
import sys
import networkx

from NL.CommonConcepts.Configurations.Utilities import findMoleculesInBulkConfiguration
from NL.IO.NLSaveUtilities import nlread


def configurationPeriodicity(conf):
    """
    Count the cell directions along which the bonded structure connects to its own image.

    The count is only evaluated when ``findMoleculesInBulkConfiguration`` reports exactly
    one entry under ``'periodic'``; in every other case the result is 0.

    :param conf:    The configuration being tested.
    :type  conf:    BulkConfiguration

    :returns:   The number of directions (0 to 3) in which a bonded path was found.
    :rtype:     int
    """
    molecules = findMoleculesInBulkConfiguration(conf)
    periodic_parts = molecules['periodic']

    # Nothing to test unless there is exactly one periodic substructure.
    if len(periodic_parts) != 1:
        return 0

    # Work on a configuration that contains only that substructure.
    substructure = conf._createSubstructure(periodic_parts[0])

    # NOTE(review): this assumes that after doubling the cell the atom with index
    # len(substructure) is the translated copy of atom 0 -- confirm against repeat().
    first_atom = 0
    image_atom = len(substructure)

    connected_directions = 0
    for repeats in ((2, 1, 1), (1, 2, 1), (1, 1, 2)):
        # Double the cell along a single direction and rebuild the bonds.
        doubled = substructure.repeat(*repeats)
        doubled.findBonds()
        bond_graph = doubled._bondGraph()

        # A bonded path from the atom to its image counts as one periodic direction.
        if networkx.has_path(bond_graph, first_atom, image_atom):
            connected_directions += 1

    return connected_directions


def numberOfMolecules(trajectory):
    """
    Calculate the number of molecules in each image of a trajectory.

    A molecule is counted as one connected component of the bond graph returned by the
    image's ``_bondGraph()``.

    :param trajectory:  The trajectory containing the system.
    :type  trajectory:  Trajectory

    :returns:   The number of molecules in each image, in image order.
    :rtype:     list of int
    """
    # ``networkx.connected_component_subgraphs`` was removed in NetworkX 2.4.
    # ``number_connected_components`` yields the same count on old and new releases
    # and does not build a subgraph copy for every component.
    return [
        networkx.number_connected_components(trajectory.image(i)._bondGraph())
        for i in range(len(trajectory))
    ]


def findTargetCrossover(trajectory, target):
    """
    Bisect the trajectory for the frames where the periodicity crosses the target.

    Only midpoints are evaluated: the first frame is taken to have a periodicity less
    than or equal to ``target`` and the last frame one greater than ``target``, without
    calculating either. The result is only meaningful if the periodicity does not
    decrease along the trajectory.

    :param trajectory:  The trajectory of the cross-linking reaction.
    :type  trajectory:  Trajectory

    :param target:      Value being searched for.
    :type  target:      int

    :returns:
        Tuple of the bounds that are either side of the target transition (a two-element
        list of frame indices) and a dict mapping each evaluated frame index to the
        periodicity calculated for it.
    :rtype:
        tuple
    """
    calculated_values = {}
    bounds = [0, len(trajectory)-1]
    while (bounds[1]-bounds[0]) > 1:
        index = (bounds[1] + bounds[0]) // 2

        # Load and evaluate the frame only when it has not been calculated yet.
        if index not in calculated_values:
            calculated_values[index] = configurationPeriodicity(trajectory.image(index))

        # Always read from the cache so the comparison can never use a stale value.
        value = calculated_values[index]

        if value <= target:
            bounds[0] = index
        else:
            bounds[1] = index

    return bounds, calculated_values


if __name__ == '__main__':
    # NOTE(review): ``Plot`` is not imported in the visible part of this file; presumably
    # the interpreter that runs the script provides it -- confirm.

    # The trajectory is the last object in the file named on the command line.
    trajectory = nlread(sys.argv[1])[-1]
    frame_count = len(trajectory)

    # One slot per frame; None marks a frame that has not been evaluated yet.
    periodicity = [None] * frame_count

    # Bisect twice: for the 0 -> 1 transition and for the 2 -> 3 transition. Every value
    # calculated along the way is kept so that it is not calculated again below.
    crossover_frames = []
    for message, target in (
        ('Calculating lower bound...', 0),
        ('Calculating upper bound...', 2),
    ):
        print(message)
        bounds, values = findTargetCrossover(trajectory, target)
        for frame, value in values.items():
            periodicity[frame] = value
        crossover_frames.append(bounds[1])
    lower_bound, upper_bound = crossover_frames

    # Frames before the lower crossover are set to 0 and frames from the upper crossover
    # onwards are set to 3 without being calculated.
    for frame in range(lower_bound):
        periodicity[frame] = 0
    for frame in range(upper_bound, frame_count):
        periodicity[frame] = 3

    # Calculate whatever is still unknown between the two crossovers.
    print('Calculating intermediate values...')
    for frame, known_value in enumerate(periodicity):
        if known_value is None:
            periodicity[frame] = configurationPeriodicity(trajectory.image(frame))

    # Molecule count and reaction completion for every frame.
    print('Calculating number of molecules...')
    number_of_molecules = numberOfMolecules(trajectory)
    reaction = [
        trajectory._getQuantity('Reaction_Complete', frame)
        for frame in range(frame_count)
    ]

    # Build the two curves, both plotted against the reaction completion.
    curves = []
    for y_values, label, color in (
        (periodicity, 'Periodicity', 'red'),
        (number_of_molecules, 'Number of molecules', 'blue'),
    ):
        curve = Plot.Line(reaction, y_values)
        curve.setLabel(label)
        curve.setColor(color)
        curve.setLineWidth(1)
        curves.append(curve)
    periodicity_line, molecules_line = curves

    # Main model: periodicity, carrying the title and the x axis label.
    periodicity_model = Plot.PlotModel()
    periodicity_model.framing().setTitle('Thermoset Gel Point')
    periodicity_model.xAxis().setLabel('Reaction completion')
    periodicity_model.yAxis().setLabel('Periodicity')
    periodicity_model.legend().setVisible(True)
    periodicity_model.legend().setLocation('center left')
    periodicity_model.addItem(periodicity_line)
    periodicity_model.setLimits()

    # Second model: number of molecules on a mirrored y axis. No title is set on it and
    # its x axis label is empty.
    molecules_model = Plot.PlotModel()
    molecules_model.xAxis().setLabel('')
    molecules_model.yAxis().setLabel('Number of molecules')
    molecules_model.yAxis().setMirrored(True)
    molecules_model.legend().setVisible(True)
    molecules_model.legend().setLocation('center right')
    molecules_model.addItem(molecules_line)
    molecules_model.setLimits()

    # Overlay both models in one layout that shares the x axis; main model first.
    layout = Plot.OverlayLayout()
    layout.setMode(Plot.LAYOUT_MODES.SHARE_X)
    for model in (periodicity_model, molecules_model):
        frame = Plot.PlotFrame(use_frame_model=False)
        frame.addModel(model)
        layout.addItem(frame)

    # Show the plot, then save it.
    Plot.show(layout)
    Plot.save(layout, 'gel_point_periodicity.png')

