import pylab
# Set output verbosity to SilentLog
setVerbosity(SilentLog)

#+----------------------------------------------------------------------------+
#| Script to calculate the formation energy of a charged vacancy defect,      |
#| given the .hdf5 files of the defect calculations and a bulk calculation in |
#| the same supercell.                                                        |
#| We apply the FNV correction scheme, which is described in:                 |
#| C. Freysoldt et al., PRL 102, 016402 (2009)                                |
#| doi: 10.1103/PhysRevLett.102.016402                                        |
#+----------------------------------------------------------------------------+

#+----------------------------------------------------------------------------+
#                          USER SPECIFIED PARAMETERS                          |
#+----------------------------------------------------------------------------+

# Filename of the bulk calculation of the element removed at the vacancy
# (elemental reference). Its total energy per atom is added back to the
# formation energy.
bulk_defect_element_filename = 'bulk-As.hdf5'

# Filename of bulk calculation in the same unit cell as the defects
bulk_filename = 'big-bulk-GaAs.hdf5'

# List of the different charge states and corresponding filenames, in the
# same order. The two lists are paired element by element, so they must
# have the same length.
list_of_charges = [+1, 0, -1, -2, -3]
list_of_filenames = ['vac-+1-GaAs.hdf5', 'vac-0-GaAs.hdf5', 'vac--1-GaAs.hdf5', 
                     'vac--2-GaAs.hdf5','vac--3-GaAs.hdf5']

# Element at (or removed from) the defect, as a NanoLanguage element object.
# Its name selects the Hoffmann-Huckel basis ("<name>_Basis") of the model
# charge used in the corrections.
element = Arsenic

# Static dielectric constant, as calculated with the same computational model.
# It screens both the periodic-image and the band-alignment corrections.
# See the tutorial "Optical Properties of Silicon" for the method.
epsilon = 12.9

# Width (in bohr) of the Gaussian used as a model charge distribution.
# Default is 3 bohr, minimum is 0.5 bohr
width = 3.0

#+----------------------------------------------------------------------------+
#                         END OF USER INPUT SECTION                           |
#+----------------------------------------------------------------------------+

print('+------------------------------------------------------------------------------+')
print('|                                                                              |')
print('|  Calculation of formation energy of charge defects using the FNV correction  |')
print('|            See: C. Freysoldt et al., PRL 102, 016402 (2009)                  |')
print('|                                                                              |')
print('+------------------------------------------------------------------------------+')

# Enforce the documented minimum width of the model Gaussian (0.5 bohr) before
# anything is derived from it. Otherwise U below would be built from a width
# that calc_per_corr later clamps for its grid cutoff, and the two would
# disagree.
if width < 0.5:
    print('| Warning: width %.2f bohr is below the minimum; using 0.5 bohr.' % width)
    width = 0.5

# Name of the defect element. It is used to look up its Hoffmann-Huckel
# basis ("<name>_Basis") when building the model charge.
element_name = element.name()
str_element = str(element_name)

# Onsite Hartree shift of the model atom: 2/(sqrt(pi)*width) Hartree.
# NOTE(review): presumably chosen so the Huckel model charge corresponds to a
# Gaussian of the given width -- confirm against the QuantumATK Huckel docs.
U = 2*Hartree * numpy.sqrt((1/width**2)/numpy.pi)

def valence_band_max_index(energies, dos_values):
    """Return the index of the highest negative energy that carries states.

    energies: energies relative to the Fermi level, as plain numbers (eV).
    dos_values: density of states at those energies, as plain numbers.

    Every index with energy < 0 and DOS > 0 is a candidate. The one with the
    largest energy is returned, so the energy grid need not be sorted.

    Raises ValueError if no energy below zero carries any states.
    """
    energies = numpy.asarray(energies, dtype=float)
    dos_values = numpy.asarray(dos_values, dtype=float)
    candidates = numpy.nonzero((energies < 0.0) & (dos_values > 0.0))[0]
    if candidates.size == 0:
        raise ValueError('No states found below the Fermi level; '
                         'cannot locate the valence band maximum.')
    return int(candidates[numpy.argmax(energies[candidates])])


