Distributed Training#

Who this is for: users whose model or batch size has outgrown a single TPU host. Kinetic with the Pathways backend lets you treat a multi-host TPU slice as one logical machine, without writing your own multi-process JAX coordination.

For single-host slices (everything that fits on one TPU node like tpu-v5litepod-8), you don’t need this page — your existing JAX or Keras code already uses every chip on the node.

A first multi-host run#

Pick a multi-host accelerator:

import kinetic

@kinetic.run(accelerator="tpu-v6e-16")
def train_distributed():
    import jax
    print(f"Total devices across all hosts: {jax.device_count()}")
    print(f"This host: {jax.process_index()} of {jax.process_count()}")
    # ... your training code ...

Whether a slice is multi-host depends on the topology and the per-VM chip count, not on the accelerator string alone. For example, tpu-v5litepod-2x2 (4 chips on one VM) and tpu-v5litepod-2x4 (8 chips on one VM via ct5lp-hightpu-8t) are both single-host, while tpu-v5litepod-16 (4×4 across four 4-chip VMs) and tpu-v6e-16 (4×4 across four 4-chip VMs) are multi-host. See Accelerators for the full topology table.
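The rule above comes down to dividing the slice's chip count by the chips each VM holds. The sketch below is illustrative only: host_count and is_multi_host are hypothetical helpers, not part of Kinetic, and the chip counts are copied from the examples in the paragraph above rather than looked up from any API.

```python
def host_count(total_chips: int, chips_per_vm: int) -> int:
    """Number of VMs (hosts) needed to hold `total_chips` chips."""
    if total_chips <= 0 or chips_per_vm <= 0:
        raise ValueError("chip counts must be positive")
    # Ceiling division: a partially filled VM still counts as a host.
    return -(-total_chips // chips_per_vm)


def is_multi_host(total_chips: int, chips_per_vm: int) -> bool:
    """A slice is multi-host when it needs more than one VM."""
    return host_count(total_chips, chips_per_vm) > 1


# (total chips, chips per VM) pairs taken from the examples in the text above.
EXAMPLES = {
    "tpu-v5litepod-2x2": (4, 4),
    "tpu-v5litepod-2x4": (8, 8),
    "tpu-v5litepod-16": (16, 4),
    "tpu-v6e-16": (16, 4),
}

for name, (chips, per_vm) in EXAMPLES.items():
    kind = "multi-host" if is_multi_host(chips, per_vm) else "single-host"
    print(f"{name}: {host_count(chips, per_vm)} host(s), {kind}")
```

The Accelerators topology table remains the authoritative source for per-VM chip counts.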

Pathways is auto-selected for multi-host slices — Kinetic resolves backend="pathways" whenever the accelerator’s topology spans more than one node, so the example above doesn’t need to set it explicitly. You only need to pass backend="pathways" yourself if you want to develop against the Pathways code path on a single-host slice — handy for shortening the iteration loop before you scale up.

Data parallelism with Keras#

@kinetic.run(accelerator="tpu-v6e-16")
def train_data_parallel():
    import keras

    devices = keras.distribution.list_devices()
    device_mesh = keras.distribution.DeviceMesh(
        shape=(len(devices),),
        axis_names=["batch"],
        devices=devices,
    )
    keras.distribution.set_distribution(
        keras.distribution.DataParallel(device_mesh=device_mesh)
    )

    model = keras.Sequential([...])
    model.compile(...)
    model.fit(...)

For a richer end-to-end example using a real model, see pathways_example.py and gemma_sft_pathways_distributed.py.

How to think about it#

Each host runs its own copy of your function. JAX’s parallelism APIs (collectives such as jax.lax.psum, sharding, pmap) and Keras’s distribution APIs handle the actual cross-host communication. Kinetic’s job is to:

  • Schedule the slice as a single logical job that the autoscaler treats atomically (no split brain).

  • Run your function on every host with the right JAX_* env vars set.

  • Stream stdout from the leader pod (process index 0) back to your local terminal. Other hosts’ stdout is not aggregated; if you need it, fetch it directly from the per-host pods (see “Debugging distributed jobs” below).

  • Return only the leader process’s (jax.process_index() == 0) value to your local machine, so you don’t get N copies of the result.

When a host throws, Kinetic catches the exception and re-raises it locally with the failing host’s traceback attached.
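Because each host runs its own copy of the function, each host usually needs to load a different part of the input. One way to do that, sketched here under assumptions, is to slice the input list by process index. shard_for_host and the file names are hypothetical; in a real job you would pass jax.process_index() and jax.process_count() instead of plain integers.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def shard_for_host(items: Sequence[T], process_index: int, process_count: int) -> list[T]:
    """Return the slice of `items` this host should load.

    Trailing items that do not divide evenly are dropped so every host
    gets the same number of items and runs the same number of steps.
    """
    if not 0 <= process_index < process_count:
        raise ValueError("process_index must be in [0, process_count)")
    per_host = len(items) // process_count
    start = process_index * per_host
    return list(items[start:start + per_host])


# Plain integers stand in for jax.process_index() / jax.process_count()
# so the sketch runs without a TPU slice.
files = [f"shard-{i:03d}.tfrecord" for i in range(10)]  # hypothetical file names
for host in range(4):
    print(host, shard_for_host(files, host, 4))
```

Dropping the remainder is a deliberate choice in this sketch: it keeps the per-host workload equal, at the cost of skipping a few items.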

Warning

When not to use this: if your model and batch fit on a single TPU host, stay there. Multi-host adds startup latency, requires Pathways, and fails the whole slice when any single host fails. Move to multi-host only when you’ve outgrown one node.

Failure modes and recovery#

Multi-host jobs fail differently from single-host jobs. The most common ones, with what to actually do:

  • Slow startup (5–10 minutes for the first multi-host run). A fresh TPU multi-host slice has to provision multiple VMs and boot Pathways. This is expected; don’t kill the job thinking it’s stuck. If startup consistently exceeds 10 minutes, run kinetic init and choose troubleshoot, and check your TPU quota.

  • Topology mismatch. Your code’s expected device count doesn’t match jax.device_count() on the slice. Symptom: shape errors deep in pmap or sharding. Fix: compute mesh shapes from jax.device_count() and jax.process_count() instead of hardcoding.

  • One host hangs, the slice times out. A single host that fails collective communication takes the slice with it, and JAX raises a collective timeout on every host. Fix: read logs from every host, not just the leader, and look for the divergent one. Kinetic streams only the leader’s stdout, so fetch the other hosts’ logs from the per-host pods (see “Debugging distributed jobs” below). Common causes are uneven data loading or a Python exception on one host before the collective.

  • Spot preemption. Multi-host slices on spot capacity die together if any one host is preempted. Fix: don’t use spot for multi-host unless you can absorb full restarts (and have checkpoints).

  • Quota exhaustion mid-run. A scheduled slice can be delayed indefinitely if regional quota is full. Symptom: job stuck in PENDING for > 10 min on a multi-host accelerator. Fix: check Cloud Console quota for your accelerator type; consider switching zones.
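For the topology-mismatch case above, a minimal sketch of deriving a mesh shape from the device count instead of hardcoding it. mesh_shape and its model_parallel parameter are hypothetical names chosen for illustration; on the slice you would pass jax.device_count() as the first argument.

```python
def mesh_shape(device_count: int, model_parallel: int = 1) -> tuple[int, int]:
    """Split `device_count` devices into a (data, model) mesh shape."""
    if model_parallel < 1:
        raise ValueError("model_parallel must be >= 1")
    if device_count % model_parallel != 0:
        # Fail early with a readable message instead of a shape error
        # deep inside sharding code.
        raise ValueError(
            f"{device_count} devices cannot be split into groups of {model_parallel}"
        )
    return (device_count // model_parallel, model_parallel)


# On the slice: mesh_shape(jax.device_count(), model_parallel=4)
print(mesh_shape(16, model_parallel=4))  # (4, 4)
print(mesh_shape(8))                     # (8, 1)
```

The resulting tuple can be passed as the shape argument of keras.distribution.DeviceMesh together with two axis names, so the same code adapts when the accelerator changes.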

Recommended checkpoint frequency: for any multi-host run, write a checkpoint at least every 10 minutes of wall time. Preemption, quota issues, and slice-wide failures are common enough that the throughput gained by checkpointing less often is not worth an unbounded loss window. See Checkpointing for the API.
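One way to enforce a wall-time checkpoint interval is a small timer checked once per training step. This is a sketch under assumptions: CheckpointTimer and train are hypothetical, and save stands in for whatever checkpoint call you use (see Checkpointing for the actual API).

```python
import time
from typing import Callable


class CheckpointTimer:
    """Tracks wall time and reports when a checkpoint is due."""

    def __init__(self, interval_s: float = 600.0, clock: Callable[[], float] = time.monotonic):
        self.interval_s = interval_s
        self._clock = clock  # injectable so the timer can be tested without sleeping
        self._last = clock()

    def due(self) -> bool:
        """True once `interval_s` seconds have passed since the last checkpoint."""
        return self._clock() - self._last >= self.interval_s

    def mark_saved(self) -> None:
        """Call right after a checkpoint has been written."""
        self._last = self._clock()


def train(num_steps: int, timer: CheckpointTimer, save: Callable[[int], None]) -> None:
    for step in range(num_steps):
        # ... one training step ...
        if timer.due():
            save(step)  # `save` is a placeholder for your checkpoint call
            timer.mark_saved()
```

Each host has its own clock, so hosts may not cross the threshold on the same step. If your checkpoint call involves cross-host communication, consider converting the interval into a fixed number of steps up front so every host checkpoints at the same step.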

Debugging distributed jobs#

kinetic jobs logs <id> (and --follow while the job is running) returns the leader pod’s stdout, which is what print() calls on process index 0 produce. To gate output to that one process, guard print statements with jax.process_index():

import jax

if jax.process_index() == 0:
    print(f"epoch {epoch}: loss={loss}")

For non-leader hosts, fetch logs directly from the per-host pods. kubectl get pods -n <namespace> | grep <job-id> lists every pod in the slice; kubectl logs <pod-name> then returns that host’s stdout. Cloud Logging in the GCP Console offers the same view through a UI filter on the job name.
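The kubectl steps above can be scripted. The following sketch wraps the same two commands with subprocess; the helper names are hypothetical, and it assumes kubectl is installed and already pointed at the cluster running the job. It adds -o name to the listing so the output is easy to parse, and passes the namespace to kubectl logs explicitly.

```python
import subprocess


def pods_for_job(pod_names: list[str], job_id: str) -> list[str]:
    """Keep only pod names that contain the job id, like `grep <job-id>`."""
    return sorted(name for name in pod_names if job_id in name)


def list_pod_names(namespace: str) -> list[str]:
    """Run `kubectl get pods -n <namespace> -o name` and return bare pod names."""
    out = subprocess.run(
        ["kubectl", "get", "pods", "-n", namespace, "-o", "name"],
        check=True, capture_output=True, text=True,
    ).stdout
    # `-o name` prints one `pod/<name>` entry per line.
    return [line.split("/", 1)[-1] for line in out.splitlines() if line.strip()]


def pod_logs(pod: str, namespace: str) -> str:
    """Return the stdout of one pod via `kubectl logs`."""
    return subprocess.run(
        ["kubectl", "logs", pod, "-n", namespace],
        check=True, capture_output=True, text=True,
    ).stdout


# Usage (requires cluster access; namespace and job id are placeholders):
# for pod in pods_for_job(list_pod_names("<namespace>"), "<job-id>"):
#     print(f"=== {pod} ===")
#     print(pod_logs(pod, "<namespace>"))
```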

If a job fails on any host, Kinetic catches the exception and re-raises it locally with that host’s stack trace, so you usually do not need to inspect non-leader logs to diagnose a crash.