import itertools

# Keep the log output to a minimum.
setVerbosity(MinimalLog)

# %% Molecule List

# %% Configuration

# Al(CH3)2Cl written as one (element, [x, y, z]) pair per atom; the positions
# are multiplied by Angstrom below. Order: C, Al, C, Cl, then six H.
_atoms = [
    (Carbon,    [ 1.883346486692,  0.322087190044, -0.762750081762]),
    (Aluminium, [-0.072021416491,  0.043271691306, -0.391067110364]),
    (Carbon,    [-1.298609410433,  1.621688052023, -0.172101706866]),
    (Chlorine,  [-0.829614644369, -1.817046311479, -0.400578804191]),
    (Hydrogen,  [ 2.508894110005, -0.517810274836, -0.412298399057]),
    (Hydrogen,  [ 2.032144537596,  0.395798557588, -1.859671034502]),
    (Hydrogen,  [ 2.272107143242,  1.256933785186, -0.321416552383]),
    (Hydrogen,  [-2.173813999523,  1.398726404326,  0.465429716889]),
    (Hydrogen,  [-0.781900542475,  2.509970600411,  0.234482785274]),
    (Hydrogen,  [-1.694311829116,  1.914636707872, -1.165449354861]),
]
elements = [element for element, _ in _atoms]
cartesian_coordinates = [position for _, position in _atoms]*Angstrom

configuration_0 = MoleculeConfiguration(
    elements=elements,
    cartesian_coordinates=cartesian_coordinates,
)

# Tag the six hydrogen atoms (indices 4 through 9) as 'H_C'.
configuration_0.addTags('H_C', list(range(4, 10)))

configuration_name_0 = "Al_(CH3)2Cl"

# Input configurations, one per row ...
configuration_table = Table()
configuration_table.addInstanceColumn(key='configurations', types=MoleculeConfiguration)
configuration_table.append(configuration_0)

# ... and their names; later steps match names to configurations by row index.
configuration_names_table = Table()
configuration_names_table.addStringColumn(key='configuration_names')
configuration_names_table.append(configuration_name_0)


# %% optimizeMoleculeList

# Build the table only when the name is not already bound in the module
# namespace (presumably so re-running this cell keeps an existing table).
# It is created with 'ald_bde.hdf5' / object_id 'table'.
if 'optimize_molecule_list' not in locals():
    _optimized_table = Table('ald_bde.hdf5', object_id='table')
    _optimized_table.addInstanceColumn(key='configuration', types=MoleculeConfiguration)
    _optimized_table.setMetatext('optimizeMoleculeList')
    optimize_molecule_list = _optimized_table


# %% Table Iteration

# Pre-optimize every input configuration with a machine-learned force field
# and collect the results in optimize_molecule_list.
# Table Iteration(preparation)
for row_index in range(configuration_table.numberOfRows()):
    configuration = configuration_table[row_index, ['configurations']]

    # %% Set ForceFieldCalculator

    # %% ForceFieldCalculator

    # A new potential and TremoloX calculator are constructed for every row.
    potentialSet = TorchX_MACE_MP_0_L0_2023(dtype='float32', enforceLTX=False)
    calculator = TremoloXCalculator(parameters=potentialSet)

    # %% Set Calculator

    configuration.setCalculator(calculator)
    calculator_checkpoint_id = f'configuration_Set_Calculator_row_index_{row_index}'
    nlsave('ald_bde.hdf5', configuration, object_id=calculator_checkpoint_id)

    # %% OptimizeGeometry

    # The restart strategy and the optimizer share one trajectory id, so the
    # optimization can presumably resume from a trajectory already in the file.
    trajectory_id = f'optimized_configuration_optimize_trajectory_row_index_{row_index}'
    restart_strategy = RestartFromTrajectory(
        trajectory_filename='ald_bde.hdf5',
        object_id=trajectory_id,
    )
    optimized_configuration = OptimizeGeometry(
        configuration=configuration,
        trajectory_filename='ald_bde.hdf5',
        trajectory_object_id=trajectory_id,
        restart_strategy=restart_strategy,
    )
    optimized_checkpoint_id = f'optimized_configuration_optgeom_row_index_{row_index}'
    nlsave('ald_bde.hdf5', optimized_configuration, object_id=optimized_checkpoint_id)

    # %% Append Row to Table

    optimize_molecule_list.append(optimized_configuration)


