# Source code for kinetic.collections

"""Async collection orchestration for Kinetic.

Provides `map()` for job-array-style fan-out, `BatchHandle` for
observing and collecting collection results, and `attach_batch()`
for cross-session reattachment.
"""

from __future__ import annotations

import collections
import threading
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Iterator

from absl import logging
from google.api_core import exceptions as google_exceptions

from kinetic.cli.profiles import resolve_infra
from kinetic.collections_helpers import (
  append_child_to_manifest,
  build_initial_manifest,
  call_with_input,
)
from kinetic.constants import build_bucket_name
from kinetic.job_status import JobStatus
from kinetic.jobs import _TERMINAL_STATUSES, JobHandle
from kinetic.utils import storage

# Default cap on concurrently active jobs for `map()`.
_DEFAULT_MAX_CONCURRENT: int = 64
# Seconds between job-status polls (submission loop and `BatchHandle.wait()`).
_STATUS_POLL_INTERVAL: float = 5.0
# Default seconds between manifest polls in `attach_batch()`.
_MANIFEST_POLL_INTERVAL: float = 10.0


def _resolve_bucket(
  project: str | None, cluster: str | None
) -> tuple[str, str]:
  """Resolve the GCP project and the Kinetic bucket name derived from it.

  Both *project* and *cluster* follow the standard resolution chain:
  explicit kwarg > KINETIC_* env var > active profile field > built-in
  default (delegated to `resolve_infra`).

  Returns:
    A `(resolved_project, bucket_name)` tuple.
  """
  infra = resolve_infra(project=project, cluster=cluster)
  resolved_project = infra["project"]
  bucket_name = build_bucket_name(resolved_project, infra["cluster"])
  return resolved_project, bucket_name


class BatchError(Exception):
  """Raised when a batch collection has failed children.

  Attributes:
    group_id: Identifier of the collection that failed.
    failures: The failed children. Entries are `JobHandle`s; inputs that
      failed at submission time (and so have no handle) appear as `None`.
    partial_results: Collected results. Successful positions hold the
      job's result and failed positions hold `None`.
  """

  def __init__(
    self,
    group_id: str,
    failures: list[JobHandle],
    partial_results: list[Any],
  ):
    self.group_id = group_id
    self.failures = failures
    self.partial_results = partial_results
    failed_count, total_count = len(failures), len(partial_results)
    super().__init__(
      f"Batch {group_id}: {failed_count} of {total_count} jobs failed"
    )
