import pylab
import sys
import matplotlib.pyplot as plt
import matplotlib.patches as ptc
from matplotlib import colors, ticker, cm


#-------------
# Introduction
#-------------
'''
In this script you need to specify the .hdf5 file containing the Hartree
difference potential from the HartreeDifferencePotential analysis as the input file,
set the name of the output file, and provide the user-defined parameters.
'''

#------------------------
# User defined parameters
#------------------------

# File containing the IV characteristics with the Hartree difference potentials
fname1 = "IVcurve_results.hdf5"

# Base name of the output files. For every drain-source voltage V the script
# writes "<fname2>_<V>V.dat" (data) and "<fname2>_<V>V.png" (figure).
fname2 = 'hdp_avr'

# z-distance between atomic layers in the left and right electrode material in Ang
# (determines the width of the Gaussian kernels used to calculate the average potential)
gauss_left = 2.26
gauss_right = 1.35

# z-coordinate of the interface position in Ang, calculated as the midpoint between
# the last atom of the left material and the first atom of the right material
zcoord_int = 27.17


#-----------------
# Helper functions
#-----------------

def _kernel_widths(z, z_interface, width_left, width_right):
    """Return the Gaussian kernel width to use at every grid point.

    Points with z < z_interface belong to the left material and get width_left;
    all other points get width_right. Widths are in the same length unit as z.
    """
    return numpy.where(numpy.asarray(z) < z_interface, float(width_left), float(width_right))


def _macroscopic_average(z, pot_z, dz, widths):
    """Smooth pot_z with a normalized Gaussian kernel whose width varies per point.

    The kernel at grid point i is exp(-(k*dz/widths[i])**2) for k = -n..n, where
    n = 2*int(mean(widths)/dz + 1). The n points at each end of the grid cannot
    hold a full kernel window and are dropped.

    Returns (z_trimmed, av_pot), both of length len(pot_z) - 2*n.
    Raises ValueError if no grid point has a full kernel window.
    """
    n = 2 * int(numpy.average(widths) / dz + 1)
    npts = len(pot_z)
    if npts <= 2 * n:
        raise ValueError(
            f"Grid of {npts} points is too short for a Gaussian window of "
            f"{2 * n + 1} points; reduce gauss_left/gauss_right."
        )

    offsets = numpy.arange(-n, n + 1) * dz
    av_pot = numpy.empty(npts - 2 * n)
    # Only build kernels for points that are actually kept (the ends are trimmed)
    for j, i in enumerate(range(n, npts - n)):
        kernel = numpy.exp(-(offsets / widths[i]) ** 2)
        av_pot[j] = numpy.sum(pot_z[i - n:i + n + 1] * kernel) / numpy.sum(kernel)
    return z[n:npts - n], av_pot


def _write_profile(path, z, delta_pot):
    """Write two whitespace-separated columns (z, delta_pot), one grid point per line."""
    with open(path, "w") as out:
        out.writelines(f"{zi} {vi}\n" for zi, vi in zip(z, delta_pot))


def _plot_profile(path, z, delta_pot):
    """Plot delta_pot (eV) against z (Ang) and save the figure to path at 300 dpi.

    The figure is always closed afterwards, so looping over many voltages does
    not accumulate open figures.
    """
    fig = plt.figure(figsize=(5, 2))
    try:
        ax = fig.add_subplot(111)
        ax.plot(z, delta_pot, color='blue', linewidth=0.75)

        ax.set_ylabel(r'$\langle\Delta$ $V_H\rangle$ (eV)', fontsize=10)
        ax.set_xlabel(r'Cell Length Z ($\AA$)', fontsize=10)
        ax.set_ylim(delta_pot.min() - 0.5, delta_pot.max() + 0.5)
        ax.set_xlim(z[0], z[-1])

        for side in ('top', 'bottom', 'right', 'left'):
            ax.spines[side].set_linewidth(0.5)
        ax.xaxis.set_tick_params(width=0.25, zorder=20)
        ax.yaxis.set_tick_params(width=0.25, zorder=20)
        ax.tick_params(direction='in', pad=7.5)

        fig.tight_layout()
        fig.savefig(path, dpi=300)
    finally:
        plt.close(fig)


#-----------
# Load files
#-----------

# Load the IV characteristics containing the Hartree difference potentials
iv = nlread(fname1, IVCharacteristics)[0]


#----------------------------------------------------------
# Calculate the averaged electrostatic difference potential
#----------------------------------------------------------

# Drain-source and gate-source voltages (these already carry units)
x_list = iv.drainSourceVoltages()
g_list = iv.gateSourceVoltages()

# One data file and one figure per drain-source voltage, at the first gate-source voltage
for voltage in x_list:
    HDP_avr = iv.results(g_list[0], voltage, result_types=[HartreeDifferencePotential])[0]

    # Average the potential over the xy-plane
    c, pot_z = HDP_avr.axisProjection("average", "c")
    pot_z = pot_z.inUnitsOf(eV)

    # c holds the fractional coordinates of the grid points; convert them to
    # Cartesian z-coordinates (Ang) for plotting
    z = numpy.array(c) * HDP_avr.primitiveVectors()[2][2].inUnitsOf(Ang)
    dz = HDP_avr.volumeElement()[2][2].inUnitsOf(Ang)  # grid spacing along z

    # Average each point over the interlayer distance of the material it lies in
    widths = _kernel_widths(z, zcoord_int, gauss_left, gauss_right)
    z_avr, av_pot = _macroscopic_average(z, pot_z, dz, widths)

    # Shift the profile so that its right-most averaged value is zero
    r_pot = av_pot[-1]
    delta_pot = av_pot - r_pot

    # Write the data and the figure, named after the drain-source voltage
    voltage_str = f"{float(voltage.inUnitsOf(Volt)):.2f}V"
    _write_profile(f"{fname2}_{voltage_str}.dat", z_avr, delta_pot)
    _plot_profile(f"{fname2}_{voltage_str}.png", z_avr, delta_pot)
