import itertools

# %% List_of_SMIs_Config

# %% Configuration

# Acetic acid (CH3COOH). Atom indices, as fixed by the bond list below:
# methyl C = 1 with hydrogens 0, 2, 3; carboxyl C = 4; O 5 is bonded only to
# C 4; O 6 carries the acidic hydrogen 7.
elements = [
    Hydrogen,  # 0
    Carbon,    # 1
    Hydrogen,  # 2
    Hydrogen,  # 3
    Carbon,    # 4
    Oxygen,    # 5
    Oxygen,    # 6
    Hydrogen,  # 7
]

# Cartesian positions, one row per atom in the order above.
cartesian_coordinates = [
    [ 0.748765, -0.127722, -1.46492 ],
    [-0.209816, -0.194167, -0.933683],
    [-0.836558,  0.642786, -1.268776],
    [-0.701067, -1.128721, -1.232691],
    [-0.001295, -0.142614,  0.547721],
    [-0.296097, -0.967368,  1.393966],
    [ 0.600553,  0.98165 ,  1.006038],
    [ 0.695512,  0.936156,  1.952344],
] * Angstrom

configuration_0 = MoleculeConfiguration(
    elements=elements,
    cartesian_coordinates=cartesian_coordinates,
)

# Explicit connectivity (pairs of atom indices).
bonds = [
    [0, 1], [1, 2], [1, 3], [1, 4],  # methyl C: three H and the carboxyl C
    [4, 5], [4, 6],                  # carboxyl C: both O
    [6, 7],                          # O-H
]
configuration_0.setBonds(bonds)

configuration_name_0 = "Acetic acid"

# Single-row table holding the molecule configuration ...
configuration_table = Table()
configuration_table.addInstanceColumn(key='configurations', types=MoleculeConfiguration)
configuration_table.append(configuration_0)

# ... and a parallel single-row table holding its display name.
configuration_names_table = Table()
configuration_names_table.addStringColumn(key='configuration_names')
configuration_names_table.append(configuration_name_0)


# %% List_of_cluster_model

# %% Configuration

# Al(NH2)3 cluster. Atom indices, as fixed by the bond list below:
# Al = 0 bonded to N 1, 3, 5; each N carries two hydrogens.
elements = [
    Aluminium,  # 0
    Nitrogen,   # 1
    Hydrogen,   # 2  (on N 1)
    Nitrogen,   # 3
    Hydrogen,   # 4  (on N 3)
    Nitrogen,   # 5
    Hydrogen,   # 6  (on N 5)
    Hydrogen,   # 7  (on N 1)
    Hydrogen,   # 8  (on N 3)
    Hydrogen,   # 9  (on N 5)
]

# Cartesian positions, one row per atom in the order above.
cartesian_coordinates = [
    [  9.865768213117,  12.383377293118,  11.748123418484],
    [ 10.669928205505,  13.645245174255,  10.058539260268],
    [ 11.634771653761,  13.378964418898,   9.709224345954],
    [ 10.604212310744,  10.469469434034,  10.809099556215],
    [ 11.654520764871,  10.322933671824,  10.861739632734],
    [ 10.622061370086,  13.790527070005,  13.296477200545],
    [ 11.433508645293,  14.437456397798,  13.164253165495],
    [ 10.069368922141,  14.458709581093,   9.746148012834],
    [  9.916597400124,   9.891134033957,  10.25249843707 ],
    [  9.923287172717,  13.530771082121,  14.034338129733],
] * Angstrom

configuration_0 = MoleculeConfiguration(
    elements=elements,
    cartesian_coordinates=cartesian_coordinates,
)

# Tag one hydrogen on each nitrogen (atoms 7, 8, 9) with the label 'H_N'.
configuration_0.addTags('H_N', [7, 8, 9])

# Explicit connectivity (pairs of atom indices).
bonds = [
    [0, 1], [0, 3], [0, 5],  # Al-N
    [1, 2], [1, 7],          # N 1 hydrogens
    [3, 4], [3, 8],          # N 3 hydrogens
    [5, 6], [5, 9],          # N 5 hydrogens
]
configuration_0.setBonds(bonds)

configuration_name_0 = "AlNH2"

# Single-row table holding the cluster configuration ...
configuration_table_1 = Table()
configuration_table_1.addInstanceColumn(key='configurations', types=MoleculeConfiguration)
configuration_table_1.append(configuration_0)

# ... and a parallel single-row table holding its display name.
configuration_names_table_1 = Table()
configuration_names_table_1.addStringColumn(key='configuration_names')
configuration_names_table_1.append(configuration_name_0)


# %% ForceFieldCalculator