# Function to extract the valence band maximum from a Density of States
# of the pristine bulk structure
def get_vbm(bulk_filename):
    """Return the absolute energy of the valence band maximum.

    Reads the last DensityOfStates object stored in ``bulk_filename``. It
    finds the highest energy below zero with a non-zero tetrahedron DOS.
    The DOS energies are relative to the Fermi level, so the Fermi level is
    added back to give the absolute energy.

    Raises ValueError if the DOS has no states below the Fermi level.
    """
    dos = nlread(bulk_filename, DensityOfStates)[-1]
    energies = dos.energies()
    spectrum = dos.tetrahedronSpectrum()

    # Search on plain numbers. The original loop started one index too high
    # and could pick a state above the Fermi level.
    index = valence_band_max_index(energies.inUnitsOf(eV),
                                   spectrum.inUnitsOf(eV**-1))

    # The DOS energies are relative to the Fermi level.
    return energies[index] + dos.fermiLevel()

def _huckel_model_calculator(element_basis, cutoff, charge, boundary_condition):
    """Return a HuckelCalculator for the FNV model charge.

    element_basis: Hoffmann-Huckel basis of the model atom.
    cutoff: density mesh cutoff (energy quantity).
    charge: net charge of the model system.
    boundary_condition: boundary-condition class applied on both faces of
        all three directions of the Multigrid Poisson solver.
    """
    poisson_solver = MultigridSolver(
        boundary_conditions=[[boundary_condition, boundary_condition]
                             for _ in range(3)]
    )
    return HuckelCalculator(
        basis_set=[element_basis],
        numerical_accuracy_parameters=NumericalAccuracyParameters(
            density_mesh_cutoff=cutoff),
        charge=charge,
        poisson_solver=poisson_solver,
    )


# Function to calculate the periodic correction term based on the FNV scheme
def calc_per_corr(bulk_filename, element, str_element, charge, width):
    """Compute the FNV periodic-image correction for one charge state.

    A single model atom of ``element`` with net charge ``charge`` is placed
    at fractional position (0.5, 0.5, 0.5) in the lattice of the last
    BulkConfiguration in ``bulk_filename``. Its electrostatic interaction
    energy -1/2 * sum(rho * V) * dV is evaluated twice: once with periodic
    boundary conditions and once with multipole (isolated) boundary
    conditions. The correction is minus their difference, divided by the
    module-level static dielectric constant ``epsilon``.

    Parameters:
        bulk_filename: file holding the bulk supercell (lattice source).
        element: NanoLanguage element object for the model atom.
        str_element: element name; selects HoffmannHuckelParameters.<name>_Basis.
        charge: charge state of the defect.
        width: Gaussian width in bohr; values below 0.5 are raised to 0.5.

    Returns:
        [charge, correction] with the correction in eV.
    """
    # Enforce the documented minimum width (0.5 bohr) of the model Gaussian.
    width = max(width, 0.5)
    # Mesh cutoff scales as (3/width)**2 Hartree, with a floor of 1 Hartree.
    cutoff = max((3.0/width)**2, 1.0)*Hartree
    # Onsite shift from the *clamped* width, using the same formula as the
    # module-level U. The module-level U is built from the raw user width, so
    # using it here would decouple the model charge from the cutoff above.
    onsite_shift = 2*Hartree*numpy.sqrt((1/width**2)/numpy.pi)

    bulk_configuration = nlread(bulk_filename, BulkConfiguration)[-1]

    # Model system: a single atom of the defect element at the cell centre.
    model_configuration = BulkConfiguration(
        bravais_lattice=bulk_configuration.bravaisLattice(),
        elements=[element],
        fractional_coordinates=[[0.5, 0.5, 0.5]],
    )
    element_basis = getattr(HoffmannHuckelParameters, str_element + '_Basis')(
        onsite_hartree_shift=[onsite_shift, onsite_shift])

    # Periodic boundary conditions. Quantities marked with "tilde" include
    # the interaction with the periodic images.
    model_configuration.setCalculator(_huckel_model_calculator(
        element_basis, cutoff, charge, PeriodicBoundaryCondition))
    Vtilde_q_lr = ElectrostaticDifferencePotential(model_configuration)
    qtilded = ElectronDifferenceDensity(model_configuration)

    # Multipole boundary conditions give the "true" (isolated) long-range
    # potential within the model.
    model_configuration.setCalculator(_huckel_model_calculator(
        element_basis, cutoff, charge, MultipoleBoundaryCondition))
    V_q_lr = ElectrostaticDifferencePotential(model_configuration)
    qd = ElectronDifferenceDensity(model_configuration)

    # Triple product of the grid volume-element vectors (presumably the
    # volume of one grid voxel).
    ve = qd.volumeElement()
    volume = numpy.dot(numpy.cross(ve[0], ve[1]), ve[2])

    # Electrostatic interaction in the periodic and the isolated case.
    sum_periodic = -0.5*(qtilded*Vtilde_q_lr)[:, :, :].sum()*volume
    sum_isolated = -0.5*(qd*V_q_lr)[:, :, :].sum()*volume
    # The correction is minus the difference between the two interactions,
    # screened by the static dielectric constant of the material.
    result = -(sum_periodic - sum_isolated).inUnitsOf(Volt)/epsilon*eV

    return [charge, result]


