import pylab
import sys
import matplotlib.pyplot as plt
import matplotlib.patches as ptc
from matplotlib import colors, ticker, cm


#-------------
# Introduction
#-------------
'''
This script must be run on an .hdf5 file that contains both the Hartree
difference potential (from a HartreeDifferencePotential analysis) and the
projected local density of states (from a ProjectedLocalDensityOfStates analysis).

Example of run:
atkpython pldos_hdp.py device.hdf5
'''

#------------------------
# User defined parameters
#------------------------

# Names of the left and right electrode materials (used as plot labels)
leftname, rightname = 'Ag', 'Si'

# z-distance in Ang between atomic layers of the left and right electrode
# materials; these set the widths of the Gaussian kernels used to compute the
# averaged potential
gauss_left, gauss_right = 2.275, 5.406

# z-coordinate (Ang) of the interface, taken as the midpoint between the last
# atom of the left material and the first atom of the right material
zcoord_int = 27.2734

# Length (Ang) of the left material shown in the plot
lenleft = 10

#-----------
# Load files
#-----------

# The .hdf5 file is the single required command-line argument
if len(sys.argv) < 2:
    sys.exit("Usage: atkpython pldos_hdp.py <file.hdf5>")
fname = sys.argv[1]

# Load the PLDOS (first ProjectedLocalDensityOfStates object in the file)
pldos_list = nlread(fname, ProjectedLocalDensityOfStates)
if len(pldos_list) == 0:
    sys.exit("No ProjectedLocalDensityOfStates found in %s" % fname)
pldos_data = pldos_list[0]

E = pldos_data.energies() # Energy range
emax = float(max(E)) # maximal energy of PLDOS
emin = float(min(E)) # minimal energy of PLDOS

# Turn the z-grid and energies into 2D meshgrid arrays for the contour plot.
# numpy.meshgrid(x, y) returns arrays of shape (len(y), len(x)), so after this
# line E[i, j] is the i-th energy and Z[i, j] the j-th z-slice position.
Z = pldos_data.zSlicing() # z-grid
Z, E = numpy.meshgrid(numpy.array(Z),numpy.array(E))

pldos = numpy.array(pldos_data.evaluate()) # PLDOS

# Load the electrostatic (Hartree) difference potential
pot_list = nlread(fname, HartreeDifferencePotential)
if len(pot_list) == 0:
    sys.exit("No HartreeDifferencePotential found in %s" % fname)
pot = pot_list[0]

#-----------------------------------------------------
# Find the band gap region and conduction band minimum
#-----------------------------------------------------

# Threshold under which the DOS is considered zero (band gap region)
thr = 0.01

# Energy indices at which the PLDOS of the last z-slice (the side this script
# treats as the right electrode) falls below the threshold: the gap region.
zeros = numpy.where(pldos[:, -1] < thr)[0]
if len(zeros) == 0:
    sys.exit("No PLDOS values below thr = %g in the last z-slice; "
             "cannot locate the band gap." % thr)

# Map the indices to energies. E is the (energy, z) meshgrid, so every column
# holds the full energy axis; use the same last column as the PLDOS slice.
gap_energies = E[zeros, -1]

# The highest in-gap energy is taken as the conduction band minimum.
# NOTE(review): assumes the energies are in ascending order and that the PLDOS
# does not fall below thr again above the conduction band -- confirm.
CB_edge = gap_energies[-1]

#----------------------------------------------------------
# Calculate the averaged electrostatic difference potential
#----------------------------------------------------------

# Project the potential onto the z-axis by averaging over the xy-plane
c, pot_z = pot.axisProjection("average", "c")
pot_z = pot_z.inUnitsOf(eV)

# c holds the fractional coordinates of the grid points; scaling by the
# z-component of the third primitive vector gives Cartesian z in Ang
c = numpy.array(c)
z = c * pot.primitiveVectors()[2][2].inUnitsOf(Ang)
dz = pot.volumeElement()[2][2].inUnitsOf(Ang)  # grid spacing along z

# Averaging width for every grid point: the interlayer distance of the
# material the point lies in (left of the interface -> gauss_left, otherwise
# gauss_right). In a homogeneous material this would be one constant.
av_z = numpy.zeros(len(pot_z))
av_z[numpy.nonzero(z < zcoord_int)[0]] = gauss_left
av_z[numpy.nonzero(z >= zcoord_int)[0]] = gauss_right

# Kernel half-width in grid points, derived from the mean averaging distance
a = numpy.average(av_z)
n = 2 * int(a / dz + 1)

# One Gaussian kernel per grid point (its width is that point's averaging
# distance), plus the kernel sums used to normalize the weighted average
index = numpy.arange(-n, n + 1)
kernels = [numpy.exp(-(index * dz / width) ** 2) for width in av_z]
weights = [sum(kernel) for kernel in kernels]

