# Source code for kinetic.core.core

from __future__ import annotations

import functools
import os
import sys
import warnings
from typing import Any, Callable

from kinetic.backend.execution import (
  GKEBackend,
  JobContext,
  PathwaysBackend,
  submit_remote,
)
from kinetic.cli.profiles import resolve_infra
from kinetic.collections import BatchHandle
from kinetic.collections import map as collections_map
from kinetic.core import accelerators
from kinetic.data import Data
from kinetic.debug import cleanup_port_forward
from kinetic.jobs import JobHandle


def _validate_volumes(volumes):
  """Check that ``volumes`` is None or a dict of mount paths to Data.

  Each entry is checked in iteration order: first the key, then the value.

  Args:
    volumes: None, or a dict whose keys are mount paths and whose values
      are ``Data`` instances.

  Raises:
    TypeError: If ``volumes`` is not a dict, or a value is not a ``Data``
      instance.
    ValueError: If a key is not a string starting with ``/``.
  """
  if volumes is None:
    return

  if not isinstance(volumes, dict):
    raise TypeError(f"volumes must be a dict, got {type(volumes).__name__}")

  for target, payload in volumes.items():
    path_is_absolute = isinstance(target, str) and target.startswith("/")
    if not path_is_absolute:
      raise ValueError(
        "Volume mount path must be an absolute path "
        f"(start with '/'), got: {target!r}"
      )
    if isinstance(payload, Data):
      continue
    raise TypeError(
      f"Volume value for {target!r} must be a Data "
      f"instance, got {type(payload).__name__}"
    )


def _capture_env(capture_env_vars):
  """Collect the local environment variables selected by ``capture_env_vars``.

  Args:
    capture_env_vars: Iterable of variable names or prefix patterns.  An
      entry ending in ``*`` selects every variable whose name starts with
      the text before the ``*``; any other entry selects the variable with
      exactly that name, if it is set.  A falsy value selects nothing.

  Returns:
    A new dict mapping each selected variable name to its current value.
  """
  captured = {}
  for spec in capture_env_vars or ():
    if not spec.endswith("*"):
      # Exact name: copy it only when it is actually set.
      if spec in os.environ:
        captured[spec] = os.environ[spec]
      continue

    stem = spec[:-1]
    for name, value in os.environ.items():
      if name.startswith(stem):
        captured[name] = value
  return captured


def _require_interactive_terminal():
  """Raise if stdin is not a TTY and KINETIC_NO_TTY_DEBUG is not set.

  ``run(debug=True)`` blocks waiting for a VS Code debugger to attach.
  Without a TTY (CI, cron, nohup, piped input), no one can attach and
  the job hangs for ``DEBUG_WAIT_TIMEOUT`` before falling through.
  Fail fast with a clear message instead.  Set
  ``KINETIC_NO_TTY_DEBUG=1`` to override (useful for automated tests).

  A missing (``None``) or closed ``sys.stdin`` is treated the same as a
  non-TTY stdin, so those cases produce the same ``RuntimeError``.

  Raises:
    RuntimeError: If the override is not set and stdin is absent, closed,
      or not attached to a terminal.
  """
  if os.environ.get("KINETIC_NO_TTY_DEBUG") == "1":
    return

  stdin = sys.stdin
  try:
    # sys.stdin may be None for processes started without a usable
    # standard input; that is certainly not an interactive session.
    interactive = stdin is not None and stdin.isatty()
  except ValueError:
    # isatty() raises ValueError on a closed stream.
    interactive = False

  if not interactive:
    raise RuntimeError(
      "debug=True requires an interactive terminal but stdin is not a TTY. "
      "Either remove debug=True, or call func.run_async() and attach with "
      "handle.debug_attach() from an interactive session, or set "
      "KINETIC_NO_TTY_DEBUG=1 to override."
    )


def _resolve_backend_name(accelerator, backend, spot=False):
  """Pick the backend name, preferring an explicit ``backend``.

  Args:
    accelerator: Accelerator spec handed to
      ``accelerators.parse_accelerator`` when auto-detecting.
    backend: Explicit backend name; returned unchanged when not None.
    spot: Forwarded to ``accelerators.parse_accelerator``.

  Returns:
    ``backend`` if given; otherwise ``"pathways"`` when the accelerator
    parses to a ``TpuConfig`` with more than one node, else ``"gke"``.
    A ``ValueError`` raised during auto-detection also yields ``"gke"``.
  """
  if backend is not None:
    return backend

  chosen = "gke"
  try:
    parsed = accelerators.parse_accelerator(accelerator, spot=spot)
    multi_node_tpu = (
      isinstance(parsed, accelerators.TpuConfig) and parsed.num_nodes > 1
    )
    if multi_node_tpu:
      chosen = "pathways"
  except ValueError:
    # Auto-detection is best effort; keep the "gke" default.
    pass
  return chosen


