profiling#

class fkat.pytorch.callbacks.profiling.Flops(schedule: Optional[Schedule] = None, *args: Any, **kwargs: Any)[source]#

A PyTorch Lightning callback that measures and logs floating-point operations (FLOPs) and Model FLOP Utilization (MFU) during training, validation, testing, and prediction.

This callback helps monitor the computational efficiency of models by measuring:

  • Total machine FLOPs available

  • Per-batch FLOPs used by the model

  • Model FLOP Utilization (MFU), i.e., how efficiently the model uses the available compute

  • Batch throughput (batches per second)

It supports two methods for estimating FLOPs:

  1. Tracing-based estimation via a Trace context manager.

  2. Formula-based estimation using a predefined GPTModel FLOP calculator.
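
As a rough illustration of the formula-based approach, a common estimate for transformer training FLOPs is the 6·N·D approximation (about 6 FLOPs per parameter per token). Whether GPTModel uses exactly this formula is an assumption; the sketch below is illustrative only.

def approx_train_flops(n_params: int, tokens_per_batch: int) -> float:
    # ~6 FLOPs per parameter per token: ~2 in the forward pass and
    # ~4 in the backward pass (the standard 6*N*D approximation).
    return 6.0 * n_params * tokens_per_batch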

Metrics are logged periodically (or once) to the experiment logger (e.g., MLflow) and include:

  • mfu: Model FLOP Utilization (traced)

  • actual_batches_per_sec: Measured throughput

  • max_batches_per_sec: Theoretical max throughput

  • batch_flops: FLOPs used in the current batch

  • batch_flops_from_formula: FLOPs estimated via formula (if available)

  • mfu_from_formula: MFU based on formula-based estimation

  • tracked_operations: Number of FLOPs tracked during tracing

  • untracked_operations: Number of operations not accounted for by the tracer
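
These quantities relate roughly as follows. This is a sketch inferred from the metric names above, not necessarily the callback's exact computation:

def mfu(batch_flops: float, actual_batches_per_sec: float,
        machine_flops_per_sec: float) -> float:
    # max_batches_per_sec = machine_flops_per_sec / batch_flops, so
    # mfu = actual_batches_per_sec / max_batches_per_sec.
    return (batch_flops * actual_batches_per_sec) / machine_flops_per_sec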

Parameters:

schedule (Optional[Schedule]) – Controls when logging occurs during training. Defaults to every 5 batches. FLOPs are always calculated at least once at the beginning.

Example

>>> trainer = L.Trainer(callbacks=[Flops()])
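
To control the logging cadence, pass a Schedule. The available Schedule constructors are not documented in this section; assuming a hypothetical Every helper that takes a batch interval, usage might look like:

>>> trainer = L.Trainer(callbacks=[Flops(schedule=Every(10))])  # Every is hypothetical
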
on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch ends.

on_predict_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch begins.

on_test_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch ends.

on_test_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch begins.

on_train_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) None[source]#

Called at the end of each training batch.

This method finalizes FLOP tracing and logs performance metrics if applicable.

If conditions are met, it:

  • Ends FLOP tracing and calculates batch-level FLOPs.

  • Aggregates tracked and untracked operations across devices.

  • Computes Model FLOP Utilization (MFU) based on actual vs. theoretical throughput.

  • Optionally estimates FLOPs using a formula-based approach (GPTModel).

  • Logs performance metrics (e.g., MFU, throughput, FLOPs) to the experiment logger.

Logging is only performed on the global rank 0 process and is skipped during sanity checks.
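
A minimal sketch of the cross-device aggregation step, assuming torch.distributed is initialized (illustrative; not the callback's actual implementation):

import torch
import torch.distributed as dist

def aggregate_ops(tracked: int, untracked: int) -> tuple[int, int]:
    # Sum per-rank operation counts across all devices.
    counts = torch.tensor([tracked, untracked], dtype=torch.int64)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    return int(counts[0]), int(counts[1])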

Parameters:
  • trainer (Trainer) – The PyTorch Lightning trainer instance.

  • pl_module (LightningModule) – The model being trained.

  • outputs (STEP_OUTPUT) – The outputs from the training step.

  • batch (Any) – The current batch of data.

  • batch_idx (int) – Index of the current batch.

on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) None[source]#

Called at the beginning of each training batch.

This method initiates FLOP tracing and throughput timing for the current batch, depending on the logging frequency.

If conditions are met, it:

  • Begins FLOP tracing using a Trace context manager.

  • Records the start time to later compute batch throughput.

Tracing and logging are skipped during sanity checks.
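
The recorded start time pairs with on_train_batch_end to produce throughput. A minimal sketch of the timing pattern (illustrative, not the callback's actual code):

import time

start = time.perf_counter()              # recorded here, in on_train_batch_start
# ... the training step for this batch runs ...
elapsed = time.perf_counter() - start    # read later, in on_train_batch_end
actual_batches_per_sec = 1.0 / elapsed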

Parameters:
  • trainer (Trainer) – The PyTorch Lightning trainer instance.

  • pl_module (LightningModule) – The model being trained.

  • batch (Any) – The current batch of data.

  • batch_idx (int) – Index of the current batch.

on_validation_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch ends.

on_validation_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch begins.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune begins.

class fkat.pytorch.callbacks.profiling.Memray(ranks: Optional[Sequence[int]] = None, flamegraph: bool = False, output_path_prefix: Optional[str] = None, schedule: Optional[Schedule] = None, compress: bool = False, **kwargs: Any)[source]#
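
Judging by the name, this callback presumably wraps the memray memory profiler (not described further in this reference). A minimal usage sketch based on the signature above, with illustrative parameter values:

>>> trainer = L.Trainer(callbacks=[Memray(ranks=[0], flamegraph=True, output_path_prefix="memray/")])
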
on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) None[source]#

Called when any trainer execution is interrupted by an exception.

on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch ends.

on_predict_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the predict begins.

on_test_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch ends.

on_test_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the test begins.

on_train_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) None[source]#

Called when the train batch ends.

Note

The value outputs["loss"] here will be the loss returned from training_step, normalized with respect to accumulate_grad_batches. For example, with accumulate_grad_batches=4, a loss of 2.0 returned from training_step is reported here as 0.5.

on_train_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the train begins.

on_validation_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch ends.

on_validation_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the validation loop begins.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune begins.

teardown(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune ends.

class fkat.pytorch.callbacks.profiling.PyTorch(ranks: Optional[Sequence[int]] = None, output_path_prefix: Optional[str] = None, schedule: Optional[Schedule] = None, compress: bool = True, **kwargs: Any)[source]#
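
This callback presumably wraps the built-in torch.profiler (inferred from the name; not described further in this reference). A minimal usage sketch based on the signature above, with illustrative parameter values:

>>> trainer = L.Trainer(callbacks=[PyTorch(ranks=[0], output_path_prefix="torch-profile/")])
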
on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) None[source]#

Called when any trainer execution is interrupted by an exception.

on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch ends.

on_test_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch ends.

on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int) None[source]#

Called when the train batch ends.

Note

The value outputs["loss"] here will be the loss returned from training_step, normalized with respect to accumulate_grad_batches.

on_validation_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch ends.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune begins.

teardown(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune ends.

class fkat.pytorch.callbacks.profiling.VizTracer(ranks: Optional[Sequence[int]] = None, output_path_prefix: Optional[str] = None, schedule: Optional[Schedule] = None, compress: bool = False, patch: bool = False, **kwargs: Any)[source]#
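
The name suggests this callback wraps the viztracer tracing profiler (an inference; not described further in this reference). A minimal usage sketch based on the signature above, with illustrative parameter values:

>>> trainer = L.Trainer(callbacks=[VizTracer(ranks=[0], patch=True)])
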
on_exception(trainer: Trainer, pl_module: LightningModule, exception: BaseException) None[source]#

Called when any trainer execution is interrupted by an exception.

on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch ends.

on_test_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch ends.

on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int) None[source]#

Called when the train batch ends.

Note

The value outputs["loss"] here will be the loss returned from training_step, normalized with respect to accumulate_grad_batches.

on_validation_batch_end(trainer: L.Trainer, pl_module: L.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch ends.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune begins.

teardown(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune ends.