# %% Cobalt disilicide

# Bulk CoSi2 model on an FCC Bravais lattice with lattice constant 10.73 Angstrom.
lattice = FaceCenteredCubic(10.73*Angstrom)

# 24 atoms: 8 Co followed by 16 Si. The order must line up row-for-row with
# fractional_coordinates below.
elements = [Cobalt] * 8 + [Silicon] * 16

# Fractional coordinates relative to the lattice vectors of `lattice`:
#   rows  0-7  : Co on the {0, 0.5}^3 grid
#   rows  8-15 : Si at each Co site shifted by (1/8, 1/8, 1/8)
#   rows 16-23 : Si at each Co site shifted by (3/8, 3/8, 3/8)
# NOTE(review): this looks like a 2x2x2 repeat of a fluorite-type primitive
# cell (Si at +/-1/4 along the primitive diagonal) -- confirm.
# The -0.0 entries are kept as written in the original input.
fractional_coordinates = [
    # Co sublattice
    [ 0.0,    0.0,    0.0  ],
    [ 0.0,    0.0,    0.5  ],
    [-0.0,    0.5,    0.0  ],
    [-0.0,    0.5,    0.5  ],
    [ 0.5,   -0.0,    0.0  ],
    [ 0.5,   -0.0,    0.5  ],
    [ 0.5,    0.5,    0.0  ],
    [ 0.5,    0.5,    0.5  ],
    # Si sublattice, Co + (1/8, 1/8, 1/8)
    [ 0.125,  0.125,  0.125],
    [ 0.125,  0.125,  0.625],
    [ 0.125,  0.625,  0.125],
    [ 0.125,  0.625,  0.625],
    [ 0.625,  0.125,  0.125],
    [ 0.625,  0.125,  0.625],
    [ 0.625,  0.625,  0.125],
    [ 0.625,  0.625,  0.625],
    # Si sublattice, Co + (3/8, 3/8, 3/8)
    [ 0.375,  0.375,  0.375],
    [ 0.375,  0.375,  0.875],
    [ 0.375,  0.875,  0.375],
    [ 0.375,  0.875,  0.875],
    [ 0.875,  0.375,  0.375],
    [ 0.875,  0.375,  0.875],
    [ 0.875,  0.875,  0.375],
    [ 0.875,  0.875,  0.875],
]

# Assemble the periodic bulk configuration from lattice, species and positions.
cosi2 = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates,
)

# Tag four atoms as 'fixed': indices 0 and 4 are Co entries of `elements`
# (Co occupies indices 0-7) and 12 and 16 are Si entries (Si occupies 8-23).
# The tag is resolved back to indices via indicesFromTags when the
# constraints are built.
cosi2.addTags('fixed', [0, 4, 12, 16])


# Load the pretrained CoSi MTP parameter set, which contains committee
# potentials, and wrap it in a TremoloX calculator.
potentialSet = QuantumATK_MTP_CoSi_2022_12()
calculator = TremoloXCalculator(parameters=potentialSet)


# Attach the calculator to the configuration, then save the configuration to
# 'CoSi2_MD_ErrorPrediction.hdf5' -- the same file later used as the MD
# trajectory file. Order matters: setCalculator runs before the save.
cosi2.setCalculator(calculator)
nlsave('CoSi2_MD_ErrorPrediction.hdf5', cosi2)


# No hooks run before or after each MD step; only a measurement hook is used.
pre_step_hooks = []
post_step_hooks = []

# Settings for the committee-based MTP error estimate consumed by
# MolecularDynamicsErrorPredictionHook.
# NOTE(review): check_interval=1 presumably evaluates the estimate on every
# step, and the None thresholds presumably disable automatic stopping --
# confirm against the MTPErrorPredictionParameters reference.
mtp_error_prediction_parameters = MTPErrorPredictionParameters(
    predict_energy_error=True,
    predict_forces_error=True,
    check_interval=1,
    write_atomic_error_estimates=True,
    stop_simulation_energy_threshold=None,
    stop_simulation_force_threshold=None,
)

# Wrap the settings in the MD hook and register it as the only measurement hook.
molecular_dynamics_error_prediction_hook = MolecularDynamicsErrorPredictionHook(
    mtp_error_prediction_parameters=mtp_error_prediction_parameters,
)
measurement_hooks = [molecular_dynamics_error_prediction_hook]

# Bundle all hook lists and persist them alongside the configuration.
hook_functions = HookFunctions(
    measurement_hooks=measurement_hooks,
    pre_step_hooks=pre_step_hooks,
    post_step_hooks=post_step_hooks,
)
nlsave('CoSi2_MD_ErrorPrediction.hdf5', hook_functions)

# Constraints: a FixStrain constraint plus FixAtomConstraints on every atom
# tagged 'fixed' on the configuration.
# NOTE(review): FixStrain(True, True, True) presumably holds the cell fixed
# along all three directions -- confirm against the FixStrain reference.
fix_atom_indices_0 = cosi2.indicesFromTags(['fixed'])
constraints = [
    FixStrain(True, True, True),
    FixAtomConstraints(fix_atom_indices_0),
]

# Run MD with the error-prediction hooks, logging every step and writing the
# trajectory under object id 'md' into 'CoSi2_MD_ErrorPrediction.hdf5'.
# NOTE(review): no `method`, `steps`, or initial velocities are passed, so the
# MolecularDynamics defaults apply -- confirm they are what this run intends.
md_trajectory = MolecularDynamics(
    configuration=cosi2,
    trajectory_filename='CoSi2_MD_ErrorPrediction.hdf5',
    trajectory_object_id='md',
    log_interval=1,
    constraints=constraints,
    hook_functions=hook_functions,
)
