
# Read the TrainingSet object(s) holding precomputed energies and forces (and stresses, if present
# and desired). nlread returns a list of all matching objects stored in the file, so
# `training_set` may hold one or more TrainingSets.
training_data_file = 'My_TrainingSet_Data.hdf5'
training_set = nlread(training_data_file, TrainingSet)
if not training_set:
    # Fail early with a clear message instead of an IndexError on the lookup below.
    raise RuntimeError("No TrainingSet objects found in '%s'." % training_data_file)

# Fetch the calculator from the first training set (the list itself has no calculator() method).
# If no calculator is stored there, it should be retrieved from elsewhere/set up analogously to
# calculate isolated atom energies - unless these are already known for all present atoms with
# the given calculator.
calculator = training_set[0].calculator()

# MACE fitting parameters for training with Multihead Replay Finetuning. The keyword arguments
# are grouped by topic below and then passed together to MACEFittingParameters.

# Loss term: relative weights of the individual contributions and the type of loss function.
loss_settings = dict(
    energy_weight=1.0,
    forces_weight=100.0,
    stress_weight=1.0,
    loss_function_type=MACEParameterOptions.LOSS_TYPE.UNIVERSAL,
)

# Most relevant parameters of the training setup.
training_settings = dict(
    # Ratio of the training data that is used for validation.
    validation_fraction=0.2,
    # Number of epochs to train for - keep low for finetuning.
    max_number_of_epochs=20,
    # Generally, higher batch sizes mean quicker training, but too high values can cause GPU
    # memory errors. This holds for both the training and the validation batch size.
    batch_size=4,
    validation_batch_size=4,
    # Can be varied to experiment with the influence of different data splits.
    random_seed=123,
    # Can be adjusted to shift the tradeoff between model accuracy and speed.
    default_dtype=MACEParameterOptions.DTYPE.FLOAT64,
    # The device can be set explicitly, but Automatic is recommended so that the GPU is used
    # automatically if available.
    device=Automatic,
)

# Settings for multihead replay finetuning on MP foundation models.
finetuning_settings = dict(
    use_multiheads_finetuning=True,
    foundation_model_path='/path/to/model/data/on/run/machine/mace_agnesi_small.model',
    # Number of samples from the original model training data used for the replay.
    number_of_samples_from_pretrained_head=15000,
    # Use 'mp' for an MP foundation type model. For other custom models, use the path to an
    # xyz file with the training data instead.
    pretrained_head_train_file='mp',
    mp_data_path='/path/to/model/data/on/run/machine/mp_traj_combined.xyz',
    mp_descriptors_path='/path/to/model/data/on/run/machine/descriptors.npy',
)

mace_fp = MACEFittingParameters(
    # Name of the model - has to be set.
    experiment_name='mace_experiment3',
    **loss_settings,
    **training_settings,
    **finetuning_settings,
)

# Collect what the trainer needs: the fitting parameters, the training data and the calculator.
trainer_arguments = {
    'fitting_parameters': mace_fp,
    'training_sets': training_set,
    'calculator': calculator,
}

# Create the ML model training object and start the training run.
mlfft = MachineLearnedForceFieldTrainer(**trainer_arguments)
mlfft.train()


# Load the trained model so that it can be used for evaluation in QuantumATK.
# NOTE(review): the file name matches experiment_name='mace_experiment3' used for training -
# presumably the trainer writes '<experiment_name>.qatkpt'; verify against the training output.
model_path = 'mace_experiment3.qatkpt'

# Potential backed by the trained model file.
# NOTE(review): the device is hard-coded to 'cuda' here, whereas training uses Automatic -
# confirm that a GPU is available on the machine loading the model.
potential = TorchXPotential(
    file=model_path,
    device='cuda',
    dtype='float64',
)

# Register one particle type per symbol reported for the model, then add the potential itself.
potentialSet = TremoloXPotentialSet(name='my_mace_potential3')
element_symbols = potential.get_symbols(model_path)
for element_symbol in element_symbols:
    potentialSet.addParticleType(ParticleType(element_symbol))
potentialSet.addPotential(potential)

# From here on `calculator` refers to the force field calculator built from the trained model
# (it rebinds the name used for the calculator passed to the trainer).
calculator = TremoloXCalculator(parameters=potentialSet)

# Define configuration and attach the calculator
# configuration = ...
# configuration.setCalculator(calculator)
# ...
