# -*- coding: utf-8 -*-
setVerbosity(MinimalLog)

# -------------------------------------------------------------
# Bulk Configuration
# -------------------------------------------------------------
# Cubic cell: three mutually orthogonal lattice vectors of identical length
# (in Angstrom), so the edge length is written only once.
cell_edge = 9.497426971725758
vector_a = [cell_edge, 0.0, 0.0]*Angstrom
vector_b = [0.0, cell_edge, 0.0]*Angstrom
vector_c = [0.0, 0.0, cell_edge]*Angstrom
lattice = UnitCell(vector_a, vector_b, vector_c)

# Define elements: 40 oxygen atoms followed by 20 hafnium atoms (HfO2
# stoichiometry). This order must line up with the rows of
# fractional_coordinates.
elements = [Oxygen]*40 + [Hafnium]*20

# Define coordinates
# Fractional coordinates (in units of the lattice vectors), one row of
# [a, b, c] per atom: 60 rows in total, the same count as `elements`. Row i is
# expected to describe elements[i], i.e. the first 40 rows are oxygen and the
# last 20 are hafnium.
fractional_coordinates = [[ 0.076286602313,  0.484331164618,  0.381316169166],
                          [ 0.766917781374,  0.494937841286,  0.291874574266],
                          [ 0.731969980361,  0.943836312627,  0.57296061439 ],
                          [ 0.554574027258,  0.461332853561,  0.886717927207],
                          [ 0.250228263221,  0.890222307852,  0.388755307743],
                          [ 0.929803223539,  0.676225151653,  0.795242624538],
                          [ 0.39718592907 ,  0.158335635766,  0.388556756402],
                          [ 0.11645891105 ,  0.71352269763 ,  0.628339753261],
                          [ 0.464081550737,  0.146998191259,  0.949948088773],
                          [ 0.593809676382,  0.723661829469,  0.282815566712],
                          [ 0.956144927607,  0.410082162955,  0.63118390115 ],
                          [ 0.630261587996,  0.068949857119,  0.312897385008],
                          [ 0.676614862447,  0.916136708317,  0.080737525515],
                          [ 0.283280934016,  0.082466086602,  0.147414791521],
                          [ 0.156534915042,  0.264983388403,  0.44504783418 ],
                          [ 0.323709076563,  0.944373529376,  0.767758731255],
                          [ 0.086974924396,  0.245291615829,  0.13994601509 ],
                          [ 0.382753861034,  0.801331012658,  0.120736888779],
                          [ 0.325331129964,  0.599980937095,  0.400551045963],
                          [ 0.710664562389,  0.543883143178,  0.640949170099],
                          [ 0.768632684085,  0.291833856211,  0.982008566834],
                          [ 0.935781100858,  0.449313456644,  0.91777328168 ],
                          [ 0.470152665574,  0.892151512961,  0.514373913142],
                          [ 0.896162402168,  0.246679452251,  0.280748840967],
                          [ 0.056111635294,  0.968087964348,  0.543755879799],
                          [ 0.394873138661,  0.413695452858,  0.165606209471],
                          [ 0.953268976235,  0.076126206638,  0.75940728911 ],
                          [ 0.629651636858,  0.634630228571,  0.05858775273 ],
                          [ 0.589855514666,  0.756388866923,  0.772396168952],
                          [ 0.012660970567,  0.697385940088,  0.191689851382],
                          [ 0.869214445599,  0.734481879712,  0.431281932903],
                          [ 0.297163631424,  0.496592752664,  0.698151011982],
                          [ 0.464653722521,  0.600528668785,  0.541204813482],
                          [ 0.734874860709,  0.260514207603,  0.638705119014],
                          [ 0.912475699997,  0.936927768911,  0.225749701655],
                          [ 0.41838929944 ,  0.18217514434 ,  0.693574297042],
                          [ 0.580921531053,  0.352165986484,  0.471443636145],
                          [ 0.190302700851,  0.486222762835,  0.020574472234],
                          [ 0.161258411077,  0.268084679045,  0.84395447297 ],
                          [ 0.328992105963,  0.706695342299,  0.870286697912],
                          # Rows below (41-60) belong to the hafnium atoms.
                          [ 0.1214543959  ,  0.707765332143,  0.395966412374],
                          [ 0.324898608394,  0.402583945263,  0.38949737155 ],
                          [ 0.466227728418,  0.908619926909,  0.935331908209],
                          [ 0.302880914006,  0.24554211747 ,  0.01438921806 ],
                          [ 0.580163847321,  0.116820565433,  0.537604874893],
                          [ 0.886135894107,  0.268791895681,  0.783965762093],
                          [ 0.344730617223,  0.766690970039,  0.665638788583],
                          [ 0.645121097015,  0.750009044287,  0.53235612777 ],
                          [ 0.988059904634,  0.436680835838,  0.160021696006],
                          [ 0.109670678906,  0.533225387478,  0.786010316405],
                          [ 0.749014847635,  0.116815557743,  0.153472283775],
                          [ 0.53286379991 ,  0.405031432457,  0.674071349234],
                          [ 0.437918433144,  0.923188119777,  0.277003037718],
                          [ 0.810521688743,  0.424399343153,  0.479289133013],
                          [ 0.800449713604,  0.735269463106,  0.219879631298],
                          [ 0.093005373938,  0.069416769287,  0.30541397964 ],
                          [ 0.204687273036,  0.142966757615,  0.642380807038],
                          [ 0.403413430916,  0.591194801698,  0.041298434967],
                          [ 0.924039151622,  0.8794440263  ,  0.681366359624],
                          [ 0.732190278544,  0.574240691674,  0.870477684012]]


