from QuantumATK import *

def blochStateSpinVector(bloch_state):
    # spin components
    projection_up = bloch_state.axisProjection(spin=Spin.Up)
    projection_dn = bloch_state.axisProjection(spin=Spin.Down)
    c = projection_up[0]
    up = projection_up[1].inUnitsOf(Bohr**(-1.5))
    dn = projection_dn[1].inUnitsOf(Bohr**(-1.5))

    upup = up.conjugate() * up
    updn = up.conjugate() * dn
    dndn = dn.conjugate() * dn

    # spin vector
    m_x = 2*updn.real
    m_y = 2*updn.imag
    m_z = upup.real - dndn.real
    m = numpy.array([m_x, m_y, m_z])

    # calculate the spin vector length
    r_list = (m_x**2 + m_y**2 + m_z**2)**0.5

    # calculate angles theta and phi, but only for non-zero vector lengths
    theta_list = []
    phi_list   = []
    for i in range(len(r_list)):
        r = r_list[i]
        mx = m_x[i]
        my = m_y[i]
        mz = m_z[i]
        if r > 0.0:
            phi = (numpy.arctan2(my, mx) + numpy.pi) % (2*numpy.pi)
            theta = numpy.arctan2(mx*numpy.cos(phi) + my*numpy.sin(phi), mz)
            theta = (theta + 2*numpy.pi) % (2*numpy.pi)
        else:
            phi = 0.0
            theta = 0.0
        theta_list.append(theta)
        phi_list.append(phi)

    # Convert to Degrees
    theta_list = 180.*numpy.asarray(theta_list)/numpy.pi 
    phi_list = 180.*numpy.asarray(phi_list)/numpy.pi

    return r_list, theta_list, phi_list, c
