from spinVector import blochStateSpinVector
from NL.CommonConcepts.Configurations.Utilities import cartesian2fractional, fractional2cartesian
import itertools

# -------------------------------------------------------------
# Load the Bi2Se3 slab and look up its Bravais lattice
# -------------------------------------------------------------
# The file is read for all stored BulkConfiguration objects and the
# entry at index 1 (i.e. the second one) is the one analysed below.
stored_configurations = nlread('Bi2Se3_slab.nc', BulkConfiguration)
configuration = stored_configurations[1]
lattice = configuration.bravaisLattice()

# -------------------------------------------------------------
# Square Cartesian (kx, ky) mesh centred on the Gamma point
# -------------------------------------------------------------
# Nk must be odd so that the mesh contains the Gamma point itself.
Nk = 51
k_max = 0.1
k_axis = numpy.linspace(-k_max, k_max, Nk)

# 'ij' indexing makes the first coordinate the slowly varying one, so
# the flattened points come out in the order
# (k_0, k_0), (k_0, k_1), ..., (k_0, k_{Nk-1}), (k_1, k_0), ...
slow, fast = numpy.meshgrid(k_axis, k_axis, indexing='ij')

# One row per k-point: (first coordinate, second coordinate, kz = 0).
kXkYkZ = numpy.vstack((slow.ravel(), fast.ravel(), numpy.zeros(Nk*Nk))).transpose()

# The mesh is given in units of 1/Angstrom; express it in fractional
# coordinates of the reciprocal lattice for the band calculation.
kAkB = cartesian2fractional(kXkYkZ*Ang**-1, lattice.reciprocalVectors())

# -------------------------------------------------------------
# Band energies on the k-mesh
# -------------------------------------------------------------
bandstructure = Bandstructure(configuration, kpoints=kAkB)

# Energies in eV; the band is selected through the second index below.
all_band_energies = bandstructure.evaluate().inUnitsOf(eV)

# Take the k-points back from the Bandstructure object and convert them
# from fractional to Cartesian coordinates, in units of 1/Angstrom.
kpoints_cartesian = fractional2cartesian(
    bandstructure.kpoints(),
    lattice.reciprocalVectors(),
    ).inUnitsOf(Angstrom**-1)

# -------------------------------------------------------------
# Pick band 144 and arrange energies and k-points as Nk x Nk grids
# -------------------------------------------------------------
# Band index 144 is the band above the Fermi level; the same index is
# passed as quantum_number to BlochState further down.
grid_shape = (Nk, Nk)
energies = all_band_energies[:,144].reshape(grid_shape)
kx = kpoints_cartesian[:,0].reshape(grid_shape)
ky = kpoints_cartesian[:,1].reshape(grid_shape)

# -------------------------------------------------------------
# Constant-energy contours of the selected band
# -------------------------------------------------------------
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch

# Square figure; the axes handle is reused later to draw the spin arrows.
plt.figure(figsize=(8,8))
ax = plt.gca()

# Ten automatically chosen contour levels, each labelled inline.
CS = plt.contour(kx, ky, energies, 10)
plt.clabel(CS, CS.levels, inline=True, fontsize=10)

# Red dot marking the Gamma point at the origin.
plt.plot(0, 0, 'ro')

# -------------------------------------------------------------
# Extract contour line 2 for spin directions at E_F=0.15 eV
# -------------------------------------------------------------
# Collect every connected contour polyline in one flat list, together
# with the energy level it belongs to (parallel list `levels`).
#
# CS.allsegs (one list of (N, 2) vertex arrays per level) is used rather
# than CS.collections: the latter was deprecated in Matplotlib 3.8 and
# later removed, whereas allsegs is available in old and new releases.
isolines = []
levels   = []
for (level, segments) in zip(CS.levels, CS.allsegs):
    for vertices in segments:
        isolines.append(numpy.array(vertices).copy())
        levels.append(level)

# The Fermi surface is taken to be the third polyline found.  Which
# energy this corresponds to depends on the automatically chosen levels;
# levels[2] holds the actual value (0.15 eV according to the note above
# -- verify when changing k_max, Nk or the number of levels).
if len(isolines) < 3:
    raise IndexError(
        "expected at least 3 contour lines, found %d" % len(isolines))
fermi_surface = isolines[2]

# -------------------------------------------------------------
# For every 10th point on the Fermi surface:
# - compute the Bloch state
# - compute the corresponding spin vector
# - plot a 2D vector illustrating the spin angle phi
# -------------------------------------------------------------
# Loop-invariant quantities, evaluated once.
reciprocal_vectors = lattice.reciprocalVectors()
arrow_length = 0.02   # fixed arrow length, in the k-units of the plot

for p in fermi_surface[::10]:
    # -------------------------------------------------------------
    # Bloch State
    # -------------------------------------------------------------
    # Contour vertices are 2D Cartesian (kx, ky); append kz = 0 and
    # convert to fractional coordinates for BlochState.
    kpt = numpy.append(p, 0.0)
    kpoint = cartesian2fractional(kpt*Ang**-1, reciprocal_vectors)
    bloch_state = BlochState(
        configuration=configuration,
        quantum_number=144,   # same band index as the plotted energies
        k_point=kpoint,
        )

    # -------------------------------------------------------------
    # Spin vector
    # -------------------------------------------------------------
    # Only the magnitude profile and the angle phi are needed here.
    spin_r, _theta, spin_phi, _c_relative = blochStateSpinVector(bloch_state)
    # use the angle at the c-axis position where the magnitude is largest
    i = numpy.argmax(spin_r)
    # the pi/180 factor converts phi from degrees to radians
    angle = spin_phi[i]*numpy.pi/180.

    # components of a fixed-length vector pointing along phi
    dx = arrow_length*numpy.cos(angle)
    dy = arrow_length*numpy.sin(angle)

    # -------------------------------------------------------------
    # Plotting
    # -------------------------------------------------------------
    # black dot at the k-point, arrow from it in the spin direction
    ax.plot(p[0], p[1], 'ko')
    p1 = (p[0], p[1])
    p2 = (p[0]+dx, p[1]+dy)
    arrow = FancyArrowPatch(posA=p1, posB=p2, mutation_scale=20, lw=2,
                            arrowstyle="-|>", color='k')
    ax.add_patch(arrow)

# -------------------------------------------------------------
# Axis labels, aspect ratio and plot window, then display
# -------------------------------------------------------------
plt.xlabel(r'$k_x$')
plt.ylabel(r'$k_y$')

# Equal scaling on both axes, then restrict the view to the sampled
# k-window [-k_max, k_max] in both directions.
plt.axis('equal')
k_window = (-k_max, k_max)
plt.xlim(k_window)
plt.ylim(k_window)

plt.show()
