data#
This module provides specialized data loading components for efficient handling and processing of datasets.
DictDataLoader#
- class fkat.data.DictDataLoader(dataloaders: dict[str, collections.abc.Iterable[dict[str, Any]]], strategy: SamplerStrategy, key: str = 'dataset')[source]#
A LightningDataModule that manages multiple DataLoaders for different stages.
DataModule#
- class fkat.data.DataModule(dataloaders: dict[str, dict[str, Any] | collections.abc.Callable[[], collections.abc.Iterable[Any]]], profiler: Optional[Profiler] = None)[source]#
A LightningDataModule that manages multiple DataLoaders for different stages.
- Parameters:
dataloaders (dict[str, dict[str, Any] | Callable[[], Iterable[Any]]]) – Dataloaders for different stages.
profiler (Profiler | None) – Profiler instance for worker initialization.
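Example (a minimal construction sketch; the stage keys "train" and "val" and the factory callables below are assumptions for illustration, not confirmed by this page):
import torch
from torch.utils.data import DataLoader, TensorDataset

from fkat.data import DataModule

def build_train_loader():
    # Factory returning an iterable of training batches.
    dataset = TensorDataset(torch.randn(128, 8), torch.randint(0, 2, (128,)))
    return DataLoader(dataset, batch_size=16, shuffle=True)

def build_val_loader():
    dataset = TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,)))
    return DataLoader(dataset, batch_size=16)

dm = DataModule(dataloaders={"train": build_train_loader, "val": build_val_loader})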
- load_state_dict(state_dict: dict[str, Any]) → None[source]#
Called when loading a checkpoint to reload the DataModule’s state for a ShardedDataLoader.
This method iterates over each stage’s dataloader, loads its state from the provided state_dict, and sets the RNG states. If a dataloader does not implement the RestoreStates protocol, it sets its load_state_dict attribute to vanilla_dataloader_load_state_dict to allow loading its state.
- Parameters:
state_dict (Dict[str, Any]) – A dictionary containing the dataloader states and RNG states.
- on_exception(exception: BaseException) → None[source]#
Called when the trainer execution is interrupted by an exception.
- on_load_checkpoint(checkpoint: dict[str, Any]) → None[source]#
Called by Lightning to restore your model. If you saved something with on_save_checkpoint() this is your chance to restore this.
- Parameters:
checkpoint – Loaded checkpoint
Example:
def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.
- on_save_checkpoint(checkpoint: dict[str, Any]) → None[source]#
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
checkpoint – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
Example:
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- predict_dataloader() → collections.abc.Iterable[Any] | None[source]#
An iterable or collection of iterables specifying prediction samples.
For more information about multiple dataloaders, see this section.
It’s recommended that all data downloads and preparation happen in prepare_data(). This dataloader is used by predict().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
- prepare_data() → None[source]#
Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.
Warning
DO NOT set state to the model (use setup() instead) since this is NOT called on every device.
Example:
def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()

    # bad
    self.split = data_split
    self.some_state = some_other_state()
In a distributed environment, prepare_data() can be called in two ways (using prepare_data_per_node):
- Once per node. This is the default and is only called on LOCAL_RANK=0.
- Once in total. Only called on GLOBAL_RANK=0.
Example:
# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True


# call on GLOBAL_RANK=0 (great for shared file systems)
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = False
This is called before requesting the dataloaders:
model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()
- setup(stage: str) → None[source]#
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- state_dict() → dict[str, Any][source]#
Called when saving a checkpoint; implement this to generate and save the DataModule’s state for a ShardedDataLoader.
This method iterates over each stage’s dataloader to retrieve its state using the state_dict() method, and saves it along with the RNG states. If a dataloader does not implement the PersistStates protocol, it sets its state_dict attribute to vanilla_dataloader_state_dict to allow saving its state.
- Returns:
A dictionary containing the dataloader states and RNG states.
- Return type:
dict[str, Any]
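Example (a hedged round-trip sketch of the state_dict()/load_state_dict() pair; the stage key "train" and the internal layout of the returned dictionary are assumptions, and when training through Lightning's Trainer these hooks are invoked automatically during checkpointing):
import torch
from torch.utils.data import DataLoader, TensorDataset

from fkat.data import DataModule

dm = DataModule(dataloaders={
    "train": lambda: DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=8),
})

# Capture the dataloader and RNG states alongside any other checkpoint payload.
dm_state = dm.state_dict()
torch.save({"datamodule": dm_state}, "checkpoint.pt")

# Later, restore the saved states before resuming iteration.
checkpoint = torch.load("checkpoint.pt", weights_only=False)  # state may hold non-tensor objects
dm.load_state_dict(checkpoint["datamodule"])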
- teardown(stage: str | None) → None[source]#
Called at the end of fit (train + validate), validate, test, or predict.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
- test_dataloader() → collections.abc.Iterable[Any] | None[source]#
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
This dataloader is used by test().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a
test_step(), you don’t need to implement this method.
- train_dataloader() → collections.abc.Iterable[Any] | None[source]#
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
This dataloader is used by fit().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
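Example (a minimal, generic sketch of the download-in-prepare_data / split-in-setup / return-in-train_dataloader pattern referenced above; toy tensors stand in for a real dataset, and the lightning.pytorch import path assumes Lightning 2.x):
import torch
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader, TensorDataset, random_split

class ToyDataModule(LightningDataModule):
    def prepare_data(self):
        # Download/tokenize here; runs in a single process, so assign no state.
        pass

    def setup(self, stage):
        # Build and split datasets here; runs on every process.
        full = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
        self.train_set, self.val_set = random_split(full, [80, 20])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=16, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=16)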
- val_dataloader() → collections.abc.Iterable[Any] | None[source]#
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data(). This dataloader is used by fit() and validate().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a
validation_step(), you don’t need to implement this method.
ShmDataLoader#
- class fkat.data.ShmDataLoader(seed: int, dataloader_factory: DataLoaderFactory[T_co], num_microbatch_prefetches: int = -1, dp_rank: int = 0, profiler: Optional[Profiler] = None, device: Optional[device] = None, multiprocessing: bool = True)[source]#
A DataLoader that uses shared memory to efficiently manage and prefetch data batches.
Enables double-buffered micro-batch processing and fetching that overlaps with model forward/backward passes, minimizing dataloading overhead.
- Parameters:
seed (int) – Random seed for reproducibility. Use ${seed} at top level in config.yaml.
dataloader_factory (DataLoaderFactory[T_co]) – Factory for creating DataLoaders.
num_microbatch_prefetches (int, optional) – Number of microbatches to prefetch. Defaults to -1.
dp_rank (int, optional) – Rank of the current process. Defaults to 0.
profiler (Optional[Profiler], optional) – Optional profiler instance. Defaults to None.
device (Optional[torch.device], optional) – Device to move the microbatches to in the background. Defaults to None.
multiprocessing (bool, optional) – Whether to instantiate the DataLoader in a separate process. Defaults to True to relieve pressure from the training process; use False to debug and profile.
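Example (a hedged instantiation sketch; DataLoaderFactory’s interface is not documented on this page, so the zero-argument callable below is only an assumption about what a factory might look like, and every value shown is a placeholder):
import torch
from torch.utils.data import DataLoader, TensorDataset

from fkat.data import ShmDataLoader

def make_loader():
    # Assumed factory: builds the underlying DataLoader on demand
    # (check fkat's DataLoaderFactory definition for the real interface).
    dataset = TensorDataset(torch.randn(256, 8))
    return DataLoader(dataset, batch_size=16)

loader = ShmDataLoader(
    seed=1234,                       # or ${seed} when configured via config.yaml
    dataloader_factory=make_loader,  # assumption; see note above
    num_microbatch_prefetches=2,     # default is -1; see the parameter description above
    dp_rank=0,
    device=torch.device("cuda", 0) if torch.cuda.is_available() else None,
    multiprocessing=True,            # set False to debug or profile in-process
)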