ShardedDataLoader

class fkat.data.ShardedDataLoader(seed: int, shard_sampler: ShardSampler, dataloader_factory: DataLoaderFactory[T_co], num_shard_prefetches: int = 0, num_microbatch_prefetches: int = -1, dp_rank: int = 0, profiler: Optional[Profiler] = None, device: Optional[device] = None, multiprocessing: bool = True)

Bases: Iterable[list[T_co]]

A DataLoader that processes data in shards, designed for distributed training scenarios.

Enables double-buffered micro-batch fetching and processing that overlap with the model's forward/backward passes, minimizing data-loading overhead.

Parameters:
  • seed (int) – Random seed for reproducibility. Use ${seed} at the top level of config.yaml.

  • shard_sampler (ShardSampler) – Sampler for generating shards.

  • dataloader_factory (DataLoaderFactory[T_co]) – Factory for creating DataLoaders.

  • num_shard_prefetches (int, optional) – Number of shards to prefetch. Defaults to 0.

  • num_microbatch_prefetches (int, optional) – Number of microbatches to prefetch. Defaults to -1.

  • dp_rank (int, optional) – Data-parallel rank of the current process. Defaults to 0.

  • profiler (Profiler, optional) – Profiler to instrument data loading. Defaults to None.

  • device (torch.device, optional) – Device to move the microbatches to in the background. Defaults to None.

  • multiprocessing (bool, optional) – Whether to instantiate the DataLoader in a separate process. Defaults to True, which relieves pressure on the training process; set it to False for debugging and profiling.
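
A minimal construction-and-iteration sketch. Here my_shard_sampler and my_dataloader_factory stand in for ShardSampler / DataLoaderFactory implementations from your project; they are placeholders, not part of this API:

    import torch
    from fkat.data import ShardedDataLoader

    loader = ShardedDataLoader(
        seed=42,                                   # use ${seed} when driven from config.yaml
        shard_sampler=my_shard_sampler,            # placeholder ShardSampler
        dataloader_factory=my_dataloader_factory,  # placeholder DataLoaderFactory
        num_shard_prefetches=2,                    # keep two shards in flight
        dp_rank=0,                                 # data-parallel rank of this worker
        device=torch.device("cuda"),               # move microbatches to GPU in the background
        multiprocessing=True,                      # set False to debug/profile in-process
    )

    # The loader is an Iterable[list[T_co]]: each iteration yields a list of microbatches.
    for microbatches in loader:
        for mb in microbatches:
            ...  # forward/backward on each microbatch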

load_batch() → None
load_batch_sync() → list[T_co]
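
The two methods above suggest a double-buffered pattern; the sketch below is an assumption based on their signatures (load_batch() starting an asynchronous fetch, load_batch_sync() blocking until the prefetched list of microbatches is ready), not a documented recipe. num_steps and train_step are placeholders.

    loader.load_batch()                          # start fetching the first list of microbatches
    for _ in range(num_steps):
        microbatches = loader.load_batch_sync()  # wait for the in-flight fetch
        loader.load_batch()                      # immediately start fetching the next list...
        for mb in microbatches:
            train_step(mb)                       # ...so fetching overlaps forward/backward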
load_state_dict(state_dict: dict[str, Any]) → None

Loads the state dict into the shard sampler, restoring the data shard indices for each rank.

The state dict should look like {"all_rank_indices": [torch.tensor(1), torch.tensor(1), torch.tensor(5), torch.tensor(6)]}, where each tensor holds the data-shard index for a specific rank.
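
For example, restoring shard positions for four data-parallel ranks using the format described above (the dict literal is illustrative):

    import torch

    # One shard-index tensor per rank, in rank order.
    state = {
        "all_rank_indices": [
            torch.tensor(1), torch.tensor(1), torch.tensor(5), torch.tensor(6),
        ]
    }
    loader.load_state_dict(state)  # each rank resumes from its recorded shard index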

on_exception(exception: BaseException) → None
prefetch_shards(count: int) → None
set_device(device: torch.device | None) → None
state_dict() → dict[str, Any]

Returns the shard sampler's state dict with shard indices adjusted to account for prefetched shards and the parallel backfilling of the prefetch buffer.

Example: if num_shard_prefetches is 3 and the raw state dict is {"all_rank_indices": [torch.tensor(4), torch.tensor(5)]}, the returned state dict will be {"all_rank_indices": [torch.tensor(0), torch.tensor(1)]}. This ensures that each rank resumes training from the correct shard index and does not reprocess shards it has already trained on.
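
A hedged sketch of a checkpoint round trip built on state_dict() together with load_state_dict(); the checkpoint path and surrounding training code are placeholders:

    import torch

    # Save: state_dict() already folds in the prefetch adjustment described above.
    torch.save({"dataloader": loader.state_dict()}, "ckpt.pt")

    # Resume: hand the saved dict to a freshly constructed ShardedDataLoader.
    ckpt = torch.load("ckpt.pt")
    loader.load_state_dict(ckpt["dataloader"])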

teardown(*args: Any) → None