# %% Defect Formation

# %% Calculators
#
# Four calculators are configured for this workflow. Their roles, as used
# further down in this script:
#   calculator                                    MTP force field; pre-relaxation
#                                                 calculator of every charged defect
#   bandgap_lcaocalculator__hse06                 HSE06 LCAO; band_gap_calculator of
#                                                 the pristine configuration
#   calculator_1                                  MTP force field; phonon calculator
#                                                 (pristine material, chemical potentials)
#   reference_lcaocalculator__spin_polarized_pbe  spin-polarized PBE LCAO; reference
#                                                 calculator for all total energies

# %% Pre-Optimization ForceFieldCalculator - SiC MTP

calculator = TremoloXCalculator(parameters=QuantumATK_MTP_SiC_Defects_2023())


# %% Bandgap LCAOCalculator - HSE06

# Hybrid HSE06 functional with a k-point density of 4.0 Angstrom along each
# lattice direction.
bandgap_lcaocalculator__hse06 = LCAOCalculator(
    exchange_correlation=HybridGGA.HSE06,
    numerical_accuracy_parameters=NumericalAccuracyParameters(
        k_point_sampling=KpointDensity(
            density_a=4.0*Angstrom,
            density_b=4.0*Angstrom,
            density_c=4.0*Angstrom,
        ),
    ),
    checkpoint_handler=NoCheckpointHandler,
)


# %% Phonon ForceFieldCalculator - SiC MTP

# A second, independent instance of the same MTP force field.
calculator_1 = TremoloXCalculator(parameters=QuantumATK_MTP_SiC_Defects_2023())


# %% Reference LCAOCalculator - Spin Polarized PBE

# Spin-polarized GGA (PBE) with the same 4.0 Angstrom k-point density as the
# HSE06 calculator above.
reference_lcaocalculator__spin_polarized_pbe = LCAOCalculator(
    exchange_correlation=SGGA.PBE,
    numerical_accuracy_parameters=NumericalAccuracyParameters(
        k_point_sampling=KpointDensity(
            density_a=4.0*Angstrom,
            density_b=4.0*Angstrom,
            density_c=4.0*Angstrom,
        ),
    ),
    checkpoint_handler=NoCheckpointHandler,
)


# %% Pristine Material

# %% SiC-4H
# 8-atom hexagonal cell of SiC-4H (4 C + 4 Si) built from fractional
# coordinates. The atom order in `elements` defines the site indices used by
# the point defects in the "Defects" cell: C at sites 0, 1, 4, 5 and
# Si at sites 2, 3, 6, 7.

# Set up lattice
# NOTE(review): presumably Hexagonal(a, c) with a = 3.08051 A, c = 10.0848 A.
lattice = Hexagonal(3.08051*Angstrom, 10.0848*Angstrom)

# Define elements
elements = [Carbon, Carbon, Silicon, Silicon, Carbon, Carbon, Silicon, Silicon]

# Define coordinates
fractional_coordinates = [[ 0.25          ,  0.25          ,  0.719485      ],
                          [ 0.25          ,  0.25          ,  0.219485      ],
                          [ 0.25          ,  0.25          ,  0.531645      ],
                          [ 0.25          ,  0.25          ,  0.031645      ],
                          [ 0.916666666667,  0.583333333333,  0.968355      ],
                          [ 0.583333333333,  0.916666666667,  0.468355      ],
                          [ 0.916666666667,  0.583333333333,  0.781465      ],
                          [ 0.583333333333,  0.916666666667,  0.281465      ]]

# Set up configuration
sic4h = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates
    )
# Label for the structure; not read anywhere else in this script.
sic4h_name = "sic4h"


# %% Set Calculator

