# Source code for fkat.pytorch.callbacks.debugging.optimizer

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from typing_extensions import override
import datetime as dt

import torch
import fsspec
import lightning as L

from fkat.pytorch.schedule import Schedule, Never


class OptimizerSnapshot(L.Callback):
    """Callback that saves optimizer state at specified intervals during training.

    This callback allows you to capture the state of optimizers at specific
    points during training, which can be useful for debugging, analysis, or
    resuming training from specific optimization states.

    Args:
        output_path_prefix (str): Output path prefix for generated optimizer
            snapshots. The rank, optimizer index, and a UTC timestamp are
            appended directly to this prefix (no separator is inserted).
        schedule (Optional[Schedule]): Schedule at which to take a snapshot
            of optimizers. Defaults to ``Never``.
    """

    def __init__(
        self,
        output_path_prefix: str,
        schedule: Schedule | None = None,
    ) -> None:
        # Cooperative init so Lightning's Callback machinery is set up correctly.
        super().__init__()
        self.output_path_prefix = output_path_prefix
        # Default to a schedule that never fires, so the callback is inert
        # unless explicitly configured.
        self.schedule = schedule or Never()

    @override
    def on_train_batch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        """Snapshot every optimizer when the configured schedule fires.

        Args:
            trainer: The running ``Trainer``; supplies optimizers, global step,
                and the global rank used in the output filename.
            pl_module: The ``LightningModule`` being trained (unused here).
            batch: The current batch (unused here).
            batch_idx: Index of the current batch, passed to the schedule.
        """
        # Guard clause: do nothing unless the schedule says it's time.
        if not self.schedule.check(
            trainer=trainer, stage="train", batch_idx=batch_idx, step=trainer.global_step
        ):
            return
        # One timestamp per firing so all optimizers from this step share it.
        timestamp = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H-%M-%SZ")
        for i, opt in enumerate(trainer.optimizers):
            # Each rank writes its own file to avoid collisions in distributed runs.
            path = f"{self.output_path_prefix}rank{trainer.global_rank}_opt{i}_{timestamp}.pt"
            with fsspec.open(path, "wb", makedirs=True) as f:
                # NOTE(review): this pickles the entire optimizer object;
                # ``opt.state_dict()`` would be the more portable snapshot —
                # confirm downstream consumers before changing the format.
                torch.save(opt, f)