def _make_decorator(
  accelerator,
  container_image,
  base_image_repo,
  zone,
  project,
  capture_env_vars,
  cluster,
  backend,
  namespace,
  volumes,
  spot,
  sync,
  output_dir,
  debug,
):
  """Build a decorator that submits the wrapped function for remote execution.

  Args:
    accelerator: Accelerator spec; used for backend auto-detection and
      forwarded to ``JobContext.from_params``.
    container_image: Forwarded to ``JobContext.from_params``.
    base_image_repo: Forwarded to ``JobContext.from_params``.
    zone: Passed to ``resolve_infra``; the resolved value is used.
    project: Passed to ``resolve_infra``; the resolved value is used.
    capture_env_vars: Names/patterns of local environment variables to
      capture at call time via ``_capture_env``.
    cluster: Passed to ``resolve_infra``; the resolved value is used.
    backend: Explicit backend name (``"gke"`` or ``"pathways"``), or None
      to auto-detect from the accelerator.
    namespace: Passed to ``resolve_infra``; the resolved value is used.
    volumes: Forwarded to ``JobContext.from_params``.
    spot: Used for backend auto-detection and forwarded to
      ``JobContext.from_params``.
    sync: If True, block on result (`run()` semantics).
      If False, return a `JobHandle` immediately (`run_async()` semantics).
    output_dir: Forwarded to ``JobContext.from_params``.
    debug: If True, enable debugpy remote debugging.

  Returns:
    A decorator.  The wrapper it produces raises ``ValueError`` for an
    unknown backend and, when ``sync`` and ``debug`` are both True,
    ``RuntimeError`` if no interactive terminal is available.
  """

  def decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      # Fail fast: a blocking debug session needs someone at a terminal
      # to attach.  Check before anything is submitted so that a
      # non-interactive caller does not leave a debug job behind with no
      # handle to it.
      if sync and debug:
        _require_interactive_terminal()

      env_vars = _capture_env(capture_env_vars)
      resolved_backend = _resolve_backend_name(accelerator, backend, spot=spot)

      if resolved_backend not in ("gke", "pathways"):
        raise ValueError(
          f"Unknown backend: {resolved_backend}. "
          "Use 'gke', 'pathways', or None for auto-detection"
        )

      infra = resolve_infra(
        project=project, zone=zone, cluster=cluster, namespace=namespace
      )

      ctx = JobContext.from_params(
        func,
        args,
        kwargs,
        accelerator,
        container_image,
        infra["zone"],
        infra["project"],
        env_vars,
        cluster_name=infra["cluster"],
        volumes=volumes,
        spot=spot,
        debug=debug,
        output_dir=output_dir,
        base_image_repo=base_image_repo,
      )

      if resolved_backend == "pathways":
        backend_inst = PathwaysBackend(
          cluster=infra["cluster"], namespace=infra["namespace"]
        )
      else:
        backend_inst = GKEBackend(
          cluster=infra["cluster"], namespace=infra["namespace"]
        )

      handle = submit_remote(ctx, backend_inst)

      if not sync:
        return handle

      if debug:
        pf_proc = handle.debug_attach(working_dir=ctx.working_dir)
        try:
          return handle.result(stream_logs=False, cleanup=False)
        finally:
          # Always tear down the port-forward started by debug_attach.
          cleanup_port_forward(pf_proc)
      return handle.result(stream_logs=True)

    return wrapper

  return decorator


class RemoteCallable:
  """Callable object produced by ``@kinetic.run``.

  Calling the instance goes through the synchronous wrapper, while
  ``run_async`` and ``run_async_map`` go through the asynchronous one.
  ``__get__`` makes the object a descriptor, so decorating a method
  yields a proxy bound to the instance on attribute access.
  """

  def __init__(self, func, sync_wrapper, async_wrapper):
    self._func = func
    self._sync_wrapper, self._async_wrapper = sync_wrapper, async_wrapper
    # Mirror func's metadata (__name__, __doc__, __wrapped__, ...) here.
    functools.update_wrapper(wrapper=self, wrapped=func)

  def __call__(self, *args, **kwargs):
    """Run through the synchronous wrapper and return its result."""
    run_blocking = self._sync_wrapper
    return run_blocking(*args, **kwargs)

  def run_async(self, *args, **kwargs) -> JobHandle:
    """Run through the asynchronous wrapper and return its result."""
    submit = self._async_wrapper
    return submit(*args, **kwargs)

  def run_async_map(self, inputs, **kwargs) -> BatchHandle:
    """Hand the asynchronous wrapper and ``inputs`` to ``collections_map``."""
    submit = self._async_wrapper
    return collections_map(submit, inputs, **kwargs)

  def __get__(self, instance, owner):
    # Class-level access returns the descriptor itself; instance access
    # returns a proxy that supplies the instance as the first argument.
    return self if instance is None else _BoundRemoteCallable(self, instance)


