from QuantumATK import *
from NL.CommonConcepts.Configurations.ConfigurationCopy import configurationCopy
from NL.CommonConcepts.Configurations.FindSymmetryOperations import findDirectSpaceSymmetryOperations, translateCenterOfMass
from NL.CommonConcepts.Configurations.Utilities import wrap
import scipy

def createInversionTransformationMap(input_configuration, tolerance=1.0e-6):
    """
    Function for generating the transformation map of atoms under inversion symmetry

         @param input_configuration   : The configuration for which to calculate the transformation.
                                        A wrapped copy is used; the input itself is not modified.
         @type                        : BulkConfiguration
         @param tolerance             : Maximum norm of (R + 1) for a lattice symmetry operation R
                                        to be accepted as the inversion operation.
         @type                        : float

         @return                      : Tuple (atom_map, cell_vectors). atom_map[i] is the index of the
                                        atom onto which atom i is mapped by the inversion, and
                                        cell_vectors[i] is the integer-valued lattice vector (fractional
                                        units) of the cell where the inverted atom i ends up.

         @raise ValueError            : If no inversion operation is found among the lattice symmetries.
    """
    # Work on a wrapped copy so all fractional coordinates lie in the home cell
    bulk_configuration = configurationCopy(input_configuration)
    wrap(bulk_configuration)
    original_coordinates = numpy.asarray(bulk_configuration.fractionalCoordinates(), dtype=float)

    # Locate the pure inversion (-1) among the direct-space symmetry operations
    symmetry_operations = findDirectSpaceSymmetryOperations(bulk_configuration)
    inversion_symmetry = -1.0 * numpy.eye(3)
    inversion_match = [numpy.linalg.norm(rotation - inversion_symmetry)
                       for rotation, _ in symmetry_operations]
    # Compare with a tolerance: an exact "> 0.0" test rejects valid operations with round-off
    if len(inversion_match) == 0 or numpy.min(inversion_match) > tolerance:
        raise ValueError("Inversion symmetry not found")
    translation = numpy.asarray(symmetry_operations[int(numpy.argmin(inversion_match))][1], dtype=float)

    # Inverted (and translated) position of every atom
    inverted_coordinates = -original_coordinates + translation

    # displacement[i, j] = inverted position of atom i minus position of atom j.
    # Atom i maps onto atom j when this displacement is a lattice vector, i.e. when
    # it is (numerically) integer-valued. Matching modulo lattice vectors avoids
    # failures for coordinates that were wrapped to just below 1.0.
    displacement = inverted_coordinates[:, numpy.newaxis, :] - original_coordinates[numpy.newaxis, :, :]
    lattice_shift = numpy.rint(displacement)
    mismatch = numpy.linalg.norm(displacement - lattice_shift, axis=-1)

    index_inverted_coordinates = [int(j) for j in numpy.argmin(mismatch, axis=1)]
    # The lattice vector belonging to the chosen partner is the cell of the inverted atom
    cell_vectors = lattice_shift[numpy.arange(len(index_inverted_coordinates)), index_inverted_coordinates]

    return index_inverted_coordinates, cell_vectors

def inversionOfEigenstate(c, overlap, atom_map, orbital_map, cell_vectors, kpoint):
    """
    Function for transforming an eigenstate with the inversion symmetry

    The coefficients on atom i are moved to the corresponding orbitals on atom
    atom_map[i]. Each coefficient is multiplied by the orbital parity (-1)**l and
    by the Bloch phase exp(-2*pi*i * R.k) for R = cell_vectors[i], which brings
    the inverted atom back into the home cell.

    Coefficient layout: shells of each center are stored consecutively. Within a
    shell, the spin components of each orbital are adjacent. A shell with
    n orbitals (per spin component) is assumed to have angular momentum l = n // 2,
    i.e. n = 2l + 1.

         @param c                : The eigenstate (1D coefficient vector) to be transformed
         @param overlap          : The overlap matrix (not used here; kept for interface compatibility)
         @param atom_map         : Map of how atoms transform under inversion
         @param orbital_map      : Map of the layout of orbitals on the atoms
         @param cell_vectors     : The cell vector of each transformed atom
         @param kpoint           : The kpoint of the eigenstate

         @return                 : The transformed eigenstate

         @raise ValueError       : If len(c) is not a multiple of the number of orbitals
    """
    # Integer division: under Python 3 "/" yields a float that cannot be used in range() or as an index
    num_spins, remainder = divmod(c.shape[0], orbital_map.numberOfOrbitals())
    if remainder != 0:
        raise ValueError("Length of the eigenstate is not a multiple of the number of orbitals")

    c_inv = 0.0 * c
    i_shell = 0
    i_orb = 0

    for i in range(orbital_map.numberOfCenters()):
        i_orb_transformed = num_spins * orbital_map.firstOrbitalOnCenter(atom_map[i])
        # Phase for translating the inverted atom back into the unit cell
        kpoint_phase = numpy.exp(2. * numpy.pi * numpy.dot(-1 * cell_vectors[i], kpoint) * 1j)

        for _ in range(orbital_map.numberOfShellsOnCenter(i)):
            n_orb_shell = orbital_map.numberOfOrbitalsInShell(i_shell)
            # n = 2l + 1 orbitals per shell; "//" keeps l an integer so (-1)**l stays +-1
            angular_momentum = n_orb_shell // 2
            phase = (-1) ** angular_momentum * kpoint_phase

            # All orbitals and spin components of the shell share the same phase
            n_coefficients = n_orb_shell * num_spins
            c_inv[i_orb_transformed:i_orb_transformed + n_coefficients] = phase * c[i_orb:i_orb + n_coefficients]
            i_orb += n_coefficients
            i_orb_transformed += n_coefficients
            i_shell += 1

    return c_inv

