data#

This module provides specialized data loading components for efficient handling and processing of datasets.

DictDataLoader#

class fkat.data.DictDataLoader(dataloaders: dict[str, collections.abc.Iterable[dict[str, Any]]], strategy: SamplerStrategy, key: str = 'dataset')[source]#

A LightningDataModule that manages multiple DataLoaders for different stages.

DataModule#

class fkat.data.DataModule(dataloaders: dict[str, dict[str, Any] | collections.abc.Callable[[], collections.abc.Iterable[Any]]], profiler: Optional[Profiler] = None)[source]#

A LightningDataModule that manages multiple DataLoaders for different stages.

Parameters:
  • dataloaders (dict[str, dict[str, Any] | Callable[[], Iterable[Any]]]) – Dataloaders for different stages.

  • profiler (Profiler | None) – Profiler instance for worker initialization.
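
A minimal sketch of the callable form of the dataloaders argument. The stage keys "train" and "val" are assumptions (the keys are not stated here), plain lists stand in for real DataLoaders, and resolve_dataloader is a hypothetical helper that only illustrates looking up and invoking a zero-argument factory per stage. It is not fkat's implementation.

```python
from __future__ import annotations

from collections.abc import Callable, Iterable
from typing import Any


def make_train_batches() -> Iterable[Any]:
    # Stand-in for a real DataLoader: any iterable of batches works here.
    return [{"x": i} for i in range(4)]


def make_val_batches() -> Iterable[Any]:
    return [{"x": i} for i in range(2)]


# Assumed stage keys; each factory is only called when its stage is requested.
dataloaders: dict[str, Callable[[], Iterable[Any]]] = {
    "train": make_train_batches,
    "val": make_val_batches,
}


def resolve_dataloader(stage: str) -> Iterable[Any] | None:
    """Hypothetical helper: build the iterable for a stage, or None if absent."""
    factory = dataloaders.get(stage)
    return factory() if factory is not None else None
```

Passing factories rather than built loaders defers construction until a stage is actually needed, which is one plausible reason for accepting callables.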

load_state_dict(state_dict: dict[str, Any]) None[source]#

Called when loading a checkpoint to reload the DataModule’s state for a ShardedDataLoader.

This method iterates over each stage’s dataloader, loads its state from the provided state_dict, and sets the RNG states. If a dataloader does not implement the RestoreStates protocol, it sets its load_state_dict attribute to vanilla_dataloader_load_state_dict to allow loading its state.

Parameters:

state_dict (Dict[str, Any]) – A dictionary containing the dataloader states and RNG states.
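
The fallback described above can be sketched as follows. The protocol body, the no-op fallback_load_state_dict (a stand-in for vanilla_dataloader_load_state_dict, whose real behavior is not shown here), and the per-stage layout of the state dict are all assumptions made for illustration.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class RestoreStates(Protocol):
    """Assumed shape of the protocol: any object exposing load_state_dict."""

    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...


def fallback_load_state_dict(state_dict: dict[str, Any]) -> None:
    """Hypothetical stand-in for the vanilla fallback: accepts state, does nothing."""
    return None


class PlainLoader:
    """A loader without checkpoint support."""


class ResumableLoader:
    """A loader that can restore its position from a checkpoint."""

    def __init__(self) -> None:
        self.position = 0

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.position = state_dict["position"]


def restore(loaders: dict[str, Any], state_dict: dict[str, Any]) -> None:
    for stage, loader in loaders.items():
        if not isinstance(loader, RestoreStates):
            # Attach the fallback so every loader can be restored uniformly.
            loader.load_state_dict = fallback_load_state_dict
        loader.load_state_dict(state_dict.get(stage, {}))


loaders = {"train": ResumableLoader(), "val": PlainLoader()}
restore(loaders, {"train": {"position": 7}})
```

After the call, the resumable loader holds the restored position, and the plain loader has gained a load_state_dict attribute without otherwise changing.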

on_exception(exception: BaseException) None[source]#

Called when the trainer execution is interrupted by an exception.

on_load_checkpoint(checkpoint: dict[str, Any]) None[source]#

Called by Lightning to restore your model. If you saved something with on_save_checkpoint() this is your chance to restore this.

Parameters:

checkpoint – Loaded checkpoint

Example:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Note

Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.

on_save_checkpoint(checkpoint: dict[str, Any]) None[source]#

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:

checkpoint – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Example:

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object

Note

Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.

predict_dataloader() collections.abc.Iterable[Any] | None[source]#

An iterable or collection of iterables specifying prediction samples.

For more information about multiple dataloaders, see the corresponding section of the Lightning documentation.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.

prepare_data() None[source]#

Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.

Warning

Do not set state on the model here (use setup instead), since this method is not called on every device.

Example:

def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()

    # bad
    self.split = data_split
    self.some_state = some_other_state()

In a distributed environment, prepare_data can be called in two ways, controlled by prepare_data_per_node:

  1. Once per node. This is the default and is only called on LOCAL_RANK=0.

  2. Once in total. Only called on GLOBAL_RANK=0.

Example:

# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True


# call on GLOBAL_RANK=0 (great for shared file systems)
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = False
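
The rank rule described above can be sketched as a small predicate. This is an illustration of the two modes only, not Lightning's implementation; should_prepare_data and the 2-node, 2-process layout are hypothetical.

```python
def should_prepare_data(
    prepare_data_per_node: bool, local_rank: int, global_rank: int
) -> bool:
    """Return True if this process is the one that runs prepare_data."""
    if prepare_data_per_node:
        # Once per node: the first process on each node.
        return local_rank == 0
    # Once in total: only the first process overall.
    return global_rank == 0


# Hypothetical cluster: 2 nodes with 2 processes each, as (global_rank, local_rank).
ranks = [(node * 2 + local, local) for node in range(2) for local in range(2)]

per_node = [g for g, l in ranks if should_prepare_data(True, l, g)]
once_total = [g for g, l in ranks if should_prepare_data(False, l, g)]
print(per_node)    # [0, 2]
print(once_total)  # [0]
```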

This is called before requesting the dataloaders:

model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()

setup(stage: str) None[source]#

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

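from lightning.pytorch import LightningModule
from torch import nn


# download_data, tokenize, load_data and something_else are placeholders.
class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: prepare_data is not called on every process
        self.something = something_else

    def setup(self, stage):
        # build layers here, once the data is available on every process
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)

state_dict() dict[str, Any][source]#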

Called when saving a checkpoint. Implement this to generate and save the DataModule state for a ShardedDataLoader.

This method iterates over each stage’s dataloader to retrieve its state using the state_dict() method, and saves it along with the RNG states. If a dataloader does not implement the PersistStates protocol, it sets its state_dict attribute to vanilla_dataloader_state_dict to allow saving its state.

Returns:

A dictionary containing the dataloader states and RNG states.

Return type:

dict[str, Any]
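
A minimal sketch of saving dataloader states together with RNG state, using only Python's standard random module. Which RNGs the DataModule actually captures is not stated here, and the "dataloaders" and "rng" keys are hypothetical.

```python
import random
from typing import Any


def capture_state(loader_states: dict[str, Any]) -> dict[str, Any]:
    """Bundle per-stage loader states with the current Python RNG state."""
    return {
        "dataloaders": dict(loader_states),  # hypothetical key
        "rng": {"python": random.getstate()},  # hypothetical key
    }


def restore_rng(state: dict[str, Any]) -> None:
    """Reset the Python RNG to the state stored in the snapshot."""
    random.setstate(state["rng"]["python"])


random.seed(0)
snapshot = capture_state({"train": {"position": 3}})
first = [random.random() for _ in range(3)]

# Restoring the snapshot replays the same random sequence.
restore_rng(snapshot)
second = [random.random() for _ in range(3)]
```

Saving the RNG state alongside the loader state is what lets shuffling and augmentation continue deterministically after a resume.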

teardown(stage: str | None) None[source]#

Called at the end of fit (train + validate), validate, test, or predict.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

test_dataloader() collections.abc.Iterable[Any] | None[source]#

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the corresponding section of the Lightning documentation.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

train_dataloader() collections.abc.Iterable[Any] | None[source]#

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the corresponding section of the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_dataloader() collections.abc.Iterable[Any] | None[source]#

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the corresponding section of the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

class fkat.data.PersistStates(*args, **kwargs)[source]#
class fkat.data.RestoreStates(*args, **kwargs)[source]#

ShmDataLoader#

class fkat.data.ShmDataLoader(seed: int, dataloader_factory: DataLoaderFactory[T_co], num_microbatch_prefetches: int = -1, dp_rank: int = 0, profiler: Optional[Profiler] = None, device: Optional[device] = None, multiprocessing: bool = True)[source]#

A DataLoader that uses shared memory to efficiently manage and prefetch data batches.

Enables double-buffered micro-batch processing and fetching that overlaps with model forward/backward passes, minimizing dataloading overhead.

Parameters:
  • seed (int) – Random seed for reproducibility. Use ${seed} at top level in config.yaml.

  • dataloader_factory (DataLoaderFactory[T_co]) – Factory for creating DataLoaders.

  • num_microbatch_prefetches (int, optional) – Number of microbatches to prefetch. Defaults to -1.

  • dp_rank (int, optional) – Rank of the current process. Defaults to 0.

  • profiler (Optional[Profiler], optional) – Profiler for profiling. Defaults to None.

  • device (Optional[torch.device]) – Device to move the microbatches to in the background.

  • multiprocessing (bool, optional) – Whether to instantiate the DataLoader in a separate process. Defaults to True to relieve pressure on the training process; use False to debug and profile.
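
The idea of overlapping batch fetching with compute can be illustrated with a generic prefetcher. This sketch uses a background thread and a bounded queue as a stand-in; it does not use shared memory or a separate process, and prefetch is a hypothetical helper, not part of fkat.

```python
import queue
import threading
from collections.abc import Iterable, Iterator
from typing import TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel marking the end of the source


def prefetch(source: Iterable[T], depth: int = 2) -> Iterator[T]:
    """Yield items from source while a background thread fetches ahead.

    depth=2 gives double buffering: one item being consumed, one ready.
    Limitation of this sketch: an exception raised by the source inside
    the producer thread is not forwarded to the consumer.
    """
    buffer: queue.Queue = queue.Queue(maxsize=depth)

    def producer() -> None:
        for item in source:
            buffer.put(item)  # blocks while the buffer is full
        buffer.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is _DONE:
            return
        yield item


batches = list(prefetch(range(5), depth=2))
```

The bounded queue caps memory use: the producer stalls once depth items are waiting, so fetching only runs ahead of the consumer by a fixed amount.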

ShardedDataLoader#

class fkat.data.ShardedDataLoader(seed: int, shard_sampler: ShardSampler, dataloader_factory: DataLoaderFactory[T_co], num_shard_prefetches: int = 0, num_microbatch_prefetches: int = -1, dp_rank: int = 0, profiler: Optional[Profiler] = None, device: Optional[device] = None, multiprocessing: bool = True)[source]#

A DataLoader that processes data in shards, designed for distributed training scenarios.

Enables double-buffered micro-batch processing and fetching that overlaps with model forward/backward passes, minimizing dataloading overhead.

Parameters:
  • seed (int) – Random seed for reproducibility. Use ${seed} at top level in config.yaml.

  • shard_sampler (ShardSampler) – Sampler for generating shards.

  • dataloader_factory (DataLoaderFactory[T_co]) – Factory for creating DataLoaders.

  • num_shard_prefetches (int, optional) – Number of shards to prefetch. Defaults to 0.

  • num_microbatch_prefetches (int, optional) – Number of microbatches to prefetch. Defaults to -1.

  • dp_rank (int, optional) – Rank of the current process. Defaults to 0.

  • profiler (Profiler, optional) – Profiler for profiling. Defaults to None.

  • device (Optional[torch.device]) – Device to move the microbatches to in the background.

  • multiprocessing (bool, optional) – Whether to instantiate the DataLoader in a separate process. Defaults to True to relieve pressure on the training process; use False to debug and profile.

load_state_dict(state_dict: dict[str, Any]) None[source]#

Loads the state dict into the shard sampler, restoring the data shard indices for each rank.

The state dict should look like {"all_rank_indices": [torch.tensor(1), torch.tensor(1), torch.tensor(5), torch.tensor(6)]}, where each tensor corresponds to the indices of data shards for specific ranks.

state_dict() dict[str, Any][source]#

Returns the shard sampler state dict with adjusted shard indices, accounting for shard prefetches and prefetch backfill in parallel.

Example: If num_shard_prefetches is 3 and the original state dict is {"all_rank_indices": [torch.tensor(4), torch.tensor(5)]}, it will be updated to {"all_rank_indices": [torch.tensor(0), torch.tensor(1)]}. This ensures that each rank resumes training from the correct shard index, preventing reprocessing of shards that have already been trained on.
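
The arithmetic in the example is consistent with rolling each index back by num_shard_prefetches + 1. That rule is inferred from the numbers above and is an assumption; the actual adjustment may differ. The sketch uses plain integers in place of tensors, and adjust_shard_indices is a hypothetical helper.

```python
def adjust_shard_indices(indices: list[int], num_shard_prefetches: int) -> list[int]:
    """Roll each rank's shard index back past the prefetched shards.

    Assumed rule: subtract the prefetched shards plus one for the
    backfill that runs in parallel.
    """
    rollback = num_shard_prefetches + 1
    return [index - rollback for index in indices]


adjusted = adjust_shard_indices([4, 5], num_shard_prefetches=3)
print(adjusted)  # [0, 1]
```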