# -*- coding: utf-8 -*-
# Select the MinimalLog verbosity level for everything that follows.
setVerbosity(MinimalLog)

# -------------------------------------------------------------
# Bulk Configuration
# -------------------------------------------------------------

# Face-centered cubic Bravais lattice with a 5.4306 Angstrom lattice constant.
lattice = FaceCenteredCubic(5.4306*Angstrom)

# Two-atom basis, both atoms silicon.
elements = [Silicon] * 2

# Basis positions in fractional coordinates: the origin and (1/4, 1/4, 1/4).
fractional_coordinates = [
    [0.0, 0.0, 0.0],
    [0.25, 0.25, 0.25],
]

# Combine lattice and basis into the bulk configuration used as the
# reference structure further down.
bulk_configuration = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates,
)

# -------------------------------------------------------------
# Calculator
# -------------------------------------------------------------
# Iteration control with a tolerance of 5e-05.
iteration_control_parameters = IterationControlParameters(tolerance=5e-05)

# k-point sampling specified as a density; only density_a is given explicitly.
k_point_sampling = KpointDensity(density_a=7.0*Angstrom)

# Numerical accuracy: 30 Hartree density mesh cutoff, the k-point sampling
# defined above, and Fermi-Dirac occupations constructed with 25 meV.
numerical_accuracy_parameters = NumericalAccuracyParameters(
    occupation_method=FermiDirac(25.0*meV),
    k_point_sampling=k_point_sampling,
    density_mesh_cutoff=30.0*Hartree,
)

# LCAO calculator; every setting not passed here is left at the library default.
calculator = LCAOCalculator(
    iteration_control_parameters=iteration_control_parameters,
    numerical_accuracy_parameters=numerical_accuracy_parameters,
)

# Build the list of candidate fitting parameters by scanning over initial
# guesses for the non-linear coefficients: 30 guesses, small predefined basis,
# fixed random seed (42), and perform_optimization switched off.
fitting_parameters_list = scanOverNonLinearCoefficients(
    perform_optimization=False,
    basis_size=PredefinedBasisSmall,
    number_of_initial_guesses=30,
    random_seed=42,
    mtp_filename_suffix='MTP_fit.mtp',
)

# Random-displacement settings used to generate the training sets from the
# bulk configuration: 1x1x1 and 2x2x2 supercell repetitions, a sample size of
# 10, atomic rattling of 0.15 Angstrom and cell rattling of 0.07.
training_sets = RandomDisplacementsParameters(
    reference_configurations=bulk_configuration,
    supercell_repetitions_list=[(n, n, n) for n in (1, 2)],
    sample_size=10,
    cell_rattling_amplitudes=0.07,
    atomic_rattling_amplitudes=0.15*Angstrom,
)

# Set up the MTP training: displaced structures come from `training_sets`,
# reference data (including stress) is calculated with `calculator`, and one
# fit is requested per entry of `fitting_parameters_list`.
# NOTE(review): the 'testtest' part of the output filename looks like a
# leftover from experimentation - confirm before reusing this script.
mtp_training = MomentTensorPotentialTraining(
    training_sets=training_sets,
    calculator=calculator,
    calculate_stress=True,
    fitting_parameters_list=fitting_parameters_list,
    filename='Silicon_crystal_mtp_training_testtest.hdf5',
    object_id='mtp_training',
)

# Execute the training study.
mtp_training.update()

# Rank all fits using the R2 score with unit weights and no data-tag filter.
ranked_fits = mtp_training.rankFits(
    statistical_measure=R2Score,
    weights=[[1, 1, 1], [1, 1, 1]],
    data_tags=None,
)

# The first element of the first ranking entry is used as the index of the
# best fit (presumably the ranking is ordered best-first - confirm against the
# rankFits documentation).
best_fit_index = ranked_fits[0][0]

# Look up the fitting parameters that belong to that fit.
all_fitting_parameters = mtp_training.fittingParametersList()
best_fitting_parameters = all_fitting_parameters[best_fit_index]