# Assemble the periodic bulk configuration from the cubic lattice, the
# element list and the matching fractional coordinates.
bulk_configuration = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates)

# -------------------------------------------------------------
# Calculator: DFT reference calculator.
# -------------------------------------------------------------
# SCF iteration control: tolerance 5e-5, damping factor 0.3 and 12 history
# steps.
iteration_control_parameters = IterationControlParameters(
    tolerance=5.0e-5,
    damping_factor=0.3,
    number_of_history_steps=12,
)

# Numerical accuracy: k-points from a KpointDensity with density_a of
# 7.0 Angstrom, and a 125 Hartree density mesh cutoff.
k_point_sampling = KpointDensity(density_a=7.0*Angstrom)
numerical_accuracy_parameters = NumericalAccuracyParameters(
    k_point_sampling=k_point_sampling,
    density_mesh_cutoff=125.0*Hartree,
)

# LCAO reference calculator; it is handed to ActiveLearningSimulation and to
# the final MomentTensorPotentialTraining further down.
reference_calculator = LCAOCalculator(
    iteration_control_parameters=iteration_control_parameters,
    numerical_accuracy_parameters=numerical_accuracy_parameters,
)

# Load the initial training data. This may be one or several TrainingSet,
# MomentTensorPotentialTraining or Trajectory objects holding energies, forces
# and stress computed with the reference calculator. Here only the first
# TrainingSet stored in the file is used.
initial_training_data = [
    nlread('HfO2_crystal_training.hdf5', TrainingSet)[0],
]

# MTP fit used during active learning: predefined small basis, 4.5 Angstrom
# outer cutoff, element-specific coefficients, and non-linear coefficients
# optimized on the energy only.
mtp_basis = PredefinedBasisSmall
nl_parameters = NonLinearCoefficientsParameters(
    perform_optimization=True,
    energy_only=True,
)
fitting_parameters = MomentTensorPotentialFittingParameters(
    basis_size=mtp_basis,
    outer_cutoff_radii=4.5*Ang,
    mtp_filename='HfO2_active_learning.mtp',
    non_linear_coefficients_parameters=nl_parameters,
    use_element_specific_coefficients=True,
)

# Candidate selection uses the QueryByCommitteeForces extrapolation grade
# with a descriptor cutoff of 0.1.
extrapolation_selection_parameters = ExtrapolationSelectionParameters(
    extrapolation_grade_algorithm=QueryByCommitteeForces,
    descriptor_cutoff=0.1,
)