# TremoloX calculator driven by the TorchX MACE-MP-0 (L0, 2023) potential,
# evaluated with dtype='float32'. Every configuration in the loop below is
# attached to this single `calculator` instance.
potentialSet = TorchX_MACE_MP_0_L0_2023(
    dtype='float32',
    enforceLTX=False,
)
calculator = TremoloXCalculator(parameters=potentialSet)


# %% generate_complex


def generate_complex(smis, surface_models, name_surface_models, name_smis,
                     central_atoms='[Al, Si]'):
    """Generate complex configurations via ``generateAcidDerivativeComplex``.

    Args:
        smis: Table with a ``'configurations'`` column of molecule
            configurations.
        surface_models: Table with a ``'configurations'`` column of surface
            cluster models.
        name_surface_models: Table with a ``'configuration_names'`` column
            naming the surface models.
        name_smis: Table with a ``'configuration_names'`` column naming the
            molecules.
        central_atoms: String forwarded unchanged to
            ``generateAcidDerivativeComplex``. It defaults to ``'[Al, Si]'``,
            the value that was previously hard-coded here. Its exact format
            and meaning are defined by that helper; presumably it lists the
            surface atoms the molecule may bind to (TODO confirm).

    Returns:
        Whatever ``generateAcidDerivativeComplex`` returns. This script then
        indexes it by row with columns such as ``'surf_cluster'`` and
        ``'complex_config'``.
    """
    # Pull out the single column each input table carries.
    config_smis = smis.column('configurations')
    surf_config = surface_models.column('configurations')
    name_surf = name_surface_models.column('configuration_names')
    name_smi = name_smis.column('configuration_names')

    # Project-local helper module (not part of this file).
    from generateAcidDerivativeComplex import generateAcidDerivativeComplex

    return generateAcidDerivativeComplex(
        config_smis, surf_config, name_surf, name_smi, central_atoms
    )


# Generate the complex table from the molecule and surface-model tables,
# passing each table by keyword so the mapping is explicit, then save it.
table_of_complex_config = generate_complex(
    smis=configuration_table,
    surface_models=configuration_table_1,
    name_surface_models=configuration_names_table_1,
    name_smis=configuration_names_table,
)

nlsave('AdsorptionEnergies_AcidDerivatives.hdf5', table_of_complex_config)


# %% Array Table Iteration

# HDF5 file that receives every object saved by the loop below.
RESULTS_FILENAME = 'AdsorptionEnergies_AcidDerivatives.hdf5'


def create_table_for_calculations(surf_cluster, smi_mol, complex_config, products):
    """Build the per-row configuration table and an empty energy table.

    The configurations are appended in a fixed order that
    ``adsorption_energy`` relies on:
    0 = surface cluster, 1 = molecule, 2 = complex, 3 = product.

    Returns:
        ``(configuration_table, energy)``: the filled configuration table,
        and an empty table with one ``TotalEnergy`` column named
        ``'energy'``.
    """
    configuration_table = Table()
    # NOTE(review): the column is typed BulkConfiguration, while the input
    # molecule and surface tables use MoleculeConfiguration. Confirm that the
    # objects produced by generateAcidDerivativeComplex are accepted.
    configuration_table.addInstanceColumn(
        key='configuration', types=BulkConfiguration
    )
    configuration_table.append(surf_cluster)
    configuration_table.append(smi_mol)
    configuration_table.append(complex_config)
    configuration_table.append(products)

    energy = Table()
    energy.addInstanceColumn(key='energy', types=TotalEnergy)
    return configuration_table, energy


def save_energies_to_table(energy, energy_total):
    """Append ``energy_total`` to the ``energy`` table and return the table."""
    energy.append(energy_total)
    return energy


def adsorption_energy(energies):
    """Combine the four total energies into one energy difference.

    ``energies`` rows follow the order of ``create_table_for_calculations``
    (0 = surface, 1 = molecule, 2 = complex, 3 = product). The result is

        E = E[product] + E[complex] - E[molecule] - E[surface],

    i.e. the energy change for surface + molecule -> complex + product.
    Each term is the value returned by ``TotalEnergy.evaluate()``.
    """
    energy_all = energies.column('energy')
    return (
        energy_all[3].evaluate()
        + energy_all[2].evaluate()
        - energy_all[1].evaluate()
        - energy_all[0].evaluate()
    )


# %% CollectResults

