MachineLearnedForceFieldTrainer¶
- class MachineLearnedForceFieldTrainer(fitting_parameters, training_sets=None, calculator=None, train_test_split=None, random_seed=None)¶
Class for training a machine learned force field.
- Parameters:
fitting_parameters (BaseMLFFFittingParameters) – The parameters for the training.
training_sets (TrainingSet | Table | sequence of [TrainingSet] | None) – The list of training sets to use for training. Default: None
calculator (Calculator | None) – The calculator to use for calculating the constant terms, if applicable for the model. If None, the calculator of the training set is used. Default: None
train_test_split (float) – The fraction of the training set to use for training. The rest is used for testing. Must be a float between 0 and 1. If set to 1, the entire training set is used for training. Default: 0.9
random_seed (int) – The random seed used for splitting the data into training and testing data. Default: Generated automatically.
- train()¶
Train the machine learned force field.
Usage Examples¶
To train a MACE model using the MachineLearnedForceFieldTrainer, the general approach is to set up a fitting parameters object for the model type to be trained, load the training data, and load or configure an appropriate calculator for calculating isolated atom energies, if required. Additionally, non-default train_test_split and random_seed values can be supplied to control how the input data is split. These objects and values are passed to the MachineLearnedForceFieldTrainer class, and the training process is then started by calling the train() method.
The example below shows how a MACE model is trained using the MachineLearnedForceFieldTrainer class.
# TrainingSet with precomputed energy and forces data (and stress, if present and desired).
# This can be one or more TrainingSets.
training_set = nlread('My_TrainingSet_Data.hdf5', TrainingSet)

# Fetch the calculator from the training set. If not present, it should be retrieved from
# elsewhere or set up analogously to calculate isolated atom energies.
calculator = training_set.calculator()

# Set up the model-specific fitting parameters object for the training.
mace_fp = MACEFittingParameters(
    # Name of the model - has to be set
    experiment_name='mace_experiment',
    # Other parameters can be set as desired
    # number_of_channels=32,
    # ...
)

# Set up the ML model training object.
mlfft = MachineLearnedForceFieldTrainer(
    fitting_parameters=mace_fp,
    training_sets=training_set,
    calculator=calculator,
    # Optional parameters can be set as desired
    train_test_split=0.8,
    random_seed=1234,
)

# Run the training.
mlfft.train()
This example script is available for download:
mlfft_example.py
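To give a feel for what train_test_split and random_seed control, the following is a minimal, library-independent sketch of a seeded random split. It is not the QuantumATK implementation: the helper name split_indices, the rounding rule, and the use of Python's random module are all assumptions made for illustration only.

```python
import random


def split_indices(n_items, train_fraction=0.9, seed=None):
    """Shuffle the indices 0..n_items-1 and split them into (train, test) lists.

    Hypothetical helper for illustration; not part of QuantumATK.
    """
    # Assumption: a fraction of 1 means "use everything for training",
    # and a fraction of 0 is rejected because nothing would be trained on.
    if not 0.0 < train_fraction <= 1.0:
        raise ValueError("train_fraction must be in the interval (0, 1]")

    # A dedicated generator makes the split reproducible for a given seed.
    rng = random.Random(seed)
    indices = list(range(n_items))
    rng.shuffle(indices)

    # Assumption: the number of training items is rounded to the nearest integer.
    n_train = round(n_items * train_fraction)
    return indices[:n_train], indices[n_train:]


# Split 10 configurations 80/20 with a fixed seed.
train_idx, test_idx = split_indices(10, train_fraction=0.8, seed=1234)
print(len(train_idx), len(test_idx))  # prints "8 2"
```

Calling the helper again with the same seed returns the same partition, which is the practical reason for fixing random_seed when training runs need to be comparable.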
Notes¶
The MachineLearnedForceFieldTrainer class is used to train machine-learned force fields (MLFFs) in QuantumATK. It is intended as a model-agnostic trainer object that can train various kinds of models, where the type of the fitting_parameters object determines how the training is conducted. It can train MLFF models whose FittingParameters objects fulfil certain setup requirements. This does not yet include MomentTensorPotentialFittingParameters.
The MachineLearnedForceFieldTrainer class is designed to be used only with the TrainingSet class for training data input. Training data therefore has to be converted into that format, either by using the conversion/"export as" utility in the Data View in the GUI or by directly converting/wrapping other data storage types into TrainingSet objects.
The MachineLearnedForceFieldTrainer class is similarly designed to be used only for the actual training step of an MLFF model. Gathering training data is not part of the class and has to be done separately before starting the training. Since the training process can take various paths once multiple types of models are supported, the trainer object also does not support a general model evaluation step. The evaluation of the trained model therefore has to be done separately, after training has concluded and an MLFF model file has been saved.
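Because evaluation is left to the user, one common approach is to compare reference forces against forces predicted by the trained model on held-out configurations. The sketch below computes the mean absolute error and root-mean-square error over all force components using plain Python. The function name force_errors and the sample numbers are hypothetical, and obtaining predictions from a saved MLFF model is deliberately not shown, since that depends on the model type.

```python
import math


def force_errors(reference, predicted):
    """Return (MAE, RMSE) over all Cartesian force components.

    Both arguments are sequences of per-atom [fx, fy, fz] lists in the
    same units. Hypothetical helper for illustration; not part of QuantumATK.
    """
    # Flatten the per-atom vectors into a single list of component differences.
    diffs = [
        p - r
        for ref_atom, pred_atom in zip(reference, predicted)
        for r, p in zip(ref_atom, pred_atom)
    ]
    if not diffs:
        raise ValueError("no force components to compare")

    mae = sum(abs(d) for d in diffs) / len(diffs)
    rmse = math.sqrt(sum(d * d for d in diffs) / len(diffs))
    return mae, rmse


# Made-up forces for two atoms; only one component differs, by 0.5.
reference = [[0.0, 0.0, 1.0], [0.5, -0.5, 0.0]]
predicted = [[0.0, 0.0, 1.5], [0.5, -0.5, 0.0]]
mae, rmse = force_errors(reference, predicted)
print(f"MAE = {mae:.4f}, RMSE = {rmse:.4f}")
```

Reporting both numbers is useful because the RMSE weights large individual deviations more heavily than the MAE does.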