# Bulk copper reference crystal: one Cu atom at the origin of a
# face-centered cubic lattice built from 3.61496 Angstrom.
elements = [Copper]
fractional_coordinates = [[0.0, 0.0, 0.0]]
lattice = FaceCenteredCubic(3.61496 * Angstrom)

bulk_configuration_copper = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates,
)

# Bulk iron reference crystal: one Fe atom at the origin of a
# body-centered cubic lattice built from 2.8665 Angstrom.
# NOTE: this rebinds the module-level names `lattice`, `elements` and
# `fractional_coordinates`.
elements = [Iron]
fractional_coordinates = [[0.0, 0.0, 0.0]]
lattice = BodyCenteredCubic(2.8665 * Angstrom)

bulk_configuration_iron = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates,
)

# Force-field calculator (EAM_CuFeNi_2009 potential set) used for the
# pre-optimization calculations.
potentialSet = EAM_CuFeNi_2009()
calculator_ff = TremoloXCalculator(parameters=potentialSet)

# Geometry-optimization settings for the pre-optimization step.
optimize_geometry_parameters = OptimizeGeometryParameters(
    # Convergence and step control.
    max_forces=0.1 * eV / Ang,
    max_step_length=0.2 * Ang,
    max_steps=1000,
    # Strain constraint flagged for x and y but not z.
    constraints=[FixStrain(x=True, y=True, z=False)],
    optimizer_method=FIRE(),
    # Restart / stop-file handling.
    restart_strategy=NoRestart,
    enable_optimization_stop_file=True,
    # Trajectory output, written at every step interval of 1.
    trajectory_filename='interface-optimization-trajectory.hdf5',
    trajectory_interval=1,
)

# The two bulk crystals (copper first, iron second) are the reference
# configurations from which the interface training data is generated.
reference_configurations = [bulk_configuration_copper, bulk_configuration_iron]

training_set = CrystalInterfaceTrainingParameters(
    reference_configurations,
    # Pre-optimization with the force-field calculator.
    optimize=calculator_ff,
    optimize_geometry_parameters=optimize_geometry_parameters,
    # Rattling and sampling.
    rattle=True,
    sample_size=2,
    atomic_rattling_amplitudes=[0.3] * Angstrom,
    buffer_zone=0.3 * Angstrom,
    # Surface terminations and planes for side 0 and side 1.
    # NOTE(review): presumably the 0/1 suffix follows the order of
    # reference_configurations (0 = copper, 1 = iron) -- confirm.
    surface_termination_0=Copper,
    plane_0=[[1, 1, 1], [0, 1, 1], [0, 0, 1]],
    surface_termination_1=Iron,
    plane_1=[[0, 1, 1]],
    # Size limits of the generated structures.
    vacuum=10 * Angstrom,
    max_number_of_atoms=150,
    thickness_max=9 * Angstrom,
    # Constraints on the surface cell and the allowed strain.
    shortest_surface_lattice_vector=5 * Angstrom,
    longest_surface_lattice_vector=12 * Angstrom,
    minimum_surface_lattice_vector_angle=60 * Degrees,
    strain_max=0.06,
)

# =============================================================
# Reference calculator (LCAO) passed to the MTP training
# =============================================================

# k-point sampling given as a density along a.
k_point_sampling = KpointDensity(density_a=5.0 * Angstrom)

# Real-space mesh cutoff, k-points and Methfessel-Paxton occupations.
numerical_accuracy_parameters = NumericalAccuracyParameters(
    k_point_sampling=k_point_sampling,
    density_mesh_cutoff=100.0 * Hartree,
    occupation_method=MethfesselPaxton(0.2 * eV, 1),
)

# SCF loop control; the calculation is stopped if it does not converge
# within max_steps.
iteration_control_parameters = IterationControlParameters(
    max_steps=200,
    tolerance=5e-05,
    damping_factor=0.3,
    number_of_history_steps=12,
    non_convergence_behavior=StopCalculation(),
)

algorithm_parameters = AlgorithmParameters(
    scf_restart_step_length=0.3 * Angstrom,
)

# Assemble the calculator from the three parameter objects above.
calculator = LCAOCalculator(
    numerical_accuracy_parameters=numerical_accuracy_parameters,
    iteration_control_parameters=iteration_control_parameters,
    algorithm_parameters=algorithm_parameters,
)

# Non-linear coefficients: optimization switched on, and not restricted
# to energies only.
non_linear_coefficients_parameters = NonLinearCoefficientsParameters(
    perform_optimization=True,
    energy_only=False,
)

# MTP fitting settings: basis size, outer cutoff radius, the output
# potential file name, and the non-linear coefficient options above.
fitting_parameters = MomentTensorPotentialFittingParameters(
    basis_size=1000,
    outer_cutoff_radii=4.5 * Angstrom,
    non_linear_coefficients_parameters=non_linear_coefficients_parameters,
    mtp_filename='mtp_Cu-Fe_interface.mtp',
)

# Configure the MTP training study from the training set, the reference
# calculator and the fitting parameters.
moment_tensor_potential_training = MomentTensorPotentialTraining(
    # Study file and object identifier.
    filename='mtp_study',
    object_id='training',
    # What to train on and how to compute the reference data.
    training_sets=training_set,
    calculator=calculator,
    calculate_stress=True,
    # Fitting setup and train/test partitioning.
    fitting_parameters_list=fitting_parameters,
    train_test_split=0.8,
    random_seed=13345,
    # Execution and logging.
    number_of_processes_per_task=8,
    log_filename_prefix='fit_mtp_Cu-Fe_interface',
)

# Run the training, then print a report of the study object.
moment_tensor_potential_training.update()
nlprint(moment_tensor_potential_training)

# Take the first MTP calculator produced by the training object.
mtp_calculator = moment_tensor_potential_training.momentTensorPotentialCalculators()[0]
