Source code for fkat
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import os
import sys
import asyncio
from collections.abc import Callable

import hydra
import omegaconf as oc
import lightning as L
import torch
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import record

from fkat.utils import config, pdb
from fkat.utils.config import SingletonResolver
from fkat.utils.logging import rank0_logger

log = rank0_logger(__name__)

def run_main(main: Callable[[], None]) -> None:
    """Patch ``sys.argv`` for Hydra, then invoke ``main`` inside an async wrapper
    decorated with torchelastic's ``@record`` so worker errors are recorded."""
    patch_args()

    @record
    async def async_main() -> None:
        try:
            main()
        except Exception as e:
            import traceback

            traceback.print_tb(e.__traceback__)
            raise e

    asyncio.run(async_main())  # type: ignore[arg-type]
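
# Usage sketch (illustrative only, not part of the module): a Hydra entry point
# wrapped by ``run_main``. The app name, config path, and config name below are
# hypothetical.
#
#     @hydra.main(config_path="conf", config_name="train", version_base=None)
#     def my_app(cfg: oc.DictConfig) -> None:
#         initialize(cfg)
#
#     if __name__ == "__main__":
#         run_main(my_app)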

def patch_args() -> None:
    """
    Hydra expects wildcard arguments (e.g. overrides) on the command line, but some
    runtimes only accept named arguments; in that case the overrides are passed as
    the value of a bogus ``--overrides`` flag. This function removes that flag and
    splices its value back into ``sys.argv`` before Hydra parses the arguments.
    """
    overrides_pos = -1
    for i, a in enumerate(sys.argv):
        if a == "--overrides":
            overrides_pos = i
            break
    if overrides_pos >= 0:
        overrides = sys.argv[overrides_pos + 1] if overrides_pos + 1 < len(sys.argv) else ""
        # skip the flag and its value when rebuilding argv; more args may follow
        sys.argv = sys.argv[:overrides_pos] + (
            sys.argv[overrides_pos + 2 :] if overrides_pos + 2 < len(sys.argv) else []
        )
        if overrides:
            # append the individual overrides at the end
            sys.argv.extend(overrides.split(" "))
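
# Illustrative before/after for ``patch_args`` (argument values are made up):
#
#     sys.argv = ["train.py", "--overrides", "trainer.max_epochs=5 seed=1"]
#     patch_args()
#     assert sys.argv == ["train.py", "trainer.max_epochs=5", "seed=1"]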

def setup(
    cfg: oc.DictConfig | None = None,
    print_config: bool = False,
    multiprocessing: str = "spawn",
    seed: int | None = None,
    post_mortem: bool = False,
    determinism: bool = False,
    resolvers: dict[str, "oc.Resolver"] | None = None,
) -> SingletonResolver:
    """Set up the training environment.

    Args:
        cfg (oc.DictConfig | None): Full configuration
        print_config (bool): Whether to print the configuration to the output. Defaults to ``False``
        multiprocessing (str): Multiprocessing start method. Defaults to ``"spawn"``
        seed (int | None): Random number generator seed; applied when set
        post_mortem (bool): Whether to start the pdb debugger when an uncaught exception is encountered. Defaults to ``False``
        determinism (bool): Whether to enforce deterministic algorithms. Defaults to ``False``
        resolvers (dict[str, oc.Resolver] | None): Custom resolvers to register for configuration processing

    Returns:
        :class:`SingletonResolver` object that holds initialized data, trainer, model, etc.
    """
    if print_config:
        log.info(config.to_str(cfg))
    mp.set_start_method(multiprocessing, force=True)
    if seed:
        L.seed_everything(seed)
    if post_mortem:
        pdb.post_mortem()
    if determinism:  # enable deterministic algorithms globally
        assert seed is not None, "seed has to be set for deterministic runs"
        torch.use_deterministic_algorithms(True)
        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        deterministic_env_vars = {
            "CUBLAS_WORKSPACE_CONFIG": [":16:8", ":4096:8"],
            "NCCL_ALGO": ["^NVLS"],
            "NVTE_ALLOW_NONDETERMINISTIC_ALGO": ["0"],
        }
        # each variable must be unset (its default is then applied) or already set to an allowed value
        for var, vals in deterministic_env_vars.items():
            if (val := os.environ.get(var, vals[-1])) not in vals:
                raise ValueError(f"{var} has to be set to one of {vals} for deterministic runs, got: {val}")
            os.environ[var] = val
    for rn, fn in (resolvers or {}).items():
        oc.OmegaConf.register_new_resolver(rn, fn, replace=True)
    s = config.register_singleton_resolver()
    return s
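
# Minimal sketch of calling ``setup`` directly (argument values are illustrative;
# the custom resolver is hypothetical). Unset deterministic env vars are filled
# with their defaults; already-set ones must match an allowed value:
#
#     custom = {"mul": lambda x, y: x * y}  # usable in configs as ${mul:2,3}
#     s = setup(cfg, print_config=True, seed=42, determinism=True, resolvers=custom)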

def initialize(cfg: oc.DictConfig) -> SingletonResolver:
    """Initialize the data, model, and trainer from the supplied configuration.

    Args:
        cfg (oc.DictConfig): Configuration supplied by the user through a YAML file.

    Returns:
        :class:`SingletonResolver` object that holds initialized data, trainer, model, etc.
    """
    # 0. set up the training environment
    s = setup(cfg, **(hydra.utils.instantiate(cfg["setup"]) if "setup" in cfg else {}))
    # 1. instantiate `trainer`
    s.trainer = hydra.utils.instantiate(cfg.trainer)
    # 2. instantiate optional `data`
    s.data = hydra.utils.instantiate(cfg.get("data"))
    # 3. instantiate `model` after `trainer`
    s.model = hydra.utils.instantiate(cfg.model)
    # 4. obtain optional `ckpt_path` and `return_predictions` after `model`
    s.ckpt_path = hydra.utils.call(cfg.get("ckpt_path"))
    s.return_predictions = hydra.utils.call(cfg.get("return_predictions"))
    # 5. save and upload the config
    config.save(cfg, s.trainer)
    # 6. run tuners
    s.tuners = hydra.utils.instantiate(cfg.get("tuners"))
    return s
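
# End-to-end sketch (the config keys `trainer`, `model`, and `data` follow the
# layout consumed above; calling ``fit`` is one possible next step, not something
# this module prescribes):
#
#     @hydra.main(config_path="conf", config_name="config", version_base=None)
#     def main(cfg: oc.DictConfig) -> None:
#         s = initialize(cfg)
#         s.trainer.fit(s.model, datamodule=s.data, ckpt_path=s.ckpt_path)
#
#     if __name__ == "__main__":
#         run_main(main)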