from QuantumATK import *
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import math

# Load the ComplexBandstructure objects stored in the .nc file; nlread
# returns a sequence of them, and the last one is the one plotted here.
cbs = nlread('si_100_cbs.nc', ComplexBandstructure)[-1]

# Energy grid converted to eV, the real and complex k solutions found at each
# energy (indexed in step with `energies` below), and the layer separation L
# used below to express every k in units of pi/L.
energies = cbs.energies().inUnitsOf(eV)
k_real, k_complex = cbs.evaluate()
L = cbs.layerSeparation()

# A single 3D axes occupying the whole figure (same as subplot code 111).
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')

# First plot the real bands
kvr = numpy.array([])
e = numpy.array([])

# Loop over energies, and pick those where we have solutions
for (j, energy) in enumerate(energies):
    k = k_real[j]*L/math.pi
    if len(k)>0:
        e = numpy.append(e,[energy,]*len(k))
        kvr = numpy.append(kvr,k)
        
# Plot the bands with red
ax.scatter([0.0,]*len(kvr), kvr, e, c='r', marker='o', linewidths=0, s=10)

# Next plot the complex bands. Each complex-k solution (normalized by L/pi)
# becomes one point at (-|Im k|, |Re k|, energy); only solutions with
# -|Im k| > -0.3 are kept, since rapidly decaying modes clutter the plot.
kvr = []
kvi = []
e = []

for j, energy in enumerate(energies):
    solutions = k_complex[j]
    if not len(solutions) > 0:
        # No complex solutions at this energy.
        continue
    for kappa in numpy.array(solutions*L/math.pi):
        im_part = -numpy.abs(kappa.imag)
        # Written as "not (... > ...)" so NaN is skipped exactly as before.
        if not (im_part > -0.3):
            continue
        e.append(energy)
        kvr.append(numpy.abs(kappa.real))
        kvi.append(im_part)

# Plot the complex bands with blue
ax.scatter(kvi, kvr, e, c='b', marker='o', linewidths=0, s=10)

# Put on labels.
# Raw strings keep the LaTeX backslashes literal: "\k" and "\p" are invalid
# escape sequences in ordinary string literals (SyntaxWarning on Python 3.12+).
# Both k axes hold dimensionless values scaled by L/pi (the x axis holds
# -|Im k|*L/pi, the y axis |Re k|*L/pi or real k*L/pi), so the kappa axis is
# labelled -kappa*L/pi rather than carrying a 1/Ang unit.
ax.set_xlabel(r'$-\kappa L/\pi$')
ax.set_ylabel(r'$kL/\pi$')
ax.set_zlabel('Energy / eV')

plt.show()
