cuda

class fkat.pytorch.callbacks.cuda.EmptyCache(schedule: Optional[Schedule] = None)[source]#
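
A callback that empties the CUDA cache at scheduled points during the training, validation, test, and predict loops. A minimal usage sketch; the cache release itself (presumably torch.cuda.empty_cache()) is handled by the callback, and the optional fkat Schedule argument controls how often it runs:

import lightning as L

from fkat.pytorch.callbacks.cuda import EmptyCache

# Clear unused cached CUDA memory at the scheduled hook points.
# schedule is left at its default (None) here; pass an fkat Schedule
# to control how frequently the cache is emptied.
trainer = L.Trainer(callbacks=[EmptyCache()])
# trainer.fit(model)  # `model` is your LightningModule
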
maybe_empty_cache(trainer: Trainer, stage: str, batch_idx: Optional[int] = None) None[source]#

Empty the CUDA cache if the schedule conditions are met.

Parameters:
  • trainer (L.Trainer) – Lightning Trainer

  • stage (str) – training stage

  • batch_idx (int | None) – Current batch index if available

on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Empty the CUDA cache after the prediction batch if needed.

on_predict_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Empty the CUDA cache after the predict epoch.

on_test_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Empty the CUDA cache after the test batch if needed.

on_test_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Empty the CUDA cache after the test epoch.

on_train_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) None[source]#

Empty the CUDA cache after the training batch if needed.

on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Empty the CUDA cache after the training epoch.

on_validation_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Empty the CUDA cache after the validation batch if needed.

on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Empty the CUDA cache after the validation epoch.

class fkat.pytorch.callbacks.cuda.MemoryObserver(flamegraph: bool = True, reset_memory_history: bool = False, snapshot_pickle: bool = False, tensor_cycles: bool = False, schedule: Optional[Schedule] = None, oom: bool = True, **kwargs: Any)[source]#

This callback registers an observer that dumps and logs CUDA memory snapshots.

Parameters:
  • oom (bool) – whether to dump a memory snapshot on an out-of-memory (OOM) event. Defaults to True

  • flamegraph (bool) – whether to save the memory snapshot in flamegraph format. Defaults to True

  • reset_memory_history (bool) – whether to reset the memory history after taking a snapshot. Defaults to False

  • snapshot_pickle (bool) – whether to dump the memory snapshot in pickle format. Defaults to False

  • tensor_cycles (bool) – whether to detect and dump object graphs with reference cycles containing tensors found in the garbage. Defaults to False

  • schedule (Optional[Schedule]) – controls when snapshots are logged in addition to OOM events. Defaults to Never

  • **kwargs (Any) – arbitrary keyword arguments passed as-is to torch.cuda.memory._record_memory_history
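
A minimal configuration sketch; where the snapshot files are written and how they are logged depends on the trainer and logger setup and is not shown here:

import lightning as L

from fkat.pytorch.callbacks.cuda import MemoryObserver

# Capture a CUDA memory snapshot on OOM (default) and optionally on a schedule.
# With snapshot_pickle=True the raw snapshot is also dumped; pickled snapshots
# from torch.cuda.memory can be inspected with https://pytorch.org/memory_viz.
mem_observer = MemoryObserver(
    flamegraph=True,       # save the snapshot in flamegraph format
    snapshot_pickle=True,  # additionally dump the snapshot as a pickle
    oom=True,              # dump a snapshot when an out-of-memory error occurs
)
trainer = L.Trainer(callbacks=[mem_observer])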

dump_memory_snapshot(rank: int) None[source]#
maybe_dump_memory_snapshot(trainer: Trainer, stage: Optional[str] = None, batch_idx: Optional[int] = None) None[source]#
on_predict_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch begins.

on_test_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch begins.

on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) None[source]#

Called when the train batch begins.

on_validation_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch begins.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune begins.

class fkat.pytorch.callbacks.cuda.Nsys(ranks: Optional[Sequence[int]] = None, output_path_prefix: Optional[str] = None, schedule: Optional[Schedule] = None, compress: bool = True, record_shapes: bool = False, **kwargs: Any)[source]#
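
A callback that appears to integrate with NVIDIA Nsight Systems (nsys) profiling on selected ranks. A wiring sketch under the assumption that the callback gates an externally launched nsys capture (for example via the CUDA profiler start/stop API); the exact interaction with nsys is an assumption, so treat the launch command below as illustrative:

import lightning as L

from fkat.pytorch.callbacks.cuda import Nsys

# Profile rank 0 only; output_path_prefix and schedule are left at their defaults.
trainer = L.Trainer(callbacks=[Nsys(ranks=[0])])

# The training script would then typically be launched under Nsight Systems,
# e.g. (illustrative; assumes the capture is gated on the CUDA profiler API):
#   nsys profile --capture-range=cudaProfilerApi -o train_profile python train.py
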
on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) None[source]#

Called when any trainer execution is interrupted by an exception.

on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch ends.

on_test_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch ends.

on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int) None[source]#

Called when the train batch ends.

Note

The value outputs["loss"] here will be the loss returned from training_step, normalized with respect to accumulate_grad_batches.

on_validation_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch ends.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune begins.

teardown(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune ends.

class fkat.pytorch.callbacks.cuda.Nvtx[source]#
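
Nvtx takes no constructor arguments; it presumably brackets each Lightning hook with NVTX ranges (e.g., via torch.cuda.nvtx) so that training phases show up as named regions on an Nsight Systems timeline. A minimal sketch combining it with the Nsys callback above:

import lightning as L

from fkat.pytorch.callbacks.cuda import Nsys, Nvtx

# Nvtx annotates the Lightning hooks with NVTX ranges;
# Nsys controls the profiler capture on rank 0.
trainer = L.Trainer(callbacks=[Nvtx(), Nsys(ranks=[0])])
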
load_state_dict(state_dict: dict[str, Any]) None[source]#

Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.

Parameters:

state_dict – the callback state returned by state_dict.

on_after_backward(trainer: Trainer, pl_module: LightningModule) None[source]#

Called after loss.backward() and before optimizers are stepped.

on_before_backward(trainer: Trainer, pl_module: LightningModule, loss: Tensor) None[source]#

Called before loss.backward().

on_before_optimizer_step(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) None[source]#

Called before optimizer.step().

on_before_zero_grad(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) None[source]#

Called before optimizer.zero_grad().

on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) None[source]#

Called when any trainer execution is interrupted by an exception.

on_load_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) None[source]#

Called when loading a model checkpoint, use to reload state.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the full checkpoint dictionary that got loaded by the Trainer.

on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch ends.

on_predict_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch begins.

on_predict_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when predict ends.

on_predict_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the predict epoch ends.

on_predict_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the predict epoch begins.

on_predict_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the predict begins.

on_sanity_check_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the validation sanity check ends.

on_sanity_check_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the validation sanity check starts.

on_save_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) None[source]#

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the checkpoint dictionary that will be saved.

on_test_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch ends.

on_test_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch begins.

on_test_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the test ends.

on_test_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the test epoch ends.

on_test_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the test epoch begins.

on_test_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the test begins.

on_train_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the train batch ends.

Note

The value outputs["loss"] here will be the loss returned from training_step, normalized with respect to accumulate_grad_batches.

on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) None[source]#

Called when the train batch begins.

on_train_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the train ends.

on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

on_train_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the train epoch begins.

on_train_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the train begins.

on_validation_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch ends.

on_validation_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch begins.

on_validation_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the validation loop ends.

on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the val epoch ends.

on_validation_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the val epoch begins.

on_validation_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the validation loop begins.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune begins.

state_dict() dict[str, Any][source]#

Called when saving a checkpoint, implement to generate callback’s state_dict.

Returns:

A dictionary containing callback state.

teardown(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune ends.

class fkat.pytorch.callbacks.cuda.Xid(actions: dict[str, fkat.pytorch.actions.LightningAction], schedule: Schedule)[source]#

A callback to monitor and log Xid errors in a separate process during training.

It uses a separate process to monitor these errors so that the main training process remains unaffected. The monitoring process is started at the beginning of training and terminated either when training raises an exception or at the end of the training/validation stage.

check(trainer: Trainer, stage: str, batch_idx: int) None[source]#
monitor: multiprocessing.Process | None = None#
on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) None[source]#

Callback method to handle exceptions during training.

If an exception occurs during the training process, this method ensures that the Xid error monitoring process is terminated to prevent resource leakage.

Parameters:
  • trainer (L.Trainer) – The PyTorch Lightning Trainer instance.

  • pl_module (L.LightningModule) – The PyTorch Lightning module being trained.

  • exception (BaseException) – The exception that occurred during training.

Returns:

None.

on_predict_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch begins.

on_test_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch begins.

on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) None[source]#

Called when the train batch begins.

on_validation_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch begins.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Initializes the Xid error monitoring process at the start of the training stage.

This method is automatically invoked by the PyTorch Lightning framework. It starts a separate background process dedicated to monitoring Xid errors.

Parameters:
  • trainer (L.Trainer) – The PyTorch Lightning Trainer instance.

  • pl_module (L.LightningModule) – The PyTorch Lightning module being trained.

  • stage (str) – The stage of the training process (e.g., ‘fit’, ‘test’).

Returns:

None.

teardown(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Ensures the Xid error monitoring process is terminated at the end of training.

This method is automatically called by the PyTorch Lightning framework at the end of the training or validation stage to clean up the monitoring process.

Parameters:
  • trainer (L.Trainer) – The PyTorch Lightning Trainer instance.

  • pl_module (L.LightningModule) – The PyTorch Lightning module being trained.

  • stage (str) – The stage of the training process (e.g., ‘fit’, ‘test’).

Returns:

None.