# Load the training sets and calculator used in all model trainings

# TrainingSet with precomputed Energy, Forces, and Stress data (if present and desired). This can
# be one or more TrainingSets from one or more files, all added to a single list.
training_sets = nlread('c-am-TiSi-TrainingSets.hdf5', TrainingSet)

# Define the calculator used for calculating the training set(s). This can be done either by saving
# and loading the calculator or by fully defining the calculator in scripting. Here, we use a saved
# calculator definition for convenience. The calculator is required in order to calculate the
# isolated atom energies for all elements present in the training set(s) as described by the
# particular calculator.
calculator = nlread('c-am-TiSi-TrainingSets.hdf5', LCAOCalculator)[0]
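
# Quick sanity check (optional): nlread returns a list, so we can confirm what was loaded before
# starting the long trainings.
print('Loaded %d training set(s)' % len(training_sets))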


# NOTE: This file can be run with or without training the last two models, which use finetuning.
# In order to run with finetuning, download the model and data files referenced in the tutorial
# and place them either in the local folder (and be sure to upload them with the Job Tool -
# slow because of the file sizes) or on the cluster.
# TODO: To train the finetuning models, ensure the above step has been done and set this
# variable to True.
do_finetuning = True


# If running on a remote cluster, a path to a shared folder for the large MACE files can be
# defined. Otherwise, the files can be stored locally, but the transfer times via the
# QuantumATK Job Tool can be long.
cluster_path = '/absolute/path/to/MACE/files/on/cluster/'
# cluster_path = ''  # Set to empty string if submitting the script directly on the cluster with files in the same directory.
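
# Optional sanity check (a minimal sketch using only the standard library): fail early if the
# finetuning files referenced later in this script are not reachable at cluster_path.
import os
if do_finetuning:
    for _filename in ('mace-mp-0b3-medium.model', 'mp_traj_combined.xyz', 'descriptors.npy'):
        _filepath = os.path.join(cluster_path, _filename)
        if not os.path.isfile(_filepath):
            raise FileNotFoundError('Missing finetuning file: ' + _filepath)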



# -------------------------------------------------------------------------------------------------
# Model 1: Training from scratch with 8 channels and 5 Ang cutoff
# -------------------------------------------------------------------------------------------------

# Set up MACE parameters for training from scratch
mace_fp = MACEFittingParameters(
    # Name of the training experiment - should be set uniquely for each training.
    experiment_name='scratch_8_5Ang',

    # Most important parameters for the size, speed, and accuracy of the model:

    # Maximum L value - controls the degree of equivariance in the model messages - 0 corresponds
    # to 'small', 1 to 'medium', and 2 to 'large' MACE models.
    max_l_equivariance=0,

    # Number of channels in the model. A higher number of channels (more model parameter weights)
    # means a more accurate model but also slower training and inference.
    number_of_channels=8,

    # Distance cutoff for the model. A higher distance cutoff means a more accurate model but
    # also slower training and inference.
    distance_cutoff=5.0*Ang,

    # Loss function and loss function weights to focus on specific parts of the loss function

    # Weight for the energy term in the loss function. Higher weight means more focus on
    # the energy term.
    energy_weight=1.0,

    # Weight for the forces term in the loss function. Higher weight means more focus on the
    # forces term. By default, the forces weight is higher than those of the other quantities.
    forces_weight=100.0,

    # Weight for the stress term in the loss function. Higher weight means more focus on the
    # stress term.
    stress_weight=5.0,
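
    # Schematically (a hedged sketch - the exact form depends on the loss type chosen below),
    # the weights combine the three terms as
    #     loss = energy_weight * L_energy + forces_weight * L_forces + stress_weight * L_stress
    # so forces_weight=100.0 makes the force errors dominate the gradient signal.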

    # Loss function type to use. Options are contained within the MACEParameterOptions.LOSS_TYPE
    # class. Other member classes of MACEParameterOptions exist for settings with defined options.
    loss_function_type=MACEParameterOptions.LOSS_TYPE.UNIVERSAL,

    # Other relevant parameters regarding training setup

    # Fraction of the training data used as validation data.
    validation_fraction=0.2,

    # Number of epochs to train for.
    max_number_of_epochs=1600,

    # Number of epochs without any improvement after which model training stops early.
    patience=50,

    # Generally, higher means quicker training, but too large a batch_size can run into GPU
    # memory limitations. A higher batch_size also means fewer gradient updates per epoch,
    # which might require a higher number of epochs (see the sketch after this parameter block).
    batch_size=4,

    # Generally, higher means quicker training, but too large a batch_size can run into GPU
    # memory limitations.
    validation_batch_size=4,

    # Seed for the train-validation data split and the initial model parameter values.
    random_seed=123,

    # Whether to compute stress or not. If set to False, the stress loss term is not included in
    # the loss function.
    compute_stress=True,

    # Whether to use distributed training or not. If set to True and available, the training is
    # done on multiple GPUs.
    distributed_training=True,

    # Whether to keep checkpoints or not during training. If set to True, the checkpoints are
    # kept in the checkpoints directory. This is useful for debugging and for resuming training
    # from a checkpoint but quickly uses up a lot of disk space.
    keep_checkpoints=False,
)
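
# Back-of-the-envelope illustration of the batch_size comment above, using a hypothetical
# training-set size (the actual size comes from the loaded TrainingSets): with
# validation_fraction=0.2 and batch_size=4, one epoch performs ceil(0.8 * n_frames / 4)
# gradient updates.
import math
hypothetical_n_frames = 1000
updates_per_epoch = math.ceil(hypothetical_n_frames * (1 - 0.2) / 4)
print('With %d frames: %d gradient updates per epoch' % (hypothetical_n_frames, updates_per_epoch))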


# Set up the ML FF trainer object handling the training flow
mlfft = MachineLearnedForceFieldTrainer(
    # MACEFittingParameters object with all the parameters for the training from scratch.
    fitting_parameters=mace_fp,

    # List of TrainingSet objects to use for the training.
    training_sets=training_sets,

    # Calculator object used to calculate the training data, which will be used to determine
    # isolated atom energies.
    calculator=calculator,

    # Random seed for the train-test split. Note that the training data resulting from this
    # split will be split further into an effective train set and a validation set.
    random_seed=12345678,

    # 1.0 yields only a training set. For lower values, the given ratio of training and test
    # data is produced. Note that the training data resulting from this split will be split
    # further into an effective train set and a validation set.
    train_test_split=1.0,
)

# Run the training
mlfft.train()


# -------------------------------------------------------------------------------------------------
# Model 2: Training from scratch with 16 channels and 5 Ang cutoff
# -------------------------------------------------------------------------------------------------
# Set up MACE parameters for training from scratch
# The parameters object is callable and creates a new object with the same parameters as the
# previous one, apart from the parameters explicitly set for the new object.
mace_fp = mace_fp(experiment_name='scratch_16_5Ang', number_of_channels=16)
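
# The call above is roughly equivalent to copying all previous settings and overriding the named
# ones. A minimal sketch of the pattern (hypothetical - not the actual MACEFittingParameters
# implementation):
#
#     class CallableParameters:
#         def __init__(self, **kwargs):
#             self._kwargs = kwargs
#
#         def __call__(self, **overrides):
#             # New object: old settings updated with the overrides.
#             return CallableParameters(**{**self._kwargs, **overrides})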

# Set up ML FF Trainer object
mlfft = MachineLearnedForceFieldTrainer(
    # MACEFittingParameters object with all the parameters for the training from scratch.
    fitting_parameters=mace_fp,

    # List of TrainingSet objects to use for the training.
    training_sets=training_sets,

    # Calculator object used to calculate the training data, which will be used to determine
    # isolated atom energies.
    calculator=calculator,

    # Random seed for the train-test split. Note that the training data resulting from this
    # split will be split further into an effective train set and a validation set.
    random_seed=12345678,

    # 1.0 yields only a training set. For lower values, the given ratio of training and test
    # data is produced. Note that the training data resulting from this split will be split
    # further into an effective train set and a validation set.
    train_test_split=1.0,
)

# Run the training
mlfft.train()


# -------------------------------------------------------------------------------------------------
# Model 3: Training from scratch with 32 channels and 5 Ang cutoff
# -------------------------------------------------------------------------------------------------
# Set up MACE parameters for training from scratch
# The parameters object is callable and creates a new object with the same parameters as the
# previous one, apart from the parameters explicitly set for the new object.
mace_fp = mace_fp(experiment_name='scratch_32_5Ang', number_of_channels=32)

# Set up ML FF Trainer object
mlfft = MachineLearnedForceFieldTrainer(
    # MACEFittingParameters object with all the parameters for the training from scratch.
    fitting_parameters=mace_fp,

    # List of TrainingSet objects to use for the training.
    training_sets=training_sets,

    # Calculator object used to calculate the training data, which will be used to determine
    # isolated atom energies.
    calculator=calculator,

    # Random seed for the train-test split. Note that the training data resulting from this
    # split will be split further into an effective train set and a validation set.
    random_seed=12345678,

    # 1.0 yields only a training set. For lower values, the given ratio of training and test
    # data is produced. Note that the training data resulting from this split will be split
    # further into an effective train set and a validation set.
    train_test_split=1.0,
)

# Run the training
mlfft.train()


# -------------------------------------------------------------------------------------------------
# Model 4: Training from scratch with 64 channels and 5 Ang cutoff
# -------------------------------------------------------------------------------------------------
# Set up MACE parameters for training from scratch
mace_fp = mace_fp(experiment_name='scratch_64_5Ang', number_of_channels=64)

# Set up ML FF Trainer object
mlfft = MachineLearnedForceFieldTrainer(
    # MACEFittingParameters object with all the parameters for the training from scratch.
    fitting_parameters=mace_fp,

    # List of TrainingSet objects to use for the training.
    training_sets=training_sets,

    # Calculator object used to calculate the training data, which will be used to determine
    # isolated atom energies.
    calculator=calculator,

    # Random seed for the train-test split. Note that the training data resulting from this
    # split will be split further into an effective train set and a validation set.
    random_seed=12345678,

    # 1.0 yields only a training set. For lower values, the given ratio of training and test
    # data is produced. Note that the training data resulting from this split will be split
    # further into an effective train set and a validation set.
    train_test_split=1.0,
)

# Run the training
mlfft.train()


# -------------------------------------------------------------------------------------------------
# Model 5: Training from scratch with 128 channels and 5 Ang cutoff
# -------------------------------------------------------------------------------------------------
# Set up MACE parameters for training from scratch
mace_fp = mace_fp(experiment_name='scratch_128_5Ang', number_of_channels=128)

# Set up ML FF Trainer object
mlfft = MachineLearnedForceFieldTrainer(
    # MACEFittingParameters object with all the parameters for the training from scratch.
    fitting_parameters=mace_fp,

    # List of TrainingSet objects to use for the training.
    training_sets=training_sets,

    # Calculator object used to calculate the training data, which will be used to determine
    # isolated atom energies.
    calculator=calculator,

    # Random seed for the train-test split. Note that the training data resulting from this
    # split will be split further into an effective train set and a validation set.
    random_seed=12345678,

    # 1.0 yields only a training set. For lower values, the given ratio of training and test
    # data is produced. Note that the training data resulting from this split will be split
    # further into an effective train set and a validation set.
    train_test_split=1.0,
)

# Run the training
mlfft.train()
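
# Aside: models 1-5 differ only in number_of_channels, so the five scratch trainings above could
# equivalently be written as a loop, e.g.:
#
#     for channels in (8, 16, 32, 64, 128):
#         fp = mace_fp(experiment_name='scratch_%d_5Ang' % channels, number_of_channels=channels)
#         MachineLearnedForceFieldTrainer(
#             fitting_parameters=fp,
#             training_sets=training_sets,
#             calculator=calculator,
#             random_seed=12345678,
#             train_test_split=1.0,
#         ).train()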



# -------------------------------------------------------------------------------------------------
# Model 6: Training with Naive Finetuning
# -------------------------------------------------------------------------------------------------
if do_finetuning:
    # Set up MACE parameters for training with Naive Finetuning
    mace_fp = MACEFittingParameters(
        # Name of the training experiment - should be set uniquely for each training.
        experiment_name='naive_finetuning',

        # Loss function and loss function weights to focus on specific parts of the loss function

        # Weight for the energy term in the loss function. Higher weight means more focus on the
        # energy term.
        energy_weight=1.0,

        # Weight for the forces term in the loss function. Higher weight means more focus on the
        # forces term. By default, the forces weight is higher than those of the other quantities.
        forces_weight=100.0,

        # Weight for the stress term in the loss function. Higher weight means more focus on the
        # stress term.
        stress_weight=1.0,

        # Loss function type to use. Options are contained within the MACEParameterOptions.LOSS_TYPE
        # class. Other member classes of MACEParameterOptions exist for settings with defined options.
        loss_function_type=MACEParameterOptions.LOSS_TYPE.UNIVERSAL,

        # Other relevant parameters regarding training setup

        # Fraction of the training data used as validation data.
        validation_fraction=0.2,

        # Number of epochs to train for - keep low for finetuning.
        max_number_of_epochs=100,

        # Generally, higher means quicker training, but too large a batch_size can cause GPU
        # memory errors.
        batch_size=4,

        # Generally, higher means quicker training, but too large a batch_size can cause GPU
        # memory errors.
        validation_batch_size=4,

        # Seed for the train-validation data split and the initial model parameter values.
        random_seed=123,

        # Whether to compute stress or not. If set to False, the stress loss term is not included in
        # the loss function.
        compute_stress=True,

        # Whether to use distributed training or not. If set to True and available, the training is
        # done on multiple GPUs.
        distributed_training=True,

        # Whether to keep checkpoints or not during training. If set to True, the checkpoints are
        # kept in the checkpoints directory. This is useful for debugging and for resuming training
        # from a checkpoint but quickly uses up a lot of disk space.
        keep_checkpoints=False,

        # Parameters related to Naive Finetuning

        # Path to a foundation MACE model to finetune - here a recent universal MP MACE model is
        # utilized. For efficient data transfers, the model can be stored in a separate directory on
        # the cluster (potentially accessible for all users).
        foundation_model_path=cluster_path+'mace-mp-0b3-medium.model',
    )

    # Set up ML FF Trainer object
    mlfft = MachineLearnedForceFieldTrainer(
        # MACEFittingParameters object with all the parameters for the Naive Finetuning.
        fitting_parameters=mace_fp,

        # List of TrainingSet objects to use for the training.
        training_sets=training_sets,

        # Calculator object used to calculate the training data, which will be used to determine
        # isolated atom energies.
        calculator=calculator,

        # Random seed for the train-test split. Note that the training data resulting from this
        # split will be split further into an effective train set and a validation set.
        random_seed=12345678,

        # 1.0 yields only a training set. For lower values, the given ratio of training and test
        # data is produced. Note that the training data resulting from this split will be split
        # further into an effective train set and a validation set.
        train_test_split=1.0,
    )

    # Run the training
    mlfft.train()


# -------------------------------------------------------------------------------------------------
# Model 7: Training with Multihead Finetuning
# -------------------------------------------------------------------------------------------------
if do_finetuning:
    mace_fp = MACEFittingParameters(
        # Name of the training experiment - should be set uniquely for each training.
        experiment_name='multihead_finetuning',

        # Loss function and loss function weights to focus on specific parts of the loss function

        # Weight for the energy term in the loss function. Higher weight means more focus on the
        # energy term.
        energy_weight=1.0,

        # Weight for the forces term in the loss function. Higher weight means more focus on the
        # forces term. By default, the forces weight is higher than those of the other quantities.
        forces_weight=100.0,

        # Weight for the stress term in the loss function. Higher weight means more focus on the
        # stress term.
        stress_weight=1.0,

        # Loss function type to use. Options are contained within the MACEParameterOptions.LOSS_TYPE
        # class. Other member classes of MACEParameterOptions exist for settings with defined options.
        loss_function_type=MACEParameterOptions.LOSS_TYPE.UNIVERSAL,

        # Other relevant parameters regarding training setup

        # Fraction of the training data used as validation data.
        validation_fraction=0.2,

        # Number of epochs to train for - keep low for finetuning.
        max_number_of_epochs=100,

        # Number of epochs without any improvement after which model training stops early.
        patience=50,

        # Generally, higher means quicker training, but too large a batch_size can cause GPU
        # memory errors.
        batch_size=4,

        # Generally, higher means quicker training, but too large a batch_size can cause GPU
        # memory errors.
        validation_batch_size=4,

        # Seed for the train-validation data split and the initial model parameter values.
        random_seed=123,

        # Whether to compute stress or not. If set to False, the stress loss term is not included in
        # the loss function.
        compute_stress=True,

        # Whether to use distributed training or not. If set to True and available, the training is
        # done on multiple GPUs.
        distributed_training=True,

        # Whether to keep checkpoints or not during training. If set to True, the checkpoints are
        # kept in the checkpoints directory. This is useful for debugging and for resuming training
        # from a checkpoint but quickly uses up a lot of disk space.
        keep_checkpoints=False,

        # Parameters related to Multihead Finetuning

        # Path to a foundation MACE model to finetune - here a recent universal MP MACE model is
        # utilized. For efficient data transfers, the model can be stored in a separate directory
        # on the cluster (potentially accessible for all users).
        foundation_model_path=cluster_path+'mace-mp-0b3-medium.model',

        # Whether to use multihead finetuning or not.
        use_multiheads_finetuning=True,

        # Use 'mp' if finetuning a universal MP model. Otherwise give the path to the xyz file
        # containing the original training data.
        pretrained_head_train_file='mp',

        # If using the 'mp' keyword above, provide the path to the original training data for the
        # universal MP MACE model. For efficient data transfers, the data can be stored in a
        # separate directory on the cluster (potentially accessible for all users).
        mp_data_path=cluster_path+'mp_traj_combined.xyz',

        # If using the 'mp' keyword above, provide the path to the original descriptors file for
        # the universal MP MACE model. For efficient data transfers, the file can be stored in a
        # separate directory on the cluster (potentially accessible for all users).
        mp_descriptors_path=cluster_path+'descriptors.npy',
    )

    # Set up ML FF Trainer object
    mlfft = MachineLearnedForceFieldTrainer(
        # MACEFittingParameters object with all the parameters for the Multihead Finetuning.
        fitting_parameters=mace_fp,

        # List of TrainingSet objects to use for the training.
        training_sets=training_sets,

        # Calculator object used to calculate the training data, which will be used to determine
        # isolated atom energies.
        calculator=calculator,

        # Random seed for the train-test split. Note that the training data resulting from this
        # split will be split further into an effective train set and a validation set.
        random_seed=12345678,

        # 1.0 yields only a training set. For lower values, the given ratio of training and test
        # data is produced. Note that the training data resulting from this split will be split
        # further into an effective train set and a validation set.
        train_test_split=1.0,
    )

    # Run the training
    mlfft.train()