Distributed Gemma 2B fine-tune

Distributed Gemma 2B fine-tune#

import os

# JAX must be set as the backend before importing Keras
os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
import keras_hub

import kinetic


@kinetic.run(
  accelerator="tpu-v5litepod-2x4",
  backend="pathways",
  capture_env_vars=["KAGGLE_USERNAME", "KAGGLE_KEY"],
)
def finetune_gemma_distributed():
  """
  Distributed Finetuning of Gemma using Kinetic + ML Pathways on TPUs.
  This mirrors the Kaggle DataParallel fine-tuning methodology, but executed remotely.
  """
  print(
    f"Number of JAX devices available across all hosts: {jax.device_count()}"
  )

  # 1. Setup Data Parallel Distribution using a DeviceMesh
  # We construct a 1D mesh wrapping all available TPUs in the Pathways cluster
  devices = keras.distribution.list_devices()
  device_mesh = keras.distribution.DeviceMesh(
    shape=(len(devices),),
    axis_names=["batch"],
    devices=devices,
  )

  # Set global distribution to DataParallel
  keras.distribution.set_distribution(
    keras.distribution.DataParallel(device_mesh=device_mesh)
  )
  print(f"Global distribution set to DataParallel across {len(devices)} TPUs.")

  # 2. Load Gemma Model via Keras Hub
  print("Loading Gemma 2B model...")
  gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")

  # 3. Enable LoRA (Low-Rank Adaptation)
  # This drastically reduces the number of trainable parameters
  print("Enabling LoRA (rank=4)...")
  gemma_lm.backbone.enable_lora(rank=4)
  gemma_lm.summary()

  # 4. Prepare Fine-tuning Data
  # In a real environment, you might load TFRecords from GCS here.
  data = [
    "Patient: I have a sore throat and slight fever. Doctor: You might have a mild infection. Make sure to rest and drink fluids.",
    "Patient: My ankle hurts when I put weight on it. Doctor: It sounds like a sprain. Try keeping it elevated and apply ice.",
    "Patient: I've been feeling very tired lately. Doctor: Fatigue can be caused by many things. Are you getting enough sleep?",
    "Patient: I have a rash on my arm. Doctor: Is it itchy or painful? I can prescribe a topical cream.",
  ]
  # Duplicate data to fill batches across multiple TPUs
  train_data = data * 8

  # 5. Compile the Model
  optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
  )
  # Exclude LayerNorms and biases from weight decay per standard practice
  optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

  gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
  )

  # 6. Run Distributed Fine-tuning
  print("\nStarting Distributed Fine-Tuning across Pathways Cluster...")
  gemma_lm.fit(train_data, batch_size=len(devices), epochs=3)
  print("\nTraining completed successfully!")

  # 7. Local Evaluation Test
  prompt = (
    "Patient: I've had a headache that won't go away for two days. Doctor: "
  )
  output = gemma_lm.generate([prompt], max_length=64)

  gen_output = f"\nGenerative Output Test:\n{output[0]}"
  print(gen_output)

  return gen_output


if __name__ == "__main__":
  # Make sure you have exported your Kaggle credentials locally so they are sent to the cluster:
  # export KAGGLE_USERNAME="your_username"
  # export KAGGLE_KEY="your_key"

  print("Submitting Gemma Distributed Finetuning job to Pathways Backend...")
  result = finetune_gemma_distributed()
  print("Job Finished With Result:", result)