Training Keras Models
Who this is for: anyone with a working Keras training script who wants
it to run on a cloud TPU or GPU without standing up infrastructure.
Kinetic ships your existing model.compile() / model.fit() code to a
remote accelerator with a single decorator change. You don’t need to
restructure your training loop.
A first run
```python
import kinetic

@kinetic.run(accelerator="tpu-v6e-8")
def train_model():
    # Imports live inside the function so the remote worker uses its own install.
    import keras
    import numpy as np

    model = keras.Sequential([
        keras.Input(shape=(10,)),  # declare the input shape up front
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")

    # Random toy data: 1,000 samples with 10 features each.
    x_train = np.random.randn(1000, 10)
    y_train = np.random.randn(1000, 1)

    history = model.fit(x_train, y_train, epochs=5, verbose=0)
    return history.history["loss"][-1]  # a single float travels back

final_loss = train_model()
print(f"Final loss: {final_loss}")
```
A few things to note:
- Imports for keras, jax, etc. live inside the function so the remote worker uses its hardware-tuned install.
- The return value is serialized back to your local process. Keep it small: a final metric, a path under KINETIC_OUTPUT_DIR, or a dict of numbers. Don’t return the model object itself.
- accelerator="tpu-v6e-8" picks an 8-chip TPU v6e slice. Use "cpu" while iterating, and switch when you’re ready for hardware. See Accelerators.
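One way to follow the "keep it small" advice is to reduce the per-epoch history.history dict to a few final numbers before returning. The sketch below is illustrative only: summarize_history is a hypothetical helper, not part of Kinetic or Keras, and the json.dumps call is a local sanity check that the result is plain numbers, not a claim about how Kinetic serializes return values.

```python
import json


def summarize_history(history_dict):
    """Reduce a Keras-style History.history dict to final per-metric values.

    Hypothetical helper. `history_dict` maps metric names to per-epoch lists,
    e.g. {"loss": [...], "val_loss": [...]}.
    """
    summary = {}
    for name, values in history_dict.items():
        if values:  # skip metrics that recorded no epochs
            # float() turns NumPy scalars such as np.float32 into plain floats;
            # json.dumps raises TypeError on np.float32 itself.
            summary[f"final_{name}"] = float(values[-1])
    summary["epochs"] = max((len(v) for v in history_dict.values()), default=0)
    return summary


# A hand-written dict standing in for history.history after three epochs.
fake_history = {"loss": [0.9, 0.5, 0.31], "mae": [0.7, 0.55, 0.42]}
result = summarize_history(fake_history)
print(json.dumps(result))  # {"final_loss": 0.31, "final_mae": 0.42, "epochs": 3}
```

Following the same rule as the imports, you would define the helper inside the decorated function and finish with return summarize_history(history.history), so only a handful of numbers crosses back.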
For the canonical end-to-end example with a real dataset, see
fashion_mnist.py (first entry under Quickstart).
How to think about it
Your decorated function runs in a fresh process inside a container on a remote node. That has two practical consequences:
- No local state crosses the boundary. Anything the function needs must be passed as an argument, captured by closure, or shipped via kinetic.Data. Locally loaded variables that you reference by global name will not exist on the remote worker.
- The Keras backend is whatever the remote worker has installed. By default, Kinetic’s prebuilt and bundled images use JAX. If you need a different backend, set KERAS_BACKEND and list it in capture_env_vars:

```python
@kinetic.run(accelerator="tpu-v6e-8", capture_env_vars=["KERAS_BACKEND"])
def train():
    ...
```
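The "no local state crosses the boundary" rule can be seen locally with an analogy: a brand-new Python interpreter, like a fresh remote process, has never seen your variables. This is only a stand-in. run_fresh is a hypothetical helper built on subprocess, and it does not model Kinetic’s containers, serialization, or kinetic.Data.

```python
import json
import subprocess
import sys


def run_fresh(code, *args):
    """Run `code` in a brand-new Python interpreter and capture its output.

    Hypothetical helper: the child process shares no memory with this one,
    loosely mimicking a remote worker.
    """
    return subprocess.run(
        [sys.executable, "-c", code, *args],
        capture_output=True,
        text=True,
    )


learning_rate = 0.01  # state that exists only in this (local) process

# Referencing the variable by its global name fails: the child never saw it.
by_global = run_fresh("print(learning_rate)")
print(by_global.returncode, "NameError" in by_global.stderr)  # 1 True

# Passing the value explicitly (here as a JSON-encoded CLI argument) works.
by_argument = run_fresh(
    "import json, sys; print(json.loads(sys.argv[1]) * 2)",
    json.dumps(learning_rate),
)
print(by_argument.stdout.strip())  # 0.02
```

The design point carries over: make every input an explicit argument (or a kinetic.Data object) rather than relying on names that happen to exist in your local session.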
Scaling beyond a single host
For multi-host TPU slices like tpu-v5litepod-2x4, switch to the Pathways
backend so Keras’s distribution strategies have a working multi-host
runtime to talk to:
```python
@kinetic.run(accelerator="tpu-v5litepod-2x4", backend="pathways")
def train_distributed():
    ...
```
See Distributed Training for the full multi-host setup, and LLM Fine-tuning for a concrete Gemma example.
Data
Creating or fetching NumPy arrays inside the function works for tiny
datasets but breaks down quickly. For real data, construct a
kinetic.Data(...) object at the call site in your local script
and pass it as an argument. Kinetic uploads (or mounts) the source and
delivers a plain filesystem path to the remote function. The decorated
function only ever sees a str path:
```python
import kinetic
from kinetic import Data

@kinetic.run(accelerator="tpu-v6e-8")
def train(data_dir):
    # `data_dir` is a local filesystem path on the remote pod.
    import keras
    ...

# Local directory:
train(Data("./my_dataset/"))

# Existing GCS bucket:
train(Data("gs://my-bucket/dataset/"))

# Large GCS dataset, streamed on demand via FUSE:
train(Data("gs://my-bucket/large/", fuse=True))
```
Data accepts both local paths and gs:// URIs. See Data
for the decision matrix between downloaded, FUSE-mounted, and direct
access patterns.
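Because the decorated function receives an ordinary str path, its body can read the data with standard file APIs. The sketch below assumes the dataset directory contains x.npy and y.npy. Those file names and the load_arrays helper are hypothetical, and a temporary directory stands in for the path Kinetic would deliver on the remote pod.

```python
import tempfile
from pathlib import Path

import numpy as np


def load_arrays(data_dir):
    """Load (x, y) arrays from a dataset directory given as a plain str path.

    Hypothetical layout: the directory holds x.npy and y.npy.
    """
    root = Path(data_dir)
    x = np.load(root / "x.npy")
    y = np.load(root / "y.npy")
    if len(x) != len(y):
        raise ValueError(f"x has {len(x)} rows but y has {len(y)}")
    return x, y


# A temporary directory stands in for the path Kinetic delivers on the remote.
with tempfile.TemporaryDirectory() as tmp:
    np.save(Path(tmp) / "x.npy", np.random.randn(1000, 10))
    np.save(Path(tmp) / "y.npy", np.random.randn(1000, 1))
    x_train, y_train = load_arrays(tmp)

print(x_train.shape, y_train.shape)  # (1000, 10) (1000, 1)
```

On the remote, you would define the helper inside train(data_dir) and call x_train, y_train = load_arrays(data_dir) before model.fit(...). The loading code itself does not care whether Data downloaded the files or mounted them, since either way the function sees a filesystem path.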
Next steps
- fashion_mnist.py: full working example with a real dataset (first entry under Quickstart).
- Checkpointing: persist model weights and resume across runs.