import matplotlib.pyplot as plt

# -------------------------------------------------------------
# Reference energies
# -------------------------------------------------------------
# Per-atom energies (eV) of the elemental reference phases: the last
# TotalEnergy stored in each file divided by the atom count of the last
# BulkConfiguration stored in the same file.
E_Li, E_S = (
    nlread(ref_file, TotalEnergy)[-1].evaluate().inUnitsOf(eV)
    / len(nlread(ref_file, BulkConfiguration)[-1])
    for ref_file in ('Lithium.hdf5', 'a-Sulphur.hdf5')
)

# -------------------------------------------------------------
# OCV
# -------------------------------------------------------------
# One open-circuit voltage (OCV) per configuration stored in the
# relaxation file, computed against the per-atom references E_Li and E_S.
OCV = []
energies = nlread('Li0.4S_relax.hdf5', TotalEnergy)
bulks    = nlread('Li0.4S_relax.hdf5', BulkConfiguration)
# Explicit check rather than `assert`, which is removed under `python -O`.
if len(energies) != len(bulks):
    raise ValueError(
        'Li0.4S_relax.hdf5 holds %d TotalEnergy objects but %d '
        'BulkConfigurations; cannot pair them.' % (len(energies), len(bulks)))
for energy, bulk in zip(energies, bulks):
    e = energy.evaluate().inUnitsOf(eV)
    symbols = numpy.asarray(bulk.symbols())
    a = int(numpy.count_nonzero(symbols == 'S'))   # number of S atoms
    b = int(numpy.count_nonzero(symbols == 'Li'))  # number of Li atoms
    if b == 0:
        # The OCV below is normalised per Li atom, so it is undefined
        # without Li (the original code crashed here with ZeroDivisionError).
        raise ValueError('Configuration contains no Li atoms; '
                         'the OCV per Li atom is undefined.')
    # Composition x of Li_xS (Li atoms per S atom); only printed.
    # NOTE(review): with no S atoms x = b/a is unbounded; 1.0 is kept from
    # the original code but looks like a leftover of an Li/(Li+S) fraction.
    composition = 1.0 if a == 0 else 1.0 * b / a
    # -------------------------------------------------------------
    # Compute OCV
    # -------------------------------------------------------------
    # Negative energy of the configuration relative to b Li + a S atoms in
    # their reference phases, per Li atom (eV/Li; numerically volts,
    # assuming one electron per Li as the plot's y-label implies).
    ocv = -(e - b*E_Li - a*E_S)/b
    print(ocv, composition)
    OCV.append(ocv)

# -------------------------------------------------------------
# Mean OCV
# -------------------------------------------------------------
# 1-based image indices for the horizontal axis, and the average OCV over
# all images (drawn later as a horizontal reference line).
x = range(1, 1 + len(OCV))
mean = numpy.asarray(OCV).mean()

# -------------------------------------------------------------
# Plotting
# -------------------------------------------------------------
# OCV per image with the mean as a dashed line; saved to ocv.png and shown.
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(x, OCV, '-o', color='b', label='x=0.40')
ax.axhline(mean, ls='--', lw=2, color='k')
# Axis limits follow the data instead of being hard-coded. The former
# (0.8, 6.2) / (2.21, 2.23) only fit a 6-image run whose OCV lay inside a
# 20 mV window; anything else was silently clipped off the figure.
# For 6 images the x-range is unchanged.
ax.set_xlim((0.8, max(len(x), 1) + 0.2))
if len(OCV) > 0:
    lo, hi = min(OCV), max(OCV)
    # Pad by a fraction of the spread, with a floor so flat data stays visible.
    pad = max(0.25 * (hi - lo), 0.005)
    ax.set_ylim((lo - pad, hi + pad))
ax.legend(loc='upper right')
ax.set_xlabel('MD image')
ax.set_ylabel('Open-circuit voltage (Volt)')
plt.tight_layout()
plt.savefig('ocv.png')
plt.show()
