logging#

class fkat.pytorch.callbacks.logging.Heartbeat(schedule: Optional[Schedule] = None, last_check_in_time_tag: str = 'last_check_in_time', last_check_in_step_tag: str = 'last_check_in_step')[source]#

Publishes tags indicating the time and step of the last heartbeat with the provided schedule.
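A minimal usage sketch (constructing a concrete Schedule is assumed to follow fkat's Schedule API, which is not shown here; passing None uses the default schedule):

import lightning as L
from fkat.pytorch.callbacks.logging import Heartbeat

# Publish heartbeat tags under custom names; omitting the arguments
# falls back to "last_check_in_time" and "last_check_in_step".
heartbeat = Heartbeat(
    last_check_in_time_tag="hb_time",
    last_check_in_step_tag="hb_step",
)
trainer = L.Trainer(callbacks=[heartbeat])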

on_predict_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch begins.

on_test_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch begins.

on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) None[source]#

Called when the train batch begins.

on_validation_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch begins.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune begins.

class fkat.pytorch.callbacks.logging.Throughput(dp_ranks: Optional[int] = None, schedule: Optional[Schedule] = None)[source]#

Reports step throughput metrics during training with the provided schedule.
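A minimal usage sketch (attaching the callback to a Trainer; the meaning of dp_ranks as the number of data-parallel ranks is an assumption based on the parameter name):

import lightning as L
from fkat.pytorch.callbacks.logging import Throughput

# Assumption: dp_ranks is the number of data-parallel ranks used when
# aggregating per-rank throughput; omit it to use the default (None).
throughput = Throughput(dp_ranks=8)
trainer = L.Trainer(callbacks=[throughput])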
on_before_zero_grad(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) None[source]#

Report metrics for individual steps during training.

on_predict_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the predict batch ends.

on_predict_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when predict ends.

on_predict_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the predict epoch ends.

on_predict_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the predict epoch begins.

on_test_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the test batch ends.

on_test_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the test ends.

on_test_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the test epoch ends.

on_test_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the test epoch begins.

on_train_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) None[source]#

Called when the train batch ends.

Note

The value of outputs["loss"] here is the loss returned from training_step, normalized by accumulate_grad_batches.
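For instance (illustrative values; assumes the loss is not otherwise scaled):

# With accumulate_grad_batches=4, a training_step returning loss=2.0
# surfaces outputs["loss"] == 0.5 in on_train_batch_end.
trainer = L.Trainer(accumulate_grad_batches=4)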

on_train_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the train ends.

on_train_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

on_train_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the train epoch begins.

on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) None[source]#

Called when the validation batch ends.

on_validation_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the validation loop ends.

on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the validation epoch ends.

on_validation_epoch_start(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the validation epoch begins.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune begins.

class fkat.pytorch.callbacks.logging.ValidationMetrics(output_path: Optional[str] = None)[source]#

Saves validation metrics after each validation epoch.

This callback persists validation metrics in JSON format, creating both a versioned file (with epoch and step) and a “latest” file for easy access.

Parameters:

output_path (str | None) – Directory path where metrics will be saved. Supports any fsspec-compatible filesystem (local, s3://, gcs://, etc.). If None, logs to MLflow artifacts. Defaults to None.

Example

>>> # MLflow artifacts (default)
>>> callback = ValidationMetrics()
>>> # Local storage
>>> callback = ValidationMetrics(output_path="/tmp/metrics")
>>> # S3 storage
>>> callback = ValidationMetrics(output_path="s3://my-bucket/metrics")
>>> trainer = L.Trainer(callbacks=[callback])
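Since output_path accepts any fsspec-compatible URL, saved metrics can be read back with fsspec; the file name below is hypothetical (the exact naming of the versioned and "latest" files is not specified here):

import json
import fsspec

# Hypothetical file name; the callback writes a versioned file plus a
# "latest" file, per the description above.
with fsspec.open("s3://my-bucket/metrics/latest.json", "r") as f:
    metrics = json.load(f)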

on_validation_end(trainer: Trainer, pl_module: LightningModule) None[source]#

Called when the validation loop ends.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) None[source]#

Called when fit, validate, test, predict, or tune begins.