# Attach the spin-polarized PBE reference calculator to the as-built SiC-4H
# structure, run it with update(), and store the configuration.
# NOTE(review): the silicon and diamond cells further down only call
# setCalculator() and go straight to OptimizeGeometry without update(); this
# extra calculation on the unoptimized structure may be redundant — confirm
# it is intended.
sic4h.setCalculator(reference_lcaocalculator__spin_polarized_pbe)

sic4h.update()

nlsave('SiC_Defects_Example_results.hdf5', sic4h)


# %% OptimizeGeometry

# Restart settings that point at the 'optimize_trajectory' object in the
# results file. OptimizeGeometry below writes its trajectory to the same
# file/ID, presumably so an interrupted optimization can resume from it.
restart_strategy = RestartFromTrajectory(
    trajectory_filename='SiC_Defects_Example_results.hdf5',
    object_id='optimize_trajectory'
)

# Relax atoms and cell (optimize_cell=True). BravaisLatticeConstraint
# presumably keeps the hexagonal lattice type while its parameters change.
# The result feeds the OpticalSpectrum, ElasticConstants and
# PristineConfiguration cells below.
optimized_configuration = OptimizeGeometry(
    configuration=sic4h,
    constraints=[
        BravaisLatticeConstraint()
    ],
    trajectory_filename='SiC_Defects_Example_results.hdf5',
    trajectory_object_id='optimize_trajectory',
    optimize_cell=True,
    restart_strategy=restart_strategy
)

nlsave('SiC_Defects_Example_results.hdf5', optimized_configuration)


# %% OpticalSpectrum

# K-point density for the optical spectrum: 2.0 Angstrom along each direction
# (the LCAO calculators above use 4.0 Angstrom).
kpoints = KpointDensity(
    density_a=2.0*Angstrom,
    density_b=2.0*Angstrom,
    density_c=2.0*Angstrom
)

# Optical spectrum of the optimized bulk, including 100 bands below and
# 100 bands above the Fermi level.
optical_spectrum = OpticalSpectrum(
    configuration=optimized_configuration,
    kpoints=kpoints,
    bands_below_fermi_level=100,
    bands_above_fermi_level=100
)
nlsave('SiC_Defects_Example_results.hdf5', optical_spectrum)


# %% ElasticConstants

# No calculator is passed, so this presumably uses the calculator attached to
# optimized_configuration (the PBE reference inherited from sic4h) — confirm.
elastic_constants = ElasticConstants(
    configuration=optimized_configuration
)
nlsave('SiC_Defects_Example_results.hdf5', elastic_constants)


# %% IsotropicFiniteSizeCorrectionParameters

# Finite-size correction settings handed to PristineConfiguration: the width
# is optimized, the position is not.
isotropic_finite_size_correction_parameters = IsotropicFiniteSizeCorrectionParameters(
    optimize_width=True,
    optimize_position=False
)
nlsave('SiC_Defects_Example_results.hdf5', isotropic_finite_size_correction_parameters)


# %% PristineConfiguration

# Pristine SiC-4H reference for the defect calculations in the TableIteration
# loop. The roles are wired up as follows:
#   reference_calculator -> spin-polarized PBE LCAO
#   band_gap_calculator  -> HSE06 LCAO
#   phonon_calculator    -> MTP force field (calculator_1), used because
#                           include_vibrations=True
# The defect supercell is the optimized cell repeated 3 x 3 x 1.
pristine_configuration = PristineConfiguration(
    configuration=optimized_configuration,
    reference_calculator=reference_lcaocalculator__spin_polarized_pbe,
    band_gap_calculator=bandgap_lcaocalculator__hse06,
    phonon_calculator=calculator_1,
    supercell_repetitions=(3, 3, 1),
    finite_size_correction_parameters=isotropic_finite_size_correction_parameters,
    include_vibrations=True
)

pristine_configuration.update()
nlsave('SiC_Defects_Example_results.hdf5', pristine_configuration)


# %% Si Reference Material
# %% Silicon
# Elemental silicon reference: 8-atom cubic cell with the diamond-structure
# fractional coordinates below (a = 5.4306 Angstrom). It is optimized and
# then used for the silicon chemical potential.

