dp#

class fkat.pytorch.callbacks.monitoring.dp.DistDpGroup(dp_size: int)[source]#

Calculates DP group info based on dp_size using the distributed rank.

dp_group_info() → tuple[int, int][source]#

Return (group_id, rank_in_group) for the current rank.
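
The exact mapping is not documented here; below is a minimal sketch of one plausible implementation, assuming contiguous grouping of ranks (ranks 0..dp_size-1 form group 0, and so on) and an already-initialized torch.distributed process group. The class name DistDpGroupSketch is illustrative, not part of the library:

    import torch.distributed as dist

    class DistDpGroupSketch:
        """Sketch: derive (group_id, rank_in_group) from the global rank.

        Assumes contiguous grouping; the real class may use a different layout.
        """

        def __init__(self, dp_size: int) -> None:
            self.dp_size = dp_size

        def dp_group_info(self) -> tuple[int, int]:
            rank = dist.get_rank()               # global rank of this process
            group_id = rank // self.dp_size      # which DP group this rank belongs to
            rank_in_group = rank % self.dp_size  # position inside that group
            return group_id, rank_in_group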

class fkat.pytorch.callbacks.monitoring.dp.DpGroupStrategy(*args, **kwargs)[source]#

Protocol for getting DP group info for the current rank.

dp_group_info() → tuple[int, int][source]#

Return (group_id, rank_in_group) for the current rank.
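
Because DpGroupStrategy is a structural protocol, any object with a matching dp_group_info() method satisfies it without explicit inheritance. Below is a minimal sketch of a custom strategy, here a fixed strategy handy for single-process testing (the name FixedDpGroup is an assumption, not part of the library):

    class FixedDpGroup:
        """Sketch: always report the same (group_id, rank_in_group)."""

        def __init__(self, group_id: int = 0, rank_in_group: int = 0) -> None:
            self._info = (group_id, rank_in_group)

        def dp_group_info(self) -> tuple[int, int]:
            return self._info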

class fkat.pytorch.callbacks.monitoring.dp.DpSyncMonitor(dp_group: DpGroupStrategy, schedule: Optional[Schedule] = None)[source]#

Monitors the time each DP group takes to reach the synchronization point. Timing runs from batch start to just before the optimizer step, so slow and fast groups can be identified. A usage sketch follows the method entries below.

on_before_optimizer_step(trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer) → None[source]#

End timing when the rank is ready to synchronize (just before the optimizer step) and log if needed.

on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) → None[source]#

Start timing when batch processing begins.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) → None[source]#

Called when fit, validate, test, predict, or tune begins.
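
A minimal usage sketch wiring the callback into a Lightning Trainer, assuming a DP group size of 8, the lightning.pytorch import path (older installs use pytorch_lightning), and the optional schedule left at its default; the model and datamodule are placeholders:

    from lightning.pytorch import Trainer

    from fkat.pytorch.callbacks.monitoring.dp import DistDpGroup, DpSyncMonitor

    # Time each DP group from batch start to just before the optimizer step,
    # using rank-based grouping with 8 ranks per DP group. schedule is left
    # unset, so the callback falls back to its default logging behaviour.
    monitor = DpSyncMonitor(dp_group=DistDpGroup(dp_size=8))

    trainer = Trainer(callbacks=[monitor])
    # trainer.fit(model, datamodule=dm)  # model/dm are placeholders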

class fkat.pytorch.callbacks.monitoring.dp.EnvDpGroup(dp_size: int)[source]#

Calculates DP group info based on dp_size using environment variables.

dp_group_info() → tuple[int, int][source]#

Return (group_id, rank_in_group) for the current rank.
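
A minimal sketch of the environment-variable approach, assuming the global rank is exposed through the standard RANK variable (as set by launchers such as torchrun) and the same contiguous grouping as above; the real class may read different variables:

    import os

    class EnvDpGroupSketch:
        """Sketch: derive (group_id, rank_in_group) from environment variables."""

        def __init__(self, dp_size: int) -> None:
            self.dp_size = dp_size

        def dp_group_info(self) -> tuple[int, int]:
            # RANK is an assumption about which variable the real class reads.
            rank = int(os.environ.get("RANK", "0"))
            return rank // self.dp_size, rank % self.dp_size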

class fkat.pytorch.callbacks.monitoring.dp.MegatronDpGroup(*args, **kwargs)[source]#

Gets DP group info from Megatron's parallel_state.

dp_group_info() → tuple[int, int][source]#

Return (group_id, rank_in_group) for the current rank.
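
A sketch of how a Megatron-backed strategy could look on top of megatron.core.parallel_state: the rank within the group comes from get_data_parallel_rank(), and the lowest global rank in the data-parallel process group is used here as a stable group identifier. That identifier choice, and the class name, are assumptions rather than the library's actual implementation:

    import torch.distributed as dist
    from megatron.core import parallel_state

    class MegatronDpGroupSketch:
        """Sketch: read DP group info from Megatron's parallel_state."""

        def dp_group_info(self) -> tuple[int, int]:
            # Rank of this process within its data-parallel group.
            rank_in_group = parallel_state.get_data_parallel_rank()

            # Use the smallest global rank in the DP group as the group id.
            # (Assumption: any value that is stable and unique per group works.)
            dp_group = parallel_state.get_data_parallel_group()
            group_id = min(dist.get_process_group_ranks(dp_group))
            return group_id, rank_in_group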