# Source code for kinetic.data.data

"""Data class for declaring data dependencies in remote functions.

Wraps local file/directory paths or GCS URIs. On the remote side, Data
resolves to a plain filesystem path — the user's function only sees paths.
"""

import hashlib
import itertools
import os
import posixpath
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from absl import logging

# Directories with more files than this threshold are hashed in parallel
# using a thread pool. At or below this, sequential hashing avoids pool
# overhead.
_PARALLEL_HASH_THRESHOLD = 16
# Number of (relpath, fpath) pairs grouped into each task submitted to the
# thread pool when a directory is hashed in parallel.
_HASH_BATCH_SIZE = 512


def _hash_single_file(fpath: str, relpath: str) -> bytes:
  """Return the raw SHA-256 digest of relpath, a NUL byte, then the file bytes.

  Args:
    fpath: Path of the file whose contents are read from disk.
    relpath: Name mixed into the digest ahead of the contents (UTF-8
      encoded and followed by a single NUL separator).

  Returns:
    The 32-byte binary digest.
  """
  # 256 KB: matches hashlib.file_digest's default buffer size.
  chunk_size = 2**18
  hasher = hashlib.sha256(relpath.encode("utf-8") + b"\0")
  with open(fpath, "rb") as stream:
    # read() returns b"" at end of file, which terminates the loop.
    while chunk := stream.read(chunk_size):
      hasher.update(chunk)
  return hasher.digest()


def _hash_file_batch(batch: list[tuple[str, str]]) -> list[bytes]:
  """Compute per-file digests for a batch of files.

  Args:
    batch: Sequence of (relpath, fpath) pairs. Note the order is the
      reverse of `_hash_single_file`'s positional arguments.

  Returns:
    One 32-byte digest per pair, in the same order as `batch`.
  """
  digests: list[bytes] = []
  for relpath, fpath in batch:
    digests.append(_hash_single_file(fpath, relpath))
  return digests