# Created before the loop so that `collect_results` exists even when the
# complex table has no rows; previously the later save_results(collect_results)
# call would then raise NameError. The guard keeps an existing table when this
# cell is re-run in the same session (at module level globals() is locals()).
if 'collect_results' not in globals():
    collect_results = Table(RESULTS_FILENAME, object_id='table')
    collect_results.addInstanceColumn(key='model', types=BulkConfiguration)
    collect_results.addStringColumn(key='name_surf')
    collect_results.addInstanceColumn(key='adsorbate', types=BulkConfiguration)
    collect_results.addStringColumn(key='name_smi')
    collect_results.addQuantityColumn(key='adsorption_energy', unit=eV)
    collect_results.setMetatext('CollectResults')

# Array Table Iteration(preparation)
for row_index in range(table_of_complex_config.numberOfRows()):
    # Unpack the six columns of this row of the generated complex table.
    surf_cluster, name_surf, smi_mol, name_smi, complex_config, products = (
        table_of_complex_config[
            row_index,
            [
                'surf_cluster',
                'name_surf',
                'config_mol',
                'name_smi',
                'complex_config',
                'product_1',
            ],
        ]
    )

    # %% Create table for calculations

    configuration_table_2, energy = create_table_for_calculations(
        surf_cluster, smi_mol, complex_config, products
    )

    # NOTE(review): unlike the ids saved further down, these two object_ids
    # do not include row_index, so every outer iteration saves under the same
    # id. Confirm how nlsave handles the repeated id and that this is intended.
    nlsave(
        RESULTS_FILENAME,
        configuration_table_2,
        object_id='configuration_table_2_Create_table_for_calculations',
    )
    nlsave(
        RESULTS_FILENAME,
        energy,
        object_id='energy_Create_table_for_calculations',
    )

    # %% Table Iteration

    # Optimize each of the four configurations and record its total energy.
    # Energies are appended in the same order as the configurations.
    for row_index_1 in range(configuration_table_2.numberOfRows()):
        # NOTE(review): the outer loop unpacks a multi-column lookup. Confirm
        # that this one-column lookup returns the configuration itself rather
        # than a one-element sequence.
        configuration = configuration_table_2[row_index_1, ['configuration']]

        # %% Set Calculator

        configuration.setCalculator(calculator)
        configuration.update()
        nlsave(
            RESULTS_FILENAME,
            configuration,
            object_id=f'configuration_Set_Calculator_row_index_1_{row_index_1}_row_index_{row_index}',
        )

        # %% OptimizeGeometry

        # The restart strategy reads from the same trajectory object that
        # this optimization writes to.
        trajectory_id = (
            f'optimized_configuration_optimize_trajectory_row_index_1_{row_index_1}'
            f'_row_index_{row_index}'
        )
        restart_strategy = RestartFromTrajectory(
            trajectory_filename=RESULTS_FILENAME,
            object_id=trajectory_id,
        )
        optimized_configuration = OptimizeGeometry(
            configuration=configuration,
            max_forces=0.05 * eV / Angstrom,
            constraints=[BravaisLatticeConstraint()],
            trajectory_filename=RESULTS_FILENAME,
            trajectory_object_id=trajectory_id,
            optimizer_method=FIRE(),
            restart_strategy=restart_strategy,
        )
        nlsave(
            RESULTS_FILENAME,
            optimized_configuration,
            object_id=f'optimized_configuration_optgeom_row_index_1_{row_index_1}_row_index_{row_index}',
        )

        # %% TotalEnergy

        total_energy = TotalEnergy(configuration=optimized_configuration)
        nlsave(
            RESULTS_FILENAME,
            total_energy,
            object_id=f'total_energy_TotalEnergy_row_index_1_{row_index_1}_row_index_{row_index}',
        )

        # %% Save energies to Table

        energy = save_energies_to_table(energy, total_energy)
        # Snapshot of the (partially filled) energy table after each step.
        nlsave(
            RESULTS_FILENAME,
            energy,
            object_id='energy_Save_energies_to_Table',
        )

    # %% AdsorptionEnergy

    binding__energy = adsorption_energy(energy)
    nlsave(
        RESULTS_FILENAME,
        binding__energy,
        object_id=f'binding__energy_AdsorptionEnergy_row_index_{row_index}_1',
    )

    # %% CollectResults (append this row)

    collect_results.append(surf_cluster, name_surf, smi_mol, name_smi, binding__energy)


# %% SaveResults


def save_results(table):
    """Write the collected results ``table`` to ``E_ADS.hdf5``.

    Every expected results column is looked up first. The looked-up values
    are not used; only ``nlsave`` has an effect on the output.
    """
    expected_columns = (
        'model',
        'name_surf',
        'adsorbate',
        'name_smi',
        'adsorption_energy',
    )
    for column_key in expected_columns:
        table.column(column_key)

    nlsave('E_ADS.hdf5', table)


save_results(collect_results)
