import torch

from QATK.MLDFT import *

# Directory handed to DeepDFTModel as `model_dir`; its name appears to encode
# the hyperparameters chosen below (cutoff 4, 3 interactions, 128 node size).
model_dir = '/densitymodel/cutoff_4_model_3_128/'

# --- Model -------------------------------------------------------------------
# PaiNN variant of the DeepDFT architecture. The keyword arguments are gathered
# in one mapping so the architecture is easy to read and tweak in one place.
model_settings = dict(
    model_dir=model_dir,
    model_type=DeepDFTModelType.PaiNN,
    cutoff=4.0,
    num_interactions=3,
    node_size=128,
    num_components=1,
)
model = DeepDFTModel(**model_settings)

# --- Data --------------------------------------------------------------------
# Every directory listed here is passed to GridValuesDataset; each is expected
# to contain HDF5 files with training data.
dataset_dirs = [
    "/directory/with/hdf5/files/for/training1/",
    "/directory/with/hdf5/files/for/training2/",
]

# 5% of the data is reserved for validation. The probe counts presumably set
# how many grid probe points are sampled per structure for training and for
# validation, and the fixed seed presumably makes the split reproducible --
# confirm against the GridValuesDataset documentation.
dataset = GridValuesDataset(
    dataset_dirs=dataset_dirs,
    probe_count_train=1000,
    probe_count_val=5000,
    validation_ratio=0.05,
    seed=123456,
)

# --- Loss --------------------------------------------------------------------
# Mean absolute error; reduction="mean" is PyTorch's default for L1Loss and is
# only spelled out here for clarity.
criterion = torch.nn.L1Loss(reduction="mean")

# --- Trainer -----------------------------------------------------------------
# NOTE(review): training runs at module import time. If GridValuesModelTrainer
# uses multiprocessing data loading on a spawn-based platform (Windows/macOS),
# this code may need an `if __name__ == "__main__":` guard -- confirm.
trainer = GridValuesModelTrainer(
    # What to train and on which data.
    grid_values_model=model,
    grid_values_dataset=dataset,
    criterion=criterion,
    # Optimisation schedule and stopping criteria.
    batch_size=2,
    learning_rate_base=None,
    max_steps=1_000_000,
    target_mae=1e-5,
    validation_interval=5000,
    # Hardware selection.
    gpu_acceleration=Automatic,
)

trainer.train()
