Tunix SFT Example

# Tunix SFT Example

# -*- coding: utf-8 -*-
"""Tunix SFT Guide Script

Adapted for local execution outside of Google Colab and launched on remote TPU v6e-8 slice via Kinetic.
"""

import json
import logging
import os

import etils.epath as _epath
import jax
import nest_asyncio
import optax
import qwix
import wandb
from dotenv import load_dotenv
from flax import nnx
from huggingface_hub import snapshot_download
from tunix.examples.data import translation_dataset as data_lib
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as gemma3_model_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.sft import metrics_logger, peft_trainer, utils
from tunix.sft.utils import show_hbm_usage

import kinetic
import kinetic.credentials

# Load variables through python-dotenv before the environment checks further
# down read os.environ.
load_dotenv()
# Replace Kinetic's credential check with a no-op that accepts any arguments
# and returns None.
# NOTE(review): presumably credentials are provisioned out of band — confirm.
kinetic.credentials.ensure_credentials = lambda *args, **kwargs: None

# Wrap etils' Path.mkdir so that the `mode` argument is never forwarded.
# Applying a mode while creating directories triggers permission issues in
# some environments, so the original method is kept and called without it.
_orig_mkdir = _epath.Path.mkdir


def safe_mkdir(self, mode=0o777, parents=False, exist_ok=False):
  """Create this path as a directory, ignoring the requested ``mode``.

  Args:
    self: The path object the patched method is bound to.
    mode: Accepted for signature compatibility only; it is discarded.
    parents: Forwarded unchanged to the original ``mkdir``.
    exist_ok: Forwarded unchanged to the original ``mkdir``.

  Returns:
    Whatever the original ``mkdir`` returns.
  """
  del mode  # Intentionally unused: dropping it is the point of the patch.
  return _orig_mkdir(self, parents=parents, exist_ok=exist_ok)


_epath.Path.mkdir = safe_mkdir

nest_asyncio.apply()

# Configure the root logger before anything is logged: its default level is
# WARNING, so the INFO messages emitted by the checks below would otherwise be
# dropped.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Log in to Weights & Biases only when a non-empty API key is present;
# otherwise switch wandb off through the WANDB_MODE environment variable.
if os.environ.get("WANDB_API_KEY"):
  wandb.login(key=os.environ["WANDB_API_KEY"])
else:
  os.environ["WANDB_MODE"] = "disabled"
  logging.info("WANDB_API_KEY not found. Running wandb in disabled mode.")

# Kaggle credentials are optional here; their absence is only reported.
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
  logging.info(
    "KAGGLE credentials not found. Skipping interactive kagglehub login."
  )

# Report whether a non-empty Hugging Face token is available.
if os.environ.get("HF_TOKEN"):
  logging.info("HF_TOKEN found in environment.")
else:
  logging.info("HF_TOKEN not found. Ensure Hugging Face is authenticated.")

# Model checkpoint identifier (passed to snapshot_download as the repo id) and
# the GCS path of the tokenizer model file.
model_id = "google/gemma-3-270m-it"
GEMMA_TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"

# --- Data ---
# Tune both values to the TPU memory and the model size.
BATCH_SIZE = 32
MAX_TARGET_LENGTH = 256

# --- Model setup ---
# The mesh has two named axes, ("fsdp", "tp"): Fully Sharded Data Parallel and
# Tensor Parallel. The table maps the number of visible devices to the number
# of devices along each of those axes; adjust it to the TPU memory and model
# size.
# NOTE(review): with 8 devices the mesh shape (1, 4) only covers 4 of them —
# confirm this is intended.
NUM_TPUS = len(jax.devices())
_MESH_COUNTS_BY_DEVICE_COUNT = {
  8: (1, 4),  # 1 for fsdp, 4 for tp
  4: (1, 4),  # all 4 devices used for tensor parallel
  1: (1, 1),
}
MESH_COUNTS = _MESH_COUNTS_BY_DEVICE_COUNT.get(NUM_TPUS)
if MESH_COUNTS is None:
  raise ValueError(f"Unsupported number of TPUs: {NUM_TPUS}")

# [axis sizes, axis names], unpacked positionally when the mesh is built.
MESH = [MESH_COUNTS, ("fsdp", "tp")]

# --- LoRA / QLoRA ---
USE_QUANTIZATION = True  # True selects QLoRA, False selects plain LoRA.
RANK = 16
ALPHA = float(RANK * 2)

# --- Training ---
MAX_STEPS = 100
EVAL_EVERY_N_STEPS = 20
NUM_EPOCHS = 3

