# Workflow: collect existing configurations into TrainingSets, (re)label them
# with one reference calculator, and fit a Moment Tensor Potential (MTP) to
# the resulting data.

# Reference calculator, defined once so that the training-data generation and
# the MTP training below are guaranteed to use the same calculator object.
calculator = LCAOCalculator()

training_sets = []

# Load another ForceFieldTrainingSetGenerator (or old MomentTensorPotentialTraining) object with precalculated data.
# For the configurations in this dataset, we keep the energy, forces, and stress data, if available.
mtp_training_data_input = nlread('mtp_training_data.hdf5', ForceFieldTrainingSetGenerator)[0]
training_sets.append(
    TrainingSet(mtp_training_data_input, recalculate_training_data=False)
)

# Load a Trajectory object whose data was computed with another calculator.
# For the configurations in this dataset, we recalculate the training data with the new calculator.
trajectory_training_data_input = nlread('trajectory_training_data.hdf5', Trajectory)[0]
training_sets.append(
    TrainingSet(trajectory_training_data_input, recalculate_training_data=True)
)

# Set up training data generation.
force_field_training_set_generator = ForceFieldTrainingSetGenerator(
    filename='study_training_data.hdf5',
    object_id='fftsg',
    training_sets=training_sets,
    calculator=calculator,
    calculate_stress=True,
)
force_field_training_set_generator.update()

# Retrieve the TrainingSet produced by the generator (labeled with the
# reference calculator defined above).
generated_trainingset = force_field_training_set_generator.generatedTrainingSet()

# This TrainingSet can then be used for training machine-learned force fields.
# For example, an MTP can be trained with this TrainingSet by:

# Set up non-linear coefficients with optimization.
non_linear_coefficients_parameters = NonLinearCoefficientsParameters(
    perform_optimization=True,
)

# Set up parameters to use in the MTP fitting.
fitting_parameters = MomentTensorPotentialFittingParameters(
    basis_size=PredefinedBasisSmall,
    outer_cutoff_radii=3.0*Angstrom,
    mtp_filename='mtp_study.mtp',
    non_linear_coefficients_parameters=non_linear_coefficients_parameters,
)

# Set up and run the MTP training.
# NOTE: the same `calculator` instance used for generating the training data
# is passed here; previously this referenced a name that was never bound in
# this script.
machine_learned_force_field_trainer = MachineLearnedForceFieldTrainer(
    fitting_parameters=fitting_parameters,
    training_sets=generated_trainingset,
    calculator=calculator,
    train_test_split=0.8,
)
machine_learned_force_field_trainer.train()