# Set up lattice
lattice = SimpleCubic(5.4306*Angstrom)

# Define elements
elements = [Silicon, Silicon, Silicon, Silicon, Silicon, Silicon, Silicon,
            Silicon]

# Define coordinates
fractional_coordinates = [[ 0.  ,  0.  ,  0.  ],
                          [ 0.25,  0.25,  0.25],
                          [ 0.5 ,  0.5 ,  0.  ],
                          [ 0.75,  0.75,  0.25],
                          [ 0.5 ,  0.  ,  0.5 ],
                          [ 0.75,  0.25,  0.75],
                          [ 0.  ,  0.5 ,  0.5 ],
                          [ 0.25,  0.75,  0.75]]

# Set up configuration
silicon = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates
    )
# Label for the structure; not read anywhere else in this script.
silicon_name = "silicon"


# %% Si Reference Calculator

# Same spin-polarized PBE calculator as the SiC reference, so that the
# chemical potential is consistent with the defect energies. No update() is
# run here; the optimization below performs the calculations.
silicon.setCalculator(reference_lcaocalculator__spin_polarized_pbe)

nlsave('SiC_Defects_Example_results.hdf5', silicon)


# %% Si Reference OptimizeGeometry

# Restart settings for this optimization's own trajectory ID
# ('optimize_trajectory_1'), distinct from the SiC optimization's ID.
restart_strategy = RestartFromTrajectory(
    trajectory_filename='SiC_Defects_Example_results.hdf5',
    object_id='optimize_trajectory_1'
)

# Relax atoms and cell of bulk silicon; the result is the configuration of the
# silicon CalculatedChemicalPotential.
optimized_configuration_1 = OptimizeGeometry(
    configuration=silicon,
    constraints=[
        BravaisLatticeConstraint()
    ],
    trajectory_filename='SiC_Defects_Example_results.hdf5',
    trajectory_object_id='optimize_trajectory_1',
    optimize_cell=True,
    restart_strategy=restart_strategy
)

nlsave('SiC_Defects_Example_results.hdf5', optimized_configuration_1)


# %% C Reference Material
# %% Diamond
# Elemental carbon reference: 8-atom cubic diamond cell (a = 3.56679 Angstrom)
# with the same fractional coordinates as the silicon reference. It is
# optimized and then used for the carbon chemical potential.

# Set up lattice
lattice = SimpleCubic(3.56679*Angstrom)

# Define elements
elements = [Carbon, Carbon, Carbon, Carbon, Carbon, Carbon, Carbon, Carbon]

# Define coordinates
fractional_coordinates = [[ 0.  ,  0.  ,  0.  ],
                          [ 0.25,  0.25,  0.25],
                          [ 0.5 ,  0.5 ,  0.  ],
                          [ 0.75,  0.75,  0.25],
                          [ 0.5 ,  0.  ,  0.5 ],
                          [ 0.75,  0.25,  0.75],
                          [ 0.  ,  0.5 ,  0.5 ],
                          [ 0.25,  0.75,  0.75]]

# Set up configuration
diamond = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates
    )
# Label for the structure; not read anywhere else in this script.
diamond_name = "diamond"


# %% C Reference Calculator

# Same spin-polarized PBE reference calculator as for SiC and silicon.
diamond.setCalculator(reference_lcaocalculator__spin_polarized_pbe)

nlsave('SiC_Defects_Example_results.hdf5', diamond)


# %% C Reference OptimizeGeometry

# Restart settings for this optimization's own trajectory ID
# ('optimize_trajectory_2').
restart_strategy = RestartFromTrajectory(
    trajectory_filename='SiC_Defects_Example_results.hdf5',
    object_id='optimize_trajectory_2'
)