# Function to calculate the band offset resulting from introduction of the
# defect
def calc_band_corr(bulk_filename, defect_filename, element, str_element, charge=None):
    """Compute the FNV potential-alignment (band shift) correction.

    The electrostatic difference potential of the bulk (``bulk_filename``) is
    subtracted from that of the defect (``defect_filename``). The
    long-range potential of a Huckel model charge at the vacancy site,
    screened by the module-level ``epsilon``, is then removed to leave the
    short-range part. That part is plane-averaged along the c axis. It is
    then sampled at five grid points centred half a cell (along c) away from
    the vacancy.

    Side effect: writes 'potentials<charge>.png' with the three
    plane-averaged potentials.

    Parameters:
        bulk_filename: file with the pristine bulk supercell.
        defect_filename: file with the charged defect supercell. Its first
            ghost atom marks the vacancy site.
        element: NanoLanguage element object for the model atom.
        str_element: element name; selects HoffmannHuckelParameters.<name>_Basis.
        charge: defect charge state. If omitted, the module-level variable
            ``charge`` is used, which is how this function was originally
            called from the driver loop.

    Returns:
        -charge * <V_sr> as an energy in eV.
    """
    if charge is None:
        # Backward compatibility: callers that do not pass the charge rely on
        # the module-level loop variable of the same name.
        charge = globals()['charge']

    # Bulk electrostatic potential and the mesh cutoff it was computed with.
    bulk_configuration = nlread(bulk_filename, BulkConfiguration)[-1]
    V_bulk = ElectrostaticDifferencePotential(bulk_configuration)
    cutoff = bulk_configuration.calculator().numericalAccuracyParameters().densityMeshCutoff()

    # Defect electrostatic potential and the vacancy position (first ghost atom).
    vac_configuration = nlread(defect_filename, BulkConfiguration)[-1]
    V_defect_q = ElectrostaticDifferencePotential(vac_configuration)
    vac_index = vac_configuration.ghostAtoms()[0]
    vac_pos = vac_configuration.fractionalCoordinates()[vac_index]

    # Defect-minus-bulk difference potential.
    V_qb = V_defect_q - V_bulk

    # Model charge at the vacancy site, on the bulk lattice and mesh cutoff.
    model_configuration = BulkConfiguration(
        bravais_lattice=bulk_configuration.bravaisLattice(),
        elements=[element],
        fractional_coordinates=[[vac_pos[0], vac_pos[1], vac_pos[2]]],
    )
    element_basis = getattr(HoffmannHuckelParameters, str_element + '_Basis')(
        onsite_hartree_shift=[U, U])
    model_configuration.setCalculator(HuckelCalculator(
        basis_set=[element_basis],
        numerical_accuracy_parameters=NumericalAccuracyParameters(density_mesh_cutoff=cutoff),
        charge=charge,
    ))

    # Screened long-range potential of the model defect.
    V_q_lr = 1./epsilon*ElectrostaticDifferencePotential(model_configuration)

    # Short-range potential: the charged-defect/bulk difference potential
    # minus the screened long-range model potential.
    V_qb_sr = V_qb - V_q_lr

    # Plane averages along the c axis.
    V_qb_sr_mean_z = V_qb_sr.axisProjection('average', 'c')
    V_qb_mean_z = V_qb.axisProjection('average', 'c')
    V_q_lr_mean_z = V_q_lr.axisProjection('average', 'c')

    # Plot the potentials (mostly useful for debugging).
    pylab.figure(0)
    pylab.plot(range(len(V_qb_sr_mean_z[1])), V_qb_sr_mean_z[1], color='k', linewidth=2.0, label='V$_{q/b}$$^{sr}$')
    pylab.plot(range(len(V_qb_mean_z[1])), V_qb_mean_z[1], color='r', linewidth=2.0, label='V$_{q/b}$')
    pylab.plot(range(len(V_q_lr_mean_z[1])), V_q_lr_mean_z[1], color='g', linewidth=2.0, label='V$_{q}$$^{lr}$')
    pylab.legend()
    pylab.savefig('potentials' + str(charge)+'.png', format='png')
    pylab.clf()

    # Average the short-range potential over five grid points centred half a
    # cell (along c) from the vacancy, to smooth out local irregularities.
    # Indices wrap periodically around the cell.
    profile = V_qb_sr_mean_z[1]
    n0 = profile.shape[0]
    centre = n0/2.0 + n0*vac_pos[2]
    V_mean = 0
    for offset in range(-2, 3):
        V_mean += profile[int(centre + offset) % n0]
    V_mean /= 5.0

    return -charge*V_mean.convertTo(Volt)/Volt*eV

