# ARRAY prepare 11 finalize

import itertools
import os

# Keep QuantumATK log output to a minimum.
setVerbosity(MinimalLog)

# HDF5 file name used by every part of this workflow.
results_filename = 'array_workflow_results.hdf5'

# Remember the starting folder before the 'main' part changes directory, so
# the shared results file can still be located from inside an array folder.
common_folder = os.getcwd()

# Which array part to run: 'prepare', 'main', 'collect' or 'finalize'; None
# when the script runs as a single serial job without array parameters.
part = os.environ.get('PART')

# Prepare part: build the silicon crystal once and store it for later parts.
if part in ('prepare', None):
    # Two-atom silicon basis on an FCC lattice (diamond structure).
    lattice = FaceCenteredCubic(5.4306 * Angstrom)
    elements = [Silicon] * 2
    fractional_coordinates = [
        [0.0, 0.0, 0.0],
        [0.25, 0.25, 0.25],
    ]

    silicon_alpha_name = "silicon_alpha"
    silicon_alpha = BulkConfiguration(
        bravais_lattice=lattice,
        elements=elements,
        fractional_coordinates=fractional_coordinates,
    )

    # Persist the configuration so the 'main' part can reload it by object id.
    nlsave(results_filename, silicon_alpha, object_id=silicon_alpha_name)

# Main part: evaluate this job's share of the cut-off scan.
if part in ('main', None):
    if part != 'main':
        # Serial run: stay in the current folder and process every cut-off.
        folder = '.'
        array_index = array_size = None
    else:
        # Array run: read this job's folder, index and the array size from
        # the environment.
        folder = os.environ['ARRAY_FOLDER']
        array_index = int(os.environ['ARRAY_INDEX'])
        array_size = int(os.environ['ARRAY_SIZE'])

    # Work inside the array folder so per-job output stays separate.
    os.chdir(folder)

    # Reload the configuration that the 'prepare' part stored in the
    # starting folder.
    shared_results_path = os.path.join(common_folder, results_filename)
    silicon_alpha = nlread(shared_results_path, object_id='silicon_alpha')[0]

    ############################################################
    #          Array Iteration (density_mesh_cutoff)           #
    ############################################################

    def densityMeshCutoffsIterator():
        """Yield the 11 density mesh cut-offs 10, 20, ..., 110 Hartree."""
        for step in range(1, 12):
            yield (10.0 * step) * Hartree

    # Lazy generator; only this job's slice of it is consumed below.
    density_mesh_cutoffs = densityMeshCutoffsIterator()

    def arraySplit(generator, length, array_size, array_index):
        """
        Select the share of ``generator`` that belongs to one job of an array.

        The ``length`` items are split into ``array_size`` contiguous chunks
        whose sizes differ by at most one; the first ``length % array_size``
        chunks each receive one extra item.

        :param generator:       The iterable to distribute.
        :type  generator:       iterable

        :param length:          The number of items produced by ``generator``.
        :type  length:          int

        :param array_size:      The number of jobs in the array, or None.
        :type  array_size:      int

        :param array_index:     The zero-based index of this job, or None.
        :type  array_index:     int

        :returns:               ``generator`` itself when either array
                                parameter is None; otherwise an iterator over
                                the items of chunk ``array_index`` (empty when
                                the index lies past the last non-empty chunk).
        :rtype:                 iterator

        :raises ValueError:     If ``array_size`` is smaller than 1 or
                                ``array_index`` is negative.
        """
        # Fallback if no array parameters were set: process everything.
        if array_size is None or array_index is None:
            return generator

        # Fail with a clear message instead of a ZeroDivisionError or an
        # opaque islice() error on nonsensical array parameters.
        if array_size < 1:
            raise ValueError(f'array_size must be at least 1, got {array_size}.')
        if array_index < 0:
            raise ValueError(f'array_index must be non-negative, got {array_index}.')

        size, extra = divmod(length, array_size)

        # The first `extra` chunks are one item longer than the rest.
        start = array_index * size + min(array_index, extra)
        stop = (array_index + 1) * size + min(array_index + 1, extra)

        return itertools.islice(generator, start, stop)

    # %% LCAOCalculator settings that do not depend on the cut-off; built once
    # instead of on every iteration.
    k_point_sampling = KpointDensity(
        density_a=4.0 * Angstrom, density_b=4.0 * Angstrom, density_c=4.0 * Angstrom
    )

    # %% Array Join (collect cut-off and energy)
    # Created on the first iteration only, so a job whose share of the array is
    # empty writes no table at all. An explicit sentinel replaces the previous
    # `not in locals()` introspection.
    array_join_collect_cutoff_and_energy = None

    for density_mesh_cutoff in arraySplit(
        density_mesh_cutoffs, 11, array_size, array_index
    ):

        # %% LCAOCalculator for this cut-off
        numerical_accuracy_parameters = NumericalAccuracyParameters(
            density_mesh_cutoff=density_mesh_cutoff, k_point_sampling=k_point_sampling
        )

        calculator = LCAOCalculator(
            numerical_accuracy_parameters=numerical_accuracy_parameters,
            checkpoint_handler=NoCheckpointHandler,
        )

        # %% Set Calculator
        silicon_alpha.setCalculator(calculator)
        silicon_alpha.update()

        # %% Calculate and extract Total Energy
        calculate_total_energy = TotalEnergy(configuration=silicon_alpha)
        evaluate = calculate_total_energy.evaluate()

        # %% Array Join (collect cut-off and energy)
        if array_join_collect_cutoff_and_energy is None:
            array_join_collect_cutoff_and_energy = Table(
                results_filename, object_id='table'
            )
            array_join_collect_cutoff_and_energy.addQuantityColumn(
                key='density_mesh_cutoff', unit=Hartree
            )
            array_join_collect_cutoff_and_energy.addQuantityColumn(
                key='total_energy', unit=eV
            )
            array_join_collect_cutoff_and_energy.setMetatext(
                'Array Join (collect cut-off and energy)'
            )

        array_join_collect_cutoff_and_energy.append(density_mesh_cutoff, evaluate)

