# Load the training sets and calculator used in all model trainings

# TrainingSets with precomputed Energy, Forces, and Stress (if present and desired) data. This can
# be one or more TrainingSets read from one or more files, all collected into a single list.
training_sets = nlread('c-am-TiSi-TrainingSets.hdf5', TrainingSet)
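
# For illustration (hypothetical file names), training sets read from several files could be
# combined into a single list:
# training_sets = nlread('file_a.hdf5', TrainingSet) + nlread('file_b.hdf5', TrainingSet)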

# Define the calculator used for calculating the training set(s). This can be done either by saving
# and loading the calculator or by fully defining the calculator in scripting. Here, we use a saved
# calculator definition for convenience. The calculator is required in order to calculate the
# isolated atom energies for all elements present in the training set(s) as described by the
# particular calculator.
calculator = nlread('c-am-TiSi-TrainingSets.hdf5', LCAOCalculator)[0]
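
# Alternatively (a sketch, not used here), the calculator could be fully defined in scripting,
# e.g. with default LCAO settings; the settings should match those used to compute the training
# data:
# calculator = LCAOCalculator()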


# NOTE: This file can be run with or without training the last 2 models, for which finetuning is
# performed. In order to run with finetuning, download the model and data files referenced in the
# tutorial and place them either in the local folder (and be sure to upload them with the Job Tool,
# which is slow because of the file size) or directly on the cluster.


# If running on a remote cluster, a path to a shared folder for large MACE files can be defined.
# Otherwise, the files can be stored locally, but transferring them between the local machine and
# the cluster via the QuantumATK Job Tool can take a long time.
cluster_path = '/absolute/path/to/MACE/files/on/cluster/'
# cluster_path = ''  # Set to empty string if submitting the script directly on the cluster with files in the same directory.
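
# Optional sanity check (a minimal sketch using only the standard library): verify that the
# foundation model file is reachable before launching a long training job. Left commented out so
# the script still runs without the finetuning files.
# import os
# assert os.path.isfile(os.path.join(cluster_path, 'mace-mp-0b3-medium.model')), \
#     'Foundation model not found - check cluster_path.'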

# -------------------------------------------------------------------------------------------------
# Model 6: Training with Naive Finetuning
# -------------------------------------------------------------------------------------------------
# Set up MACE parameters for training with Naive Finetuning
mace_fp = MACEFittingParameters(
    # Name of the training experiment - should be set uniquely for each training.
    experiment_name='naive_finetuning',

    # Loss function type and loss function weights used to emphasize specific parts of the loss
    # function.

    # Weight for the energy term in the loss function. Higher weight means more focus on the
    # energy term.
    energy_weight=1.0,

    # Weight for the forces term in the loss function. Higher weight means more focus on the
    # forces term. By default, the forces weight is higher than the weights for the other
    # quantities.
    forces_weight=100.0,

    # Weight for the stress term in the loss function. Higher weight means more focus on the
    # stress term.
    stress_weight=1.0,
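
    # Schematically (a sketch of the general structure; the exact per-term form depends on the
    # chosen loss_function_type), the three weights combine the loss terms as
    #   L = energy_weight * L_energy + forces_weight * L_forces + stress_weight * L_stress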

    # Loss function type to use. Options are contained within the MACEParameterOptions.LOSS_TYPE
    # class. Other member classes of MACEParameterOptions exist for settings with defined options.
    loss_function_type=MACEParameterOptions.LOSS_TYPE.UNIVERSAL,

    # Other relevant parameters regarding training setup

    # Fraction of the training data set aside as validation data.
    validation_fraction=0.2,

    # Number of epochs to train for - keep low for finetuning.
    max_number_of_epochs=100,

    # Generally, a higher batch size means quicker training, but too large a batch size can cause
    # GPU memory errors.
    batch_size=4,

    # The same considerations as for batch_size apply: higher values are generally quicker, but
    # too large a value can cause GPU memory errors.
    validation_batch_size=4,

    # Seed for the train-validation data split and the initial model parameter values.
    random_seed=123,

    # Whether to compute stress or not. If set to False, the stress loss term is not included in
    # the loss function.
    compute_stress=True,

    # Whether to use distributed training or not. If set to True and multiple GPUs are available,
    # the training is distributed across them.
    distributed_training=True,

    # Whether to keep checkpoints or not during training. If set to True, the checkpoints are
    # kept in the checkpoints directory. This is useful for debugging and for resuming training
    # from a checkpoint but quickly uses up a lot of disk space.
    keep_checkpoints=False,

    # Parameters related to Naive Finetuning

    # Path to a foundation MACE model to finetune - here a recent universal MP MACE model is
    # utilized. For efficient data transfers, the model can be stored in a separate directory on
    # the cluster (potentially accessible for all users).
    foundation_model_path=cluster_path+'mace-mp-0b3-medium.model',
)

# Set up ML model training object
mlfft = MachineLearnedForceFieldTrainer(
    # MACEFittingParameters object with all the parameters for the Naive Finetuning.
    fitting_parameters=mace_fp,

    # List of TrainingSet objects to use for the training.
    training_sets=training_sets,

    # Calculator object used to calculate the training data, which will be used to determine
    # isolated atom energies.
    calculator=calculator,

    # Random seed for the train-test split. Note that the training data resulting from this split
    # will be split further into an effective train set and a validation set.
    random_seed=12345678,

    # A value of 1.0 yields only a training set. For lower values, the given ratio of training and
    # test data is produced. Note that the training data resulting from this split will be split
    # further into an effective train set and a validation set.
    train_test_split=1.0,
)
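
# As a concrete example of the two-stage split: with train_test_split=1.0 no test set is held out,
# and validation_fraction=0.2 then reserves 20% of the training data for validation, leaving 80%
# as the effective training set. With e.g. train_test_split=0.9, 10% of the data would first be
# held out for testing before the validation split is applied.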

# Run the training
mlfft.train()