Resumable Keras training

Resumable Keras training

import os

# Select the JAX backend for Keras. This assignment must run before the first
# `import keras` anywhere in the process (keras is imported lazily inside the
# training function below), which is why it sits ahead of the other imports.
os.environ["KERAS_BACKEND"] = "jax"

import kinetic


@kinetic.run(accelerator="cpu")
def train_keras_with_checkpoints():
  """Train a small Keras model with Orbax checkpointing and auto-resume.

  Checkpoints are written to the directory named by the ``KINETIC_OUTPUT_DIR``
  environment variable, or to ``/tmp/local_keras_checkpoints`` when that
  variable is unset or empty. If the directory already holds a checkpoint, the
  model weights are restored from the latest step and training continues from
  the following epoch; otherwise training starts at epoch 0. Every call trains
  three further epochs, saving a checkpoint after each one, and finally
  restores the newest checkpoint as a sanity check.

  Only the model weights and the epoch number are checkpointed. The optimizer
  state is not part of the saved state, so the optimizer starts fresh on
  resume.

  Returns:
    True once training and the final restore check have completed.

  Raises:
    AssertionError: If the epoch stored in the newest checkpoint differs from
      the step it was saved under.
  """
  # Imported lazily so they are resolved where the function actually runs and
  # after KERAS_BACKEND has been set at module import time.
  import keras
  import numpy as np
  import orbax.checkpoint as ocp

  output_dir = os.environ.get("KINETIC_OUTPUT_DIR")
  print(f"\n--- Kinetic Output Dir: {output_dir} ---")

  if not output_dir:
    # Fallback for local testing if run without kinetic context
    output_dir = "/tmp/local_keras_checkpoints"
    print(f"No KINETIC_OUTPUT_DIR found, using local: {output_dir}")

  # Define a simple Keras model
  model = keras.Sequential(
    [
      keras.layers.Input(shape=(10,)),
      keras.layers.Dense(32, activation="relu"),
      keras.layers.Dense(1),
    ]
  )
  model.compile(optimizer="adam", loss="mse")

  # Initialize Orbax CheckpointManager; max_to_keep=2 retains only the two
  # most recent checkpoints.
  options = ocp.CheckpointManagerOptions(max_to_keep=2)
  mngr = ocp.CheckpointManager(
    output_dir, ocp.StandardCheckpointer(), options=options
  )

  # The manager is closed in the `finally` below so its resources are released
  # even when training, saving or verification raises.
  try:
    # Orbax handles discovery + restore natively. model.get_weights() returns
    # a list of numpy arrays, which Orbax treats as a PyTree.
    latest = mngr.latest_step()
    if latest is not None:
      print(f"Found latest checkpoint for epoch {latest}. Restoring...")
      state = mngr.restore(latest)
      model.set_weights(state["weights"])
      # Steps are epoch indices, so resume with the epoch after the saved one.
      start_epoch = latest + 1
    else:
      print("No checkpoint found. Starting from scratch (epoch 0).")
      start_epoch = 0

    print(f"--- Starting from epoch: {start_epoch} ---\n")

    # Dummy data
    x_train = np.random.randn(256, 10).astype("float32")
    y_train = np.random.randn(256, 1).astype("float32")

    # Simulated training loop (run 3 epochs)
    end_epoch = start_epoch + 3
    print(f"Will run epochs from {start_epoch} to {end_epoch - 1}")

    for epoch in range(start_epoch, end_epoch):
      print(f"\n>> Training epoch {epoch}...")
      history = model.fit(x_train, y_train, epochs=1, verbose=0)
      loss = history.history["loss"][-1]
      print(f"epoch {epoch}: loss={loss:.4f}")

      state = {
        "epoch": epoch,
        "weights": model.get_weights(),
      }

      print(f"Saving checkpoint at epoch {epoch}...")
      mngr.save(epoch, state)
      # Block until the save has finished before reporting success and moving
      # on to the next epoch.
      mngr.wait_until_finished()
      print(f"Checkpoint for epoch {epoch} saved successfully.")

    # Verify by restoring the latest step
    latest_step = mngr.latest_step()
    print(f"\nVerifying by restoring latest epoch ({latest_step})...")
    if latest_step is not None:
      restored_state = mngr.restore(latest_step)
      model.set_weights(restored_state["weights"])
      assert restored_state["epoch"] == latest_step
      print(f"Verified: Restored model weights match epoch {latest_step}!")
  finally:
    mngr.close()

  return True


if __name__ == "__main__":
  # Script entry point: announce the demo, run the training function once and
  # report the value it returned.
  print("Starting Keras + Orbax checkpointing demo...")
  succeeded = train_keras_with_checkpoints()
  print(f"Demo run success: {succeeded}")