class Data:
  """A reference to data that should be available on the remote pod.

  Wraps a local file/directory path or a GCS URI. When passed as a function
  argument or used in the `volumes` decorator parameter, Data is resolved
  to a plain filesystem path on the remote side. The user's function code
  never needs to know about Data — it just receives paths.

  By default, data is downloaded into the container before execution.
  Pass `fuse=True` to lazily mount data from GCS via the GCS FUSE CSI
  driver instead — useful for large datasets where only a subset of files
  are read at runtime.

  Args:
    path: Local file/directory path (absolute or relative), GCS URI
      (`gs://bucket/prefix`), or Hugging Face dataset URI (`hf://...`).
    fuse: If `True`, mount the data via GCS FUSE instead of downloading
      it. The data is read on demand — only files that are actually
      opened are fetched from cloud storage. Requires the GCS FUSE CSI
      driver addon on the GKE cluster (`kinetic up` enables it by
      default). Not supported for `hf://` URIs.
    hf_trust_remote_code: If `True`, allow remote code execution for a
      Hugging Face dataset (`hf://` URIs).

  Raises:
    ValueError: If `path` is empty, or if `fuse=True` is combined with an
      `hf://` URI.
    FileNotFoundError: If `path` is a local path that does not exist.

  .. note::

    For GCS URIs, a trailing slash indicates a directory (prefix).
    `Data("gs://my-bucket/dataset/")` is treated as a directory, while
    `Data("gs://my-bucket/dataset")` is treated as a single object.
    If you intend to reference a GCS directory, always include the
    trailing slash.

  Examples::

    # Local directory
    Data("./my_dataset/")

    # Local file
    Data("./config.json")

    # GCS directory — trailing slash required
    Data("gs://my-bucket/datasets/imagenet/")

    # GCS single object
    Data("gs://my-bucket/datasets/weights.h5")

    # FUSE-mounted directory (lazy loading)
    Data("./large_dataset/", fuse=True)

    # FUSE-mounted GCS data
    Data("gs://my-bucket/datasets/imagenet/", fuse=True)

    # Hugging Face Dataset (downloads on pod)
    Data("hf://imdb?split=train")

    # Hugging Face Dataset with remote code execution allowed
    Data("hf://custom/repo", hf_trust_remote_code=True)
  """

  def __init__(
    self, path: str, fuse: bool = False, hf_trust_remote_code: bool = False
  ):
    if not path:
      raise ValueError("Data path must not be empty")
    self._raw_path = path
    self._fuse = fuse
    self._hf_trust_remote_code = hf_trust_remote_code

    if self.is_hf and self._fuse:
      raise ValueError(
        "fuse=True is not supported for Hugging Face datasets (hf:// URIs)"
      )

    if self.is_gcs or self.is_hf:
      # Cloud URIs are kept verbatim; nothing is checked locally.
      self._resolved_path = path
      if self.is_gcs:
        _warn_if_missing_trailing_slash(path)
    else:
      # Local paths are normalised to an absolute path and must exist at
      # construction time.
      self._resolved_path = os.path.abspath(os.path.expanduser(path))
      if not os.path.exists(self._resolved_path):
        raise FileNotFoundError(
          f"Data path does not exist: {path} "
          f"(resolved to {self._resolved_path})"
        )

  @property
  def path(self) -> str:
    """Absolute local path, or the original URI for gs:// and hf:// data."""
    return self._resolved_path

  @property
  def fuse(self) -> bool:
    """Whether the data was requested with `fuse=True`."""
    return self._fuse

  @property
  def hf_trust_remote_code(self) -> bool:
    """Whether `hf_trust_remote_code=True` was passed at construction."""
    return self._hf_trust_remote_code

  @property
  def is_gcs(self) -> bool:
    """True if the path given at construction starts with `gs://`."""
    return self._raw_path.startswith("gs://")

  @property
  def is_hf(self) -> bool:
    """True if the path given at construction starts with `hf://`."""
    return self._raw_path.startswith("hf://")

  @property
  def is_dir(self) -> bool:
    """Whether this reference denotes a directory.

    GCS URIs are directories iff they end with "/"; hf:// URIs are always
    reported as directories; local paths are checked on the filesystem.
    """
    if self.is_gcs:
      return self._raw_path.endswith("/")
    if self.is_hf:
      return True
    return os.path.isdir(self._resolved_path)

  def content_hash(self) -> str:
    """SHA-256 hash of all file contents in deterministic order.

    Uses two-level hashing for parallelism: each file is hashed
    independently (SHA-256 of relpath + contents), then per-file digests
    are combined in sorted-walk (DFS) order into a final hash.

    Includes a type prefix ("dir:" or "file:") to prevent collisions
    between a single file and a directory containing only that file.

    Symlinked directories are not recursed into (followlinks=False) to
    prevent infinite recursion from circular symlinks. Symlinked files
    are read and their resolved contents are hashed, so the hash reflects
    the actual data visible at runtime.

    Returns:
      Hex-encoded SHA-256 digest.

    Raises:
      ValueError: If this Data refers to a gs:// or hf:// URI.
    """
    if self.is_gcs or self.is_hf:
      raise ValueError(
        f"Cannot compute content hash for cloud URI: {self.path}"
      )

    if os.path.isdir(self._resolved_path):
      return self._content_hash_dir()
    return self._content_hash_file()

  def _content_hash_file(self) -> str:
    """Hash a single local file, keyed by its basename, under "file:"."""
    h = hashlib.sha256()
    h.update(b"file:")
    h.update(
      _hash_single_file(
        self._resolved_path,
        os.path.basename(self._resolved_path),
      )
    )
    return h.hexdigest()

  def _content_hash_dir(self) -> str:
    """Hash every file under the local directory, under "dir:"."""
    resolved = self._resolved_path

    # Walk in sorted order for determinism. Sorting dirs in-place
    # controls os.walk's traversal order; sorting files within each
    # directory yields a deterministic DFS order without materializing
    # the full file list — critical for datasets with millions of files.
    def _iter_files():
      for root, dirs, files in os.walk(resolved, followlinks=False):
        dirs.sort()
        for fname in sorted(files):
          fpath = os.path.join(root, fname)
          relpath = os.path.relpath(fpath, resolved)
          yield (relpath, fpath)

    file_iter = _iter_files()
    # Pull one more than the threshold so we can tell "exactly at the
    # threshold" apart from "more files remain".
    first_batch = list(
      itertools.islice(file_iter, _PARALLEL_HASH_THRESHOLD + 1)
    )

    h = hashlib.sha256()
    h.update(b"dir:")

    if len(first_batch) <= _PARALLEL_HASH_THRESHOLD:
      # Small directory — hash sequentially, no pool overhead.
      for digest in _hash_file_batch(first_batch):
        h.update(digest)
    else:
      # Large directory — stream batches to a thread pool.
      # Futures are kept in a bounded deque so at most a few batches
      # worth of file tuples reside in memory at any time.
      max_workers = min(32, (os.cpu_count() or 4) + 4)
      with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = deque()
        batch = []
        for item in itertools.chain(first_batch, file_iter):
          batch.append(item)
          if len(batch) >= _HASH_BATCH_SIZE:
            pending.append(pool.submit(_hash_file_batch, batch))
            batch = []
            # Drain oldest completed futures to bound memory. Results
            # are consumed strictly in submission order, so the final
            # digest does not depend on thread scheduling.
            while len(pending) > max_workers * 2:
              for digest in pending.popleft().result():
                h.update(digest)
        if batch:
          pending.append(pool.submit(_hash_file_batch, batch))
        for future in pending:
          for digest in future.result():
            h.update(digest)

    return h.hexdigest()

  def __repr__(self) -> str:
    # Show every non-default constructor flag so the repr does not hide
    # that remote code execution was enabled.
    args = [repr(self._raw_path)]
    if self._fuse:
      args.append("fuse=True")
    if self._hf_trust_remote_code:
      args.append("hf_trust_remote_code=True")
    return f"Data({', '.join(args)})"