# Charges and filenames are paired positionally; zip() would silently drop
# the extra entries of the longer list, so reject a mismatch up front.
if len(list_of_charges) != len(list_of_filenames):
    raise ValueError(
        'list_of_charges and list_of_filenames must have the same length '
        '(got %i and %i)' % (len(list_of_charges), len(list_of_filenames)))

# Get the energy at the valence band maximum from the bulk calculation
E_V = get_vbm(bulk_filename)
E_V_eV = E_V.inUnitsOf(eV)
print('+------------------------------------------------------------------------------+')
print('| Absolute energy of valence band maximum: %.2f eV' %(E_V.inUnitsOf(eV)))

# Energies that do not depend on the charge state are read once.
big_bulk = nlread(bulk_filename, TotalEnergy)[-1]
E_big_bulk = big_bulk.evaluate().inUnitsOf(eV)

# Energy per atom of the elemental reference for the removed element.
defect_element_configuration = nlread(bulk_defect_element_filename, BulkConfiguration)[-1]
defect_bulk_energy = nlread(bulk_defect_element_filename, TotalEnergy)[-1]
e2 = defect_bulk_energy.evaluate().inUnitsOf(eV)/defect_element_configuration.numberOfAtoms()

# Electron chemical potential grid (x axis of the formation-energy plot).
x = numpy.linspace(0, 1.5, 10)

# Go through the list of charges and corresponding files, calculating the
# corrections for each and the full corrected formation energy
for charge, filename in zip(list_of_charges, list_of_filenames):
    print('+------------------------------------------------------------------------------+')
    print('| Defect charge state: %i' %(charge))

    defect = nlread(filename, TotalEnergy)[-1]
    E_defect = defect.evaluate().inUnitsOf(eV)

    # Term depending only on the total energies
    E_f_total_energies = E_defect - E_big_bulk + e2
    print('| Formation energy from total energies: %.2f eV' %(E_f_total_energies + charge*E_V_eV))

    # Correction due to periodic images of charges
    E_corr_periodic = calc_per_corr(bulk_filename, element, str_element, charge, width)[1]
    print('| Correction for periodic interaction: %.2f eV' %(E_corr_periodic.inUnitsOf(eV)))

    # Correction due to band-shifting introduced by the defect
    E_corr_bands = calc_band_corr(bulk_filename, filename, element, str_element)
    print('| Correction for band shifts due to defect: %.2f eV' %(E_corr_bands.inUnitsOf(eV)))

    # Total formation energy at mu_e = 0 (Fermi level at the VBM)
    E_formation = E_f_total_energies + charge*E_V_eV + E_corr_periodic.inUnitsOf(eV) + E_corr_bands.inUnitsOf(eV)
    print('| Fully corrected formation energy: %.2f eV' %(E_formation))

    # The formation energy is linear in mu_e with slope equal to the charge.
    # Nonzero charges are labelled with an explicit sign (+1, +2, -1, ...).
    label = 'V$_{As}$$^{%+d}$' % charge if charge else 'V$_{As}$$^{0}$'
    pylab.figure(1)
    pylab.plot(x, E_formation + charge*x, linewidth=2.0, linestyle='-', label=label)

pylab.figure(1)
pylab.title("As vacancies")
pylab.xlim([min(x), max(x)])
pylab.ylim([1, 5])
pylab.grid(True)
pylab.xlabel(r"$\mu_e$ (eV)")
pylab.ylabel("Formation energy (eV)")
pylab.legend()
pylab.savefig('vacancy-As-mue.png', format='png')
print('+------------------------------------------------------------------------------+')