# Active-learning driver: starts from the initial training data, uses the
# reference calculator for new data, stores the MTP study in
# 'HfO2_mtp_study' (object id 'HfO2') and writes candidate structures to a
# separate trajectory file. Thresholds, check interval and force limit are
# as given below; their precise meaning is defined by ActiveLearningSimulation.
active_learning = ActiveLearningSimulation(
    fitting_parameters=fitting_parameters,
    initial_training_data=initial_training_data,
    reference_calculator=reference_calculator,
    mtp_study_filename='HfO2_mtp_study',
    mtp_study_object_id='HfO2',
    candidate_trajectory_filename='HfO2_am_active_learning_candidates.hdf5',
    candidate_threshold=1.0,
    retrain_threshold=3.0,
    check_interval=20,
    max_forces_check=10.0*eV/Ang,
    use_stress=True,
    restart_simulation=True,
    extrapolation_selection_parameters=extrapolation_selection_parameters,
)

# The centre of mass is held fixed throughout the MD run.
constraints = [FixCenterOfMass()]

# High-temperature MD at 3000 K. Initial velocities come from a
# Maxwell-Boltzmann distribution with the centre-of-mass momentum removed and
# the temperature enforced; no fixed random_seed is given (None).
initial_velocity = MaxwellBoltzmannDistribution(
    temperature=3000.0*Kelvin,
    random_seed=None,
    remove_center_of_mass_momentum=True,
    enforce_temperature=True,
)

# Langevin integrator: 1 fs time step, 3000 K reservoir, friction 0.01 / fs.
method = Langevin(
    initial_velocity=initial_velocity,
    time_step=1*femtoSecond,
    friction=0.01*femtoSecond**-1,
    reservoir_temperature=3000*Kelvin,
)

# Drive the MD through the active-learning object so that it can collect new
# training data while the trajectory is written to file.
md_trajectory = active_learning.runMolecularDynamics(
    bulk_configuration,
    method=method,
    constraints=constraints,
    steps=100_000,
    log_interval=100,
    trajectory_filename='HfO2_am_active_learning_3000K.hdf5',
    domain_decomposition_pattern=[1, 1, 1],
)

# Store the training data added during active learning as a TrainingSet file.
additional_training_data = active_learning.additionalTrainingSet()
nlsave(
    'HfO2_active_learning_additional_training_data.hdf5',
    additional_training_data,
)

# Table combining the initial training data with the data added during
# active learning; it is the input to the final MTP fit below.
training_set_table = active_learning.trainingSetTable()

# Test different basis sizes for the final MTP fit.
# Each entry pairs a filesystem-safe label with the basis_size given to the
# fitter. PredefinedBasisSmall is an object rather than a number. Its str()
# form is not controlled by this script and is unlikely to be a clean
# filename, so it gets an explicit label. Integer sizes keep their numeric
# label, so their .mtp filenames are unchanged.
basis_size_scan = [
    ('small', PredefinedBasisSmall),
    ('400', 400),
    ('800', 800),
]

fitting_parameters_list = []
for basis_label, mtp_basis in basis_size_scan:
    # Optimize the non-linear coefficients on the energy only. A fresh
    # parameter object is created per fit so the fits share no instance.
    nl_parameters = NonLinearCoefficientsParameters(
        perform_optimization=True,
        energy_only=True,
    )

    fitting_parameters = MomentTensorPotentialFittingParameters(
        basis_size=mtp_basis,
        outer_cutoff_radii=4.5*Ang,
        # TODO(RJL): these final fits are not active learning; consider a
        # different filename prefix (kept as-is for compatibility with
        # existing outputs).
        mtp_filename=f'HfO2_active_learning_{basis_label}.mtp',
        non_linear_coefficients_parameters=nl_parameters,
        use_element_specific_coefficients=True,
    )

    fitting_parameters_list.append(fitting_parameters)

# Final MTP training: one fit per entry in fitting_parameters_list, trained on
# the combined training-set table with train_test_split=0.9 and a fixed
# random_seed (presumably for a reproducible split). Results go to
# 'Final_MTP_training_basis_size_scan.hdf5' and are printed afterwards.
mtp_training = MomentTensorPotentialTraining(
    filename='Final_MTP_training_basis_size_scan.hdf5',
    object_id='mtp_training',
    training_sets=training_set_table,
    fitting_parameters_list=fitting_parameters_list,
    calculator=reference_calculator,
    calculate_stress=True,
    train_test_split=0.9,
    random_seed=13345,
    log_filename_prefix='mtp_basis_size_scan',
)
mtp_training.update()
nlprint(mtp_training)
