# -------------------------------------------------------------
# Calculator
# -------------------------------------------------------------
# k-point sampling given as a density; only density_a is set explicitly.
k_point_sampling = KpointDensity(density_a=4.0 * Angstrom)

# Numerical accuracy settings: the k-point sampling defined above combined
# with an explicit density mesh cutoff.
numerical_accuracy_parameters = NumericalAccuracyParameters(
    k_point_sampling=k_point_sampling,
    density_mesh_cutoff=125.0 * Hartree,
)

# LCAO calculator; the numerical accuracy parameters are the only setting
# passed here, every other option is left unspecified.
calculator = LCAOCalculator(numerical_accuracy_parameters=numerical_accuracy_parameters)

# -----------------------------------
# Compile training data
# -----------------------------------
def _read_last(filename, object_type):
    """Return the last object of ``object_type`` that nlread gives for ``filename``."""
    return nlread(filename, object_type)[-1]


# Both data sets are wrapped in a TrainingSet with recalculate_training_data
# switched off.
# NOTE(review): presumably this reuses the data stored in the files instead of
# recomputing it - confirm against the TrainingSet documentation.

# Last Trajectory stored in the active-learning candidates file.
active_learning_data = _read_last('active_learning_candidates.hdf5', Trajectory)
active_learning_dataset = TrainingSet(
    active_learning_data,
    recalculate_training_data=False,
)

# Last MomentTensorPotentialTraining object stored in the reference-data file.
displaced_crystal_data = _read_last('reference-data.hdf5', MomentTensorPotentialTraining)
displaced_crystal_dataset = TrainingSet(
    displaced_crystal_data,
    recalculate_training_data=False,
)

# ---------------------------------------------
# Non-linear coefficients, with optimization enabled.
# Reference:
# https://docs.quantumatk.com/manual/Types/NonLinearCoefficientsParameters/NonLinearCoefficientsParameters.html
# ---------------------------------------------
non_linear_coefficients_parameters = NonLinearCoefficientsParameters(
    # Optimization switches.
    perform_optimization=True,
    energy_only=True,
    # Numerical settings; the seed is fixed explicitly.
    regularization=1.0e-2,
    random_seed=865512,
)

# ---------------------------------------------
# MTP fitting parameters. Basis size, cutoff radii, forces cap and the
# output file name are all set explicitly below.
# Reference:
# https://docs.quantumatk.com/manual/Types/MomentTensorPotentialFittingParameters/MomentTensorPotentialFittingParameters.html
# ---------------------------------------------
fitting_parameters = MomentTensorPotentialFittingParameters(
    # File the fitted potential is written to.
    mtp_filename='mtp_potential.mtp',
    basis_size=300,
    # Cutoff radii, from innermost to outermost.
    inner_cutoff_radii=1.0 * Angstrom,
    tapering_cutoff_radii=1.1 * Angstrom,
    outer_cutoff_radii=5.0 * Angstrom,
    forces_cap=100.0 * eV / Angstrom,
    # Non-linear coefficient settings defined earlier in this script.
    non_linear_coefficients_parameters=non_linear_coefficients_parameters,
)

# ---------------------------------------------
# Set up the MTP training.
# Several training sets are combined by passing them as a list,
# e.g. [training_set_1, training_set_2].
# ---------------------------------------------
mtp_training_sets = [active_learning_dataset, displaced_crystal_dataset]

moment_tensor_potential_training = MomentTensorPotentialTraining(
    # Output file, object id and log-file prefix.
    filename='MTP_training.hdf5',
    object_id='training',
    log_filename_prefix='mtp_',
    # Training data and the calculator handed to the training.
    training_sets=mtp_training_sets,
    calculator=calculator,
    calculate_stress=True,
    # Train/test split and its fixed random seed.
    train_test_split=0.8,
    random_seed=5000,
    # Fitting settings; a single parameter object is passed here.
    fitting_parameters_list=fitting_parameters,
)

#---------------------------------------------
# Fit the MTP: update() runs the training configured above; it must be
# called before the training object is printed.
#---------------------------------------------
moment_tensor_potential_training.update()
# Print the training object after the update.
nlprint(moment_tensor_potential_training)
