Submit, monitor, and reattach

Submit, monitor, and reattach

"""
Example: Async Jobs with Kinetic

This demonstrates the run_async/attach/list workflow for detached execution.
Instead of blocking until the remote function finishes,
run_async returns a JobHandle immediately so you can monitor,
reattach from another session, or manage multiple jobs concurrently.

Prerequisites:
1. A GKE cluster with a CPU node pool (the default setup works)
2. kubectl configured to access the cluster
3. KINETIC_PROJECT environment variable set

Workflow overview:
    1. run_async() → fire-and-forget, get a JobHandle back instantly
    2. status()   → poll the job without blocking
    3. logs()     → fetch or stream logs
    4. result()   → block until completion and collect the return value
    5. attach()   → reattach to a job from a different session using its ID
    6. list_jobs()→ discover all live jobs on the cluster
"""

import os
import time

os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np

import kinetic

# ---------------------------------------------------------------------------
# 1. Define functions using @kinetic.run
# ---------------------------------------------------------------------------


@kinetic.run(accelerator="cpu")
def train_model_a():
  """Train a small dense regression model on random data.

  Builds a three-layer dense network (64 -> 64 -> 1) over 10 input
  features, fits it for 5 epochs on 1000 randomly generated samples with
  a batch size of 32, and prints progress messages before and after
  training.

  Returns:
    The training loss recorded for the last epoch.
  """
  model = keras.Sequential(
    [
      # Declare the input explicitly: passing `input_shape` to a layer is
      # deprecated in Keras 3 and triggers a warning at construction time.
      keras.Input(shape=(10,)),
      keras.layers.Dense(64, activation="relu"),
      keras.layers.Dense(64, activation="relu"),
      keras.layers.Dense(1),
    ]
  )
  model.compile(optimizer="adam", loss="mse")

  # Synthetic data: inputs and targets are independent Gaussian noise.
  x = np.random.randn(1000, 10)
  y = np.random.randn(1000, 1)

  print("Training model A...")
  history = model.fit(x, y, epochs=5, batch_size=32, verbose=1)
  # `history.history["loss"]` holds one entry per epoch; keep the last one.
  final_loss = history.history["loss"][-1]
  print(f"Model A done — loss: {final_loss}")
  return final_loss


@kinetic.run(accelerator="cpu")
def train_model_b():
  """Train a slightly larger dense regression model on random data.

  Builds a three-layer dense network (128 -> 128 -> 1) over 20 input
  features, fits it for 10 epochs on 2000 randomly generated samples with
  a batch size of 64, and prints progress messages before and after
  training.

  Returns:
    The training loss recorded for the last epoch.
  """
  model = keras.Sequential(
    [
      # Declare the input explicitly: passing `input_shape` to a layer is
      # deprecated in Keras 3 and triggers a warning at construction time.
      keras.Input(shape=(20,)),
      keras.layers.Dense(128, activation="relu"),
      keras.layers.Dense(128, activation="relu"),
      keras.layers.Dense(1),
    ]
  )
  model.compile(optimizer="adam", loss="mse")

  # Synthetic data: inputs and targets are independent Gaussian noise.
  x = np.random.randn(2000, 20)
  y = np.random.randn(2000, 1)

  print("Training model B...")
  history = model.fit(x, y, epochs=10, batch_size=64, verbose=1)
  # `history.history["loss"]` holds one entry per epoch; keep the last one.
  final_loss = history.history["loss"][-1]
  print(f"Model B done — loss: {final_loss}")
  return final_loss


# ---------------------------------------------------------------------------
# 2. Submit both jobs (non-blocking)
# ---------------------------------------------------------------------------