# Relax atoms and cell of diamond; the result is the configuration of the
# carbon CalculatedChemicalPotential.
optimized_configuration_2 = OptimizeGeometry(
    configuration=diamond,
    constraints=[
        BravaisLatticeConstraint()
    ],
    trajectory_filename='SiC_Defects_Example_results.hdf5',
    trajectory_object_id='optimize_trajectory_2',
    optimize_cell=True,
    restart_strategy=restart_strategy
)

nlsave('SiC_Defects_Example_results.hdf5', optimized_configuration_2)


# %% Chemical Potential
# %% ChemicalPotentialTable

# Create an empty two-column table, (element name, chemical potential object),
# unless a table with this name already exists in the session. At module scope
# locals() is the global namespace. The table is saved under the object ID
# 'table'; the cells below append one row per element and re-save under the
# same ID.
if 'chemical_potential_table' not in locals():
    columns = [
        StringColumn('element', 'Element'),
        InstanceColumn('potential', 'Potential', instance_types=(BaseChemicalPotential,))
    ]
    chemical_potential_table = Table(columns=columns, data=[])
    chemical_potential_table.setMetatext('ChemicalPotentialTable')

    nlsave('SiC_Defects_Example_results.hdf5', chemical_potential_table, object_id='table')


# %% Silicon CalculatedChemicalPotential

# Silicon chemical potential from optimized bulk silicon, computed with the
# PBE reference calculator. Vibrational contributions come from the MTP
# force field (calculator_1) because include_vibrations=True.
silicon_calculated_chemical_potential = CalculatedChemicalPotential(
    element=Silicon,
    calculator=reference_lcaocalculator__spin_polarized_pbe,
    phonon_calculator=calculator_1,
    configuration=optimized_configuration_1,
    include_vibrations=True
)

silicon_calculated_chemical_potential.update()
nlsave('SiC_Defects_Example_results.hdf5', silicon_calculated_chemical_potential)

# Value written to the 'element' column of the table row below.
element_name = "Silicon"


# %% Append Potential to Chemical Potential Table

# Row order matches the table columns: (element name, potential).
row_data = []
row_data.append(element_name)
row_data.append(silicon_calculated_chemical_potential)
chemical_potential_table.appendRow(row_data=row_data)

# Re-save the growing table under the same 'table' object ID.
nlsave('SiC_Defects_Example_results.hdf5', chemical_potential_table, object_id='table')


# %% Carbon CalculatedChemicalPotential

# Carbon chemical potential from optimized diamond; same calculators and
# vibration setting as the silicon entry above.
carbon_calculated_chemical_potential = CalculatedChemicalPotential(
    element=Carbon,
    calculator=reference_lcaocalculator__spin_polarized_pbe,
    phonon_calculator=calculator_1,
    configuration=optimized_configuration_2,
    include_vibrations=True
)

carbon_calculated_chemical_potential.update()
nlsave('SiC_Defects_Example_results.hdf5', carbon_calculated_chemical_potential)

# Value written to the 'element' column of the table row below.
element_name_1 = "Carbon"


# %% Append Potential to Chemical Potential Table

row_data = []
row_data.append(element_name_1)
row_data.append(carbon_calculated_chemical_potential)
chemical_potential_table.appendRow(row_data=row_data)

# Re-save the table, now holding both elements, under the 'table' object ID.
nlsave('SiC_Defects_Example_results.hdf5', chemical_potential_table, object_id='table')


# %% Chemical Potential Table Final

def chemical_potential_table_final(chemical_potential_table):
    """Return the finished chemical potential table unchanged.

    This is a pass-through step: it performs no processing and returns the
    very same object it was given (no copy is made).

    Parameters
    ----------
    chemical_potential_table
        The table of (element, chemical potential) rows.

    Returns
    -------
    The ``chemical_potential_table`` argument itself.
    """
    finished_table = chemical_potential_table
    return finished_table

