from QATK.Analysis import *
from QATK.Calculators.DFT import *
from QATK.Core import *

# %% CoSi2

# Bulk CoSi2 in the fluorite (CaF2-type) arrangement: one Co atom on the FCC
# lattice point and two Si atoms at 1/4 and 3/4 along the primitive-cell
# body diagonal (the tetrahedral interstitial sites). The fractional
# coordinates refer to the FaceCenteredCubic cell vectors; 5.356 Å is the
# cubic lattice constant passed to FaceCenteredCubic.
lattice = FaceCenteredCubic(5.356 * Angstrom)

# One (element, fractional position) entry per atom in the cell.
_cosi2_sites = (
    (Cobalt, (0.0, 0.0, 0.0)),
    (Silicon, (0.25, 0.25, 0.25)),
    (Silicon, (0.75, 0.75, 0.75)),
)
elements = [element for element, _ in _cosi2_sites]
fractional_coordinates = [list(position) for _, position in _cosi2_sites]

cosi2 = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates,
)


# %% Set LCAOCalculator

# %% LCAOCalculator

# ----------------------------------------
# Exchange-correlation and basis set
# ----------------------------------------
# SGGA is QuantumATK's spin-polarized GGA family.
# NOTE(review): confirm a spin-polarized run is intended for CoSi2;
# GGA.PBE would be the spin-unpolarized counterpart.
exchange_correlation = SGGA.PBE

# PseudoDojo "High" basis for each element present in the configuration.
basis_set = [BasisGGAPseudoDojo.Silicon_High, BasisGGAPseudoDojo.Cobalt_High]

# ----------------------------------------
# Numerical accuracy: k-point density of 4 Å along each cell direction
# ----------------------------------------
_kpoint_density = 4.0 * Angstrom
k_point_sampling = KpointDensity(
    density_a=_kpoint_density,
    density_b=_kpoint_density,
    density_c=_kpoint_density,
)
numerical_accuracy_parameters = NumericalAccuracyParameters(
    k_point_sampling=k_point_sampling,
)

# ----------------------------------------
# Machine-learned density predictor
# ----------------------------------------
# The pretrained model is handed to the SCF algorithm through
# AlgorithmParameters(grid_values_predictor=...); presumably it supplies
# the starting electron density — confirm in the DensityPredictor docs.
model = PretrainedGridValuesModels.density_MP_PBE_PD_High_Y2026
grid_values_predictor = DensityPredictor(model=model)
algorithm_parameters = AlgorithmParameters(
    grid_values_predictor=grid_values_predictor,
)

# ----------------------------------------
# LCAO calculator
# ----------------------------------------
calculator = LCAOCalculator(
    exchange_correlation=exchange_correlation,
    basis_set=basis_set,
    numerical_accuracy_parameters=numerical_accuracy_parameters,
    algorithm_parameters=algorithm_parameters,
    # Checkpointing disabled via NoCheckpointHandler.
    checkpoint_handler=NoCheckpointHandler,
)


# %% Set Calculator

# Both the updated configuration and the total-energy analysis object are
# written to the same HDF5 file.
# NOTE(review): if nlsave adds objects rather than overwriting, re-running
# this script accumulates duplicate entries — confirm.
_results_file = 'CoSi2_results.hdf5'

# Attach the calculator and trigger the calculation before saving.
cosi2.setCalculator(calculator)
cosi2.update()
nlsave(_results_file, cosi2)


# %% TotalEnergy

total_energy_1 = TotalEnergy(configuration=cosi2)
nlsave(_results_file, total_energy_1)
