from QuantumATK import *
import pylab
# Custom analyzer for calculating the band gap of a bandstructure

# Helper function to refine a band extremum (minimum or maximum) with a parabolic fit
def fitEnergyMinimum(i_min, energies, k_points, nfit=3):
    """
    Refine the position of a band extremum located around index i_min.

    A parabola is least-squares fitted to nfit consecutive energies centred
    on i_min, and its vertex gives the refined extremum. Despite the name,
    this works for maxima as well as minima; the sign of the returned
    quadratic coefficient tells which (positive: minimum, negative: maximum).

    @param i_min   : index of the approximate (discrete) extremum.
    @param energies: sequence of energies, one per k-point.
    @param k_points: sequence of k-points matching energies; elements must
                     support scaling and addition (e.g. numpy arrays).
    @param nfit    : number of points used in the fit (the polynomial is
                     always of order 2, so nfit should be at least 3).
    @return a, e_min, k_min : quadratic coefficient of the fit in energy per
                              (k-point index)^2 (half the second derivative
                              with respect to the index), fitted extremal
                              energy, and linearly interpolated k-point of
                              the extremum.
    """
    n = len(energies)
    # Integer division: the half-width is used as an index offset, and "/"
    # yields a float in Python 3, which cannot index a sequence.
    half = nfit // 2

    # Sample nfit energies centred on i_min; indices wrap around modulo n.
    efit = numpy.array([energies[(n + i_min - half + i) % n]
                        for i in range(nfit)], dtype=float)
    kfit = numpy.arange(nfit) - half

    # At the end points the wrapped neighbours are not physical neighbours,
    # so mirror the points from inside the path instead.
    if i_min == 0:  # assume band structure symmetric around the first point
        for i in range(half):
            efit[i] = energies[half - i]
    if i_min == n - 1:  # assume band structure symmetric around the last point
        for i in range(half + 1, nfit):
            efit[i] = energies[n - 1 + half - i]

    p = numpy.polyfit(kfit, efit, 2)
    if p[0] != 0.0:
        i_fit_min = -p[1] / (2.0 * p[0])
    else:
        # Straight line: there is no vertex, so keep the sampled point.
        i_fit_min = 0.0
    # The extremum is expected within the sampled window; clamp so that a
    # nearly flat band cannot push the vertex far outside it.
    i_fit_min = min(max(i_fit_min, float(kfit[0])), float(kfit[-1]))
    e_min = numpy.polyval(p, i_fit_min)

    # Linear interpolation between the two k-points bracketing the vertex;
    # adding n keeps the position non-negative before the modulo.
    pos = i_fit_min + i_min + n
    i0 = int(numpy.floor(pos))
    w = pos - i0
    k_min = (1 - w) * k_points[i0 % n] + w * k_points[(i0 + 1) % n]

    return p[0], e_min, k_min

def analyseBandstructure(bandstructure, spin):
    """
    Locate the band edges of a band structure and print its band gaps.

    Energies below 0 eV are treated as valence states and energies at or
    above 0 eV as conduction states (this assumes the energy zero is the
    Fermi level). The discrete band edges are refined with fitEnergyMinimum.

    @param bandstructure :  The bandstructure to analyze.
    @param spin          :  Which spin to select from the bandstructure.
    @return  e_val, e_con, e_gap :  fitted valence band maximum,
                                    fitted conduction band minimum,
                                    and fitted minimal direct band gap (eV).
    @raise ValueError    :  if the band structure has no k-points, or if at
                            some k-point all bands lie on the same side of
                            0 eV so that no gap can be identified.
    """
    energies = bandstructure.evaluate(spin=spin).inUnitsOf(eV)
    # Fetched once; the same k-points are used for all three fits.
    k_points = bandstructure.kpoints()
    n_kpoints, n_bands = energies.shape[0], energies.shape[1]
    if n_kpoints == 0:
        raise ValueError('Bandstructure contains no k-points')

    # Running extrema: energy, k-point index i, and band index n of each.
    e_valence_max = -numpy.inf
    e_conduction_min = numpy.inf
    e_gap_min = numpy.inf
    i_valence_max = i_conduction_min = i_gap_min = 0
    n_valence_max = n_conduction_min = n_gap_min = 0

    for i in range(n_kpoints):
        row = energies[i]
        # Find the first band at or above 0 eV (the lowest conduction band).
        n = 0
        while n < n_bands and row[n] < 0.0:
            n += 1
        # Without a band on each side of 0 eV, row[n - 1] would silently wrap
        # to the highest band (n == 0) or row[n] would be out of range.
        if n == 0 or n == n_bands:
            raise ValueError(
                'No band gap around 0 eV at k-point index %d' % i)
        e_v = row[n - 1]
        e_c = row[n]

        # Maximum of the valence band.
        if e_v > e_valence_max:
            e_valence_max = e_v
            i_valence_max = i
            n_valence_max = n - 1
        # Minimum of the conduction band.
        if e_c < e_conduction_min:
            e_conduction_min = e_c
            i_conduction_min = i
            n_conduction_min = n
        # Minimum direct gap; n_gap_min is the valence band index.
        if e_c - e_v < e_gap_min:
            e_gap_min = e_c - e_v
            i_gap_min = i
            n_gap_min = n - 1

    # Refine the discrete extrema and print the results.
    _, e_val, k_val = fitEnergyMinimum(i_valence_max,
                                       energies[:, n_valence_max],
                                       k_points)
    print('Valence band maximum    %7.4f eV at [%6.4f, %6.4f,%6.4f]   '
          % (e_val, k_val[0], k_val[1], k_val[2]))

    _, e_con, k_con = fitEnergyMinimum(i_conduction_min,
                                       energies[:, n_conduction_min],
                                       k_points)
    print('Conduction band minimum %7.4f eV at [%6.4f, %6.4f,%6.4f]   '
          % (e_con, k_con[0], k_con[1], k_con[2]))

    print('Fundamental band gap    %7.4f eV ' % (e_con - e_val))

    _, e_gap, k_gap = fitEnergyMinimum(
        i_gap_min,
        energies[:, n_gap_min + 1] - energies[:, n_gap_min],
        k_points)
    print('Direct band gap         %7.4f eV at [%6.4f, %6.4f,%6.4f]   '
          % (e_gap, k_gap[0], k_gap[1], k_gap[2]))

    return e_val, e_con, e_gap

def analyzer(filename, **args):
    """
    Find band gaps of the band structures stored in a file.

    Every Bandstructure read from the file is analysed with
    analyseBandstructure (using Spin.All), which prints its results.

    @param filename : path of the file to read; nothing is done if None.
    @param args     : accepted for compatibility; currently unused.
    @return list of (e_val, e_con, e_gap) tuples, one per band structure,
            or None if filename is None or the file has no band structures.
    """
    if filename is None:
        return None

    # Read in the band structures to analyze.
    bandstructure_list = nlread(filename, Bandstructure)
    if len(bandstructure_list) == 0:
        print('No Bandstructures in file ', filename)
        return None

    spin = Spin.All
    results = []
    for index, bandstructure in enumerate(bandstructure_list):
        print('Analyzing bandstructure number ', index)
        results.append(analyseBandstructure(bandstructure, spin))
        # Two blank lines separate the reports of consecutive band structures.
        print()
        print()
    return results


# Analyze the band structures stored in the silicon nanowire example file.
analyzer(filename="si_100_nanowire.hdf5")
