# Keep the log output to a minimum.
setVerbosity(MinimalLog)

# NOTE(review): this name is not referenced anywhere else in the visible
# script; the later steps use hard-coded file names instead.
results_filename = 'CuH_results.hdf5'


# Cubic cell: three mutually orthogonal lattice vectors of equal length,
# one along each Cartesian axis.
vector_a, vector_b, vector_c = (
    [14.45984, 0.0, 0.0] * Angstrom,
    [0.0, 14.45984, 0.0] * Angstrom,
    [0.0, 0.0, 14.45984] * Angstrom,
)
lattice = UnitCell(vector_a, vector_b, vector_c)

# Define elements
# 256 copper atoms followed by a single hydrogen atom; the hydrogen is the
# last entry of the list (index 256).
elements = [Copper] * 256 + [Hydrogen]

# Define coordinates
#
# The first 256 positions are enumerated as four groups of 4 x 4 x 4 points.
# In each group every axis takes either the "quarter" values
# (0, 0.25, 0.5, 0.75) or the "eighth" values (0.125, 0.375, 0.625, 0.875).
# Within a group x varies slowest and z varies fastest.  One extra position
# is appended at the end so that it lines up with the final entry of
# `elements`.
fractional_coordinates = [
    [x, y, z]
    for x_values, y_values, z_values in (
        # Group 1: quarters on x, y and z.
        (
            (0.0, 0.25, 0.5, 0.75),
            (0.0, 0.25, 0.5, 0.75),
            (0.0, 0.25, 0.5, 0.75),
        ),
        # Group 2: eighths on x and y, quarters on z.
        (
            (0.125, 0.375, 0.625, 0.875),
            (0.125, 0.375, 0.625, 0.875),
            (0.0, 0.25, 0.5, 0.75),
        ),
        # Group 3: eighths on x and z, quarters on y.
        (
            (0.125, 0.375, 0.625, 0.875),
            (0.0, 0.25, 0.5, 0.75),
            (0.125, 0.375, 0.625, 0.875),
        ),
        # Group 4: quarters on x, eighths on y and z.
        (
            (0.0, 0.25, 0.5, 0.75),
            (0.125, 0.375, 0.625, 0.875),
            (0.125, 0.375, 0.625, 0.875),
        ),
    )
    for x in x_values
    for y in y_values
    for z in z_values
]

# Final position (index 256).
fractional_coordinates.append([0.375, 0.375, 0.375])

# Assemble the bulk configuration from the lattice, the element list and the
# fractional positions defined above.
copper = BulkConfiguration(
    bravais_lattice=lattice, elements=elements, fractional_coordinates=fractional_coordinates
)

# Attach named tags to individual atoms so that later steps can look the
# indices up by tag: atom 256 gets 'hydrogen', atom 100 gets 'constrain'.
copper.addTags('hydrogen', [256])
copper.addTags('constrain', [100])

# %% ForceFieldCalculator

potentialSet = TorchX_M3GNet_MP_DIRECT_PES_2021(dtype='float32')
calculator = TremoloXCalculator(parameters=potentialSet)

# %% Set Calculator

copper.setCalculator(calculator)

# Hold the atom tagged 'constrain' fixed so the lattice stays in place, and
# add a BravaisLatticeConstraint alongside it.
fix_atom_index = copper.indicesFromTags(['constrain'])
constraints = [
    FixAtomConstraints(fix_atom_index),
    BravaisLatticeConstraint(),
]

# The restart strategy points at the same file and object id that the
# optimization below uses for its trajectory.
restart_strategy = RestartFromTrajectory(
    trajectory_filename='copper_1_results.hdf5',
    object_id='optimize_trajectory',
)

# Relax the initial geometry under the constraints defined above, with a
# force criterion of 0.01 eV/Angstrom.
optimized_configuration = OptimizeGeometry(
    configuration=copper,
    constraints=constraints,
    max_forces=0.01 * eV / Angstrom,
    restart_strategy=restart_strategy,
    trajectory_filename='copper_1_results.hdf5',
    trajectory_object_id='optimize_trajectory',
)

# NOTE(review): the optimized geometry is saved to 'copper_results.hdf5',
# whereas the trajectory and the total energy go to 'copper_1_results.hdf5'
# -- confirm that the two different file names are intentional.
nlsave('copper_results.hdf5', optimized_configuration, object_id='optgeom')


# %% TotalEnergy

# Total energy of the relaxed structure; stored on file and evaluated later
# when the Markov chain is created.
total_energy = TotalEnergy(configuration=optimized_configuration)
nlsave('copper_1_results.hdf5', total_energy)


# %% AdaptiveKineticMonteCarlo

# Imported explicitly so that the restart checks below do not depend on `os`
# happening to be present in the interpreter's pre-populated namespace.
import os

# NOTE(review): this object is created but not passed to anything in the
# visible part of the script -- confirm whether it was meant to be handed to
# the AKMC setup.
saddle_search_parameters = SaddleSearchParameters(max_neb_images=5)

# Resume from a previously written Markov chain if its file exists (the last
# object stored in the file is used); otherwise start a new chain from the
# optimized configuration and its total energy.
if os.path.isfile('akmc_markov_chain.nc'):
    markov_chain = nlread('akmc_markov_chain.nc')[-1]
else:
    markov_chain = MarkovChain(
        configuration=optimized_configuration,
        configuration_energy=total_energy.evaluate(),
    )

# Same idea for the KMC state: reuse the last stored object if the file
# exists, otherwise pass None on to the AKMC setup.
if os.path.isfile('akmc_kmc.nc'):
    kmc = nlread('akmc_kmc.nc')[-1]
else:
    kmc = None

# Index of the atom tagged 'hydrogen'; the saddle search is pointed at this
# atom through its index_selector.
hydrogen_index = copper.indicesFromTags(['hydrogen'])

# Generator for the initial displacement direction of each saddle search,
# with a magnitude range of 0.5 to 1.0 Angstrom.
direction_generator = HypersphereDirection(
    order=Furthest, magnitude=(0.5*Angstrom, 1.0*Angstrom)
)

# Lanczos saddle search; the atom tagged 'constrain' is passed as the
# constraint.
saddle_search = LanczosSaddleSearch(
    initial_direction_generator=direction_generator,
    index_selector=hydrogen_index,
    constraints=optimized_configuration.indicesFromTags(['constrain']),
)

# AKMC driver at 300 K, reusing the calculator attached to the optimized
# configuration and any restart state loaded earlier.
akmc = AdaptiveKineticMonteCarlo(
    markov_chain=markov_chain,
    kmc=kmc,
    kmc_temperature=300.0 * Kelvin,
    calculator=optimized_configuration.calculator(),
    saddle_search_method=saddle_search,
)

akmc.run(max_searches=10, max_kmc_steps=1000)
