# wrap_uma.py
from fairchem.core import FAIRChemCalculator
from fairchem.core.units.mlip_unit import load_predict_unit

def wrap_uma(ckpt="uma-s-1.pt", task_name="omat", device="cuda"):
    """Load a UMA checkpoint and return it wrapped in a FAIRChemCalculator.

    Args:
        ckpt: Checkpoint identifier given to ``load_predict_unit`` as its
            first positional argument. The default looks like a local
            ``.pt`` file path (TODO confirm whether hub names are accepted).
        task_name: Task selector forwarded unchanged to
            ``FAIRChemCalculator(task_name=...)``.
        device: Device string forwarded unchanged to
            ``load_predict_unit(device=...)``.

    Returns:
        A new ``FAIRChemCalculator`` built around the freshly loaded
        predict unit.
    """
    # Nothing is cached: every call reloads the checkpoint from scratch.
    unit = load_predict_unit(ckpt, device=device)
    return FAIRChemCalculator(unit, task_name=task_name)
