from spinVector import blochStateSpinVector

# -------------------------------------------------------------
# Reading Bi2Se3 slab configuration
# -------------------------------------------------------------
# nlread yields the BulkConfiguration objects found in the file; the slab
# geometry used below is the entry at index 1 (the second match).
# NOTE(review): confirm index 1 is the intended geometry (e.g. the
# optimized one) rather than the first/last stored configuration.
bulk_configurations = nlread('Bi2Se3_slab.nc', BulkConfiguration)
configuration = bulk_configurations[1]

# -------------------------------------------------------------
# Preparing pyplot
# -------------------------------------------------------------
import matplotlib.pyplot as plt

# A single-axes figure: ax1 carries the spin-vector length r on the left
# y-axis, and its twin ax2 (same x-axis) carries the angles on the right.
fig = plt.figure()
ax1 = fig.add_subplot(1, 1, 1)
ax2 = ax1.twinx()

# -------------------------------------------------------------
# Slab geometry (identical for every band, so computed once)
# -------------------------------------------------------------
# Cell length along c: z-component of the third conventional lattice
# vector, in Angstrom. Kept as a numpy scalar so it broadcasts over the
# spin-vector sample positions below.
lattice = configuration.bravaisLattice()
vectors = lattice.conventionalVectors().inUnitsOf(Angstrom)
cell_length = vectors[2, 2]

# Atomic c-coordinates (last Cartesian column, Angstrom), split by species
# so the atoms can be marked on the plot: c1 = selenium, c2 = bismuth.
coordinates = configuration.cartesianCoordinates().inUnitsOf(Angstrom)
c_list = coordinates[:, -1]
elements = configuration.elements()
c1 = []
c2 = []
for c_value, element in zip(c_list, elements):
    element_name = element.name()
    if element_name == 'Selenium':
        c1.append(c_value)
    elif element_name == 'Bismuth':
        c2.append(c_value)
    else:
        # Explicit error rather than `assert`, which is stripped under -O.
        raise ValueError(
            "Unexpected element %s in Bi2Se3 slab" % element_name)

# -------------------------------------------------------------
# Computing bloch states for electronic band indices 144 and 145.
# Both are surface states just above the Fermi level.
# -------------------------------------------------------------
for band_index, (band, line_type) in enumerate(zip([144, 145], ['-', '--'])):
    # Bloch state of this band at the chosen k-point.
    bloch_state = BlochState(
        configuration=configuration,
        quantum_number=band,
        k_point=[0, 0.04, 0],
        )

    # Spin vector of the Bloch state as (r, theta, phi) sampled at positions
    # c_relative along c. c_relative is presumably fractional (it is scaled
    # by the cell length below) — TODO confirm against blochStateSpinVector.
    r, theta, phi, c_relative = blochStateSpinVector(bloch_state)

    # Plot r on the left axis and the angles on the twin (right) axis.
    x = cell_length * numpy.asarray(c_relative)
    ax1.plot(x, r, line_type, color='r', label='r')
    ax2.plot(x, theta, line_type, color='b', label=r'$\theta$')
    ax2.plot(x, phi, line_type, color='g', label=r'$\phi$')

    # Place the band label next to the maximum of r; y=1500 is in ax1 data
    # units (r). The first band is labelled to the right of its peak, the
    # following one to the left.
    q = numpy.argmax(r)
    if band_index == 0:
        # Build the legends from the first band's curves only, so each
        # quantity appears once regardless of how many bands are plotted.
        ax1.legend(loc='upper left', frameon=False)
        ax2.legend(loc='upper right', frameon=False)
        xy = (x[q] + 1, 1500)
    else:
        xy = (x[q] - 6.5, 1500)
    ax1.annotate("band %i" % band, xy=xy)

# -------------------------------------------------------------
# Finalizing plot
# -------------------------------------------------------------
# Dashed reference line at 0 on the angle axis, spanning the c-range of
# the last band plotted in the loop above (x is left over from that loop).
x_span = [x[0], x[-1]]
ax2.plot(x_span, [0, 0], 'k--')

# Atom positions along c, drawn as dots at a fixed height of 20 degrees:
# selenium (c1) in orange, bismuth (c2) in purple.
for positions, marker_color in ((c1, 'orange'), (c2, 'purple')):
    ax2.plot(positions, [20] * len(positions), 'o', color=marker_color)

# Angle axis spans 0-360 degrees with a tick every 30 degrees.
ax2.set_ylim([0, 360])
ax2.set_yticks(numpy.arange(0, 390, 30))

# Axis labels, then display.
ax1.set_xlabel('C-direction (Angstrom)')
ax1.set_ylabel('r (a.u.)')
ax2.set_ylabel('Angles (Degrees)')
plt.show()
