# %% Manganosite

# Rock-salt MnO: a face-centred cubic lattice (a = 4.4448 Angstrom) with a
# two-atom basis -- Mn at the origin and O at fractional (1/2, 1/2, 1/2).
# NOTE(review): this primitive cell contains a single Mn atom, so every Mn
# carries the same moment; antiferromagnetic order would need a larger cell.
# Confirm that ferromagnetic/non-magnetic MnO is what is intended here.
lattice = FaceCenteredCubic(4.4448 * Angstrom)

elements = [Manganese, Oxygen]

fractional_coordinates = [
    [0.0, 0.0, 0.0],  # Mn
    [0.5, 0.5, 0.5],  # O
]

manganosite = BulkConfiguration(
    fractional_coordinates=fractional_coordinates,
    elements=elements,
    bravais_lattice=lattice,
)

manganosite_name = "manganosite"


# %% Set LCAOCalculator

# Reference (no +U) calculator.
# - Exchange-correlation: spin-polarized PBE (SGGA.PBE).
# - Basis set: PseudoDojo "Medium" LCAO.
# - k-point grid: derived from a density of 4.0 Angstrom along a, b and c.
# - Checkpointing: disabled.
k_point_sampling = KpointDensity(
    density_a=4.0 * Angstrom,
    density_b=4.0 * Angstrom,
    density_c=4.0 * Angstrom,
)

numerical_accuracy_parameters = NumericalAccuracyParameters(k_point_sampling=k_point_sampling)

calculator = LCAOCalculator(
    basis_set=BasisGGAPseudoDojo.Medium,
    exchange_correlation=SGGA.PBE,
    numerical_accuracy_parameters=numerical_accuracy_parameters,
    checkpoint_handler=NoCheckpointHandler,
)


# %% Set Calculator

manganosite.setCalculator(calculator)

# Reference run without +U.
# First converge the plain spin-polarized PBE ground state, then store its
# density of states. The DFT+U result is later saved to the same file, so the
# two can be compared.
# NOTE(review): no initial spin is passed to update(), so the magnetic starting
# point is QuantumATK's default. Confirm this is intended.
manganosite.update()

density_of_states = DensityOfStates(configuration=manganosite)
nlsave('MnO_LSCC_U.hdf5', density_of_states, object_id='MnO DOS standard DFT')


# Calculate LSCC U values for the Mn 3d shell.

# Maps each element to the basis-set shell indices that should receive a U value.
# Index 3 is taken to be the Mn 3d shell of BasisGGAPseudoDojo.Medium.
# NOTE(review): this index is basis-set specific; re-check it if the basis changes.
shell_indices = {Manganese: [3]}

lscc_u_per_atom = calculateLocalScreenedCoulombCorrectionHubbardU(
    shell_indices=shell_indices,
    configuration=manganosite,
)

# Each entry pairs an atom index with its per-shell U values. Only one shell
# was requested above, so element 0 is the 3d value.
for atom_index, u_values_per_shell in lscc_u_per_atom:
    u_3d_in_ev = u_values_per_shell[0].inUnitsOf(eV)
    nlprint('{:d}: U for 3d shell: {:.3f} eV'.format(atom_index, u_3d_in_ev))

# Build a copy of the configuration in which the atoms that received a U value
# are tagged. Also build a basis set whose entries for those tags carry the
# LSCC U values computed above.
configuration_with_tags, basis_set_with_u = createLocalDFTUBasisSet(
    configuration=manganosite,
    shell_indices=shell_indices,
    atom_u_values=lscc_u_per_atom,
    basis_set=calculator.basisSet(),
)

# Spin-polarized (two spins) PBE exchange-correlation with a Hubbard +U term,
# using the 'Dual' projection method.
xc_with_dftu = ExchangeCorrelation(
    number_of_spins=2,
    hubbard_term=Dual,
    exchange=PerdewBurkeErnzerhofExchange,
    correlation=PerdewBurkeErnzerhofCorrelation,
)

# Calling the reference calculator yields a copy in which only the listed
# parameters are replaced:
# - the exchange-correlation becomes DFT+U;
# - the basis set becomes the one carrying the U values.
# All other settings, such as the k-point sampling, are taken from `calculator`.
calculator_with_u = calculator(
    basis_set=basis_set_with_u,
    exchange_correlation=xc_with_dftu,
)
configuration_with_tags.setCalculator(calculator_with_u)

# Self-consistent DFT+U calculation on the tagged configuration.
configuration_with_tags.update()

# Store the DFT+U density of states in the same file under its own object id.
density_of_states_with_u = DensityOfStates(configuration=configuration_with_tags)
nlsave('MnO_LSCC_U.hdf5', density_of_states_with_u, object_id='MnO DOS DFT+U')
