# -------------------------------------------------------------
# Calculator (LCAO)
# -------------------------------------------------------------
# LCAO calculator: k-point grid chosen from a k-point density (density_a),
# plus an explicit density-mesh cutoff.
k_point_sampling = KpointDensity(density_a=4.0 * Angstrom)

numerical_accuracy_parameters = NumericalAccuracyParameters(
    k_point_sampling=k_point_sampling,
    density_mesh_cutoff=125.0 * Hartree,
)

calculator = LCAOCalculator(numerical_accuracy_parameters=numerical_accuracy_parameters)

#-----------------------------------
# compile training data
#-----------------------------------
# Active-learning candidates: the last Trajectory object stored in the file,
# wrapped as a TrainingSet without recalculating its training data.
active_learning_data = nlread('active_learning_candidates.hdf5', Trajectory)[-1]
active_learning_dataset = TrainingSet(
    active_learning_data,
    recalculate_training_data=False,
)

# Reference (displaced-crystal) data: the last MomentTensorPotentialTraining
# object stored in the file, also wrapped without recalculation.
# NOTE(review): unlike the set above, this passes a MomentTensorPotentialTraining
# object (not a Trajectory) to TrainingSet -- confirm TrainingSet accepts it.
displaced_crystal_data = nlread('reference-data.hdf5', MomentTensorPotentialTraining)[-1]
displaced_crystal_dataset = TrainingSet(
    displaced_crystal_data,
    recalculate_training_data=False,
)

#---------------------------------------------
# Set up the non-linear coefficient parameters (with optimization) for each
# candidate MTP.
#https://docs.quantumatk.com/manual/Types/NonLinearCoefficientsParameters/NonLinearCoefficientsParameters.html
#---------------------------------------------
import numpy

# One MTP fitting-parameter set per candidate potential. Each candidate gets
# its own random seed and writes its own file: mtp_1.mtp, mtp_2.mtp, ...
fitting_parameters = []

# Number of candidate potentials to fit.
number_of_potentials = 200

# Set to True for a quick run: start from random non-linear coefficients and
# skip their optimization. False reproduces the full (optimized) setup.
quick_run = False

# Fixed master seed, so the per-candidate seeds are the same on every run.
rng = numpy.random.RandomState(12345)

for i in range(number_of_potentials):
    rand = rng.randint(0, 1000000)

    if quick_run:
        coefficient_options = dict(initial_coefficients=Random, perform_optimization=False)
    else:
        coefficient_options = dict(perform_optimization=True)

    # NOTE(review): max_steps=1 presumably limits the non-linear optimization
    # to a single step -- confirm this is intended.
    non_linear_coefficients_parameters = NonLinearCoefficientsParameters(
        random_seed=rand,
        max_steps=1,
        energy_only=True,
        **coefficient_options
    )

    #---------------------------------------------
    # MTP fitting parameters: basis size, cutoff radii and force cap are set
    # explicitly; all other options keep their defaults.
    #https://docs.quantumatk.com/manual/Types/MomentTensorPotentialFittingParameters/MomentTensorPotentialFittingParameters.html
    #---------------------------------------------
    fitting_parameter = MomentTensorPotentialFittingParameters(
        basis_size=300,
        inner_cutoff_radii=1.0 * Angstrom,
        tapering_cutoff_radii=1.1 * Angstrom,
        outer_cutoff_radii=5.0 * Angstrom,
        mtp_filename='mtp_{}.mtp'.format(i + 1),  # 1-based file index
        forces_cap=100.0 * eV / Angstrom,
        non_linear_coefficients_parameters=non_linear_coefficients_parameters,
    )
    fitting_parameters.append(fitting_parameter)

#---------------------------------------------
# Set up MTP training.
# training_sets takes a list, so several training sets can be combined,
# e.g. [training_set_1, training_set_2].
#---------------------------------------------
# Train on both datasets at once, using every fitting-parameter set built
# in the loop; results go to MTP_training.hdf5 under object id 'training'.
moment_tensor_potential_training = MomentTensorPotentialTraining(
    filename="MTP_training.hdf5",
    object_id='training',
    calculator=calculator,
    training_sets=[
        active_learning_dataset,
        displaced_crystal_dataset,
    ],
    fitting_parameters_list=fitting_parameters,
    # NOTE(review): presumably the fraction of data used for training
    # (rest for testing) -- confirm against the QuantumATK docs.
    train_test_split=0.8,
    random_seed=5000,
    calculate_stress=True,
    log_filename_prefix='mtp_'
)

#---------------------------------------------
# Fit the MTPs, then print the resulting training object.
#---------------------------------------------
moment_tensor_potential_training.update()
nlprint(moment_tensor_potential_training)
