from NL.TrainMLModel.MachineLearnedForceFieldTrainer.MACE.MACEFittingParameters import MACEParameterOptions

# Load the TrainingSet data with precomputed Energy and Forces (and Stress if present and desired).
# nlread returns a list of all TrainingSet objects found in the file, so 'training_set' may hold
# one or more TrainingSets; the whole list is handed on to the trainer.
training_set = nlread('My_TrainingSet_Data.hdf5', TrainingSet)
if not training_set:
    raise RuntimeError("No TrainingSet objects found in 'My_TrainingSet_Data.hdf5'.")

# Fetch the calculator from the first training set (the list itself has no calculator() method).
# If no calculator is stored there, it should be retrieved from elsewhere/set up analogously to
# calculate isolated atom energies - unless these are already known for all present atoms with
# the given calculator.
calculator = training_set[0].calculator()

# Configure the MACE fitting parameters for training a model from scratch
mace_fp = MACEFittingParameters(
    # Mandatory: the name of the model
    experiment_name='mace_experiment1',

    # Model size - these settings matter most for accuracy and speed
    max_l_equivariance=0,
    number_of_channels=64,
    distance_cutoff=4.0*Ang,

    # Loss function type and the weights of the different parts of the loss term
    loss_function_type=MACEParameterOptions.LOSS_TYPE.UNIVERSAL,
    energy_weight=1.0,
    forces_weight=100.0,
    stress_weight=1.0,

    # Include stress if it is desired and present in the training data
    compute_stress=True,

    # --- Most relevant remaining parameters of the training setup ---
    # Ratio of the training data that is used for validation
    validation_fraction=0.2,
    # Number of epochs to train for
    max_number_of_epochs=200,
    # Model training is finalized after this many epochs without any improvement
    patience=50,
    # Higher batch sizes generally mean quicker training, but too high values can cause
    # GPU memory errors (applies to both the training and the validation batch size)
    batch_size=4,
    validation_batch_size=4,
    # Can be varied to experiment with the influence of different data splits
    random_seed=123,
    # Can be adjusted to shift the tradeoff between model accuracy and training speed
    default_dtype=MACEParameterOptions.DTYPE.FLOAT64,
    # The device can be set explicitly, but Automatic is recommended so that the GPU is
    # used automatically if available
    device=Automatic,
)

# Create the trainer object from the training data, the reference calculator and the
# MACE fitting parameters
mlfft = MachineLearnedForceFieldTrainer(
    training_sets=training_set,
    calculator=calculator,
    fitting_parameters=mace_fp,
)

# Start the model training
mlfft.train()


# Load the trained model so that it can be used for evaluation in QuantumATK
model_path = 'mace_experiment1.qatkpt'

# Wrap the model file in a TorchX potential.
# NOTE(review): the device is hard-coded to 'cuda' here, whereas the training setup uses
# Automatic - confirm that a GPU is available wherever this part of the script is run.
potential = TorchXPotential(
    file=model_path,
    dtype='float64',
    device='cuda',
)

# Build the potential set: register one particle type per symbol reported for the model file,
# then add the potential itself
potentialSet = TremoloXPotentialSet(name='my_mace_potential1')
for element_symbol in potential.get_symbols(model_path):
    potentialSet.addParticleType(ParticleType(element_symbol))
potentialSet.addPotential(potential)

# NOTE(review): this rebinds 'calculator', which previously referred to the calculator taken
# from the training set
calculator = TremoloXCalculator(parameters=potentialSet)

# Define configuration and attach the calculator
# configuration = ...
# configuration.setCalculator(calculator)
# ...
