ShardedDataLoader
- class fkat.data.ShardedDataLoader(seed: int, shard_sampler: ShardSampler, dataloader_factory: DataLoaderFactory[T_co], num_shard_prefetches: int = 0, num_microbatch_prefetches: int = -1, dp_rank: int = 0, profiler: Optional[Profiler] = None, device: Optional[torch.device] = None, multiprocessing: bool = True)
Bases: Iterable[list[T_co]]
A DataLoader that processes data in shards, designed for distributed training. It enables double-buffered micro-batch processing, with fetching that overlaps the model's forward/backward passes to minimize data-loading overhead.
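The overlap described above can be illustrated with a generic background-prefetch pattern. The following is a minimal sketch, not fkat's implementation. It assumes a plain Python iterable of micro-batches, and the `prefetch` function and its `depth` parameter are hypothetical. A producer thread fills a bounded queue while the training loop consumes from it. A small `depth` (such as 1 or 2) approximates double buffering: while the caller works on one item, the next is already being prepared.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

_DONE = object()  # sentinel: the source iterable is exhausted


class _Failure:
    """Wraps an exception raised by the producer so the consumer can re-raise it."""

    def __init__(self, exc: BaseException) -> None:
        self.exc = exc


def prefetch(source: Iterable[T], depth: int = 2) -> Iterator[T]:
    """Hypothetical helper: yield items from `source` while a background thread fetches ahead.

    At most `depth` finished items wait in the queue at any time.
    """
    if depth < 1:
        raise ValueError("depth must be at least 1")  # Queue(maxsize=0) would be unbounded
    buf: "queue.Queue[object]" = queue.Queue(maxsize=depth)

    def producer() -> None:
        try:
            for item in source:
                buf.put(item)  # blocks once `depth` items are waiting
        except BaseException as exc:  # forward loader errors to the consumer
            buf.put(_Failure(exc))
            return
        buf.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _DONE:
            return
        if isinstance(item, _Failure):
            raise item.exc
        yield item  # the caller computes on this item while the next one loads
```

For example, `list(prefetch(range(5), depth=2))` returns `[0, 1, 2, 3, 4]`; a single producer feeding a FIFO queue preserves order. Because the queue is bounded, memory use stays capped. If the consumer stops early, the daemon producer thread may remain blocked on `put` until the process exits.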
- Parameters:
seed (int) – Random seed for reproducibility. In config.yaml, set this to ${seed} so that it interpolates the top-level seed value.
shard_sampler (ShardSampler) – Sampler for generating shards.
dataloader_factory (DataLoaderFactory[T_co]) – Factory for creating DataLoaders.
num_shard_prefetches (int, optional) – Number of shards to prefetch. Defaults to 0.
num_microbatch_prefetches (int, optional) – Number of microbatches to prefetch. Defaults to -1.
dp_rank (int, optional) – Data-parallel rank of the current process. Defaults to 0.
profiler (Profiler, optional) – Profiler for profiling. Defaults to None.
device (Optional[torch.device], optional) – Device to move the microbatches to in the background. Defaults to None.
multiprocessing (bool, optional) – Whether to instantiate the DataLoader in a separate process. Defaults to True to relieve pressure on the training process; use False to debug and profile.
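To make the effect of `num_shard_prefetches` concrete, here is a hypothetical sketch of shard-level prefetching with `concurrent.futures`. The `iter_shards` function and its `load_shard` callback are illustrative assumptions, and fkat may count in-flight shards differently. In this sketch, up to `num_shard_prefetches` shard loads stay in flight while the caller consumes the current shard, and each consumed slot is backfilled immediately. Threads are used only for simplicity; the class itself can run its DataLoader in a separate process (`multiprocessing=True`).

```python
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Iterable, Iterator, TypeVar

S = TypeVar("S")  # shard descriptor (hypothetical, e.g. a list of file paths)
D = TypeVar("D")  # loaded shard data


def iter_shards(
    shards: Iterable[S],
    load_shard: Callable[[S], D],
    num_shard_prefetches: int = 0,
) -> Iterator[D]:
    """Hypothetical helper: load shards in order, prefetching ahead of the consumer."""
    if num_shard_prefetches <= 0:
        # No prefetching: load each shard only when it is requested.
        for shard in shards:
            yield load_shard(shard)
        return

    it = iter(shards)
    pending: "deque[Future[D]]" = deque()
    with ThreadPoolExecutor(max_workers=num_shard_prefetches) as pool:

        def submit_next() -> bool:
            # Take at most one shard from the source and start loading it.
            for shard in it:
                pending.append(pool.submit(load_shard, shard))
                return True
            return False

        # Prime the pipeline with `num_shard_prefetches` loads.
        for _ in range(num_shard_prefetches):
            if not submit_next():
                break
        while pending:
            data = pending.popleft().result()  # waits for the oldest shard; re-raises loader errors
            submit_next()  # backfill the freed slot before handing the shard out
            yield data
```

For instance, `list(iter_shards(range(6), lambda s: s * 10, num_shard_prefetches=2))` yields the shards in their original order even though loading runs ahead of consumption.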
- load_state_dict(state_dict: dict[str, Any]) → None
Loads the state dict into the shard sampler, restoring the data shard index for each rank.
The state dict should look like {"all_rank_indices": [torch.tensor(1), torch.tensor(1), torch.tensor(5), torch.tensor(6)]}, where each tensor holds the data shard index of the corresponding rank.
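A minimal sketch of the restore step, assuming that position `i` in `all_rank_indices` belongs to data-parallel rank `i`. `ToyShardSampler` is a hypothetical stand-in for a real `ShardSampler`. Plain integers replace the tensors so the sketch runs without PyTorch; `int()` also accepts 0-d `torch.tensor` values.

```python
from typing import Any, Sequence


class ToyShardSampler:
    """Hypothetical sampler that tracks the next shard index for one data-parallel rank."""

    def __init__(self, dp_rank: int = 0) -> None:
        self.dp_rank = dp_rank
        self.shard_index = 0

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        all_rank_indices: Sequence[Any] = state_dict["all_rank_indices"]
        # Assumption: list position i holds the shard index of rank i.
        # int() converts plain ints and 0-d tensors alike.
        self.shard_index = int(all_rank_indices[self.dp_rank])


sampler = ToyShardSampler(dp_rank=2)
sampler.load_state_dict({"all_rank_indices": [1, 1, 5, 6]})
```

With the state dict from the example above, the sampler for rank 2 resumes at shard index 5.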
- state_dict() → dict[str, Any]
Returns the shard sampler state dict with adjusted shard indices, accounting for prefetched shards and the prefetch backfill running in parallel.
Example: if num_shard_prefetches is 3 and the original state dict is {"all_rank_indices": [torch.tensor(4), torch.tensor(5)]}, it is updated to {"all_rank_indices": [torch.tensor(0), torch.tensor(1)]}; that is, each index is reduced by 4 (num_shard_prefetches + 1). This ensures that each rank resumes training from the correct shard index, preventing reprocessing of shards that have already been trained on.
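The arithmetic in the example can be reproduced with a small helper. This sketch rests on an assumption inferred only from that example: the offset is `num_shard_prefetches + 1`. `adjust_state_dict` is hypothetical, plain integers stand in for tensors, and the behavior for indices smaller than the offset is not specified here, so the sketch does not clamp.

```python
from typing import Any


def adjust_state_dict(state_dict: dict[str, Any], num_shard_prefetches: int) -> dict[str, Any]:
    """Hypothetical helper: rewind each rank's shard index by the prefetch offset.

    Assumption (inferred from the documented example): offset = num_shard_prefetches + 1.
    Returns a new dict; the input is left unchanged.
    """
    offset = num_shard_prefetches + 1
    # Works for plain ints and for tensors, since both support `- int`.
    adjusted = [idx - offset for idx in state_dict["all_rank_indices"]]
    return {**state_dict, "all_rank_indices": adjusted}
```

Applied to `{"all_rank_indices": [4, 5]}` with `num_shard_prefetches=3`, it returns `{"all_rank_indices": [0, 1]}`, matching the example above.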