# Pass the table through the finalization step and save it once more, this
# time without an explicit object ID (unlike the earlier saves under 'table').
chemical_potential_table = chemical_potential_table_final(chemical_potential_table)

nlsave('SiC_Defects_Example_results.hdf5', chemical_potential_table)


# %% Defects

# %% Defects

# Accumulators handed to the Defects object at the end of this cell:
#   defects_defects              every NamedPointDefect, in registration order
#   defects_defect_generators    generator key -> filtered defect generator
#   defects_defect_generator_map defect name -> generator key
defects_defects = []
defects_defect_generators = {}
defects_defect_generator_map = {}


def _register_defect_group(defect_parameters, named_defect_specs, generator_key):
    """Register one group of point defects that share a single defect generator.

    Builds a defect generator for ``sic4h`` from *defect_parameters* and wraps
    each ``(point_defect, index, name)`` triple of *named_defect_specs* in a
    ``NamedPointDefect``. It then filters the generator down to exactly those
    point defects and records everything under *generator_key* in the
    module-level accumulators ``defects_defects``,
    ``defects_defect_generators`` and ``defects_defect_generator_map``.
    The accumulators are mutated in place; nothing is returned.
    """
    # NOTE(review): the generator is built from the unoptimized `sic4h`, while
    # pristine_configuration uses `optimized_configuration`; the site indices
    # are assumed to coincide — confirm this is intended.
    defect_generator = defect_parameters.defectGenerator(sic4h)
    named_defects = [
        NamedPointDefect(
            point_defect=point_defect,
            index=index,
            name=name,
        )
        for point_defect, index, name in named_defect_specs
    ]
    # Keep only the explicitly selected point defects in the generator.
    kept_point_defects = [named.pointDefect() for named in named_defects]
    defect_generator = defect_generator.filterByPointDefect(kept_point_defects)

    defects_defects.extend(named_defects)
    defects_defect_generators[generator_key] = defect_generator
    for named in named_defects:
        defects_defect_generator_map[named.name()] = generator_key


# Vacancies on two C sites (0, 5) and two Si sites (2, 7) of the SiC-4H cell.
_register_defect_group(
    DefectsParameters(
        defect_type=1
    ),
    [
        (Vacancy(site_index=0, unit_cell_index=(0, 0, 0)),
         0, 'Vacancy C 000 000'),
        (Vacancy(site_index=2, unit_cell_index=(0, 0, 0)),
         1, 'Vacancy Si 001 002'),
        (Vacancy(site_index=5, unit_cell_index=(0, 0, 0)),
         2, 'Vacancy C 002 005'),
        (Vacancy(site_index=7, unit_cell_index=(0, 0, 0)),
         3, 'Vacancy Si 003 007'),
    ],
    'c1cd9f51d23442339c9c3d6f7895535e',
)

# Carbon substituted onto the Si sites 2 and 7 (C^Si antisites).
_register_defect_group(
    DefectsParameters(
        defect_type=0,
        defect_element=Carbon
    ),
    [
        (Substitutional(element=Carbon, site_index=2, unit_cell_index=(0, 0, 0)),
         1, 'Substitutional C^Si 001 002'),
        (Substitutional(element=Carbon, site_index=7, unit_cell_index=(0, 0, 0)),
         3, 'Substitutional C^Si 003 007'),
    ],
    'e7cfdd8dfe242a993be436d02ec2072',
)

# Silicon substituted onto the C sites 0 and 5 (Si^C antisites).
_register_defect_group(
    DefectsParameters(
        defect_type=0,
        defect_element=Silicon
    ),
    [
        (Substitutional(element=Silicon, site_index=0, unit_cell_index=(0, 0, 0)),
         0, 'Substitutional Si^C 000 000'),
        (Substitutional(element=Silicon, site_index=5, unit_cell_index=(0, 0, 0)),
         2, 'Substitutional Si^C 002 005'),
    ],
    'c068bc284565a8effa2fd290a57d',
)

