Checkpointing and Outputs
Long jobs need somewhere durable to write to. Pods come and go — when
your training script exits, anything that lived only inside the pod’s
filesystem is gone. Kinetic gives you KINETIC_OUTPUT_DIR: a per-job
GCS prefix that survives the pod, so your checkpoints, logs, and final
artifacts are still there when you come back.
This page covers what to write where, how Orbax (or any other library) plugs into it, and how cleanup and TTLs work.
A first checkpointed job
Inside the pod, KINETIC_OUTPUT_DIR is already set. Read it and write
under it. Fall back to a local path when the variable is not present so
that the same function works when you exercise it locally:
import os

import kinetic


@kinetic.run(accelerator="cpu")
def train():
    # Remote: KINETIC_OUTPUT_DIR resolves to gs://.../outputs/<job_id>.
    # Local: fall back to a filesystem path under /tmp so the same code
    # works when you run the function directly for testing.
    output_dir = os.environ.get("KINETIC_OUTPUT_DIR", "/tmp/local_checkpoints")
    # ... train and write checkpoints/artifacts under output_dir ...
    return f"saved to {output_dir}"
For full Orbax-managed auto-resume with JAX or Keras, the canonical runnable examples live in the repo:
- examples/example_checkpoint.py: JAX + Orbax with auto-resume.
- examples/example_keras_checkpoint.py: same pattern using model.get_weights() / set_weights().
Outputs and checkpoints
A Kinetic job produces three distinct kinds of artifact, each with its own storage location and lifecycle:
| Artifact | What it is | Where it lives |
|---|---|---|
| Job return value | The Python value your function returns | Persisted to |
| Durable outputs | Files you wrote during the run | |
| Resumable checkpoints | Periodic state snapshots for restart | |
The return value is the right channel for small results: a final loss, a metric dict, a path string. Large files belong on the output dir; checkpoints belong on a stable subpath under the output dir so restarts can find them.
KINETIC_OUTPUT_DIR is set automatically when the job starts. By
default it resolves to the jobs bucket for your cluster:
gs://{project}-kn-{cluster}-jobs/outputs/{job_id}
{project} is your GCP project (from KINETIC_PROJECT) and {cluster}
is the Kinetic cluster name (from KINETIC_CLUSTER, defaulting to
kinetic-cluster). The bucket is created by kinetic up and reused
across all jobs submitted to that cluster.
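To make the default path concrete, the sketch below expands the template by hand. The helper name `default_output_dir` and the sample project and job IDs are hypothetical and are not part of Kinetic's API; the function only mirrors the template and environment variables described above.

```python
import os


def default_output_dir(job_id: str, env=None) -> str:
    """Expand the default output-dir template for a given job ID.

    Hypothetical helper for illustration only; Kinetic sets
    KINETIC_OUTPUT_DIR for you inside the pod.
    """
    env = os.environ if env is None else env
    project = env["KINETIC_PROJECT"]  # required here; this sketch assumes no default
    cluster = env.get("KINETIC_CLUSTER", "kinetic-cluster")
    return f"gs://{project}-kn-{cluster}-jobs/outputs/{job_id}"


# Sample values, not real resources.
print(default_output_dir("job-123", {"KINETIC_PROJECT": "my-project"}))
# gs://my-project-kn-kinetic-cluster-jobs/outputs/job-123
```

Inside a job you should still read KINETIC_OUTPUT_DIR directly rather than rebuild the path, since an override may have changed it.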
You can override it per job by passing output_dir= to the decorator,
setting KINETIC_OUTPUT_DIR in your local environment before
submission, or (when inspecting an existing job from the CLI) passing
--output-dir to the relevant kinetic jobs subcommand. See the
precedence table in Configuration for how these
resolution paths combine.
Recommended directory layout
A simple convention that scales from one job to many:
$KINETIC_OUTPUT_DIR/
├── checkpoints/ # Orbax / model.save_weights — periodic snapshots
├── logs/ # extra logs your code writes (stdout already streams)
├── metrics/ # tensorboard / json metric dumps
└── final/ # post-training artifacts: exported model, eval results
Use whichever subdirectories make sense for your workflow. The point is that the layout is yours to control — Kinetic only cares that you write under the prefix it gave you.
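One way to keep that layout consistent across jobs is to derive every path from the prefix in a single place. The following is a minimal sketch; `output_layout` and `SUBDIRS` are hypothetical names chosen for illustration, and the code only builds path strings (it does not create directories or talk to GCS).

```python
import os
import posixpath

# Subdirectory names from the convention above; adjust to your workflow.
SUBDIRS = ("checkpoints", "logs", "metrics", "final")


def output_layout(output_dir: str) -> dict:
    """Map each conventional subdirectory name to its full path or URI."""
    # posixpath.join always uses forward slashes, which suits both
    # gs:// URIs and POSIX-style local paths.
    base = output_dir.rstrip("/")
    return {name: posixpath.join(base, name) for name in SUBDIRS}


output_dir = os.environ.get("KINETIC_OUTPUT_DIR", "/tmp/local_checkpoints")
paths = output_layout(output_dir)
print(paths["checkpoints"])
```

Passing `paths["checkpoints"]` to your checkpointing library, instead of the bare prefix, keeps snapshots separate from logs and final artifacts.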
TTL and retention
By default the GCS bucket Kinetic creates has a 30-day TTL on its
contents. Anything written to KINETIC_OUTPUT_DIR is auto-deleted
after 30 days. That’s the right default for ephemeral training, but if
you want a checkpoint to outlive a month:
- Copy it to a bucket with no lifecycle policy (gsutil cp or the GCS client library).
- Or set output_dir= to a bucket you manage yourself, with whatever lifecycle rules you want.
JobHandle.cleanup(gcs=True) removes the per-job artifacts under the
GCS prefix used for code and result payloads — it does not touch
files you wrote under KINETIC_OUTPUT_DIR. Outputs survive cleanup.
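To plan around the default 30-day TTL, it can help to estimate when an artifact becomes eligible for deletion. The sketch below is plain date arithmetic under the assumption that the TTL is measured from the object's creation time; the function names are hypothetical, and the actual deletion may happen some time after the estimated date.

```python
from datetime import datetime, timedelta, timezone

DEFAULT_TTL = timedelta(days=30)  # the default TTL described above


def earliest_deletion(created_at: datetime, ttl: timedelta = DEFAULT_TTL) -> datetime:
    """Earliest time an object created at created_at could be auto-deleted."""
    return created_at + ttl


def days_left(created_at: datetime, now: datetime, ttl: timedelta = DEFAULT_TTL) -> int:
    """Whole days remaining before the object becomes eligible for deletion."""
    remaining = earliest_deletion(created_at, ttl) - now
    return max(remaining.days, 0)


created = datetime(2025, 1, 1, tzinfo=timezone.utc)
now = datetime(2025, 1, 21, tzinfo=timezone.utc)
print(earliest_deletion(created).date())  # 2025-01-31
print(days_left(created, now))            # 10
```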
Copy-paste checklist
A short checklist for any long-running job that you don’t want to redo from scratch:
[ ] Read KINETIC_OUTPUT_DIR inside the function and write everything durable under it.
[ ] Write checkpoints to a stable subdirectory (e.g. $KINETIC_OUTPUT_DIR/checkpoints/) so the resume path is predictable.
[ ] Choose a checkpoint cadence that bounds how much work a restart would lose (every N steps, or every M minutes).
[ ] Verify resume works locally before the long run: submit the same function twice with the same output_dir and confirm the second call picks up where the first left off.
[ ] If the run is critical, copy the final artifacts to a bucket without the 30-day TTL after success.
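The cadence item in the checklist can be sketched as a small policy object that triggers a save every N steps or after a time interval, whichever comes first. `CheckpointCadence` is a hypothetical helper, not part of Kinetic or Orbax, and the thresholds are placeholder values.

```python
import time


class CheckpointCadence:
    """Decide when to checkpoint: every N steps or every S seconds, whichever comes first."""

    def __init__(self, every_n_steps=None, every_seconds=None, clock=time.monotonic):
        self.every_n_steps = every_n_steps
        self.every_seconds = every_seconds
        self._clock = clock
        self._last_save = clock()

    def should_save(self, step: int) -> bool:
        due_by_step = (
            self.every_n_steps is not None
            and step > 0
            and step % self.every_n_steps == 0
        )
        due_by_time = (
            self.every_seconds is not None
            and self._clock() - self._last_save >= self.every_seconds
        )
        return due_by_step or due_by_time

    def mark_saved(self) -> None:
        # Reset the timer so the time-based trigger counts from the last save.
        self._last_save = self._clock()


cadence = CheckpointCadence(every_n_steps=100, every_seconds=600)
saved_at = []
for step in range(1, 301):
    if cadence.should_save(step):
        # A real loop would call the checkpoint library's save here.
        saved_at.append(step)
        cadence.mark_saved()
print(saved_at)
```

With these settings a restart loses at most 100 steps or roughly ten minutes of work, whichever bound is reached first.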
JAX example
import os

# Set backend to JAX before any keras imports
os.environ["KERAS_BACKEND"] = "jax"

import kinetic


@kinetic.run(accelerator="cpu")
def train_with_checkpoints():
    """Demo function showing Orbax checkpointing with Kinetic and Auto-Resume."""
    import jax.numpy as jnp
    import orbax.checkpoint as ocp

    output_dir = os.environ.get("KINETIC_OUTPUT_DIR")
    print(f"\n--- Kinetic Output Dir: {output_dir} ---")
    if not output_dir:
        # Fallback for local testing if run without kinetic context
        output_dir = "/tmp/local_checkpoints"
        print(f"No KINETIC_OUTPUT_DIR found, using local: {output_dir}")

    # Initialize Orbax CheckpointManager
    options = ocp.CheckpointManagerOptions(max_to_keep=2)
    mngr = ocp.CheckpointManager(
        output_dir, ocp.StandardCheckpointer(), options=options
    )

    # Orbax handles discovery + restore natively
    latest = mngr.latest_step()
    if latest is not None:
        print(f"Found latest checkpoint for step {latest}. Restoring...")
        state = mngr.restore(latest)
        start_step = latest + 1
    else:
        print("No checkpoint found. Starting from scratch (step 0).")
        state = {
            "step": 0,
            "weights": jnp.ones((10, 10)),
            "bias": jnp.zeros((10,)),
        }
        start_step = 0
    print(f"--- Starting from step: {start_step} ---\n")

    # Simulated training loop (run 3 steps)
    end_step = start_step + 3
    print(f"Will run steps from {start_step} to {end_step - 1}")
    for step in range(start_step, end_step):
        print(f"\n>> Simulating Step {step}...")
        state["step"] = step
        # Change weights so we can see they resume correctly
        state["weights"] = jnp.ones((10, 10)) * (step + 1)
        print(f"Saving checkpoint at step {step}...")
        mngr.save(step, state)
        mngr.wait_until_finished()
        print(f"Checkpoint for step {step} saved successfully.")

    # Verify by restoring the latest step
    latest_step = mngr.latest_step()
    print(f"\nVerifying by restoring latest step ({latest_step})...")
    if latest_step is not None:
        restored_state = mngr.restore(latest_step)
        assert restored_state["step"] == latest_step
        print(f"Verified: Restored state step matches latest step {latest_step}!")
    return True


if __name__ == "__main__":
    print("Starting Orbax checkpointing demo...")
    success = train_with_checkpoints()
    print(f"Demo run success: {success}")
After the snippet:
- The function reads KINETIC_OUTPUT_DIR and points Orbax’s CheckpointManager at it.
- Calling the function a second time picks up from the latest step rather than restarting from scratch.
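The resume behavior can be illustrated without Orbax. The sketch below is a standard-library stand-in for the discover-then-continue pattern, assuming a local directory with one numbered subdirectory per step; it is not how Orbax is implemented, and `latest_step` and `run` here are hypothetical helpers that do not work against gs:// paths.

```python
import json
import os
import tempfile


def latest_step(ckpt_dir: str):
    """Return the highest step with a saved checkpoint, or None if there is none."""
    if not os.path.isdir(ckpt_dir):
        return None
    steps = [int(name) for name in os.listdir(ckpt_dir) if name.isdigit()]
    return max(steps) if steps else None


def run(ckpt_dir: str, num_steps: int = 3) -> list:
    """Run num_steps steps, continuing after the latest saved step."""
    latest = latest_step(ckpt_dir)
    start = 0 if latest is None else latest + 1
    executed = []
    for step in range(start, start + num_steps):
        step_dir = os.path.join(ckpt_dir, str(step))
        os.makedirs(step_dir)
        # Stand-in for real model state.
        with open(os.path.join(step_dir, "state.json"), "w") as f:
            json.dump({"step": step}, f)
        executed.append(step)
    return executed


with tempfile.TemporaryDirectory() as tmp:
    print(run(tmp))  # first call starts from scratch: [0, 1, 2]
    print(run(tmp))  # second call resumes: [3, 4, 5]
```

The same check, running the function twice against one directory, is a quick way to confirm resume works before a long run.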
Keras example
import os

# Set backend to JAX before any keras imports
os.environ["KERAS_BACKEND"] = "jax"

import kinetic


@kinetic.run(accelerator="cpu")
def train_keras_with_checkpoints():
    """Demo function showing Orbax checkpointing with a Keras model and Auto-Resume."""
    import keras
    import numpy as np
    import orbax.checkpoint as ocp

    output_dir = os.environ.get("KINETIC_OUTPUT_DIR")
    print(f"\n--- Kinetic Output Dir: {output_dir} ---")
    if not output_dir:
        # Fallback for local testing if run without kinetic context
        output_dir = "/tmp/local_keras_checkpoints"
        print(f"No KINETIC_OUTPUT_DIR found, using local: {output_dir}")

    # Define a simple Keras model
    model = keras.Sequential(
        [
            keras.layers.Input(shape=(10,)),
            keras.layers.Dense(32, activation="relu"),
            keras.layers.Dense(1),
        ]
    )
    model.compile(optimizer="adam", loss="mse")

    # Initialize Orbax CheckpointManager
    options = ocp.CheckpointManagerOptions(max_to_keep=2)
    mngr = ocp.CheckpointManager(
        output_dir, ocp.StandardCheckpointer(), options=options
    )

    # Orbax handles discovery + restore natively. model.get_weights() returns a
    # list of numpy arrays, which Orbax treats as a PyTree.
    latest = mngr.latest_step()
    if latest is not None:
        print(f"Found latest checkpoint for epoch {latest}. Restoring...")
        state = mngr.restore(latest)
        model.set_weights(state["weights"])
        start_epoch = latest + 1
    else:
        print("No checkpoint found. Starting from scratch (epoch 0).")
        start_epoch = 0
    print(f"--- Starting from epoch: {start_epoch} ---\n")

    # Dummy data
    x_train = np.random.randn(256, 10).astype("float32")
    y_train = np.random.randn(256, 1).astype("float32")

    # Simulated training loop (run 3 epochs)
    end_epoch = start_epoch + 3
    print(f"Will run epochs from {start_epoch} to {end_epoch - 1}")
    for epoch in range(start_epoch, end_epoch):
        print(f"\n>> Training epoch {epoch}...")
        history = model.fit(x_train, y_train, epochs=1, verbose=0)
        loss = history.history["loss"][-1]
        print(f"epoch {epoch}: loss={loss:.4f}")
        state = {
            "epoch": epoch,
            "weights": model.get_weights(),
        }
        print(f"Saving checkpoint at epoch {epoch}...")
        mngr.save(epoch, state)
        mngr.wait_until_finished()
        print(f"Checkpoint for epoch {epoch} saved successfully.")

    # Verify by restoring the latest step
    latest_step = mngr.latest_step()
    print(f"\nVerifying by restoring latest epoch ({latest_step})...")
    if latest_step is not None:
        restored_state = mngr.restore(latest_step)
        model.set_weights(restored_state["weights"])
        assert restored_state["epoch"] == latest_step
        print(f"Verified: Restored model weights match epoch {latest_step}!")
    return True


if __name__ == "__main__":
    print("Starting Keras + Orbax checkpointing demo...")
    success = train_keras_with_checkpoints()
    print(f"Demo run success: {success}")
After the snippet:
- model.get_weights() produces a PyTree of NumPy arrays that Orbax knows how to save.
- model.set_weights() restores them on resume.