# Reference crystal: a hexagonal cell holding 3 Si and 6 O atoms
# (presumably alpha-quartz SiO2 -- confirm against the structure source).
lattice = Hexagonal(
    4.916*Angstrom,   # a
    5.4054*Angstrom,  # c
)

# Species list, ordered to match the rows of fractional_coordinates below.
elements = [Silicon]*3 + [Oxygen]*6

# Atomic positions in fractional (lattice-vector) units, one row per atom.
fractional_coordinates = [
    # Silicon sites
    [0.4697, 0.0, 0.0],
    [0.0, 0.4697, 0.666666666667],
    [0.5303, 0.5303, 0.333333333333],
    # Oxygen sites
    [0.4135, 0.2669, 0.1191],
    [0.2669, 0.4135, 0.547567],
    [0.7331, 0.1466, 0.785767],
    [0.5865, 0.8534, 0.214233],
    [0.8534, 0.5865, 0.452433],
    [0.1466, 0.7331, 0.8809],
]

# Assemble the periodic reference structure that the displacement protocol
# below starts from.
reference_configuration = BulkConfiguration(
    elements=elements,
    fractional_coordinates=fractional_coordinates,
    bravais_lattice=lattice,
)

# DFT (LCAO) calculator used to label every training configuration with
# energy, forces and stress (E/F/S).
calculator = LCAOCalculator()

# Randomly displaced supercells of the reference crystal, generated with the
# default crystal displacement protocol. As the parameter names suggest, each
# entry of supercell_repetitions_list forms one stage, and each stage draws
# sample_size_per_stage samples.
training_sets = crystalTrainingRandomDisplacements(
    reference_configuration,
    sample_size_per_stage=10,
    supercell_repetitions_list=[(2, 2, 2), (3, 3, 3)],
)

# Run the reference calculations, requesting stress on top of energies and
# forces. filename/object_id presumably identify where the generator keeps
# its results -- confirm in the ForceFieldTrainingSetGenerator docs.
force_field_training_set_generator = ForceFieldTrainingSetGenerator(
    training_sets=training_sets,
    calculator=calculator,
    calculate_stress=True,
    filename='study_training_data.hdf5',
    object_id='fftsg',
)
force_field_training_set_generator.update()

# Pull the DFT-labelled TrainingSet straight from the generator object
# (no round-trip through the HDF5 file in this script).
generated_trainingset = force_field_training_set_generator.generatedTrainingSet()

# The labelled TrainingSet can now be fed to any machine-learned force-field
# fit; the remainder of this script trains a Moment Tensor Potential (MTP).

# Let the fit optimize the non-linear coefficients as well.
non_linear_coefficients_parameters = NonLinearCoefficientsParameters(
    perform_optimization=True,
)

# MTP fit settings: the small predefined basis, a 3 Angstrom outer cutoff,
# and the output file for the fitted potential (presumably written there
# by train() -- confirm).
fitting_parameters = MomentTensorPotentialFittingParameters(
    mtp_filename='mtp_study.mtp',
    basis_size=PredefinedBasisSmall,
    outer_cutoff_radii=3.0*Angstrom,
    non_linear_coefficients_parameters=non_linear_coefficients_parameters,
)

# Fit the potential. train_test_split=0.8 divides the data into training and
# test portions (presumably 80 % train / 20 % test -- confirm the convention).
machine_learned_force_field_trainer = MachineLearnedForceFieldTrainer(
    training_sets=generated_trainingset,
    calculator=calculator,
    fitting_parameters=fitting_parameters,
    train_test_split=0.8,
)
machine_learned_force_field_trainer.train()