class _BoundRemoteCallable:
  """Instance-bound view of a RemoteCallable.

  Every entry point forwards to the underlying callable with the bound
  instance inserted as the first positional argument.
  """

  def __init__(self, callable_, instance):
    self._c, self._instance = callable_, instance

  def __call__(self, *args, **kwargs):
    target = self._c
    return target(self._instance, *args, **kwargs)

  def run_async(self, *args, **kwargs) -> JobHandle:
    submit = self._c.run_async
    return submit(self._instance, *args, **kwargs)

  def run_async_map(self, inputs, **kwargs) -> BatchHandle:
    from kinetic.collections import map as collections_map

    # Each mapped call goes to the async wrapper with the instance first.
    def bound_async_wrapper(*call_args, **call_kwargs):
      return self._c._async_wrapper(self._instance, *call_args, **call_kwargs)

    return collections_map(bound_async_wrapper, inputs, **kwargs)


def run(
  accelerator: str = "tpu-v5e-1",
  container_image: str | None = None,
  base_image_repo: str | None = None,
  zone: str | None = None,
  project: str | None = None,
  capture_env_vars: list[str] | None = None,
  cluster: str | None = None,
  backend: str | None = None,
  namespace: str | None = None,
  volumes: dict[str, Data] | None = None,
  spot: bool = False,
  output_dir: str | None = None,
  debug: bool = False,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
  """Execute function on remote TPU/GPU.

  Args:
    accelerator: TPU/GPU type (e.g., 'tpu-v3-8', 'tpu-v5litepod-4',
      'gpu-l4', 'gpu-a100')
    container_image: Controls the container image used for execution.
      `None` or `"bundled"` (default) builds a custom image with all
      dependencies baked in via Cloud Build. `"prebuilt"` uses a prebuilt
      base image and installs user requirements at pod startup via
      `uv pip install`. Any other string is treated as a custom container
      image URI.
    base_image_repo: Docker Hub repository for prebuilt base images
      (e.g., `"mycompany/kinetic"`). Defaults to `KINETIC_BASE_IMAGE_REPO`
      env var, then `"kinetic"`. Only used when `container_image` is
      `"prebuilt"`.
    zone: GCP zone. Falls back to KINETIC_ZONE, then the active profile's
      zone (from ~/.kinetic/profiles.json), then 'us-central1-a'.
    project: GCP project. Falls back to KINETIC_PROJECT, then the active
      profile's project, then GOOGLE_CLOUD_PROJECT.
    capture_env_vars: List of environment variable names or patterns
      (ending in `*`) to propagate to the remote environment. Defaults to
      None.
    cluster: GKE cluster name. Falls back to KINETIC_CLUSTER, then the
      active profile's cluster, then the built-in default.
    backend: Backend to use ('gke' or 'pathways'). `None` auto-detects
      from the accelerator.
    namespace: Kubernetes namespace. Falls back to KINETIC_NAMESPACE, then
      the active profile's namespace, then 'default'.
    volumes: Dict mapping absolute mount paths to Data objects, e.g.
      `{"/data": Data("./dataset/")}`. Data is downloaded to these paths
      on the pod before function execution.
    spot: If True, use preemptible/spot VMs for the job.
    output_dir: GCS directory where job outputs should be saved.
      Propagated to the remote worker as the `KINETIC_OUTPUT_DIR`
      environment variable. Defaults to
      `gs://{bucket_name}/outputs/{job_id}`.
    debug: If True, enable debugpy remote debugging. The pod will start a
      debugpy server and wait for a VS Code debugger to attach before
      executing the function. Port-forwarding is set up automatically.

  Returns:
    A decorator that returns a wrapper function. When called, the wrapper
    executes the function remotely and blocks until completion (sync
    mode).

    The wrapper also has the following methods:

    - run_async(*args, **kwargs): Submits the job for remote execution
      and returns a JobHandle immediately (async mode).
    - run_async_map(inputs, **kwargs): Fans out across accelerators for a
      collection of inputs, returning a BatchHandle.

  Raises:
    TypeError: If `volumes` is not a dict or holds a non-Data value.
    ValueError: If a `volumes` key is not an absolute path string.
  """
  _validate_volumes(volumes)

  if debug and spot:
    warnings.warn(
      "debug=True with spot=True is not recommended — your debug "
      "session may be interrupted by preemption.",
      # The warning is issued directly in run(), so the frame to blame is
      # run()'s immediate caller: the line that applies the decorator.
      stacklevel=2,
    )

  def decorator(func):
    def build_wrapper(sync):
      # The sync and async wrappers differ only in the `sync` flag.
      return _make_decorator(
        accelerator,
        container_image,
        base_image_repo,
        zone,
        project,
        capture_env_vars,
        cluster,
        backend,
        namespace,
        volumes,
        spot,
        sync=sync,
        output_dir=output_dir,
        debug=debug,
      )(func)

    sync_wrapper = build_wrapper(True)
    async_wrapper = build_wrapper(False)
    return RemoteCallable(func, sync_wrapper, async_wrapper)

  return decorator