defects = Defects(
    defects=defects_defects,
    defect_generators=defects_defect_generators,
    defect_generator_map=defects_defect_generator_map
)
nlsave('SiC_Defects_Example_results.hdf5', defects)


# %% Create defects table

def create_defects_table(defects):
    """Return the table of all defects contained in *defects*.

    Parameters
    ----------
    defects
        An object exposing ``totalDefects()`` (here the ``Defects`` object
        built in the "Defects" cell).

    Returns
    -------
    Exactly what ``defects.totalDefects()`` returns.
    """
    return defects.totalDefects()

# Table of all defects to iterate over. The TableIteration loop below reads
# it row by row and takes column 0 of each row as the point defect.
defects_table = create_defects_table(defects)

nlsave('SiC_Defects_Example_results.hdf5', defects_table)


# %% TableIteration

# Object ID for saving the collected charged-point-defect table. It must not
# be 'table': the chemical potential cells above save and re-save their table
# under that ID, so reusing it here would overwrite the stored chemical
# potential table in the results file.
charged_point_defects_table_id = 'charged_point_defects_table'

# TableIteration(preparation)
# For every defect row: build a ChargedPointDefectConfiguration, run it, save
# it under a row-specific ID, and collect it in a results table.
for row_index in range(defects_table.numberOfRows()):
    row_data = defects_table.row(row_index)
    # Column 0 of each row holds the point defect for this configuration.
    defect = row_data[0]

    # %% Pre-Relaxation Parameters

    # Fixed-cell (optimize_cell=False) settings for the pre-relaxation step,
    # which uses the MTP force field `calculator`.
    optimize_geometry_parameters = OptimizeGeometryParameters(
        constraints=[
            BravaisLatticeConstraint()
        ],
        optimize_cell=False
    )

    # %% Relaxation Parameters

    # Fixed-cell settings for the main relaxation step.
    optimize_geometry_parameters_1 = OptimizeGeometryParameters(
        constraints=[
            BravaisLatticeConstraint()
        ],
        optimize_cell=False
    )

    # %% ChargedPointDefectConfiguration

    # The defect in the pristine SiC-4H reference for charge states -2..+2,
    # referenced to the tabulated Si/C chemical potentials.
    charged_point_defect_configuration = ChargedPointDefectConfiguration(
        pristine_configuration=pristine_configuration,
        point_defect=defect,
        atomic_chemical_potentials=chemical_potential_table,
        charge_states=(-2, -1, 0, 1, 2),
        pre_relaxation_calculator=calculator,
        pre_relaxation_parameters=optimize_geometry_parameters,
        relaxation_parameters=optimize_geometry_parameters_1
    )

    charged_point_defect_configuration.update()
    local_object_id = f'charged_point_defect_configuration_ChargedPointDefectConfiguration_row_index_{row_index}'
    nlsave('SiC_Defects_Example_results.hdf5', charged_point_defect_configuration, object_id=local_object_id)

    # %% Add charged point defects to a table

    # Create the collecting table on the first iteration only. At module scope
    # locals() is the global namespace, so a table created in an earlier
    # iteration is found here.
    if 'add_charged_point_defects_to_a_table' not in locals():
        columns = [
            InstanceColumn('cpdc', 'Cpdc', instance_types=(ChargedPointDefectConfiguration,))
        ]
        add_charged_point_defects_to_a_table = Table(columns=columns, data=[])
        add_charged_point_defects_to_a_table.setMetatext('Add charged point defects to a table')
    row_data = []
    row_data.append(charged_point_defect_configuration)
    add_charged_point_defects_to_a_table.appendRow(row_data=row_data)

    # Re-save after every row under one dedicated ID (not the chemical
    # potential table's 'table' ID).
    nlsave(
        'SiC_Defects_Example_results.hdf5',
        add_charged_point_defects_to_a_table,
        object_id=charged_point_defects_table_id,
    )
