from QuantumATK import *
from NL.ComputerScienceUtilities.Functions import numpyToComplexVector
from NL.CommonConcepts.PoissonSolvers.PoissonSolverTools import vectorBoundaryConditions
import scipy

def vectorToGrid(vector, configuration, kpoint=(0.0, 0.0, 0.0), mesh_cutoff=75*Hartree):
    """
    Put a basis-coefficient vector onto a real-space 3-d grid.

    The grid holds the superposition sum_i v_i phi_i(r), where v_i are the
    entries of `vector` and phi_i are the basis functions of the grid tool
    built from the configuration's calculator.

    :param vector:        Coefficients v_i (converted with numpyToComplexVector).
    :param configuration: Configuration with an attached calculator.
    :param kpoint:        k-point passed to the superposition, as a 3-tuple.
    :param mesh_cutoff:   Energy cutoff setting the grid spacing pi/sqrt(E_cut [Ry]).

    :returns: GridValues with units of Bohr**-(3/2).
    """
    # C++ representations of the inputs.
    coefficients = numpyToComplexVector(vector)
    k_cartesian = NLEngine.Cartesian3D(*kpoint)

    # Internal QuantumATK handles.
    calculator = configuration.calculator()
    dm_calculator = calculator._densityMatrixCalculator()
    builder = calculator._builder(configuration)

    # Boundary conditions are taken from the Poisson solver.
    # NOTE(review): this calls the calculator object itself before querying
    # poissonSolver(); confirm whether calculator.poissonSolver() was intended.
    boundary_conditions = calculator().poissonSolver().boundaryConditions()

    # Grid spacing derived from the mesh cutoff expressed in Rydberg.
    cutoff_ry = mesh_cutoff.inUnitsOf(Units.Ry)
    spacing = math.pi/cutoff_ry**0.5

    grid_descriptor = NLEngine.GridDescriptor(
        spacing,
        dm_calculator.configuration().unitCell(),
        vectorBoundaryConditions(boundary_conditions),
        True,
    )

    grid_tool = builder.createGridTool(
        dm_calculator.orbitalMap(),
        dm_calculator.neighbourlist(),
        grid_descriptor,
    )

    # Superpose the basis functions weighted by the coefficients.
    grid_data = NLEngine.superPosition(
        grid_descriptor,
        dm_calculator.configuration(),
        dm_calculator.neighbourlist(),
        coefficients,
        dm_calculator.orbitalMap(),
        grid_tool.basisSet(),
        k_cartesian,
    )

    return GridValues(grid_data, Units.Bohr**-(3.0/2.0))

def _hermitianSqrtAndInverseSqrt(matrix):
    """ Return (M^(1/2), M^(-1/2)) for a Hermitian positive-definite matrix M.

    Both roots come from one eigendecomposition M = V diag(s) V^dagger, so they
    are Hermitian and mutually inverse without an explicit matrix inversion.
    """
    eigenvalues, eigenvectors = numpy.linalg.eigh(matrix)
    roots = numpy.sqrt(eigenvalues)
    eigenvectors_dagger = numpy.conj(eigenvectors.T)
    # (V * r) scales column j of V by r[j], i.e. V diag(r).
    sqrt_matrix = numpy.dot(eigenvectors*roots, eigenvectors_dagger)
    inv_sqrt_matrix = numpy.dot(eigenvectors/roots, eigenvectors_dagger)
    return sqrt_matrix, inv_sqrt_matrix


