Source code for fkat.predict
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#!/usr/bin/env python
"""
The ``fkat.predict`` entrypoint processes the provided config,
instantiates the ``trainer``, ``model`` and ``data`` sections and calls ``trainer.predict()``.
"""
import hydra
import lightning as L
from omegaconf import DictConfig
from fkat import initialize, run_main
[docs]@hydra.main(version_base="1.3")
def main(cfg: DictConfig) -> None:
s = initialize(cfg)
kwargs = {
"ckpt_path": s.ckpt_path,
"return_predictions": s.return_predictions,
}
if isinstance(s.data, L.LightningDataModule):
kwargs["datamodule"] = s.data
else:
kwargs["predict_dataloader"] = s.data.predict_dataloader() if s.data else None
s.trainer.predict(s.model, **kwargs)
if __name__ == "__main__":
run_main(main)