Single-TPU Gemma 3 fine-tune

Single-TPU Gemma 3 fine-tune#

import os

# JAX must be set as the backend before importing Keras
os.environ["KERAS_BACKEND"] = "jax"

import keras_hub

import kinetic


@kinetic.run(
  accelerator="tpu-v5litepod-1",
  capture_env_vars=["KAGGLE_USERNAME", "KAGGLE_KEY"],
)
def train_gemma():
  # Data for SFT
  print("Starting Gemma 3 SFT training...")
  features = {
    "prompts": ["Capital of India?", "Capital of South Africa?"],
    "responses": ["New Delhi", "Pretoria"],
  }
  print("Data prepared.")
  gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_1b")
  print("Model initialized.")
  # Fine-tune
  gemma_lm.fit(x=features, batch_size=1)

  print("Gemma 3 SFT training done")


if __name__ == "__main__":
  train_gemma()