# %% LCAOCalculator

# ----------------------------------------
# Exchange-Correlation
# ----------------------------------------
exchange_correlation = HybridSGGA.B3LYP5

# Equal k-point density (4.0 * Angstrom) along a, b and c.
# NOTE(review): the inputs are MoleculeConfigurations, for which k-point
# sampling presumably has no effect — confirm whether this can be dropped.
k_point_sampling = KpointDensity(
    **{f'density_{axis}': 4.0 * Angstrom for axis in 'abc'}
)

numerical_accuracy_parameters = NumericalAccuracyParameters(
    k_point_sampling=k_point_sampling,
)

# DFT calculator reused for every energy evaluation further down the script.
calculator = LCAOCalculator(
    checkpoint_handler=NoCheckpointHandler,
    exchange_correlation=exchange_correlation,
    numerical_accuracy_parameters=numerical_accuracy_parameters,
)


# %% Input Table To Collect Configurations


def input_table_to_collect_configurations(config, name):
    """Pair configurations with their names and create an empty reaction table.

    Parameters
    ----------
    config : Table
        Table with a 'configuration' instance column.
    name : Table
        Table with a 'configuration_names' string column; row ``i`` names row
        ``i`` of ``config``.

    Returns
    -------
    tuple
        ``(table, reaction_table)``. ``table`` has one row per configuration,
        with columns 'configuration' and 'name'. ``reaction_table`` is a new,
        empty table with instance columns 'conf_0', 'conf_1' and 'conf_2'.
    """
    configurations = config.column('configuration')
    names = name.column('configuration_names')

    table = Table()
    table.addInstanceColumn(key='configuration', types=MoleculeConfiguration)
    table.addStringColumn(key='name')

    for index, configuration in enumerate(configurations):
        # Names are looked up by position; a missing name fails loudly on
        # indexing instead of being silently skipped.
        table.append(configuration, names[index])

    # Columns of the reaction table:
    #   conf_0 - the full molecule
    #   conf_1 - the molecule with one ligand removed
    #   conf_2 - the removed ligand
    # A fresh table is always created. A `locals()` check here could never
    # see an existing `reaction_table`, because inside this function that name
    # is only bound below.
    reaction_table = Table()
    reaction_table.addInstanceColumn(key='conf_0', types=MoleculeConfiguration)
    reaction_table.addInstanceColumn(key='conf_1', types=MoleculeConfiguration)
    reaction_table.addInstanceColumn(key='conf_2', types=MoleculeConfiguration)
    return table, reaction_table


table, reaction_table = input_table_to_collect_configurations(
    optimize_molecule_list, configuration_names_table
)

# Save both tables (no explicit object_id), the name/configuration table first.
# reaction_table has no rows yet at this point; fragments are generated below.
for _table_to_save in (table, reaction_table):
    nlsave('ald_bde.hdf5', _table_to_save)


# %% Table Iteration


def generate_fragments(reaction_table, configuration):
    """Fragment *configuration* into *reaction_table* via ``GenerateFragments``.

    Thin wrapper around the project-local ``generateFragments.GenerateFragments``.
    It passes ``'[Al, Mo]'`` as the central-atom specification.

    Parameters
    ----------
    reaction_table : Table
        Reaction table (in this script, the one with columns 'conf_0',
        'conf_1' and 'conf_2' created by input_table_to_collect_configurations).
    configuration : MoleculeConfiguration
        The molecule to fragment.

    Returns
    -------
    Whatever ``GenerateFragments`` returns. Presumably this is the updated
    reaction table; confirm in generateFragments.py.
    """
    # Central-atom specification handed to GenerateFragments unchanged.
    central_atom = '[Al, Mo]'

    # Imported lazily so the helper module is only needed when a molecule is
    # actually fragmented.
    from generateFragments import GenerateFragments

    return GenerateFragments(reaction_table, configuration, central_atom)


# Table Iteration(preparation)
for row_index in range(table.numberOfRows()):
    configuration = table[row_index, ['configuration']]

    # %% generate_fragments

    # NOTE(review): the bond-dissociation loop that follows iterates
    # `reaction_table`, not `reaction_table_1`. That only works if
    # GenerateFragments fills the table it is given in place — confirm.
    reaction_table_1 = generate_fragments(reaction_table, configuration)

    # NOTE(review): this object_id is identical for every row.
    nlsave(
        'ald_bde.hdf5', reaction_table_1, object_id='reaction_table_1_generate_fragments'
    )