def _warn_if_missing_trailing_slash(path: str) -> None: """Log a warning if a GCS path looks like a directory but has no trailing slash.""" if path.endswith("/"): return gcs_path = path.split("//", 1)[1] # strip gs:// last_segment = posixpath.basename(gcs_path) if last_segment and "." not in last_segment: logging.warning( "GCS path %r does not end with '/' but the last segment " "(%r) has no file extension. If this is a directory " "(prefix), add a trailing slash: %r", path, last_segment, path + "/", ) def make_data_ref( uri: str, is_dir: bool, mount_path: str | None = None, fuse: bool = False, hf_trust_remote_code: bool = False, ) -> dict[str, object]: """Create a serializable data reference dict. These dicts replace Data objects in the payload before serialization. The remote runner identifies them by the __data_ref__ key. """ return { "__data_ref__": True, "uri": uri, "is_dir": is_dir, "mount_path": mount_path, "fuse": fuse, "hf_trust_remote_code": hf_trust_remote_code, } def is_data_ref(obj: object) -> bool: """Check if an object is a serialized data reference.""" return isinstance(obj, dict) and obj.get("__data_ref__") is True def parse_gcs_uri(gcs_uri: str) -> tuple[str, str]: """Parse a GCS URI into (bucket_name, prefix). Args: gcs_uri: A URI like `gs://my-bucket/some/prefix/`. Returns: Tuple of `(bucket_name, prefix)` where prefix has no leading or trailing slashes. For `gs://my-bucket/some/prefix/`, returns `("my-bucket", "some/prefix")`. For `gs://my-bucket`, returns `("my-bucket", "")`. """ stripped = gcs_uri[len("gs://") :] if gcs_uri.startswith("gs://") else gcs_uri parts = stripped.split("/", 1) bucket = parts[0] prefix = parts[1].strip("/") if len(parts) > 1 else "" return bucket, prefix