import matplotlib.pyplot as plt

# Nominal Li-per-S fractions x. Each one names an input file 'x<x>.hdf5'
# (formatted with two decimals in the OCV loop).
compositions = [
    0.05, 0.1, 0.2, 0.3, 0.4,
    0.6, 0.8, 1.0, 1.2, 1.4,
    1.6, 1.8, 2.0,
]

# -------------------------------------------------------------
# Reference energies
# -------------------------------------------------------------
# NOTE(review): `path` is not defined anywhere in this view. It is presumably
# set earlier (ending in a path separator, because it is joined by plain string
# concatenation). Confirm.
def _energy_per_atom(filename):
    """Return the total energy per atom (eV) of the elemental reference in *filename*.

    Uses the last TotalEnergy entry in the file and divides it by the number
    of atoms in the last BulkConfiguration entry of the same file.
    """
    total = nlread(path + filename, TotalEnergy)[-1].evaluate().inUnitsOf(eV)
    n_atoms = len(nlread(path + filename, BulkConfiguration)[-1])
    return total / n_atoms


E_Li = _energy_per_atom('Lithium.hdf5')
E_S  = _energy_per_atom('a-Sulphur.hdf5')

# -------------------------------------------------------------
# OCV
# -------------------------------------------------------------
# For every nominal composition x, read all total energies and the matching
# bulk configurations from 'x<x>.hdf5'. Each structure's energy relative to
# the elemental references (E_Li, E_S) is converted to a voltage, and the
# mean over all structures of that composition goes into `profile`.
profile = []
for composition in compositions:
    hdf5 = 'x%.2f.hdf5' % composition
    energies = nlread(path+hdf5, TotalEnergy)
    # The first two BulkConfiguration entries are skipped so that the rest
    # pair up one-to-one with the TotalEnergy entries.
    # NOTE(review): presumably those two entries carry no TotalEnergy analysis.
    # Confirm against how the .hdf5 files are written.
    bulks    = nlread(path+hdf5, BulkConfiguration)[2:]
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if len(energies) != len(bulks):
        raise ValueError(
            "%s: %d TotalEnergy entries but %d BulkConfiguration entries "
            "(after skipping the first two)"
            % (hdf5, len(energies), len(bulks)))
    if not energies:
        # Without this check numpy.mean([]) would silently give NaN.
        raise ValueError("%s: no TotalEnergy entries found" % hdf5)
    OCV = []
    for index, (energy, bulk) in enumerate(zip(energies, bulks)):
        e = energy.evaluate().inUnitsOf(eV)
        symbols = numpy.array(bulk.symbols())
        a = numpy.count_nonzero(symbols == 'S')    # number of S atoms
        b = numpy.count_nonzero(symbols == 'Li')   # number of Li atoms
        if b == 0:
            # The OCV below is normalised per Li atom, so it is undefined here.
            raise ValueError(
                "%s: configuration %d contains no Li atoms; OCV per Li is "
                "undefined" % (hdf5, index))
        # -------------------------------------------------------------
        # Compute OCV
        # -------------------------------------------------------------
        # V = -[E(Li_b S_a) - b*E_Li - a*E_S] / b, an energy in eV per Li.
        # It is reported as volts below, i.e. one electron per Li.
        ocv = -(e - b*E_Li - a*E_S)/b
        OCV.append(ocv)
    # -------------------------------------------------------------
    # Mean OCV
    # -------------------------------------------------------------
    mean = numpy.mean(OCV)
    profile.append(mean)
    # Label with the nominal composition (the value that is plotted). Use two
    # decimals so that x=0.05 and x=0.1 are distinguishable.
    print("x=%.2f %.2f Volt" % (composition, mean))

# -------------------------------------------------------------
# Plotting
# -------------------------------------------------------------
# Voltage profile: mean OCV against nominal composition x, with an arrow
# pointing towards increasing x labelled 'Discharge'.
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(compositions, profile, '-o', lw=1.5)

# Annotation: black arrow from (0.7, 1.7) to (1.2, 1.7), label just above it.
arrow_style = dict(lw=2, head_width=0.05, head_length=0.05, fc='k', ec='k')
ax.arrow(0.7, 1.7, 0.5, 0.0, **arrow_style)
ax.text(0.82, 1.75, 'Discharge')

# Axes: fixed view window and one rotated tick per computed composition.
ax.set_xlim((0, 2.05))
ax.set_ylim((1.6, 2.801))
ax.set_xticks(compositions)
ax.set_xticklabels(compositions, rotation=90)
ax.set_xlabel('x')
ax.set_ylabel('Open-circuit voltage (Volt)')

fig.tight_layout()
plt.savefig('ocv_profile.png')
plt.show()