cuda#
- class fkat.pytorch.callbacks.cuda.EmptyCache(schedule: Optional[Schedule] = None)[source]#
Empties the CUDA cache on the provided schedule.
- maybe_empty_cache(trainer: Trainer, stage: str, batch_idx: Optional[int] = None) None[source]#
Empty the CUDA cache if the scheduling conditions are met.
- Parameters:
trainer (L.Trainer) – Lightning Trainer
stage (str) – training stage
batch_idx (int | None) – Current batch index if available
- on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Empty the CUDA cache after a prediction batch if needed.
- on_predict_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Empty the CUDA cache after the predict epoch.
- on_test_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Empty the CUDA cache after a test batch if needed.
- on_test_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Empty the CUDA cache after the test epoch.
- on_train_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) None[source]#
Empty the CUDA cache after a training batch if needed.
- on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Empty the CUDA cache after the training epoch.
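A minimal usage sketch, assuming only the constructor documented above; the mapping to torch.cuda.empty_cache() is an assumption, and schedule=None keeps the class default:

    # Sketch: attach EmptyCache to a Lightning Trainer so the CUDA caching
    # allocator is released on the callback's schedule. Presumably this
    # wraps torch.cuda.empty_cache(); that mapping is an assumption.
    import lightning as L
    from fkat.pytorch.callbacks.cuda import EmptyCache

    trainer = L.Trainer(callbacks=[EmptyCache()])  # schedule=None: class default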
- class fkat.pytorch.callbacks.cuda.MemoryObserver(flamegraph: bool = True, reset_memory_history: bool = False, snapshot_pickle: bool = False, tensor_cycles: bool = False, schedule: Optional[Schedule] = None, oom: bool = True, **kwargs: Any)[source]#
This callback registers an observer to dump and log the CUDA memory snapshot.
- Parameters:
oom (bool) – whether to dump a memory snapshot on an Out-of-Memory (OOM) event. Defaults to True.
flamegraph (bool) – whether to save the memory snapshot in flamegraph format. Defaults to True.
reset_memory_history (bool) – whether to reset memory history after a snapshot. Defaults to False.
snapshot_pickle (bool) – whether to dump the memory snapshot in pickle format. Defaults to False.
tensor_cycles (bool) – whether to detect and dump graphs with cycles containing tensors in the garbage. Defaults to False.
schedule (Optional[Schedule]) – Controls when logging occurs besides the OOM event. Defaults to Never.
**kwargs (Any) – Arbitrary keyword arguments passed as-is to torch.cuda.memory._record_memory_history.
- maybe_dump_memory_snapshot(trainer: Trainer, stage: Optional[str] = None, batch_idx: Optional[int] = None) None[source]#
Dump the CUDA memory snapshot if the scheduling conditions are met.
- on_predict_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the predict batch begins.
- on_test_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the test batch begins.
- on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) None[source]#
Called when the train batch begins.
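A hedged configuration sketch using only the parameters documented above; the values shown simply make the documented defaults explicit:

    # Sketch: dump a flamegraph-format CUDA memory snapshot on OOM events.
    # All arguments below are the documented defaults written out; no fkat
    # Schedule is constructed (schedule=None keeps the class default, Never).
    import lightning as L
    from fkat.pytorch.callbacks.cuda import MemoryObserver

    observer = MemoryObserver(
        flamegraph=True,        # save snapshots in flamegraph format
        reset_memory_history=False,
        snapshot_pickle=False,  # skip the pickle dump
        oom=True,               # dump on Out-of-Memory events
    )
    trainer = L.Trainer(callbacks=[observer])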
- class fkat.pytorch.callbacks.cuda.Nsys(ranks: Optional[Sequence[int]] = None, output_path_prefix: Optional[str] = None, schedule: Optional[Schedule] = None, compress: bool = True, record_shapes: bool = False, **kwargs: Any)[source]#
- on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) None[source]#
Called when any trainer execution is interrupted by an exception.
- on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the predict batch ends.
- on_test_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the test batch ends.
- on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int) None[source]#
Called when the train batch ends.
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
- on_validation_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the validation batch ends.
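Nsight Systems capture is usually gated from inside the process via the CUDA profiler API while the script runs under nsys profile --capture-range=cudaProfilerApi; a hedged sketch of that underlying pattern follows (whether this callback uses exactly this mechanism is an assumption):

    # Sketch of the capture-gating primitive an Nsys callback typically
    # builds on: torch.cuda.profiler.start()/stop() open and close the
    # nsys capture window when launched with --capture-range=cudaProfilerApi.
    import torch

    def profiled_step(step_fn) -> None:
        torch.cuda.profiler.start()  # begin the nsys capture window
        step_fn()                    # work to appear in the profile
        torch.cuda.profiler.stop()   # end the capture window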
- class fkat.pytorch.callbacks.cuda.Nvtx[source]#
- load_state_dict(state_dict: dict[str, Any]) None[source]#
Called when loading a checkpoint; implement to reload callback state given the callback's state_dict.
- Parameters:
state_dict – the callback state returned by state_dict.
- on_after_backward(trainer: Trainer, pl_module: LightningModule) None[source]#
Called after loss.backward() and before optimizers are stepped.
- on_before_backward(trainer: Trainer, pl_module: LightningModule, loss: Tensor) None[source]#
Called before loss.backward().
- on_before_optimizer_step(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) None[source]#
Called before optimizer.step().
- on_before_zero_grad(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) None[source]#
Called before optimizer.zero_grad().
- on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) None[source]#
Called when any trainer execution is interrupted by an exception.
- on_load_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) None[source]#
Called when loading a model checkpoint, use to reload state.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the full checkpoint dictionary that got loaded by the Trainer.
- on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the predict batch ends.
- on_predict_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the predict batch begins.
- on_predict_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when predict ends.
- on_predict_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the predict epoch ends.
- on_predict_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the predict epoch begins.
- on_predict_start(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the predict begins.
- on_sanity_check_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the validation sanity check ends.
- on_sanity_check_start(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the validation sanity check starts.
- on_save_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) None[source]#
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the checkpoint dictionary that will be saved.
- on_test_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the test batch ends.
- on_test_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the test batch begins.
- on_test_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the test epoch ends.
- on_test_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the test epoch begins.
- on_test_start(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the test begins.
- on_train_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the train batch ends.
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
- on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) None[source]#
Called when the train batch begins.
- on_train_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the train ends.
- on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss


    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()
- on_train_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the train epoch begins.
- on_train_start(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the train begins.
- on_validation_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the validation batch ends.
- on_validation_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the validation batch begins.
- on_validation_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the validation loop ends.
- on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the val epoch ends.
- on_validation_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the val epoch begins.
- on_validation_start(trainer: Trainer, pl_module: LightningModule) None[source]#
Called when the validation loop begins.
- setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#
Called when fit, validate, test, predict, or tune begins.
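The hooks above lend themselves to NVTX range annotation; below is a hedged illustration of the primitive involved (torch.cuda.nvtx), with a made-up range name rather than this callback's actual labels:

    # Illustration of NVTX ranges: pushed/popped ranges appear as labeled
    # spans on the Nsight timeline. The range name here is hypothetical.
    import torch

    torch.cuda.nvtx.range_push("train_batch")  # open a labeled range
    # ... forward/backward work for one batch ...
    torch.cuda.nvtx.range_pop()                # close the range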
- class fkat.pytorch.callbacks.cuda.Xid(actions: dict[str, fkat.pytorch.actions.LightningAction], schedule: Schedule)[source]#
A callback to monitor and log Xid errors in a separate process during training.
It utilizes a separate process to monitor these errors, ensuring that the main training process remains unaffected. The monitoring process is started at the beginning of training and terminated either upon an exception in training or at the end of the training/validation stage.
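The errors in question are NVIDIA driver Xid messages, which surface in the kernel log; a minimal sketch of the generic detection idea follows (an illustration, not this callback's implementation), assuming dmesg is readable:

    # Sketch: poll the kernel log for NVIDIA Xid error lines -- the generic
    # detection idea behind an Xid watchdog. Requires permission to run dmesg.
    import re
    import subprocess

    def scan_for_xid_errors() -> list[str]:
        """Return kernel-log lines reporting NVIDIA Xid errors."""
        out = subprocess.run(["dmesg"], capture_output=True, text=True).stdout
        return [line for line in out.splitlines() if re.search(r"NVRM: Xid", line)]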
- monitor: multiprocessing.Process | None = None#
- on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) None[source]#
Callback method to handle exceptions during training.
If an exception occurs during the training process, this method ensures that the Xid error monitoring process is terminated to prevent resource leakage.
- Parameters:
trainer (L.Trainer) – The PyTorch Lightning Trainer instance.
pl_module (L.LightningModule) – The PyTorch Lightning module being trained.
exception (BaseException) – The exception that occurred during training.
- Returns:
None.
- on_predict_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the predict batch begins.
- on_test_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the test batch begins.
- on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) None[source]#
Called when the train batch begins.
- on_validation_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#
Called when the validation batch begins.
- setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#
Initializes the Xid error monitoring process at the start of the training stage.
This method is automatically invoked by the PyTorch Lightning framework. It starts a separate background process dedicated to monitoring Xid errors.
- Parameters:
trainer (L.Trainer) – The PyTorch Lightning Trainer instance.
pl_module (L.LightningModule) – The PyTorch Lightning module being trained.
stage (str) – The stage of the training process (e.g., ‘fit’, ‘test’).
- Returns:
None.
- teardown(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#
Ensures the Xid error monitoring process is terminated at the end of training.
This method is automatically called by the PyTorch Lightning framework at the end of the training or validation stage to clean up the monitoring process.
- Parameters:
trainer (L.Trainer) – The PyTorch Lightning Trainer instance.
pl_module (L.LightningModule) – The PyTorch Lightning module being trained.
stage (str) – The stage of the training process (e.g., ‘fit’, ‘test’).
- Returns:
None.