@dataclass
class BatchHandle:
  """Handle for a collection of submitted jobs.

  Created by `run_async_map()` or reconstructed by `kinetic.attach_batch()`.
  Provides collection-level observation, result gathering, and cleanup.
  """

  group_id: str
  name: str | None
  tags: dict[str, str]
  # One slot per input, in input order. A slot is `None` until that input
  # has a submitted job; it stays `None` if submission failed or the input
  # was never launched. A retried input's slot is replaced by the new attempt.
  jobs: list[JobHandle | None]

  # Bucket / project derived from eager resolution in map().
  _bucket_name: str = field(default="", repr=False, compare=False)
  _project: str = field(default="", repr=False, compare=False)

  # Internal state for background submission.
  # Set once whoever fills `jobs` (submission loop or manifest poller) is done.
  _submission_complete: threading.Event = field(
    default_factory=threading.Event, repr=False, compare=False
  )
  # Fatal error from the submission loop itself; re-raised by wait().
  _submission_error: BaseException | None = field(
    default=None, repr=False, compare=False
  )
  # Guards `jobs`, `_submission_errors` and `_cached_failures` against the
  # background thread.
  _lock: threading.Lock = field(
    default_factory=threading.Lock, repr=False, compare=False
  )
  # Per-index submission errors (index -> exception).
  _submission_errors: dict[int, Exception] = field(
    default_factory=dict, repr=False, compare=False
  )
  # Cached failure list populated by results() so that failures()
  # remains accurate after cleanup deletes K8s resources.
  _cached_failures: list[JobHandle] | None = field(
    default=None, repr=False, compare=False
  )

  def statuses(self) -> list[tuple[int, JobStatus]]:
    """Return `(index, status)` for each submitted job.

    Slots that are still `None` (not submitted, failed at submission, or
    never launched) are omitted.
    """
    return [
      (i, job.status()) for i, job in enumerate(self.jobs) if job is not None
    ]

  def status_counts(self) -> dict[str, int]:
    """Return a count of jobs in each status, keyed by the status value."""
    return dict(collections.Counter(s.value for _, s in self.statuses()))

  def _all_accounted_for(self, seen: set[int]) -> bool:
    """True once submission is complete and every filled slot is in *seen*.

    After `_submission_complete` is set, `jobs` no longer changes, so a slot
    that is still `None` will never produce a job. This covers three cases:
    it failed at submission (see `_submission_errors`), it was never launched
    because `fail_fast` stopped the loop, or it could not be loaded by
    `attach_batch()`. Such slots must not keep iteration alive.
    """
    if not self._submission_complete.is_set():
      return False
    with self._lock:
      return all(job is None or i in seen for i, job in enumerate(self.jobs))

  def wait(self, *, timeout: float | None = None) -> None:
    """Block until all jobs reach a terminal state.

    Waits for submission to finish first (re-raising any fatal submission
    error), then polls every submitted job.

    Args:
      timeout: Maximum seconds to wait overall; `None` waits indefinitely.

    Raises:
      TimeoutError: If *timeout* elapses first.
    """
    deadline = None if timeout is None else time.monotonic() + timeout

    # Wait for background submission to finish first.
    if not self._submission_complete.is_set():
      remaining = (
        None if deadline is None else max(0, deadline - time.monotonic())
      )
      if not self._submission_complete.wait(timeout=remaining):
        raise TimeoutError(
          f"Timed out waiting for submission to complete "
          f"for batch {self.group_id}"
        )
    if self._submission_error is not None:
      raise self._submission_error

    # Poll until every submitted job is terminal.
    while True:
      if all(
        job.status() in _TERMINAL_STATUSES
        for job in self.jobs
        if job is not None
      ):
        break
      if deadline is not None and time.monotonic() >= deadline:
        raise TimeoutError(
          f"Timed out waiting for batch {self.group_id} after {timeout}s"
        )
      time.sleep(_STATUS_POLL_INTERVAL)

    if self._submission_errors:
      logging.warning(
        "Batch %s: %d input(s) failed at submission time. "
        "Use handle.submission_failures to inspect.",
        self.group_id,
        len(self._submission_errors),
      )

  def as_completed(
    self,
    *,
    poll_interval: float = 5.0,
    timeout: float | None = None,
  ) -> Iterator[JobHandle]:
    """Yield jobs as they reach terminal states, in completion order.

    Jobs that finish successfully are yielded as soon as they reach a
    terminal state, even while more inputs are still being submitted.
    A job that ends `FAILED` or `NOT_FOUND` is only yielded once submission
    is complete: until then the submission loop may still retry it and
    replace its slot with a new attempt.

    Args:
      poll_interval: Seconds between status polls.
      timeout: Maximum seconds to wait. Raises `TimeoutError` if exceeded.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    seen: set[int] = set()
    # Statuses the submission loop may retry.
    retryable = (JobStatus.FAILED, JobStatus.NOT_FOUND)
    while True:
      # Read the completion flag *before* snapshotting: if it is already set,
      # the snapshot below is final and cannot hold a superseded attempt.
      submission_done = self._submission_complete.is_set()
      with self._lock:
        current_jobs = list(enumerate(self.jobs))

      newly_done: list[tuple[int, JobHandle]] = []
      for i, job in current_jobs:
        if i in seen or job is None:
          continue
        status = job.status()
        if status not in _TERMINAL_STATUSES:
          continue
        if status in retryable and not submission_done:
          continue
        newly_done.append((i, job))

      for i, job in newly_done:
        seen.add(i)
        # Yield the exact handle whose status was checked.
        yield job

      if self._all_accounted_for(seen):
        break
      if deadline is not None and time.monotonic() >= deadline:
        raise TimeoutError(
          f"as_completed() timed out after {timeout}s for batch {self.group_id}"
        )
      if not newly_done:
        time.sleep(poll_interval)

  def results(
    self,
    *,
    timeout: float | None = None,
    ordered: bool = True,
    cleanup: bool = True,
    return_exceptions: bool = False,
  ) -> list[Any]:
    """Collect results from all jobs.

    Args:
      timeout: Maximum seconds to wait for all jobs.
      ordered: If *True*, return in input order. If *False*, return in
        completion order (submission failures are appended at the end).
      cleanup: If *True*, clean up each child's K8s and GCS resources
        (the group manifest is preserved). Job statuses then become
        `NOT_FOUND`, but `failures()` keeps returning the failures recorded
        by this call.
      return_exceptions: If *True*, failed positions contain the exception
        object. If *False*, failed positions are `None` and `BatchError` is
        raised on any failure.

    Returns:
      List of results (input order when *ordered=True*, completion order
      otherwise).

    Raises:
      BatchError: If any job or submission failed and *return_exceptions*
        is False.
      TimeoutError: If *timeout* elapses first.
    """
    if ordered:
      results_list, failures = self._results_ordered(
        timeout=timeout, cleanup=cleanup, return_exceptions=return_exceptions
      )
    else:
      results_list, failures = self._results_completion_order(
        timeout=timeout, cleanup=cleanup, return_exceptions=return_exceptions
      )
    if failures and not return_exceptions:
      raise BatchError(
        group_id=self.group_id,
        failures=failures,
        partial_results=results_list,
      )
    return results_list

  def _results_ordered(
    self,
    *,
    timeout: float | None,
    cleanup: bool,
    return_exceptions: bool,
  ) -> tuple[list[Any], list[JobHandle]]:
    """Collect results in input order (waits for all jobs first)."""
    self.wait(timeout=timeout)
    failures: list[JobHandle] = []
    results_list: list[Any] = [None] * len(self.jobs)
    for i, job in enumerate(self.jobs):
      if job is None:
        # Never-launched slots (no submission error) simply stay None.
        if i in self._submission_errors:
          exc = self._submission_errors[i]
          if return_exceptions:
            results_list[i] = exc
          failures.append(None)  # type: ignore[arg-type]
        continue
      try:
        results_list[i] = job.result(cleanup=cleanup)
      except Exception as exc:
        if return_exceptions:
          results_list[i] = exc
        failures.append(job)
    with self._lock:
      self._cached_failures = [f for f in failures if f is not None]
    return results_list, failures

  def _results_completion_order(
    self,
    *,
    timeout: float | None,
    cleanup: bool,
    return_exceptions: bool,
  ) -> tuple[list[Any], list[JobHandle]]:
    """Collect results in completion order, streaming as they arrive.

    Failed positions get a placeholder (the exception, or `None`) so that
    the returned list, and any `BatchError.partial_results`, covers every
    collected position rather than only the successes.
    """
    failures: list[JobHandle] = []
    results_list: list[Any] = []
    for job in self.as_completed(timeout=timeout):
      try:
        results_list.append(job.result(cleanup=cleanup))
      except Exception as exc:
        results_list.append(exc if return_exceptions else None)
        failures.append(job)

    with self._lock:
      submission_errors = dict(self._submission_errors)
    for idx in sorted(submission_errors):
      # A retry whose resubmission failed leaves the previous attempt in
      # `jobs[idx]`; as_completed() already yielded that handle, so don't
      # report the same position twice (matches _results_ordered()).
      if self.jobs[idx] is not None:
        continue
      exc = submission_errors[idx]
      results_list.append(exc if return_exceptions else None)
      failures.append(None)  # type: ignore[arg-type]
    with self._lock:
      self._cached_failures = [f for f in failures if f is not None]
    return results_list, failures

  def failures(self) -> list[JobHandle]:
    """Return handles for jobs that failed.

    Only includes jobs whose status is `FAILED`. Jobs that are
    `NOT_FOUND` (e.g. after cleanup) are excluded because the status
    is ambiguous — use `statuses()` for finer control.

    After `results()` has been called, this returns the cached failure
    list from that collection pass, so it remains accurate even if
    cleanup has deleted K8s resources.

    See Also:
      `submission_failures`: returns per-input errors for inputs that
      failed at submission time (`jobs[idx]` is `None`).
    """
    with self._lock:
      if self._cached_failures is not None:
        return list(self._cached_failures)
      snapshot = list(self.jobs)
    # Query statuses outside the lock: status() may hit the network and
    # must not block the submission thread.
    return [
      job
      for job in snapshot
      if job is not None and job.status() == JobStatus.FAILED
    ]

  @property
  def submission_failures(self) -> dict[int, Exception]:
    """Return a copy of per-input submission errors (index -> exception).

    These are inputs where the submission itself failed (e.g. validation
    error, network error). The corresponding `jobs[idx]` slot is `None`
    unless an earlier attempt for that input had been submitted.
    These errors are included in `results()` output but are **not**
    reflected by `failures()`.
    """
    with self._lock:
      return dict(self._submission_errors)

  def cancel(self) -> None:
    """Cancel all non-terminal jobs in the collection (best effort).

    Errors from an individual job are logged and skipped. Inputs the
    background submission thread has not launched yet are not affected.
    """
    for job in self.jobs:
      if job is None:
        continue
      try:
        if job.status() not in _TERMINAL_STATUSES:
          job.cancel()
      except (RuntimeError, google_exceptions.GoogleAPIError):
        logging.warning("Failed to cancel job %s", job.job_id)

  def cleanup(self, *, k8s: bool = True, gcs: bool = True) -> None:
    """Clean up all jobs and optionally the group manifest.

    Args:
      k8s: Delete K8s resources for each child.
      gcs: Delete GCS artifacts for each child **and** the group manifest.
    """
    for job in self.jobs:
      if job is None:
        continue
      try:
        job.cleanup(k8s=k8s, gcs=gcs)
      except (RuntimeError, google_exceptions.GoogleAPIError):
        logging.warning("Failed to clean up job %s", job.job_id)

    if gcs:
      bucket = self._bucket_name
      project = self._project
      # Fall back to the first child's location when the handle was built
      # without an eagerly resolved bucket.
      if not bucket and self.jobs:
        first = next((j for j in self.jobs if j is not None), None)
        if first is not None:
          bucket = first.bucket_name
          project = first.project
      if bucket:
        try:
          storage.cleanup_manifest(bucket, self.group_id, project=project)
        except google_exceptions.GoogleAPIError:
          logging.warning(
            "Failed to clean up manifest for group %s", self.group_id
          )
def _load_child_handle(
  bucket_name: str,
  child: dict,
  total_expected: int,
  project: str,
) -> tuple[int, JobHandle] | None:
  """Download and reconstruct a single child handle.

  Returns `(group_index, handle)` on success, or `None` if the manifest
  entry is malformed (missing/out-of-range `group_index`, missing
  `job_id`) or the download fails.
  """
  idx = child.get("group_index")
  if not isinstance(idx, int) or idx < 0 or idx >= total_expected:
    logging.warning(
      "Invalid child index %r (total_expected=%d); skipping",
      idx,
      total_expected,
    )
    return None
  job_id = child.get("job_id")
  if job_id is None:
    # Checked up front so the error path below never re-raises KeyError.
    logging.warning("Child at index %d has no job_id; skipping", idx)
    return None
  try:
    payload = storage.download_handle(bucket_name, job_id, project=project)
    return idx, JobHandle.from_dict(payload)
  except (google_exceptions.GoogleAPIError, KeyError, ValueError):
    logging.warning(
      "Could not load handle for child job %s (index %d); skipping",
      job_id,
      idx,
    )
    return None


def _manifest_poll_loop(
  handle: BatchHandle,
  bucket_name: str,
  group_id: str,
  project: str,
  total_expected: int,
  poll_interval: float,
  timeout: float | None,
) -> None:
  """Poll GCS manifest until all children appear, then set `_submission_complete`.

  Used by `attach_batch()` when the manifest shows fewer children than
  `total_expected`, indicating the original `map()` is still submitting.
  `_submission_complete` is set in every exit path (all loaded, timeout,
  or an unexpected error) so waiters are never stranded.
  """
  deadline = None if timeout is None else time.monotonic() + timeout
  try:
    while True:
      if deadline is not None and time.monotonic() >= deadline:
        logging.warning(
          "Timed out polling manifest for batch %s (%d/%d children)",
          group_id,
          sum(1 for j in handle.jobs if j is not None),
          total_expected,
        )
        break
      time.sleep(poll_interval)
      try:
        manifest = storage.download_manifest(
          bucket_name, group_id, project=project
        )
      except google_exceptions.GoogleAPIError:
        logging.warning("Failed to poll manifest for batch %s", group_id)
        continue
      for child in manifest.get("children", []):
        idx = child.get("group_index")
        if not isinstance(idx, int) or idx < 0 or idx >= total_expected:
          continue
        with handle._lock:
          if handle.jobs[idx] is not None:
            continue
        result = _load_child_handle(bucket_name, child, total_expected, project)
        if result is not None:
          loaded_idx, job_handle = result
          with handle._lock:
            handle.jobs[loaded_idx] = job_handle
      loaded = sum(1 for j in handle.jobs if j is not None)
      if loaded >= total_expected:
        break
  finally:
    handle._submission_complete.set()


def _cancel_active(handle: BatchHandle, active_indices: set[int]) -> None:
  """Best-effort cancel of all active jobs.

  A `RuntimeError` from one cancel is logged and the rest are still tried.
  """
  for idx in list(active_indices):
    job = handle.jobs[idx]
    if job is None:
      continue
    try:
      job.cancel()
    except RuntimeError:
      logging.warning("Failed to cancel job at index %d", idx)


@dataclass
class _SubmissionState:
  """Groups the mutable state tracked by the submission loop.

  Provides named predicates so the main loop reads as a clear sequence
  of phases rather than a tangle of flags and counters.
  """

  handle: BatchHandle
  manifest: dict
  submit_fn: Any
  inputs: list
  input_mode: str
  max_concurrent: int | None
  max_attempts: int
  fail_fast: bool
  cancel_running_on_fail: bool

  # Submission attempts made so far, per input index.
  attempt_counts: list[int] = field(init=False)
  # Input indices waiting to be (re)submitted, FIFO.
  pending: collections.deque = field(init=False)
  # Input indices whose current attempt is submitted and not yet terminal.
  active: set[int] = field(default_factory=set, init=False)
  # Set by trigger_fail_fast(); once True nothing new is launched.
  stop_launching: bool = field(default=False, init=False)

  def __post_init__(self):
    self.attempt_counts = [0] * len(self.inputs)
    self.pending = collections.deque(range(len(self.inputs)))

  @property
  def has_work(self) -> bool:
    """True while jobs remain to be submitted or are still running."""
    return bool(self.pending) or bool(self.active)

  def can_submit_more(self) -> bool:
    """True when the next pending job is allowed to launch."""
    if not self.pending or self.stop_launching:
      return False
    return self.max_concurrent is None or len(self.active) < self.max_concurrent

  def needs_active_polling(self) -> bool:
    """True when the loop must poll active jobs itself.

    When all jobs are submitted with no retries and `fail_fast` is off,
    the caller uses `wait()`/`results()` to observe terminal states, so
    the submission loop can exit early.
    """
    if not self.active:
      return False
    return bool(self.pending) or self.max_attempts > 1 or self.fail_fast

  def trigger_fail_fast(self) -> None:
    """Stop launching new jobs and optionally cancel siblings."""
    self.stop_launching = True
    if self.cancel_running_on_fail:
      _cancel_active(self.handle, self.active)


def _submit_available(state: _SubmissionState) -> None:
  """Submit pending jobs up to the concurrency limit.

  Each launched job is tagged with the group metadata, its handle is
  re-uploaded, and it is registered in `handle.jobs` and the manifest.
  On per-input errors the exception is recorded in
  `handle._submission_errors` and, when `fail_fast` is set,
  `trigger_fail_fast` is called. Once launching has been stopped, any
  remaining pending inputs are dropped.
  """
  handle = state.handle
  while state.can_submit_more():
    idx = state.pending.popleft()
    state.attempt_counts[idx] += 1

    # Attempt submission.
    try:
      job_handle = call_with_input(
        state.submit_fn, state.inputs[idx], state.input_mode
      )
    except Exception as exc:
      logging.error("Submission failed for index %d: %s", idx, exc)
      with handle._lock:
        handle._submission_errors[idx] = exc
      if state.fail_fast:
        state.trigger_fail_fast()
      continue

    # Tag with group metadata and persist the updated handle.
    job_handle.group_id = handle.group_id
    job_handle.group_kind = state.manifest["group_kind"]
    job_handle.group_index = idx
    try:
      storage.upload_handle(
        job_handle.bucket_name,
        job_handle.job_id,
        job_handle.to_dict(),
        project=job_handle.project,
      )
    except google_exceptions.GoogleAPIError:
      logging.warning(
        "Failed to re-upload handle with group fields for %s",
        job_handle.job_id,
      )

    # Register in handle and manifest. Only `jobs` is shared with other
    # threads; the network upload stays outside the lock.
    with handle._lock:
      handle.jobs[idx] = job_handle
    state.active.add(idx)
    append_child_to_manifest(
      state.manifest, idx, job_handle.job_id, state.attempt_counts[idx]
    )
    try:
      storage.upload_manifest(
        handle._bucket_name,
        handle.group_id,
        state.manifest,
        project=handle._project,
      )
    except google_exceptions.GoogleAPIError:
      logging.warning(
        "Failed to update manifest after submitting index %d", idx
      )

  if state.stop_launching:
    state.pending.clear()


def _poll_and_handle_terminal(state: _SubmissionState) -> None:
  """Poll active jobs for terminal states; retry or trigger fail_fast."""
  handle = state.handle

  # Collect all newly-terminal jobs in one pass.
  newly_terminal: list[tuple[int, JobStatus, JobHandle]] = []
  for idx in list(state.active):
    job = handle.jobs[idx]
    if job is None:
      continue
    try:
      status = job.status()
      if status in _TERMINAL_STATUSES:
        newly_terminal.append((idx, status, job))
    except (RuntimeError, google_exceptions.GoogleAPIError):
      logging.warning("Failed to poll status for index %d", idx)

  for idx, status, job in newly_terminal:
    state.active.discard(idx)
    if status not in (JobStatus.FAILED, JobStatus.NOT_FOUND):
      continue
    if state.attempt_counts[idx] < state.max_attempts:
      # Retry: clean up previous attempt's K8s resources and re-queue.
      # Best effort, matching BatchHandle.cleanup(): a cleanup error must
      # not abort the whole submission loop.
      try:
        job.cleanup(k8s=True, gcs=False)
      except (RuntimeError, google_exceptions.GoogleAPIError):
        logging.warning("Failed to clean up before retry for index %d", idx)
      state.pending.append(idx)
    elif state.fail_fast:
      state.trigger_fail_fast()


def _submission_loop(
  submit_fn,
  inputs: list,
  input_mode: str,
  manifest: dict,
  handle: BatchHandle,
  max_concurrent: int | None,
  retries: int,
  fail_fast: bool,
  cancel_running_on_fail: bool,
) -> None:
  """Core submission and retry loop.

  Mutates *handle.jobs* and *manifest* in place. Runs in the calling
  thread for the plain fan-out case, or in a background thread otherwise.

  Each iteration follows three phases:

  1. **Submit** — launch pending jobs up to the concurrency limit.
  2. **Poll** — check active jobs for terminal states, retry or trigger
     `fail_fast` as needed.
  3. **Sleep** — back off before the next poll cycle.

  Any exception is stored on `handle._submission_error` (re-raised by
  `wait()`), and `_submission_complete` is always set on exit.
  """
  state = _SubmissionState(
    handle=handle,
    manifest=manifest,
    submit_fn=submit_fn,
    inputs=inputs,
    input_mode=input_mode,
    max_concurrent=max_concurrent,
    max_attempts=1 + retries,
    fail_fast=fail_fast,
    cancel_running_on_fail=cancel_running_on_fail,
  )
  try:
    while state.has_work:
      _submit_available(state)
      if not state.needs_active_polling():
        break
      _poll_and_handle_terminal(state)
      if state.has_work:
        time.sleep(_STATUS_POLL_INTERVAL)
  except BaseException as exc:
    handle._submission_error = exc
    logging.error("Submission loop error: %s", exc)
  finally:
    handle._submission_complete.set()


def map(
  submit_fn,
  inputs,
  *,
  input_mode: str = "auto",
  max_concurrent: int | None = _DEFAULT_MAX_CONCURRENT,
  retries: int = 0,
  fail_fast: bool = False,
  cancel_running_on_fail: bool = False,
  name: str | None = None,
  tags: dict[str, str] | None = None,
  project: str | None = None,
  cluster: str | None = None,
) -> BatchHandle:
  """Launch many independent jobs over a set of inputs.

  `submit_fn` must be a function obtained from `func.run_async` where
  `func` is decorated with `@kinetic.run(...)`. Each input is dispatched
  according to `input_mode` and submitted as a separate remote job.

  Args:
    submit_fn: A callable obtained from `func.run_async`.
    inputs: Iterable of inputs to fan out over.
    input_mode: How each input item is passed to *submit_fn*.
      `"auto"` (default) dispatches dicts as `**kwargs`,
      lists/tuples as `*args`, and scalars as a single positional
      argument.
    max_concurrent: Maximum number of concurrently active jobs.
      `None` submits all immediately.
    retries: Number of additional attempts after a job failure.
    fail_fast: Stop launching new jobs after the first failure. Failure
      monitoring runs in a background thread, so `map()` still returns
      without waiting for jobs to finish.
    cancel_running_on_fail: Cancel running siblings on failure (only
      takes effect together with *fail_fast*).
    name: Human-readable collection name.
    tags: Arbitrary key-value metadata.
    project: GCP project. Falls back to KINETIC_PROJECT, then the
      active profile's project, then GOOGLE_CLOUD_PROJECT.
    cluster: GKE cluster name. Falls back to KINETIC_CLUSTER, then the
      active profile's cluster, then the built-in default.

  Returns:
    A `BatchHandle` for observing, collecting, and cleaning up the
    collection.

  Raises:
    TypeError: If *submit_fn* is not callable.
    ValueError: For a non-positive *max_concurrent*, negative *retries*,
      unknown *input_mode*, or empty *inputs*.
  """
  if not callable(submit_fn):
    raise TypeError("submit_fn must be callable")
  if max_concurrent is not None and max_concurrent < 1:
    raise ValueError(
      f"max_concurrent must be a positive integer, got {max_concurrent}"
    )
  if retries < 0:
    raise ValueError(f"retries must be non-negative, got {retries}")
  if input_mode not in ("auto", "single", "args", "kwargs"):
    raise ValueError(f"Unknown input_mode: {input_mode!r}")

  inputs = list(inputs)
  if not inputs:
    raise ValueError("inputs must be non-empty")

  # Resolve bucket eagerly so the initial manifest can be written
  # before any jobs are submitted.
  resolved_project, bucket_name = _resolve_bucket(project, cluster)

  group_id = f"grp-{uuid.uuid4().hex[:8]}"
  group_kind = "map"
  fn_name = getattr(submit_fn, "__name__", str(submit_fn))

  manifest = build_initial_manifest(
    group_id, group_kind, name, tags, len(inputs), fn_name
  )

  # Write the initial manifest (empty children) before any jobs are
  # submitted so that crash recovery can distinguish "0 of N
  # submitted" from "collection never created".
  storage.upload_manifest(
    bucket_name, group_id, manifest, project=resolved_project
  )

  # Pre-allocate the jobs list with None placeholders.
  jobs: list[JobHandle | None] = [None] * len(inputs)

  handle = BatchHandle(
    group_id=group_id,
    name=name,
    tags=tags or {},
    jobs=jobs,
    _bucket_name=bucket_name,
    _project=resolved_project,
  )

  if max_concurrent is None and len(inputs) > 100:
    logging.warning(
      "Submitting %d jobs with max_concurrent=None. "
      "Consider setting max_concurrent to limit resource usage.",
      len(inputs),
    )

  loop_kwargs = {
    "submit_fn": submit_fn,
    "inputs": inputs,
    "input_mode": input_mode,
    "manifest": manifest,
    "handle": handle,
    "max_concurrent": max_concurrent,
    "retries": retries,
    "fail_fast": fail_fast,
    "cancel_running_on_fail": cancel_running_on_fail,
  }
  # Only the plain fan-out runs in the calling thread: there the loop exits
  # right after submitting (see needs_active_polling). A concurrency limit,
  # retries or fail_fast all require polling jobs until they finish, which
  # must not block map().
  if max_concurrent is None and retries == 0 and not fail_fast:
    _submission_loop(**loop_kwargs)
  else:
    thread = threading.Thread(
      target=_submission_loop,
      kwargs=loop_kwargs,
      daemon=False,
    )
    thread.start()

  return handle
def attach_batch(
  group_id: str,
  project: str | None = None,
  cluster: str | None = None,
  poll_interval: float = _MANIFEST_POLL_INTERVAL,
  poll_timeout: float | None = None,
) -> BatchHandle:
  """Reattach to an existing batch collection by *group_id*.

  Downloads the group manifest from GCS, reconstructs `JobHandle`
  objects for each child, and returns a fully usable `BatchHandle`.

  If the manifest has fewer children than `total_expected` (i.e. the
  original `map()` is still submitting), the returned handle polls
  the manifest in a background thread until all children appear or
  *poll_timeout* is reached.

  Args:
    group_id: The collection identifier (e.g. `"grp-a1b2c3d4"`).
    project: GCP project. Falls back to KINETIC_PROJECT, then the
      active profile's project, then GOOGLE_CLOUD_PROJECT.
    cluster: GKE cluster name. Falls back to KINETIC_CLUSTER, then the
      active profile's cluster, then the built-in default.
    poll_interval: Seconds between manifest polls when the batch is
      partially submitted.
    poll_timeout: Maximum seconds to poll for remaining children.
      `None` means poll indefinitely.

  Returns:
    A hydrated `BatchHandle` ready for `results()`, etc.
  """
  resolved_project, bucket_name = _resolve_bucket(project, cluster)
  manifest = storage.download_manifest(
    bucket_name, group_id, project=resolved_project
  )

  children = manifest.get("children", [])
  total_expected = manifest.get("total_expected", len(children))

  # Preallocate to total_expected and slot each child by group_index
  # so that index alignment is preserved even when some handles are
  # missing or the batch was only partially submitted.
  jobs: list[JobHandle | None] = [None] * total_expected
  for child in children:
    result = _load_child_handle(
      bucket_name, child, total_expected, resolved_project
    )
    if result is not None:
      idx, job_handle = result
      jobs[idx] = job_handle

  handle = BatchHandle(
    group_id=manifest.get("group_id", group_id),
    name=manifest.get("group_name"),
    # map() may record `tags=None` in the manifest; normalize to a dict
    # exactly as map() does for its own handle.
    tags=manifest.get("tags") or {},
    jobs=jobs,
    _bucket_name=bucket_name,
    _project=resolved_project,
  )

  loaded = sum(1 for j in jobs if j is not None)
  if loaded >= total_expected:
    # All children present — mark complete immediately.
    handle._submission_complete.set()
  else:
    logging.warning(
      "Batch %s was partially submitted: %d of %d expected jobs. "
      "Polling manifest for remaining children.",
      group_id,
      loaded,
      total_expected,
    )
    thread = threading.Thread(
      target=_manifest_poll_loop,
      kwargs={
        "handle": handle,
        "bucket_name": bucket_name,
        "group_id": group_id,
        "project": resolved_project,
        "total_expected": total_expected,
        "poll_interval": poll_interval,
        "timeout": poll_timeout,
      },
      daemon=True,
    )
    thread.start()

  return handle