# Collect all results into common table.
if part in ('collect', 'finalize'):
    # Extract array specifics.
    array_folder_prefix = os.environ['ARRAY_FOLDER_PREFIX']
    array_size = int(os.environ['ARRAY_SIZE'])

    # 'finalize' joins into the common results file; 'collect' joins into a
    # separate 'partial_' file in the current folder.
    if part == 'finalize':
        array_join_collect_cutoff_and_energy = Table(results_filename)
    else:
        array_join_collect_cutoff_and_energy = Table('partial_' + results_filename)

    for array_index in range(array_size):
        path = os.path.join(array_folder_prefix + str(array_index), results_filename)

        # Collection is best-effort: problems with one job are reported and
        # skipped so the remaining jobs are still joined.
        if not os.path.isfile(path):
            print(f'No file found for index {array_index} at {path}.')
            continue

        try:
            partial_tables = nlread(path, object_id='table')
        except Exception as error:
            # Report the real reason instead of claiming the file is missing.
            print(f'Could not read {path} for index {array_index}: {error!r}')
            continue

        # The main loop only creates the table once it has a result, so a job
        # with an empty share of the array leaves no table behind.
        if not partial_tables:
            print(f'No table found for index {array_index} in {path}.')
            continue

        # Extend main table.
        array_join_collect_cutoff_and_energy.extend(partial_tables[0])

# Post-process results: plot total energy against density mesh cut-off.
if part in ('collect', 'finalize', None):

    plot_model = Plot.PlotModel(x_unit=Hartree, y_unit=meV)
    plot_model.framing().setTitle('Plot_Total_Energies')

    # Pull both columns out of the joined table.
    joined_table = array_join_collect_cutoff_and_energy
    cutoff_values = joined_table[:, 'density_mesh_cutoff']
    energy_values = joined_table[:, 'total_energy']

    # Solid red line without markers.
    total_energy_line = Plot.Line(cutoff_values, energy_values)
    total_energy_line.setLabel('Total energy')
    total_energy_line.setColor('#b82832')
    total_energy_line.setLineStyle('-')
    total_energy_line.setMarkerStyle('')
    plot_model.addItem(total_energy_line)

    # Fit the axes to the data, then store the plot in the results file.
    plot_model.setLimits()
    Plot.save(plot_model, results_filename)
