Checkpointing and Outputs#

Long jobs need somewhere durable to write to. Pods come and go — when your training script exits, anything that lived only inside the pod’s filesystem is gone. Kinetic gives you KINETIC_OUTPUT_DIR: a per-job GCS prefix that survives the pod, so your checkpoints, logs, and final artifacts are still there when you come back.

This page covers what to write where, how Orbax (or any other library) plugs into it, and how cleanup and TTLs work.

A first checkpointed job#

Inside the pod, KINETIC_OUTPUT_DIR is already set. Read it and write under it. Fall back to a local path when the variable is not present so that the same function works when you exercise it locally:

import os

import kinetic

@kinetic.run(accelerator="cpu")
def train():
    # Remote: KINETIC_OUTPUT_DIR resolves to gs://.../outputs/<job_id>.
    # Local: fall back to a filesystem path under /tmp so the same code
    # works when you run the function directly for testing.
    output_dir = os.environ.get("KINETIC_OUTPUT_DIR", "/tmp/local_checkpoints")
    # ... train and write checkpoints/artifacts under output_dir ...
    return f"saved to {output_dir}"

For full Orbax-managed auto-resume with JAX or Keras, the canonical runnable examples live in the repo and are reproduced in the JAX example and Keras example sections at the end of this page.

Outputs and checkpoints#

A Kinetic job produces three distinct kinds of artifact, each with its own storage location and lifecycle:

| Artifact | What it is | Where it lives |
|---|---|---|
| Job return value | The Python value your function returns | Persisted to gs://{bucket}/{job_id}/result.pkl, then downloaded to your local process |
| Durable outputs | Files you wrote during the run | KINETIC_OUTPUT_DIR (GCS) |
| Resumable checkpoints | Periodic state snapshots for restart | KINETIC_OUTPUT_DIR/<your-subdir> (GCS) |

The return value is the right channel for small results: a final loss, a metric dict, a path string. Large files belong in the output dir. Checkpoints belong in a stable subpath under the output dir so that restarts can find them.
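
To make the "stable subpath" point concrete, here is a minimal sketch of how a restart can discover the latest checkpoint. It uses only the standard library, so it applies to the local fallback path; for gs:// paths a GCS-aware library such as Orbax does the equivalent discovery. The checkpoints/ subdirectory, the step_<n>.json naming scheme, and the helper names are assumptions made for illustration, not part of Kinetic.

```python
import json
import os
import re
import tempfile

# Hypothetical naming scheme: one JSON file per step, e.g. step_12.json.
STEP_FILE = re.compile(r"^step_(\d+)\.json$")


def latest_step(ckpt_dir):
    """Return the highest step that has a checkpoint file, or None."""
    if not os.path.isdir(ckpt_dir):
        return None
    steps = [
        int(m.group(1))
        for name in os.listdir(ckpt_dir)
        if (m := STEP_FILE.match(name))
    ]
    return max(steps, default=None)


def save_step(ckpt_dir, step, state):
    """Write to a temp file, then rename, so a crash never leaves a partial file."""
    os.makedirs(ckpt_dir, exist_ok=True)
    final_path = os.path.join(ckpt_dir, f"step_{step}.json")
    tmp_path = final_path + ".tmp"  # does not match STEP_FILE, so it is ignored
    with open(tmp_path, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, final_path)


def run(output_dir, num_steps=3):
    """Run num_steps more steps, resuming after the latest saved step."""
    ckpt_dir = os.path.join(output_dir, "checkpoints")  # stable subpath
    last = latest_step(ckpt_dir)
    start = 0 if last is None else last + 1
    for step in range(start, start + num_steps):
        save_step(ckpt_dir, step, {"step": step})
    return start
```

Because the subpath and the file naming never change between runs, a second call to run() with the same output_dir starts after the last saved step instead of at step 0.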

KINETIC_OUTPUT_DIR is set automatically when the job starts. By default it resolves to the jobs bucket for your cluster:

gs://{project}-kn-{cluster}-jobs/outputs/{job_id}

{project} is your GCP project (from KINETIC_PROJECT) and {cluster} is the Kinetic cluster name (from KINETIC_CLUSTER, defaulting to kinetic-cluster). The bucket is created by kinetic up and reused across all jobs submitted to that cluster.
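
The default prefix can be reconstructed from the pattern above, which is handy when you want to locate a job's outputs from a script. This is a sketch that follows only the rule stated here; default_output_dir is a hypothetical helper, not a Kinetic API, and Kinetic itself may resolve the project or cluster from additional sources.

```python
import os


def default_output_dir(job_id, env=None):
    """Build the documented default prefix (hypothetical helper, not a Kinetic API)."""
    env = os.environ if env is None else env
    project = env["KINETIC_PROJECT"]  # raises KeyError if the variable is unset
    cluster = env.get("KINETIC_CLUSTER", "kinetic-cluster")  # documented default
    return f"gs://{project}-kn-{cluster}-jobs/outputs/{job_id}"
```

Passing a plain dict as env makes the helper easy to exercise without touching the real environment.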

You can override it per job by passing output_dir= to the decorator, setting KINETIC_OUTPUT_DIR in your local environment before submission, or (when inspecting an existing job from the CLI) passing --output-dir to the relevant kinetic jobs subcommand. See the precedence table in Configuration for how these resolution paths combine.

TTL and retention#

By default, the GCS bucket Kinetic creates has a 30-day TTL on its contents, so anything written to the default KINETIC_OUTPUT_DIR is auto-deleted after 30 days. That’s the right default for ephemeral training, but if you want a checkpoint to outlive a month:

  • Copy it to a bucket with no lifecycle policy (gsutil cp or the GCS client library).

  • Or set output_dir= to a bucket you manage yourself, with whatever lifecycle rules you want.
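
For the first option, a copy with the GCS client library might look like the sketch below. It assumes the google-cloud-storage package and a client created by the caller; copy_prefix is a hypothetical helper, and the bucket and prefix names in the usage comment are placeholders.

```python
def copy_prefix(client, src_bucket_name, src_prefix, dst_bucket_name, dst_prefix):
    """Copy every object under src_prefix into dst_bucket under dst_prefix.

    `client` is expected to be a google.cloud.storage.Client.
    Returns the list of destination object names.
    """
    src_bucket = client.bucket(src_bucket_name)
    dst_bucket = client.bucket(dst_bucket_name)
    # Normalize both prefixes so "outputs/j1" does not also match "outputs/j10".
    src_prefix = src_prefix.rstrip("/") + "/"
    dst_prefix = dst_prefix.rstrip("/") + "/"
    copied = []
    for blob in client.list_blobs(src_bucket_name, prefix=src_prefix):
        new_name = dst_prefix + blob.name[len(src_prefix):]
        src_bucket.copy_blob(blob, dst_bucket, new_name=new_name)
        copied.append(new_name)
    return copied


# Usage (placeholder names; requires credentials and google-cloud-storage):
#   from google.cloud import storage
#   copy_prefix(storage.Client(), "<jobs-bucket>", "outputs/<job_id>",
#               "<archive-bucket>", "runs/<job_id>")
```

The destination bucket is the one without a lifecycle policy, so the copies are not subject to the 30-day TTL.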

JobHandle.cleanup(gcs=True) removes the per-job artifacts under the GCS prefix used for code and result payloads — it does not touch files you wrote under KINETIC_OUTPUT_DIR. Outputs survive cleanup.

Copy-paste checklist#

A short checklist for any long-running job that you don’t want to redo from scratch:

  • [ ] Read KINETIC_OUTPUT_DIR inside the function and write everything durable under it.

  • [ ] Write checkpoints to a stable subdirectory (e.g. $KINETIC_OUTPUT_DIR/checkpoints/) so the resume path is predictable.

  • [ ] Choose a checkpoint cadence that bounds how much work a restart would lose (every N steps, or every M minutes).

  • [ ] Verify locally that resume works before the long run: run the same function twice with the same output_dir and confirm the second call picks up where the first left off.

  • [ ] If the run is critical, copy the final artifacts to a bucket without the 30-day TTL after success.
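
The cadence item in the checklist can be sketched as a small helper that fires every N steps or every M seconds, whichever comes first. CheckpointCadence and its parameter names are hypothetical and not part of Kinetic or Orbax; the clock is injectable so the logic can be checked without waiting.

```python
import time


class CheckpointCadence:
    """Decide when a checkpoint is due: every N steps or every M seconds."""

    def __init__(self, every_n_steps=None, every_m_seconds=None, clock=time.monotonic):
        self.every_n_steps = every_n_steps
        self.every_m_seconds = every_m_seconds
        self._clock = clock
        self._last_save_time = clock()

    def should_save(self, step):
        """Return True if a checkpoint is due after finishing `step` (0-based)."""
        due_by_steps = (
            self.every_n_steps is not None
            and (step + 1) % self.every_n_steps == 0
        )
        due_by_time = (
            self.every_m_seconds is not None
            and self._clock() - self._last_save_time >= self.every_m_seconds
        )
        return due_by_steps or due_by_time

    def mark_saved(self):
        """Record that a checkpoint was just written."""
        self._last_save_time = self._clock()
```

In a training loop you would call should_save(step) after each step and, when it returns True, write the checkpoint and call mark_saved(). The worst-case work lost on restart is then bounded by N steps or roughly M seconds plus one step.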

JAX example#

import os

# Set backend to JAX before any keras imports
os.environ["KERAS_BACKEND"] = "jax"

import kinetic


@kinetic.run(accelerator="cpu")
def train_with_checkpoints():
  """Demo function showing Orbax checkpointing with Kinetic and Auto-Resume."""
  import jax.numpy as jnp
  import orbax.checkpoint as ocp

  output_dir = os.environ.get("KINETIC_OUTPUT_DIR")
  print(f"\n--- Kinetic Output Dir: {output_dir} ---")

  if not output_dir:
    # Fallback for local testing if run without kinetic context
    output_dir = "/tmp/local_checkpoints"
    print(f"No KINETIC_OUTPUT_DIR found, using local: {output_dir}")

  # Initialize Orbax CheckpointManager
  options = ocp.CheckpointManagerOptions(max_to_keep=2)
  mngr = ocp.CheckpointManager(
    output_dir, ocp.StandardCheckpointer(), options=options
  )

  # Orbax handles discovery + restore natively
  latest = mngr.latest_step()
  if latest is not None:
    print(f"Found latest checkpoint for step {latest}. Restoring...")
    state = mngr.restore(latest)
    start_step = latest + 1
  else:
    print("No checkpoint found. Starting from scratch (step 0).")
    state = {
      "step": 0,
      "weights": jnp.ones((10, 10)),
      "bias": jnp.zeros((10,)),
    }
    start_step = 0

  print(f"--- Starting from step: {start_step} ---\n")

  # Simulated training loop (run 3 steps)
  end_step = start_step + 3
  print(f"Will run steps from {start_step} to {end_step - 1}")

  for step in range(start_step, end_step):
    print(f"\n>> Simulating Step {step}...")
    state["step"] = step
    # Change weights so we can see they resume correctly
    state["weights"] = jnp.ones((10, 10)) * (step + 1)

    print(f"Saving checkpoint at step {step}...")
    mngr.save(step, state)
    mngr.wait_until_finished()
    print(f"Checkpoint for step {step} saved successfully.")

  # Verify by restoring the latest step
  latest_step = mngr.latest_step()
  print(f"\nVerifying by restoring latest step ({latest_step})...")
  if latest_step is not None:
    restored_state = mngr.restore(latest_step)
    assert restored_state["step"] == latest_step
    print(f"Verified: Restored state step matches latest step {latest_step}!")

  return True


if __name__ == "__main__":
  print("Starting Orbax checkpointing demo...")
  success = train_with_checkpoints()
  print(f"Demo run success: {success}")

Key points from the snippet:

  • The function reads KINETIC_OUTPUT_DIR and points Orbax’s CheckpointManager at it.

  • Calling the function a second time picks up from the latest step rather than restarting from scratch.

Keras example#

import os

# Set backend to JAX before any keras imports
os.environ["KERAS_BACKEND"] = "jax"

import kinetic


@kinetic.run(accelerator="cpu")
def train_keras_with_checkpoints():
  """Demo function showing Orbax checkpointing with a Keras model and Auto-Resume."""
  import keras
  import numpy as np
  import orbax.checkpoint as ocp

  output_dir = os.environ.get("KINETIC_OUTPUT_DIR")
  print(f"\n--- Kinetic Output Dir: {output_dir} ---")

  if not output_dir:
    # Fallback for local testing if run without kinetic context
    output_dir = "/tmp/local_keras_checkpoints"
    print(f"No KINETIC_OUTPUT_DIR found, using local: {output_dir}")

  # Define a simple Keras model
  model = keras.Sequential(
    [
      keras.layers.Input(shape=(10,)),
      keras.layers.Dense(32, activation="relu"),
      keras.layers.Dense(1),
    ]
  )
  model.compile(optimizer="adam", loss="mse")

  # Initialize Orbax CheckpointManager
  options = ocp.CheckpointManagerOptions(max_to_keep=2)
  mngr = ocp.CheckpointManager(
    output_dir, ocp.StandardCheckpointer(), options=options
  )

  # Orbax handles discovery + restore natively. model.get_weights() returns a
  # list of numpy arrays, which Orbax treats as a PyTree.
  latest = mngr.latest_step()
  if latest is not None:
    print(f"Found latest checkpoint for epoch {latest}. Restoring...")
    state = mngr.restore(latest)
    model.set_weights(state["weights"])
    start_epoch = latest + 1
  else:
    print("No checkpoint found. Starting from scratch (epoch 0).")
    start_epoch = 0

  print(f"--- Starting from epoch: {start_epoch} ---\n")

  # Dummy data
  x_train = np.random.randn(256, 10).astype("float32")
  y_train = np.random.randn(256, 1).astype("float32")

  # Simulated training loop (run 3 epochs)
  end_epoch = start_epoch + 3
  print(f"Will run epochs from {start_epoch} to {end_epoch - 1}")

  for epoch in range(start_epoch, end_epoch):
    print(f"\n>> Training epoch {epoch}...")
    history = model.fit(x_train, y_train, epochs=1, verbose=0)
    loss = history.history["loss"][-1]
    print(f"epoch {epoch}: loss={loss:.4f}")

    state = {
      "epoch": epoch,
      "weights": model.get_weights(),
    }

    print(f"Saving checkpoint at epoch {epoch}...")
    mngr.save(epoch, state)
    mngr.wait_until_finished()
    print(f"Checkpoint for epoch {epoch} saved successfully.")

  # Verify by restoring the latest epoch
  latest_step = mngr.latest_step()
  print(f"\nVerifying by restoring latest epoch ({latest_step})...")
  if latest_step is not None:
    restored_state = mngr.restore(latest_step)
    model.set_weights(restored_state["weights"])
    assert restored_state["epoch"] == latest_step
    print(f"Verified: Restored model weights match epoch {latest_step}!")

  return True


if __name__ == "__main__":
  print("Starting Keras + Orbax checkpointing demo...")
  success = train_keras_with_checkpoints()
  print(f"Demo run success: {success}")

Key points from the snippet:

  • model.get_weights() produces a PyTree of NumPy arrays that Orbax knows how to save.

  • model.set_weights() restores them on resume.