# Load the training sets and calculator used in all model trainings

# TrainingSet(s) with precomputed Energy, Forces, and Stress (if present and desired) data. One or
# more TrainingSets from one or more files can be combined into a single list (see the commented
# sketch below).
training_sets = nlread('c-am-TiSi-TrainingSets.hdf5', TrainingSet)
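
# For illustration, training sets from several files could be combined into a single list before
# training. The filenames below are hypothetical placeholders, not part of this example:
# training_sets = list(nlread('first-TrainingSets.hdf5', TrainingSet)) + \
#                 list(nlread('second-TrainingSets.hdf5', TrainingSet))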

# Define the calculator used for calculating the training set(s). This can be done either by saving
# and loading the calculator or by fully defining the calculator in scripting (see the commented
# sketch below). Here, we use a saved calculator definition for convenience. The calculator is
# required in order to calculate the isolated atom energies for all elements present in the
# training set(s), as described by the particular calculator.
calculator = nlread('c-am-TiSi-TrainingSets.hdf5', LCAOCalculator)[0]
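
# Alternatively, the calculator could be defined directly in the script. A minimal sketch, with
# illustrative settings that must be adapted to match the calculator actually used to generate
# the training data:
# calculator = LCAOCalculator(
#     exchange_correlation=GGA.PBE,
#     numerical_accuracy_parameters=NumericalAccuracyParameters(
#         k_point_sampling=MonkhorstPackGrid(3, 3, 3),
#     ),
# )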


# -------------------------------------------------------------------------------------------------
# Model 1: Training from scratch with 8 channels and 5 Ang cutoff
# -------------------------------------------------------------------------------------------------

# Set up MACE parameters for training from scratch
mace_fp = MACEFittingParameters(
    # Name of the training experiment - should be unique for each training run.
    experiment_name='scratch_8_5Ang',

    # Most important parameters for the size, speed, and accuracy of the model:

    # Maximum L value - controls the degree of equivariance in the model messages - 0 corresponds
    # to 'small', 1 to 'medium', and 2 to 'large' MACE models.
    max_l_equivariance=0,

    # Number of channels in the model. Higher number of channels (more model parameter weights)
    # means more accurate model but also slower training and inference.
    number_of_channels=8,

    # Distance cutoff for the model. Higher distance cutoff means more accurate model but also
    # slower training and inference.
    distance_cutoff=5.0*Ang,

    # Loss function and loss function weights to emphasize specific parts of the loss function.

    # Weight for the energy term in the loss function. Higher weight means more focus on
    # the energy term.
    energy_weight=1.0,

    # Weight for the forces term in the loss function. Higher weight means more focus on the
    # forces term. By default, the forces weight is higher than the weights of the other quantities.
    forces_weight=100.0,

    # Weight for the stress term in the loss function. Higher weight means more focus on the
    # stress term.
    stress_weight=5.0,

    # Loss function type to use. Options are contained within the MACEParameterOptions.LOSS_TYPE
    # class. Other member classes of MACEParameterOptions exist for settings with defined options.
    loss_function_type=MACEParameterOptions.LOSS_TYPE.UNIVERSAL,

    # Other relevant parameters regarding training setup

    # Fraction of the training data used as validation data.
    validation_fraction=0.2,

    # Number of epochs to train for.
    max_number_of_epochs=1600,

    # Number of epochs without any improvement after which model training is stopped early.
    patience=50,

    # Batch size used during training. Generally, a higher batch size means quicker training, but
    # too high a batch size can run into GPU memory limitations. A higher batch size also means
    # fewer gradient updates per epoch, which might require a higher number of epochs.
    batch_size=4,

    # Batch size used during validation. Generally, a higher batch size means quicker validation,
    # but too high a batch size can run into GPU memory limitations.
    validation_batch_size=4,

    # Seed for the train-validation data split and the initial model parameter values.
    random_seed=123,

    # Whether to compute stress or not. If set to False, the stress loss term is not included in
    # the loss function.
    compute_stress=True,

    # Whether to use distributed training or not. If set to True and multiple GPUs are available,
    # the training is distributed across them.
    distributed_training=True,

    # Whether to keep checkpoints or not during training. If set to True, the checkpoints are
    # kept in the checkpoints directory. This is useful for debugging and for resuming training
    # from a checkpoint but quickly uses up a lot of disk space.
    keep_checkpoints=False,
)


# Set up the ML FF trainer object handling the training flow
mlfft = MachineLearnedForceFieldTrainer(
    # MACEFittingParameters object with all the parameters for the training from scratch.
    fitting_parameters=mace_fp,

    # List of TrainingSet objects to use for the training.
    training_sets=training_sets,

    # Calculator object used to calculate the training data; it is used here to determine the
    # isolated atom energies.
    calculator=calculator,

    # Random seed for the train-test split. Note that the training data resulting from this split
    # will be further split into an effective training set and a validation set.
    random_seed=12345678,

    # Fraction of data used for training: 1.0 yields only a training set; for lower values, the
    # given ratio of training and test data is produced. Note that the training data resulting
    # from this split will be further split into an effective training set and a validation set.
    train_test_split=1.0,
)

# Run the training
mlfft.train()