def scatteringStates(device_configuration, energy, kpoint=(0.0, 0.0, 0.0)):
    """ Utility function to calculate the scattering states according to PRB 76, 115117 (2007)

    Equation numbers in the comments refer to that paper.

    Parameters
    ----------
    device_configuration : device configuration with an attached calculator.
    energy : energy passed to the Green's-function and self-energy calculations.
    kpoint : k-point passed to the QuantumATK calculate* functions.

    Returns
    -------
    T : 1-d real array of transmission eigenvalues, in ascending order
        (as returned by numpy.linalg.eigh).
    Pd : 2-d array whose column j is the scattering state belonging to T[j],
        expressed in the original (overlap S) basis.
    """
    # Retarded Green's function of the device, in eV^-1.
    Gr = calculateRetardedGreenFunction(device_configuration, energy, kpoint)
    Gr = Gr.inUnitsOf(eV**-1)

    # Electrode self-energies, in eV.
    Sigma_L = calculateSelfEnergy(device_configuration, energy, kpoint,
                                  contribution=Left).inUnitsOf(eV)
    Sigma_R = calculateSelfEnergy(device_configuration, energy, kpoint,
                                  contribution=Right).inUnitsOf(eV)

    # Broadening matrices Gamma = i(Sigma - Sigma^dagger), embedded into
    # device-sized matrices: the left block occupies the top-left corner and
    # the right block the bottom-right corner.
    Gamma_L = numpy.zeros(Gr.shape, dtype=complex)
    n_L = Sigma_L.shape[0]
    Gamma_L[:n_L, :n_L] = 1j*(Sigma_L - numpy.conj(Sigma_L.T))

    Gamma_R = numpy.zeros(Gr.shape, dtype=complex)
    n_R = Sigma_R.shape[0]
    Gamma_R[-n_R:, -n_R:] = 1j*(Sigma_R - numpy.conj(Sigma_R.T))

    # Only the overlap matrix is needed; the Hamiltonian is discarded.
    _, S = calculateHamiltonianAndOverlap(device_configuration, kpoint)

    # Left spectral function, Eq. (13): A_L = G^r Gamma_L (G^r)^dagger.
    A_L = numpy.dot(Gr, numpy.dot(Gamma_L, numpy.conj(Gr.T)))

    # S^(1/2) and S^(-1/2) used to move between the original and the
    # orthogonalised basis.
    S_sqrt, S_invsqrt = _hermitianSqrtAndInverseSqrt(S)

    # Orthogonalised left spectral function. Symmetrise to strip round-off so
    # the Hermitian eigensolver sees an exactly Hermitian matrix.
    ABar_L = numpy.dot(S_sqrt, numpy.dot(A_L, S_sqrt))
    ABar_L = 0.5*(ABar_L + numpy.conj(ABar_L.T))

    # Eq. (27): eigh (rather than eig) yields real eigenvalues and an
    # orthonormal eigenbasis U, also when eigenvalues are degenerate.
    lamb, U = numpy.linalg.eigh(ABar_L)
    # A_L = G Gamma_L G^dagger is positive semi-definite; clip tiny negative
    # round-off so the square root below stays real.
    lamb = numpy.clip(lamb, 0.0, None)

    # Eq. (28): Utilde = U diag(sqrt(lamb / 2pi)), done as column scaling.
    Utilde = U*numpy.sqrt(lamb/(2.*numpy.pi))

    # Orthogonalised right broadening matrix.
    GammaBar_R = numpy.dot(S_invsqrt, numpy.dot(Gamma_R, S_invsqrt))

    # Eq. (29): transmission matrix (Hermitian by construction).
    TM = 2.*numpy.pi*numpy.dot(numpy.conj(Utilde.T), numpy.dot(GammaBar_R, Utilde))
    TM = 0.5*(TM + numpy.conj(TM.T))

    # Eq. (30): transmission eigenvalues T and eigenvectors c.
    T, c = numpy.linalg.eigh(TM)

    # Eq. (31): transform the eigenchannels back to the original basis.
    Pd = numpy.dot(S_invsqrt, numpy.dot(Utilde, c))
    return T, Pd


def averageFermiLevel(device_configuration):
    """
    Return the average Fermi level of a device configuration.

    The chemical potentials are obtained from
    ChemicalPotential(device_configuration).evaluate(), and the result is half
    of their sum.
    """
    # NOTE(review): the factor 0.5 equals the mean only when evaluate()
    # returns exactly two chemical potentials (presumably left/right
    # electrodes) — confirm.
    return 0.5*numpy.sum(ChemicalPotential(device_configuration).evaluate())