# --- Output directories (checkpoints and profiling) ---
FULL_CKPT_DIR = "/tmp/content/full_ckpts/"
LORA_CKPT_DIR = "/tmp/content/lora_ckpts/"
PROFILING_DIR = "/tmp/content/profiling/"


def create_dir(path):
  """Create ``path`` (including missing parents) on a best-effort basis.

  An already existing directory is not an error. Any ``OSError`` raised while
  creating the directory is logged and swallowed rather than propagated.

  Args:
    path: Directory path to create.
  """
  try:
    os.makedirs(path, exist_ok=True)
  except OSError as e:
    logging.error(f"Error creating directory '{path}': {e}")
  else:
    logging.info(f"Created dir: {path}")


# NOTE(review): the module docstring mentions a TPU v6e-8 slice while the
# accelerator requested here is "tpu-v5litepod" — confirm which is intended.
@kinetic.run(
  accelerator="tpu-v5litepod",
  capture_env_vars=["KAGGLE_USERNAME", "KAGGLE_KEY", "HF_TOKEN", "WANDB_MODE"],
)
def run_tuning():
  """Fine-tune the configured Gemma 3 model with LoRA or QLoRA.

  Steps, in order:
    1. Create the checkpoint and profiling directories.
    2. Download the model snapshot from Hugging Face (skipping ``*.pth``
       files) and read the EOS token ids from ``generation_config.json`` when
       that file exists.
    3. Build the device mesh, load the base model from safetensors and create
       the tokenizer.
    4. Run a short sample generation with the base model on four translation
       prompts and log the outputs.
    5. Apply LoRA (or QLoRA when ``USE_QUANTIZATION`` is set) with qwix.
    6. Build the ``mtnt/en-fr`` datasets and train with ``PeftTrainer``.

  Raises:
    ValueError: If ``model_id`` is not one of the supported Gemma 3 variants.
  """
  # Only FULL_CKPT_DIR is passed to the trainer below; the other two
  # directories are created but not otherwise referenced in this function.
  create_dir(FULL_CKPT_DIR)
  create_dir(LORA_CKPT_DIR)
  create_dir(PROFILING_DIR)

  ignore_patterns = [
    "*.pth",  # Ignore PyTorch .pth weight files
  ]
  logging.info(f"Downloading {model_id} from Hugging Face...")
  local_model_path = snapshot_download(
    repo_id=model_id, ignore_patterns=ignore_patterns
  )
  logging.info(f"Model successfully downloaded to: {local_model_path}")

  eos_tokens = []
  generation_config_path = os.path.join(
    local_model_path, "generation_config.json"
  )
  if os.path.exists(generation_config_path):
    with open(generation_config_path, "r", encoding="utf-8") as f:
      generation_configs = json.load(f)
    # "eos_token_id" may be missing/null, a single int, or a list of ints.
    # Normalise it to a list so the membership test and append() further down
    # work for every shape.
    raw_eos = generation_configs.get("eos_token_id")
    if raw_eos is None:
      eos_tokens = []
    elif isinstance(raw_eos, int):
      eos_tokens = [raw_eos]
    else:
      eos_tokens = list(raw_eos)
    logging.info(f"Using EOS token IDs: {eos_tokens}")

  logging.info("\n--- HBM Usage BEFORE Model Load ---")
  show_hbm_usage()

  model_cp_path = local_model_path

  # Pick the model configuration that matches the checkpoint being loaded.
  if "gemma-3-270m" in model_id:
    model_config = gemma3_model_lib.ModelConfig.gemma3_270m()
  elif "gemma-3-1b" in model_id:
    model_config = gemma3_model_lib.ModelConfig.gemma3_1b_it()
  else:
    raise ValueError(f"Unsupported model: {model_id}")

  # MESH is [axis sizes, axis names]; every axis uses the Auto axis type.
  mesh = jax.make_mesh(
    *MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0])
  )
  with mesh:
    base_model = params_safetensors_lib.create_model_from_safe_tensors(
      model_cp_path, model_config, mesh
    )

  tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)
  # Make sure the tokenizer's own EOS id is always among the stop tokens.
  if tokenizer.eos_id() not in eos_tokens:
    eos_tokens.append(tokenizer.eos_id())
    logging.info(f"Using EOS token IDs: {eos_tokens}")

  sampler = sampler_lib.Sampler(
    transformer=base_model,
    tokenizer=tokenizer if "gemma" in model_id else tokenizer.tokenizer,
    cache_config=sampler_lib.CacheConfig(
      cache_size=256,
      num_layers=model_config.num_layers,
      num_kv_heads=model_config.num_kv_heads,
      head_dim=model_config.head_dim,
    ),
  )

  # Sample generation with the base model on a few prompts, run before any
  # fine-tuning takes place.
  input_batch = [
    "Translate this into French:\nHello, my name is Morgane.\n",
    "Translate this into French:\nThis dish is delicious!\n",
    "Translate this into French:\nI am a student.\n",
    "Translate this into French:\nHow's the weather today?\n",
  ]

  out_data = sampler(
    input_strings=input_batch,
    max_generation_steps=10,  # The number of steps performed when generating a response.
    eos_tokens=eos_tokens,
  )

  for input_string, out_string in zip(input_batch, out_data.text, strict=True):
    logging.info("----------------------")
    logging.info(f"Prompt:\n{input_string}")
    logging.info(f"Output:\n{out_string}")

  def get_lora_model(base_model, mesh, quantize=False):
    """Return ``base_model`` with LoRA (or QLoRA) adapters applied via qwix.

    Args:
      base_model: Model to which the adapters are added.
      mesh: Device mesh used as context when constraining the sharding.
      quantize: When true, additionally request "nf4" weight quantization
        with a tile size of 128 (QLoRA); otherwise plain LoRA is used.

    Returns:
      The model returned by ``qwix.apply_lora_to_model``, with its state
      updated through a sharding constraint.
    """
    # Settings shared by LoRA and QLoRA; `module_path` is the regex selecting
    # the layers that receive adapters.
    lora_kwargs = {
      "module_path": ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj",
      "rank": RANK,
      "alpha": ALPHA,
    }
    if quantize:
      # QLoRA: quantize the base model weights to 4-bit NormalFloat (nf4).
      lora_kwargs.update(weight_qtype="nf4", tile_size=128)
    lora_provider = qwix.LoraProvider(**lora_kwargs)

    model_input = base_model.get_model_input()
    # Apply LoRA to the base model
    lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
    )

    # Ensure the LoRA model parameters are sharded according to the mesh
    with mesh:
      state = nnx.state(lora_model)
      pspecs = nnx.get_partition_spec(state)
      sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
      nnx.update(lora_model, sharded_state)

    return lora_model

  # Create LoRA or QLoRA model based on USE_QUANTIZATION hyperparameter
  lora_model = get_lora_model(base_model, mesh=mesh, quantize=USE_QUANTIZATION)

  logging.info(f"Using {'QLoRA' if USE_QUANTIZATION else 'LoRA'} model")

  # Loads the training and validation datasets
  train_ds, validation_ds = data_lib.create_datasets(
    dataset_name="mtnt/en-fr",
    global_batch_size=BATCH_SIZE,
    max_target_length=MAX_TARGET_LENGTH,
    num_train_epochs=NUM_EPOCHS,
    tokenizer=tokenizer,
  )

  def gen_model_input_fn(x: peft_trainer.TrainingInput):
    """Build the model's keyword inputs from one training batch.

    Positions and the causal attention mask are both derived from a padding
    mask that marks every token different from the tokenizer's pad id.
    """
    pad_mask = x.input_tokens != tokenizer.pad_id()
    positions = utils.build_positions_from_mask(pad_mask)
    attention_mask = utils.make_causal_attn_mask(pad_mask)
    return {
      "input_tokens": x.input_tokens,
      "input_mask": x.input_mask,
      "positions": positions,
      "attention_mask": attention_mask,
    }

  full_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/tensorboard/full", flush_every_n_steps=20
  )

  training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
    metrics_logging_options=full_logging_options,
    checkpoint_root_directory=FULL_CKPT_DIR,
  )

  # Initialize the PeftTrainer.
  # We pass `lora_model` which has the LoRA adapters applied.
  trainer = peft_trainer.PeftTrainer(
    lora_model, optax.adamw(1e-5), training_config
  ).with_gen_model_input_fn(gen_model_input_fn)

  logging.info("Starting fine-tuning...")
  # Run the training loop within the mesh context for distributed execution
  with mesh:
    trainer.train(train_ds, validation_ds)

  # NOTE(review): wandb is initialised only after training has finished, and
  # WANDB_API_KEY is not listed in capture_env_vars above — confirm that this
  # ordering and environment propagation are intended.
  if os.environ.get("WANDB_API_KEY"):
    wandb.init()
    logging.info("Weights & Biases initialized successfully.")


# Script entry point: start the fine-tuning job only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
  run_tuning()