# %% Array Table Iteration

# Array Table Iteration(preparation)
for row_index in range(reaction_table.numberOfRows()):
    # conf_0: full molecule, conf_1: molecule with one ligand removed,
    # conf_2: the removed ligand.
    conf_0, conf_1, conf_2 = reaction_table[row_index, ['conf_0', 'conf_1', 'conf_2']]

    # %% Create table for calculations

    def create_table_for_calculations(conf_0, conf_1, conf_2):
        """Build the per-reaction work tables.

        Returns ``(configuration_table, energy)``:
        - ``configuration_table`` holds the two fragments ``conf_1`` and
          ``conf_2`` (column 'configuration').
        - ``energy`` is an empty table with an 'energy' column of
          ``TotalEnergy`` objects.

        ``conf_0`` is deliberately not added to ``configuration_table``. It is
        optimized separately, so its energy becomes row 0 of ``energy``.
        """
        configuration_table = Table()
        configuration_table.addInstanceColumn(
            key='configuration', types=MoleculeConfiguration
        )
        configuration_table.append(conf_1)
        configuration_table.append(conf_2)

        energy = Table()
        energy.addInstanceColumn(key='energy', types=TotalEnergy)
        return configuration_table, energy

    configuration_table_1, energy = create_table_for_calculations(conf_0, conf_1, conf_2)

    nlsave(
        'ald_bde.hdf5',
        configuration_table_1,
        object_id='configuration_table_1_Create_table_for_calculations',
    )

    nlsave('ald_bde.hdf5', energy, object_id='energy_Create_table_for_calculations')

    # %% Set Calculator

    # NOTE(review): breaking one bond typically leaves open-shell fragments.
    # No spin settings are passed to the LCAOCalculator used here, which may
    # bias the dissociation energy — confirm the intended spin treatment.
    conf_0.setCalculator(calculator)

    conf_0.update()

    nlsave(
        'ald_bde.hdf5', conf_0, object_id=f'conf_0_Set_Calculator_row_index_{row_index}'
    )

    # %% OptimizeGeometry

    restart_strategy = RestartFromTrajectory(
        trajectory_filename='ald_bde.hdf5',
        object_id=f'optimized_configuration_optimize_trajectory_row_index_{row_index}_1',
    )

    optimized_configuration = OptimizeGeometry(
        configuration=conf_0,
        trajectory_filename='ald_bde.hdf5',
        trajectory_object_id=f'optimized_configuration_optimize_trajectory_row_index_{row_index}_1',
        restart_strategy=restart_strategy,
    )

    nlsave(
        'ald_bde.hdf5',
        optimized_configuration,
        object_id=f'optimized_configuration_optgeom_row_index_{row_index}_1',
    )

    # %% TotalEnergy

    total_energy = TotalEnergy(configuration=optimized_configuration)
    nlsave(
        'ald_bde.hdf5',
        total_energy,
        object_id=f'total_energy_TotalEnergy_row_index_{row_index}',
    )

    # %% Append Energy of Molecule

    def append_energy_of_molecule(energy, total_energy):
        """Append *total_energy* as a new row of the *energy* table.

        The row goes into the table itself, the same way
        ``save_energies_to_table`` adds the fragment energies. Appending to
        ``energy.column('energy')`` instead would only change that returned
        column object (presumably a copy). The table would then lack the
        conf_0 row that ``bond_dissociation_energy`` reads at index 0.
        """
        energy.append(total_energy)

    append_energy_of_molecule(energy, total_energy)

    # %% Table Iteration

    # Optimize each fragment (conf_1, then conf_2) and append its energy, so
    # the energy table ends up ordered [conf_0, conf_1, conf_2].
    # Table Iteration(preparation)
    for row_index_1 in range(configuration_table_1.numberOfRows()):
        configuration = configuration_table_1[row_index_1, ['configuration']]

        # %% Set Calculator

        configuration.setCalculator(calculator)

        configuration.update()

        nlsave(
            'ald_bde.hdf5',
            configuration,
            object_id=f'configuration_Set_Calculator_row_index_1_{row_index_1}_row_index_{row_index}',
        )

        # %% OptimizeGeometry

        restart_strategy = RestartFromTrajectory(
            trajectory_filename='ald_bde.hdf5',
            object_id=f'optimized_configuration_1_optimize_trajectory_row_index_1_{row_index_1}_row_index_{row_index}',
        )

        optimized_configuration_1 = OptimizeGeometry(
            configuration=configuration,
            trajectory_filename='ald_bde.hdf5',
            trajectory_object_id=f'optimized_configuration_1_optimize_trajectory_row_index_1_{row_index_1}_row_index_{row_index}',
            restart_strategy=restart_strategy,
        )

        nlsave(
            'ald_bde.hdf5',
            optimized_configuration_1,
            object_id=f'optimized_configuration_1_optgeom_row_index_1_{row_index_1}_row_index_{row_index}',
        )

        # %% TotalEnergy

        total_energy_1 = TotalEnergy(configuration=optimized_configuration_1)
        nlsave(
            'ald_bde.hdf5',
            total_energy_1,
            object_id=f'total_energy_1_TotalEnergy_row_index_1_{row_index_1}_row_index_{row_index}',
        )

        # %% Save energies to Table

        def save_energies_to_table(energy, energy_total):
            """Append *energy_total* to the *energy* table and return the table."""
            energy.append(energy_total)
            return energy

        energy_1 = save_energies_to_table(energy, total_energy_1)

        nlsave('ald_bde.hdf5', energy_1, object_id='energy_1_Save_energies_to_Table')

    # %% BondDissociationEnergy

    def bond_dissociation_energy(energies, molecule, removed_ligand):
        """Compute the bond dissociation energy of one reaction.

        Parameters
        ----------
        energies : Table
            Table whose 'energy' column holds three ``TotalEnergy`` rows in the
            order [molecule (conf_0), fragment (conf_1), ligand (conf_2)].
        molecule : MoleculeConfiguration
            The full molecule (conf_0); only its Hill formula is used.
        removed_ligand : MoleculeConfiguration
            The removed ligand (conf_2); only its Hill formula is used.

        Returns
        -------
        tuple
            ``(bde, molecule_formula, ligand_formula)`` with
            ``bde = E(conf_2) + E(conf_1) - E(conf_0)``.
        """
        energy_column = energies.column('energy')
        original_mol = molecule.hillFormula()
        ligand = removed_ligand.hillFormula()

        Binding_energy = (
            energy_column[2].evaluate()
            + energy_column[1].evaluate()
            - energy_column[0].evaluate()
        )
        return Binding_energy, original_mol, ligand

    binding__energy, original_mol, ligand = bond_dissociation_energy(energy, conf_0, conf_2)

    nlsave(
        'ald_bde.hdf5',
        binding__energy,
        object_id=f'binding__energy_BondDissociationEnergy_row_index_{row_index}_1',
    )

    nlsave(
        'ald_bde.hdf5',
        original_mol,
        object_id=f'original_mol_BondDissociationEnergy_row_index_{row_index}_1',
    )

    nlsave(
        'ald_bde.hdf5',
        ligand,
        object_id=f'ligand_BondDissociationEnergy_row_index_{row_index}_1',
    )

    # %% CollectResults

    # At module level locals() is the module namespace, so this creates the
    # results table only on the first pass.
    if 'collect_results' not in locals():
        collect_results = Table('ald_bde.hdf5', object_id='table_1')
        collect_results.addStringColumn(key='molecule_name')
        collect_results.addInstanceColumn(key='molecule', types=MoleculeConfiguration)
        collect_results.addStringColumn(key='ligand_name')
        collect_results.addInstanceColumn(key='ligand', types=MoleculeConfiguration)
        collect_results.addQuantityColumn(key='bde', unit=eV)
        collect_results.setMetatext('CollectResults')

    # NOTE(review): conf_0 / conf_2 are stored here, not the optimized
    # geometries whose energies produced the BDE — confirm this is intended.
    collect_results.append(original_mol, conf_0, ligand, conf_2, binding__energy)


# %% SaveResults


def save_results(table):
    """Write the collected bond-dissociation results to ``BDE.hdf5``.

    Parameters
    ----------
    table : Table
        Results table to save. This script passes ``collect_results``, which
        has columns 'molecule_name', 'molecule', 'ligand_name', 'ligand' and
        'bde'.
    """
    nlsave('BDE.hdf5', table)


save_results(table=collect_results)
