from QuantumATK import *
import pylab
import sys

filename = "au_vacuumgap.hdf5"

# Read every stored potential / density difference object and every device
# configuration. Only the calculator settings of the configurations are used
# below, so their state is not read.
potentials = nlread(filename, ElectrostaticDifferencePotential)
densities = nlread(filename, ElectronDifferenceDensity)
configuration_list = nlread(filename, DeviceConfiguration, read_state=False)

# Fingerprint -> bias lookup table. The bias is the absolute difference of the
# two electrode voltages (in Volt) of the calculation with that fingerprint;
# analysis objects are matched back to their calculation via the same
# fingerprint.
bias_dict = {}
for configuration in configuration_list:
    calculator = configuration.calculator()
    voltages = calculator.electrodeVoltages()
    bias_dict[calculator._fingerPrint()] = abs(
        voltages[1].inUnitsOf(Volt) - voltages[0].inUnitsOf(Volt))

# Bias of every potential / density object, in the order nlread returned them.
voltage_list_e = [bias_dict[p._fingerPrint()] for p in potentials]
voltage_list_d = [bias_dict[n._fingerPrint()] for n in densities]

# Sort all lists by increasing bias. The object lists are reordered with plain
# list indexing: numpy.take would first push them through numpy.asarray, which
# may try to unpack sequence-like grid objects instead of keeping them opaque.
sort_order_e = numpy.argsort(voltage_list_e)
sort_order_d = numpy.argsort(voltage_list_d)
biases = numpy.take(voltage_list_d, sort_order_d, axis=0)
potentials = [potentials[i] for i in sort_order_e]
densities = [densities[i] for i in sort_order_d]

# Everything below is a difference between the highest- and lowest-bias
# results, so at least two bias points are required.
if len(potentials) < 2 or len(densities) < 2:
    sys.exit("%s must contain results for at least two bias points." % filename)

# Induced (linear-response) quantities: difference between the highest- and
# the lowest-bias results (the lists are sorted by increasing bias).
induced_potential = potentials[-1] - potentials[0]
induced_density = densities[-1] - densities[0]

# Average the induced potential over the first two grid axes (x, y),
# leaving a 1D profile along z.
v_z = numpy.apply_over_axes(numpy.mean, induced_potential[:, :, :], [0, 1]).flatten()
v_z *= induced_potential[:, :, :].unit()

# Integrate the induced density over the x-y plane -> line density along z.
shape = induced_density.shape()
dX, dY, dZ = induced_density.volumeElement().inUnitsOf(Ang)
# The sum below runs over the first two (x, y) grid indices at fixed z, so
# the matching area element is the cell face spanned by dX and dY. The y-z
# face (dY x dZ) generally has a different area.
dA_xy = numpy.linalg.norm(numpy.cross(dX, dY))
n_z = numpy.array([induced_density[:, :, i].sum() * dA_xy for i in range(shape[2])])
n_z *= induced_density[:, :, :].unit() * Ang**2

# z coordinate of every grid plane along the third grid axis
# (assumes the grid starts at z = 0 and dZ points along z -- TODO confirm).
dz = numpy.linalg.norm(dZ)
z = dz * numpy.arange(shape[2]) * Ang

# Window (Angstrom) inside the vacuum gap where the induced potential is
# fitted by a straight line; specific to the au_vacuumgap geometry.
gap_start, gap_end = 20.0, 24.0
# Grid planes in the window, clipped to the grid so that a shorter cell gives
# a clear message instead of an IndexError.
first = int(gap_start / dz)
last = min(int(gap_end / dz), shape[2] - 1)
index = numpy.arange(first, last + 1)
if len(index) < 2:
    sys.exit("Fit window %g-%g Ang contains fewer than two grid points."
             % (gap_start, gap_end))
zv = z[index].inUnitsOf(Ang)
vv = v_z[index].inUnitsOf(eV)

# Least-squares straight line vv = slope*zv + intercept. The slope (eV/Ang)
# is reported as the electric field in the gap in V/Ang.
slope, intercept = numpy.polyfit(zv, vv, 1)
E = slope * Volt / Ang

print('Electric field = ', E)
# Distance between image planes: applied bias divided by the field strength.
# Use the actual bias window of the induced quantities (sorted ascending;
# assumes potentials and densities were stored at the same bias points) and
# the field magnitude, so the distance is positive for either polarity.
bias_window = biases[-1] - biases[0]
d = bias_window / abs(slope) * Ang
print('Image plane distance = ', d)

# Bias window the induced quantities correspond to (biases sorted ascending);
# shown in the titles instead of a hard-coded 1 V.
bias_window = biases[-1] - biases[0]

fig, axarr = pylab.subplots(2, sharex=True)
axarr[0].set_title('Induced Density at %g Volt' % bias_window)
axarr[1].set_title('Induced electrostatic potential at %g Volt' % bias_window)
axarr[1].set_xlabel('z (Angstrom)')
axarr[0].plot(z.inUnitsOf(Ang), n_z.inUnitsOf(Ang**-1))
axarr[0].set_ylabel('n_z (1/Angstrom)')
axarr[1].plot(z.inUnitsOf(Ang), v_z.inUnitsOf(eV))
axarr[1].set_ylabel('v_z (eV)')

# Mark the z position of every atom on both panels (one plot call per panel
# instead of one per atom).
coords = numpy.asarray(configuration_list[0].cartesianCoordinates().inUnitsOf(Ang))
atom_z = coords[:, 2]
for ax in axarr:
    ax.plot(atom_z, numpy.zeros(len(atom_z)), 'bo', markersize=8)

pylab.savefig("%s_induced.png" % filename.replace(".hdf5", ""), dpi=120)
pylab.show()