def demo_submit_and_monitor():
  """Launch both training jobs asynchronously and follow them to completion.

  Steps, in order: submit the two jobs via `run_async()`, poll their
  status (at most 30 times, 5 seconds apart) until each reports RUNNING or
  SUCCEEDED, print the last 20 log lines of job A, then block on both
  results with `cleanup=False`.

  Returns:
    A `(job_a, job_b)` tuple of the two handles returned by `run_async()`.
  """
  rule = "=" * 60
  started_states = ("RUNNING", "SUCCEEDED")

  print(rule)
  print("Submitting two training jobs...")
  print(rule)

  handle_a = train_model_a.run_async()
  handle_b = train_model_b.run_async()

  print(f"\nJob A: id={handle_a.job_id}")
  print(f"Job B: id={handle_b.job_id}")

  # Poll until both jobs have at least started, giving up after 30 polls.
  print("\n--- Polling status ---")
  polls_left = 30
  while polls_left > 0:
    polls_left -= 1
    state_a = handle_a.status()
    state_b = handle_b.status()
    print(f"  A: {state_a.value}  |  B: {state_b.value}")
    if state_a.value in started_states and state_b.value in started_states:
      break
    time.sleep(5)

  # Show the most recent log output from the first job.
  print("\n--- Last 20 log lines from Job A ---")
  recent_logs = handle_a.tail(n=20)
  print(recent_logs)

  # Block on each result; cleanup=False is passed so the handles can still
  # be used by the caller afterwards.
  print("\n--- Collecting results ---")
  loss_a = handle_a.result(cleanup=False)
  loss_b = handle_b.result(cleanup=False)
  print(f"Job A loss: {loss_a}")
  print(f"Job B loss: {loss_b}")

  return handle_a, handle_b


# ---------------------------------------------------------------------------
# 6. Reattach from a "different session"
# ---------------------------------------------------------------------------


def demo_reattach(job_id: str):
  """Look a job up by its ID and print its status, function and backend.

  Mimics what a separate session would do: it only needs the job ID, which
  it passes to `kinetic.attach()` to obtain a handle.

  Args:
    job_id: Identifier of the job to reattach to.
  """
  rule = "=" * 60
  print("\n" + rule)
  print(f"Reattaching to job {job_id} ...")
  print(rule)

  handle = kinetic.attach(job_id)
  current = handle.status()
  print(f"Status: {current.value}")
  print(f"Function: {handle.func_name}")
  print(f"Backend:  {handle.backend}")


# ---------------------------------------------------------------------------
# 7. List all live jobs
# ---------------------------------------------------------------------------


def demo_list_jobs():
  """Print one line per job returned by `kinetic.list_jobs()`.

  Each line shows the job ID, the function name (padded to 30 characters)
  and the accelerator. A placeholder line is printed when nothing is
  returned.
  """
  rule = "=" * 60
  print("\n" + rule)
  print("Listing live jobs...")
  print(rule)

  live = kinetic.list_jobs()
  if not live:
    print("  (no live jobs)")
  for entry in live:
    row = f"  {entry.job_id}  {entry.func_name:30s}  {entry.accelerator}"
    print(row)


# ---------------------------------------------------------------------------
# 8. Cleanup
# ---------------------------------------------------------------------------


def demo_cleanup(job_a, job_b):
  """Release the k8s resources and GCS artifacts of both jobs.

  Calls `cleanup(k8s=True, gcs=True)` on job A and then job B, printing a
  confirmation line after each one.

  Args:
    job_a: Handle of the first job.
    job_b: Handle of the second job.
  """
  rule = "=" * 60
  print("\n" + rule)
  print("Cleaning up resources...")
  print(rule)

  # Same treatment for both handles, A first.
  for label, handle in (("A", job_a), ("B", job_b)):
    handle.cleanup(k8s=True, gcs=True)
    print(f"  Cleaned up Job {label} ({handle.job_id})")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

if __name__ == "__main__":
  # Run the demo end to end: submit and monitor, reattach, list, clean up.
  job_a, job_b = demo_submit_and_monitor()
  try:
    demo_reattach(job_a.job_id)
    demo_list_jobs()
  finally:
    # Results were collected with cleanup=False, so the jobs' cluster and
    # storage resources still exist here. Release them even if one of the
    # steps above raised, so a failed demo run does not leak them.
    demo_cleanup(job_a, job_b)
  print("\nDone!")