# -------------------------------------------------------------
# Bulk Configuration
# -------------------------------------------------------------
# Fe / MgO / Fe stack in a periodic (bulk) cell: set up the structure,
# attach a spin-polarized LCAO calculator, compute and save the starting
# state, relax it with the outermost Fe atoms held as a rigid body, then
# compute and save the forces on the relaxed geometry.
#
# NOTE(review): every name used here without an import (UnitCell, Iron,
# Angstrom, BulkConfiguration, LCAOCalculator, nlsave, ...) is assumed to be
# supplied by the QuantumATK / NanoLanguage runtime that executes this
# script — confirm if moving it to a plain Python environment.

# Set up lattice: square 2.866 Angstrom cross-section in the a-b plane and a
# long c axis (31.13975 Angstrom) along which the layers are stacked.
vector_a = [2.866, 0.0, 0.0]*Angstrom
vector_b = [0.0, 2.866, 0.0]*Angstrom
vector_c = [0.0, 0.0, 31.13975]*Angstrom
lattice = UnitCell(vector_a, vector_b, vector_c)

# Define elements: 24 atoms, listed in order of increasing c coordinate.
#   indices  0-5   Iron
#   indices  6-17  Magnesium / Oxygen, one Mg and one O per layer
#   indices 18-23  Iron
elements = [Iron, Iron, Iron, Iron, Iron, Iron, Magnesium, Oxygen, Oxygen,
            Magnesium, Magnesium, Oxygen, Oxygen, Magnesium, Magnesium, Oxygen,
            Oxygen, Magnesium, Iron, Iron, Iron, Iron, Iron, Iron]

# Define coordinates as fractions of (vector_a, vector_b, vector_c);
# row i is the position of elements[i].
fractional_coordinates = [[ 0.25          ,  0.25          ,  0.023009176374],
                          [ 0.75          ,  0.75          ,  0.069027529123],
                          [ 0.25          ,  0.25          ,  0.115045881871],
                          [ 0.75          ,  0.75          ,  0.16106423462 ],
                          [ 0.25          ,  0.25          ,  0.207082587368],
                          [ 0.75          ,  0.75          ,  0.253100940117],
                          [ 0.25          ,  0.25          ,  0.323750190673],
                          [ 0.75          ,  0.75          ,  0.323750190673],
                          [ 0.25          ,  0.25          ,  0.394250114404],
                          [ 0.75          ,  0.75          ,  0.394250114404],
                          [ 0.25          ,  0.25          ,  0.464750038135],
                          [ 0.75          ,  0.75          ,  0.464750038135],
                          [ 0.25          ,  0.25          ,  0.535249961865],
                          [ 0.75          ,  0.75          ,  0.535249961865],
                          [ 0.25          ,  0.25          ,  0.605749885596],
                          [ 0.75          ,  0.75          ,  0.605749885596],
                          [ 0.25          ,  0.25          ,  0.676249809327],
                          [ 0.75          ,  0.75          ,  0.676249809327],
                          [ 0.25          ,  0.25          ,  0.746899059883],
                          [ 0.75          ,  0.75          ,  0.792917412632],
                          [ 0.25          ,  0.25          ,  0.83893576538 ],
                          [ 0.75          ,  0.75          ,  0.884954118129],
                          [ 0.25          ,  0.25          ,  0.930972470877],
                          [ 0.75          ,  0.75          ,  0.976990823626]]

# Set up configuration
bulk_configuration = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates
    )

# The four outermost Fe atoms at each end of the stack. This single list is
# used both for the 'Selection 0' tag and for the rigid-body constraint in
# the optimization below, so the tag and the constraint cannot drift apart.
indices_0 = [0, 1, 2, 3, 20, 21, 22, 23]

# Add tags
bulk_configuration.addTags('Selection 0', indices_0)

# -------------------------------------------------------------
# Calculator
# -------------------------------------------------------------
#----------------------------------------
# Basis Set
#----------------------------------------
basis_set = [
    GGABasis.Oxygen_DoubleZetaPolarized,
    GGABasis.Magnesium_DoubleZetaPolarized,
    GGABasis.Iron_SingleZetaPolarized,
    ]

#----------------------------------------
# Exchange-Correlation
#----------------------------------------
# Spin-polarized PBE, so the initial spins set below are used.
exchange_correlation = SGGA.PBE

k_point_sampling = MonkhorstPackGrid(
    na=7,
    nb=7,
    nc=2,
    )
numerical_accuracy_parameters = NumericalAccuracyParameters(
    electron_temperature=1200.0*Kelvin,
    k_point_sampling=k_point_sampling,
    )

calculator = LCAOCalculator(
    basis_set=basis_set,
    exchange_correlation=exchange_correlation,
    numerical_accuracy_parameters=numerical_accuracy_parameters,
    )
# The calculator is attached once, together with the initial spin, in the
# Initial State section below; attaching it here first as well was redundant.

# -------------------------------------------------------------
# Initial State
# -------------------------------------------------------------
# Iron atoms get a scaled initial spin of 1.0, Mg and O get 0.0. Deriving
# the list from `elements` guarantees it has one entry per atom in the same
# order, instead of relying on a hand-typed 24-entry list staying in sync.
initial_spin = InitialSpin(
    scaled_spins=[1.0 if element is Iron else 0.0 for element in elements],
    )
bulk_configuration.setCalculator(
    calculator,
    initial_spin=initial_spin,
)
# Compute the starting (unrelaxed) state, then save and print it.
bulk_configuration.update()
nlsave('mgo_relax.nc', bulk_configuration)
nlprint(bulk_configuration)

# -------------------------------------------------------------
# Optimize Geometry
# -------------------------------------------------------------
# The tagged outer Fe atoms (indices_0) move only as one rigid body.
constraints = [RigidBody(indices_0)]

# NOTE(review): disable_stress=True presumably keeps the cell vectors fixed
# so that only atomic positions relax, and trajectory_filename=None
# presumably skips writing a trajectory file — confirm against the
# OptimizeGeometry documentation.
bulk_configuration = OptimizeGeometry(
        bulk_configuration,
        max_forces=0.05*eV/Angstrom,
        max_steps=200,
        max_step_length=0.2*Angstrom,
        constraints=constraints,
        trajectory_filename=None,
        disable_stress=True,
        optimizer_method=LBFGS(),
        )
nlsave('mgo_relax.nc', bulk_configuration)
nlprint(bulk_configuration)

# -------------------------------------------------------------
# Forces
# -------------------------------------------------------------
# Forces on the relaxed geometry; saved to the same file as the
# configurations above.
forces = Forces(bulk_configuration)
nlsave('mgo_relax.nc', forces)
nlprint(forces)
