Native JAX Training
Who this is for: users who write training loops directly in JAX
rather than going through Keras. Kinetic runs your JAX code on cloud
TPUs and GPUs the same way it runs Keras code — wrap the function in
@kinetic.run() and call it. JAX-specific details (multi-device
parallelism, dependency filtering, multi-host coordination) are covered
below.
A first run
```python
import kinetic

@kinetic.run(accelerator="tpu-v5litepod-8")
def jax_computation():
    import jax
    import jax.numpy as jnp

    print(f"Devices: {jax.devices()}")
    x = jnp.ones((1000, 1000))
    result = jnp.dot(x, x)
    return float(result[0, 0])

print(jax_computation())  # 1000.0
```
A standard JAX training loop with jax.grad runs without modification:
```python
@kinetic.run(accelerator="tpu-v6e-8")
def train():
    import jax
    import jax.numpy as jnp

    def loss_fn(params, x, y):
        pred = x @ params["w"] + params["b"]
        return jnp.mean((pred - y) ** 2)

    grad_fn = jax.grad(loss_fn)

    # Split the root key so weights, inputs, and noise are drawn independently.
    key_w, key_x, key_noise = jax.random.split(jax.random.PRNGKey(0), 3)
    params = {"w": jax.random.normal(key_w, (10, 1)), "b": jnp.zeros(1)}
    x = jax.random.normal(key_x, (512, 10))
    y = x @ jnp.ones((10, 1)) + 0.1 * jax.random.normal(key_noise, (512, 1))

    lr = 0.01
    for step in range(200):
        grads = grad_fn(params, x, y)
        params = {k: params[k] - lr * grads[k] for k in params}
        if step % 50 == 0:
            print(f"step {step}: loss={loss_fn(params, x, y):.4f}")
    return float(loss_fn(params, x, y))
```
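Once the loop works, a common refinement is to compile the update step with jax.jit and get the loss and gradients in one pass with jax.value_and_grad. The sketch below is shown without the @kinetic.run decorator so it runs on any machine with JAX installed; in a Kinetic job the same body would live inside the decorated function. The names train_step and train are illustrative.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # One compiled call returns both the loss and its gradients.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # tree_map applies the SGD update to every leaf of the params pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    return new_params, loss

def train(num_steps=200):
    # Independent keys for weights, inputs, and noise.
    key_w, key_x, key_noise = jax.random.split(jax.random.PRNGKey(0), 3)
    params = {"w": jax.random.normal(key_w, (10, 1)), "b": jnp.zeros(1)}
    x = jax.random.normal(key_x, (512, 10))
    y = x @ jnp.ones((10, 1)) + 0.1 * jax.random.normal(key_noise, (512, 1))

    losses = []
    for _ in range(num_steps):
        params, loss = train_step(params, x, y)
        losses.append(float(loss))  # loss before this step's update
    return losses
```

The first call to train_step compiles it; later calls reuse the compiled program, which avoids re-dispatching each operation from Python on every step.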
Imports for jax, jaxlib, and any other heavy library go inside
the decorated function so the remote worker uses its accelerator-tuned
install.
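This relies on ordinary Python semantics: an import statement inside a function body executes when the function is called, not when it is defined. A minimal illustration, using a deliberately hypothetical module name that is not installed locally:

```python
def remote_step():
    # Resolved only when remote_step() runs; defining the function is fine
    # even though this (hypothetical) module is not installed here.
    import hypothetical_accelerator_only_pkg
    return hypothetical_accelerator_only_pkg.run()
```

Defining remote_step succeeds; only calling it locally raises ModuleNotFoundError. When the call happens on the remote worker instead, the import picks up whatever copy is installed there.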
How to think about it
JAX needs the right jaxlib and the right accelerator runtime
(libtpu, CUDA) to be installed in the container. Kinetic handles this
for you:
- Bundled and prebuilt images ship with JAX matched to the accelerator type, so you don’t need to pin jax, jaxlib, or libtpu in requirements.txt.
- JAX packages in your requirements.txt are filtered out before install so they don’t shadow the accelerator-correct copy in the image. See Dependencies for the filter behavior.
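The filter itself is internal to Kinetic. As a rough illustration of the idea only, the sketch below matches each requirement line on its PEP 503-normalized project name and drops names in a hypothetical JAX_PACKAGES set; the real package list and matching rules are described on the Dependencies page and may differ.

```python
import re

# Hypothetical list for illustration; not Kinetic's actual filter set.
JAX_PACKAGES = {"jax", "jaxlib", "libtpu", "libtpu-nightly"}

# Project name at the start of a requirement line (stops at [, =, <, ;, space...).
_NAME_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")

def _normalize(name):
    # PEP 503: lowercase, and collapse runs of "-", "_", "." into "-".
    return re.sub(r"[-_.]+", "-", name).lower()

def filter_requirements(lines):
    """Return requirement lines with JAX-related packages removed."""
    kept = []
    for line in lines:
        match = _NAME_RE.match(line)
        if match and _normalize(match.group(1)) in JAX_PACKAGES:
            continue  # would shadow the image's accelerator-matched install
        kept.append(line)  # comments, options, and other packages pass through
    return kept
```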
Inside the function, jax.devices() returns the devices the pod can see: eight TPU chips for tpu-v6e-8 or tpu-v5litepod-8, a single GPU for l4, and so on.
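To confirm what a pod actually exposes, a small helper can summarize the device list from inside the function. It uses only standard JAX queries (jax.devices(), jax.device_count(), jax.local_device_count(), and a device's platform and device_kind attributes); the helper name is illustrative. It also runs on a CPU-only machine, where it reports the cpu platform.

```python
import jax

def describe_devices():
    devices = jax.devices()
    return {
        "platform": devices[0].platform,        # e.g. "tpu", "gpu", or "cpu"
        "device_kind": devices[0].device_kind,  # hardware model string
        "global_count": jax.device_count(),     # devices across all processes
        "local_count": jax.local_device_count(),  # devices on this host
    }
```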
Single-host parallelism
Use jax.pmap (or jax.sharding) to spread computation across all
devices on a single host:
```python
@kinetic.run(accelerator="tpu-v5litepod-8")
def parallel_computation():
    import jax
    import jax.numpy as jnp

    n_devices = jax.local_device_count()
    print(f"Running on {n_devices} devices")

    # pmap maps over the leading axis: each device receives one (256, 256) slice.
    @jax.pmap
    def parallel_matmul(x):
        return jnp.dot(x, x.T)

    data = jnp.ones((n_devices, 256, 256))
    result = parallel_matmul(data)
    return float(result[0, 0, 0])  # 256.0
```
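The same single-host pattern can be written with jax.sharding instead of pmap: place the array on a one-axis device mesh, split along its leading dimension, and let jax.jit compile a function over the sharded input. A minimal sketch (function name illustrative), shown outside @kinetic.run so it also runs on a CPU-only machine with one device:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def sharded_row_sq_norms():
    devices = jax.devices()
    # One mesh axis named "data" spanning every device on this host.
    mesh = Mesh(np.array(devices), axis_names=("data",))
    # Split the leading (row) dimension across the "data" axis.
    sharding = NamedSharding(mesh, PartitionSpec("data"))

    # The leading dimension must divide evenly by the number of devices.
    x = jax.device_put(jnp.ones((4 * len(devices), 256)), sharding)

    @jax.jit
    def row_sq_norms(x):
        # Row-wise reduction: each device can work on its own rows.
        return jnp.sum(x * x, axis=1)

    return row_sq_norms(x)
```

Because the computation is row-wise, it needs no cross-device communication, and the compiler typically keeps the output sharded the same way as the input.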
Scaling beyond a single host
For multi-host slices (e.g., tpu-v5litepod-2x4) JAX needs a coordination
runtime to set up cross-host collectives. Kinetic provides this through
the Pathways backend:
```python
@kinetic.run(accelerator="tpu-v5litepod-2x4", backend="pathways")
def train_distributed():
    import jax

    # jax.process_count() > 1 here; pmap/sharding work cross-host.
    ...
```
Without backend="pathways", multi-host JAX collectives won’t have a
working coordinator. See Distributed Training
for the full multi-host setup.
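When moving to a multi-host slice, it helps to log the topology JAX reports from inside the function before relying on it. The sketch uses only standard JAX process and device queries; the helper names are illustrative. The exact values under Pathways depend on how the backend presents hosts to your program, so treat the output as something to inspect rather than assume.

```python
import jax

def log_from_leader(msg):
    # Common multi-process JAX idiom: every process runs the same code,
    # so print only from process 0. Harmless with a single process.
    if jax.process_index() == 0:
        print(msg)

def topology_summary():
    return {
        "process_index": jax.process_index(),
        "process_count": jax.process_count(),
        "global_devices": jax.device_count(),
        "local_devices": jax.local_device_count(),
    }
```

For example, calling log_from_leader(topology_summary()) near the top of the decorated function prints the summary once per job rather than once per host.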
Data
To pass a dataset into a remote JAX function, construct a
kinetic.Data(...) object at the call site in your local script and
pass it as an argument. Kinetic uploads (or mounts) the source and
delivers a plain filesystem path to the remote function. The decorated
function only ever sees a str path:
```python
import kinetic
from kinetic import Data

@kinetic.run(accelerator="tpu-v6e-8")
def train(data_dir):
    # `data_dir` is a local filesystem path on the remote pod.
    import os

    files = os.listdir(data_dir)
    ...

# Local directory:
train(Data("./my_dataset/"))

# Existing GCS bucket:
train(Data("gs://my-bucket/dataset/"))

# Large GCS dataset, streamed on demand via FUSE:
train(Data("gs://my-bucket/large/", fuse=True))
```
Data accepts both local paths and gs:// URIs. See Data
for the decision matrix between downloaded, FUSE-mounted, and direct
access patterns.
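Because the function only receives a path, loading is ordinary file I/O. A minimal sketch, assuming a hypothetical layout in which data_dir holds one or more .npy shards with the same trailing shape (the helper name is illustrative):

```python
import os
import numpy as np

def load_npy_shards(data_dir):
    # Assumed layout: *.npy files whose names sort in the desired order.
    names = sorted(n for n in os.listdir(data_dir) if n.endswith(".npy"))
    if not names:
        raise FileNotFoundError(f"no .npy files in {data_dir}")
    arrays = [np.load(os.path.join(data_dir, n)) for n in names]
    # Stack shards along the example (leading) axis.
    return np.concatenate(arrays, axis=0)
```

Inside the decorated function, the result can be handed to JAX with jnp.asarray(load_npy_shards(data_dir)). With fuse=True, each np.load reads from the mounted bucket on demand, so large shards are fetched only when opened.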
Next steps
- Distributed Training: multi-host JAX with Pathways.
- Checkpointing: Orbax checkpoint patterns under KINETIC_OUTPUT_DIR.