def topologicalInvariant3D(bulk_configuration):
    """
    Function for calculating the topological invariant of a 3D inversion symmetric configuration

    The indices are obtained from the inversion (parity) expectation values of the
    occupied states at the eight time-reversal invariant momenta (TRIM), as in the
    Fu-Kane parity criterion. For each TRIM the product of square roots of the
    occupied-state parities is formed; this equals the product over pairs of states.
    NOTE(review): this pairing assumes the occupied states come in equal-parity
    (Kramers) pairs, as in a spin-orbit calculation — confirm for spin-degenerate setups.

         @param bulk_configuration    : The configuration for which to calculate the topological invariant.
                                        It must have a calculator attached.
         @type                        : BulkConfiguration

         @return                      : The topological invariant as an integer array [v0, v1, v2, v3].
                                        v0 is the product over all eight TRIM. v1, v2 and v3 are the
                                        products over the k_x = 1/2, k_y = 1/2 and k_z = 1/2 planes.

         @raise ValueError            : If a parity expectation value does not round to +1 or -1, or if an
                                        odd number of occupied states has parity -1 at some TRIM (the
                                        square-root product is then not real).
    """
    # "import scipy" alone does not guarantee the linalg submodule is loaded
    import scipy.linalg

    # The eight time-reversal invariant momenta, in units of the reciprocal lattice vectors
    trim_kpoints = numpy.array([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.0, 0.5, 0.0], [0.5, 0.5, 0.0],
                                [0.0, 0.0, 0.5], [0.5, 0.0, 0.5], [0.0, 0.5, 0.5], [0.5, 0.5, 0.5]])

    # k-point independent quantities: compute once instead of once per TRIM
    dmc = bulk_configuration.calculator()._densityMatrixCalculator()
    orbital_map = dmc.orbitalMap()
    atom_map, cell_vectors = createInversionTransformationMap(bulk_configuration)
    fermi_distribution = dmc.fermiDistribution()
    # Round instead of truncating, so e.g. 7.9999999 electrons is not counted as 7
    number_of_occupied_states = int(round(fermi_distribution.numberOfElectrons() /
                                          fermi_distribution.degeneracy()))

    band_sym_index = numpy.ones(len(trim_kpoints))
    for ik, kpoint in enumerate(trim_kpoints):
        # Hamiltonian and overlap at this k-point
        H, S = calculateHamiltonianAndOverlap(bulk_configuration, kpoint=kpoint)
        H = H.inUnitsOf(eV)
        # Generalized eigenproblem H c = E S c; eigh returns ascending eigenvalues,
        # so the first columns are the occupied states
        _, v = scipy.linalg.eigh(H, S)

        parities = numpy.zeros(number_of_occupied_states)
        for i in range(number_of_occupied_states):
            c = v[:, i]
            c_inv = inversionOfEigenstate(c, S, atom_map, orbital_map, cell_vectors, kpoint)
            # Expectation value of the inversion operator, rounded to remove numerical noise
            sym_index = numpy.dot(numpy.conj(c_inv), numpy.dot(S, c))
            parities[i] = numpy.rint(numpy.real(sym_index))

        if not numpy.all(numpy.abs(parities) == 1.0):
            raise ValueError("Parity expectation value at k-point %s is not +1 or -1" % (kpoint,))
        n_negative = int(numpy.count_nonzero(parities < 0))
        if n_negative % 2 != 0:
            raise ValueError("Odd number of occupied states with parity -1 at k-point %s" % (kpoint,))
        # Product of sqrt(parity) over all occupied states = i**n_negative = (-1)**(n_negative/2)
        band_sym_index[ik] = (-1) ** (n_negative // 2)

    topological_invariant = numpy.ones(4)
    # Main invariant: all eight TRIM
    topological_invariant[0] = numpy.prod(band_sym_index)
    # x plane (k_x = 1/2)
    topological_invariant[1] = numpy.prod(band_sym_index[[1, 3, 5, 7]])
    # y plane (k_y = 1/2)
    topological_invariant[2] = numpy.prod(band_sym_index[[2, 3, 6, 7]])
    # z plane (k_z = 1/2)
    topological_invariant[3] = numpy.prod(band_sym_index[[4, 5, 6, 7]])

    # Extract v = 0,1 from (-1)**v = x -> v = (1-x)/2; round before casting to avoid truncation
    return numpy.rint((1 - topological_invariant) * 0.5).astype(int)
