
# TrainingSet with precomputed energy and force data (plus stress data, if present and desired).
# This can be a single TrainingSet or several TrainingSets.
training_set = nlread('My_TrainingSet_Data.hdf5', TrainingSet)

# Fetch the calculator stored in the training set; it is used to calculate isolated-atom
# energies. If the training set does not carry a calculator, set up an equivalent one by hand
# instead - unless the isolated-atom energies are already known for all elements present with
# that calculator.
calculator = training_set.calculator()

# Single definition of the experiment name. The trained model is loaded back further below
# from a file named after it, so training and evaluation cannot silently refer to different
# experiments (previously training used 'mace_experiment3' but 'mace_experiment2' was loaded).
experiment_name = 'mace_experiment3'

# Setup MACE parameters for training with naive finetuning
mace_fp = MACEFittingParameters(
    # Name of the model - has to be set
    experiment_name=experiment_name,

    # Weights for the different parts of the loss term, and the type of loss function
    energy_weight=1.0,
    forces_weight=100.0,
    stress_weight=1.0,
    loss_function_type=MACEParameterOptions.LOSS_TYPE.UNIVERSAL,

    # Most relevant parameters regarding the training setup
    validation_fraction=0.2, # Fraction of the training data held out for validation
    max_number_of_epochs=30, # Number of epochs to train for - keep low for finetuning
    batch_size=4, # Generally, higher means quicker training, but too high batch sizes can cause GPU memory errors
    validation_batch_size=4, # Generally, higher means quicker training, but too high batch sizes can cause GPU memory errors
    random_seed=123, # Can be varied to experiment with the influence of different data splits
    default_dtype=MACEParameterOptions.DTYPE.FLOAT64, # Can be adjusted to shift the tradeoff between model accuracy and training speed
    device=Automatic, # Device can be set explicitly, but Automatic is recommended to use the GPU automatically if available

    # Model to start from for naive finetuning: a custom model (as here) or an MP foundation model
    foundation_model_path='/path/to/model/data/on/run/machine/my_previous_mace_model.model',
)

# Setup ML model training object
mlfft = MachineLearnedForceFieldTrainer(
    fitting_parameters=mace_fp,
    training_sets=training_set,
    calculator=calculator
)

# Run the training
mlfft.train()


# Load the trained model for evaluation use in QuantumATK.
# NOTE(review): assumes the trainer writes '<experiment_name>.qatkpt' to the working
# directory (as the original 'mace_experiment*.qatkpt' naming suggests) - confirm.
model_path = experiment_name + '.qatkpt'

potentialSet = TremoloXPotentialSet(name='my_mace_potential2')
potential = TorchXPotential(
    dtype='float64', # Matches default_dtype=FLOAT64 used for training above
    # Hard-coded to CUDA, so evaluation requires a GPU (unlike device=Automatic during training).
    # NOTE(review): presumably 'cpu' is accepted here for CPU-only machines - confirm.
    device='cuda',
    file=model_path,
)
# Register every element the model was trained on as a particle type of the potential set.
for symbol in potential.get_symbols(model_path):
    potentialSet.addParticleType(ParticleType(symbol))
potentialSet.addPotential(potential)

# Note: this rebinds `calculator` (previously the training-set calculator) to the ML calculator.
calculator = TremoloXCalculator(parameters=potentialSet)

# Define configuration and attach the calculator
# configuration = ...
# configuration.setCalculator(calculator)
# ...
