dp#
- class fkat.pytorch.callbacks.monitoring.dp.DistDpGroup(dp_size: int)[source]#
Calculates DP group info for the current rank based on dp_size and the distributed (global) rank.
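The reference does not show the implementation, so the following is only a minimal sketch of how a rank-based strategy like this could derive group info, assuming consecutive ranks of size dp_size form one DP group; the group_info method name and the rank layout are assumptions, not the documented API.

```python
import torch.distributed as dist


class DistDpGroupSketch:
    """Hypothetical stand-in for DistDpGroup: derives DP group info from
    the global rank, assuming consecutive ranks of size ``dp_size`` form
    one DP group."""

    def __init__(self, dp_size: int) -> None:
        self.dp_size = dp_size

    def group_info(self) -> tuple[int, int]:
        # Fall back to rank 0 when torch.distributed is not initialized
        # (e.g. single-process debugging).
        rank = dist.get_rank() if dist.is_initialized() else 0
        # (group index, position of this rank within its group)
        return rank // self.dp_size, rank % self.dp_size
```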
- class fkat.pytorch.callbacks.monitoring.dp.DpGroupStrategy(*args, **kwargs)[source]#
Protocol for getting DP group info for the current rank.
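As a Protocol, DpGroupStrategy only constrains the interface a strategy must expose rather than providing behavior. A sketch of the general shape is below; the group_info method is again a hypothetical placeholder, since the reference does not list the protocol's members.

```python
from typing import Protocol, Tuple, runtime_checkable


@runtime_checkable
class DpGroupStrategySketch(Protocol):
    """Hypothetical shape of a DP-group strategy: any object that can
    report which DP group the current rank belongs to satisfies it."""

    def group_info(self) -> Tuple[int, int]:
        """Return (group_index, rank_within_group) for the current rank."""
        ...
```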
- class fkat.pytorch.callbacks.monitoring.dp.DpSyncMonitor(dp_group: DpGroupStrategy, schedule: Optional[Schedule] = None)[source]#
Monitors the time each DP group takes to reach the synchronization point. Timing runs from batch start to just before the optimizer step, which helps identify slow and fast groups.
- on_before_optimizer_step(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) None[source]#
End timing when the rank is ready for sync (just before the optimizer step) and log if needed.
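A usage sketch is below: the monitor pairs a group strategy with an optional logging schedule and is attached to the Trainer like any other callback. The import path, class names, and the dp_group parameter follow this reference; the Lightning import alias, the example dp_size value, and the model/datamodule are assumptions, and the optional schedule argument is left at its default since its construction is not shown here.

```python
import lightning.pytorch as pl  # or pytorch_lightning, depending on the project

from fkat.pytorch.callbacks.monitoring.dp import DistDpGroup, DpSyncMonitor

# Example layout: each DP group contains 8 ranks. The callback times each
# rank from batch start to just before the optimizer step and logs per-group
# sync times according to the (optional) schedule.
monitor = DpSyncMonitor(dp_group=DistDpGroup(dp_size=8))

trainer = pl.Trainer(callbacks=[monitor])
# trainer.fit(model, datamodule=dm)  # model and dm defined elsewhere
```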