# Kernel-weighted average at each grid point lying at least n points from
# both ends; the outer n points on each side are dropped because the kernel
# window [i-n, i+n] would run past the grid there
i_stop = len(pot_z) - n
av_pot = numpy.array([
    sum(pot_z[i - n:i + n + 1] * kernels[i]) / weights[i]
    for i in range(n, i_stop)
])
r_pot = av_pot[-1]  # averaged potential at the right end of the kept range
z_inner = z[n:i_stop]
zmin = z_inner[0]
zmax = z_inner[-1]

# Estimate the Schottky barrier as the largest rise of the averaged potential
# above its value at the right end of the averaged range (r_pot).
# NOTE(review): the message below describes the barrier relative to the left
# electrode chemical potential, while the curve plotted later is shifted by
# CB_edge; this estimate does not include that shift -- confirm the intent.
schottky = max(av_pot - r_pot)

rule = "+------------------------------------------------------------------------------+"
print(rule)
print()
print("The conduction band minimum in the right electrode:")
print()
print("CB min =", CB_edge, "eV")
print()
for text in ("Estimated Schottky barrier as the difference between the chemical potential of ",
             "the left electrode and the maximum of the averaged Hartee difference ",
             "potential:"):
    print(text)
print()
print("Schottky barrier =", schottky * 1000, "meV")  # eV -> meV
print()
print(rule)

#------------
# Plot figure
#------------

# Figure with a single axes holding the PLDOS contour map
plt.figure(figsize=(6, 2))
ax = pylab.subplot(111)

# Filled contours of the PLDOS; the 100 levels span 0 to one tenth of the
# maximum PLDOS value
levels = numpy.linspace(0, numpy.amax(pldos) / 10., 100)
ax.contourf(Z, E, pldos, levels=levels, locator=ticker.LogLocator(), cmap=cm.copper)
# Dotted white horizontal line at E - mu_L = 0
ax.plot([0, zmax], [0, 0], color='white', linestyle=':', linewidth=0.5)

ax.set_ylabel(r'$E$-$\mathrm{\mu_{L}}$ (eV)', fontsize=10)
ax.set_xlabel(r'Cell Length Z ($\AA$)', fontsize=10)
ax.set_ylim(emin, emax)
ax.set_xlim(zcoord_int - lenleft, zmax)

# Thin frame and inward-pointing ticks
for side in ('top', 'bottom', 'right', 'left'):
    ax.spines[side].set_linewidth(0.5)
for tick_axis in (ax.xaxis, ax.yaxis):
    tick_axis.set_tick_params(width=0.25, zorder=20)
ax.tick_params(direction='in', pad=7.5)

# Twin y-axis (sharing z) that carries the label for the averaged potential
ax2 = ax.twinx()

# Averaged Hartree difference potential on the kept z-range, shifted so that
# its right-end value (r_pot) sits at CB_edge
z_plot = z[n:len(pot_z) - n]
ax.plot(z_plot, av_pot - r_pot + CB_edge, color='deepskyblue', linewidth=0.75)
# Vertical line marking the interface position
ax.plot([zcoord_int, zcoord_int], [emin, emax], color='black', linestyle='-', linewidth=0.5)

ax2.set_ylabel(r'$E$-$\mathrm{\chi}$-$\mathrm{\mu_{L}}$ (eV)', fontsize=10)
ax2.yaxis.label.set_color('deepskyblue')
ax2.tick_params(axis='y', labelcolor='deepskyblue', pad=7.5)
ax2.set_ylim(emin, emax)
ax2.set_xlim(zcoord_int - lenleft, zmax)

# Same thin frame and inward ticks as the primary axes
for side in ('top', 'bottom', 'right', 'left'):
    ax2.spines[side].set_linewidth(0.5)
for tick_axis in (ax2.xaxis, ax2.yaxis):
    tick_axis.set_tick_params(width=0.25, zorder=20)
ax2.tick_params(direction='in', pad=7.5)

# Illustration of the interface: one coloured bar per material along the top
# edge of the plot, each labelled with the material name
bar_height = 0.4
bar_bottom = emax - bar_height

# (left edge, width, RGB face colour); the right bar is added first, as before
interface_bars = (
    (zcoord_int, zmax - zcoord_int, (0.941176, 0.784314, 0.627451)),  # right material
    (zcoord_int - lenleft, lenleft, (0.752941, 0.752941, 0.752941)),  # left material
)
for x0, width, rgb in interface_bars:
    ax2.add_patch(
        ptc.Rectangle(
            (x0, bar_bottom),
            width,
            bar_height,
            facecolor=rgb,
            linewidth=0.5,
            zorder=3,  # draw as 3rd layer
        )
    )

# Material names drawn on top of the bars
label_y = emax - 0.3
for x_label, label in ((zcoord_int - (lenleft / 2 + 1), leftname),
                       (zcoord_int + (lenleft / 2 - 1), rightname)):
    t = plt.text(x_label, label_y, label, fontsize=6, color='black', zorder=4)

plt.tight_layout()
plt.savefig("ddos_hdp.png", dpi=300)
