import json
import time

# Potentials to be benchmarked: short label -> model file handed to the calculator.
# Entries are visited in insertion order by both benchmark loops below.
# The 'foundation' entry maps to '' because that branch builds a built-in
# MACE model and never reads a file path.
potentials = {
    f'scratch{width}': f'scratch_{width}_5Ang.qatkpt'
    for width in (8, 16, 32, 64, 128)
}
potentials.update(
    naive_finetuning='naive_finetuning.qatkpt',
    multihead_finetuning='multihead_finetuning.qatkpt',
    foundation='',
    mtp='MTP_TiSi_basis400_final.mtp',
)

# =================================================================================================
# ACCURACY BENCHMARKING
# =================================================================================================

# Per-potential accuracy metrics (label -> dict of RMSE values), filled in below
accuracy_benchmark_results = {}

# Function to compute Mean Squared Error
def compute_rmse(predicted, reference):
    return numpy.sqrt(numpy.mean(numpy.square(predicted - reference)))

# Load the MD trajectories with sample data for the benchmarking
trajectories = nlread('md_example_results.hdf5', MDTrajectory)
save_path = 'all_results_as_numpy.hdf5'


def _as_plain_array(quantity, unit):
    """Return *quantity* as a plain numpy array expressed in *unit*.

    Objects that expose ``inUnitsOf`` (unit-carrying quantities) are first
    converted to *unit*; anything else is passed to ``numpy.asarray`` as-is.
    Stripping units keeps the collected arrays homogeneous, makes the saved
    data match the eV / eV-per-Angstrom object ids below, and keeps the RMSE
    values JSON-serialisable.
    """
    if hasattr(quantity, 'inUnitsOf'):
        return numpy.asarray(quantity.inUnitsOf(unit))
    return numpy.asarray(quantity)


for potential_number, (potential_name, qatkpt_path) in enumerate(potentials.items()):
    nlprint(f"Processing potential {potential_number + 1}/{len(potentials)}: {potential_name}")
    # Set up a calculator with the given potential
    if potential_name == 'mtp':
        # MTP case
        potentialSet = TremoloXPotentialSet(name='Moment Tensor Potential')
        potentialSet.addParticleType(
            ParticleType(
                symbol='Si',
                mass=28.0855 * atomicMassUnit,
                charge=None,
                sigma=None,
                sigma14=None,
                epsilon=None,
                epsilon14=None,
                atomicNumber=14,
                tags=[],
            )
        )
        potentialSet.addParticleType(
            ParticleType(
                symbol='Ti',
                mass=47.867 * atomicMassUnit,
                charge=None,
                sigma=None,
                sigma14=None,
                epsilon=None,
                epsilon14=None,
                atomicNumber=22,
                tags=[],
            )
        )
        _potential = MTPPotential(
            file=qatkpt_path,
            suppress_intercept=False,
            group_name='',
        )
        potentialSet.addPotential(_potential)
    elif potential_name == 'foundation':
        # Foundation MACE model case (built in; the file path is not used)
        potentialSet = TorchX_MACE_MP_0b3_medium(dtype='float32', enforceLTX=False)
    else:
        # Trained/finetuned MACE model case
        model_path = qatkpt_path

        potentialSet = TremoloXPotentialSet(name='my_mace_potential')
        potential = TorchXPotential(
            dtype='float32',
            device='cuda',
            file=model_path,
        )
        for symbol in potential.get_symbols(model_path):
            potentialSet.addParticleType(ParticleType(symbol))
        potentialSet.addPotential(potential)

    calculator = TremoloXCalculator(parameters=potentialSet)

    # Accumulators. Energies are per-atom scalars (eV). Forces are kept as one
    # flat eV/Angstrom array per configuration and concatenated once at the
    # end, instead of extending a Python list one component at a time.
    energies_list = []
    ref_energies_list = []
    forces_chunks = []
    ref_forces_chunks = []

    for md_number, md in enumerate(trajectories):
        nlprint(f"Processing MD trajectory {md_number + 1}/{len(trajectories)}")
        num_configurations = md.length()
        # Compute energy and forces for each configuration
        for i in range(num_configurations):
            config = md.image(i)
            num_atoms = len(config)

            # Set the calculator
            config.setCalculator(calculator)

            # Calculate properties with the potential under test
            energy = _as_plain_array(TotalEnergy(config).evaluate(), eV)
            forces = _as_plain_array(Forces(config).evaluate(), eV / Angstrom)
            energies_list.append(energy.item() / num_atoms)  # Energy per atom
            forces_chunks.append(forces.ravel())

            # Retrieve reference values stored in the trajectory, same units.
            # NOTE(review): assumes md.potentialEnergies(i) / md.forces(i) return
            # the values of image i only -- confirm against the MDTrajectory API.
            ref_energy = _as_plain_array(md.potentialEnergies(i), eV)
            ref_forces = _as_plain_array(md.forces(i), eV / Angstrom)
            ref_energies_list.append(ref_energy.item() / num_atoms)  # Reference energy per atom
            ref_forces_chunks.append(ref_forces.ravel())

    energies_list = numpy.array(energies_list)
    ref_energies_list = numpy.array(ref_energies_list)
    nlsave(save_path, energies_list, object_id=f'energies_eV_per_atom_{potential_name}')
    nlsave(save_path, ref_energies_list, object_id=f'reference_energies_eV_per_atom_{potential_name}')
    # numpy.concatenate rejects an empty list, so fall back to an empty array
    forces_list = numpy.concatenate(forces_chunks) if forces_chunks else numpy.empty(0)
    ref_forces_list = numpy.concatenate(ref_forces_chunks) if ref_forces_chunks else numpy.empty(0)
    nlsave(save_path, forces_list, object_id=f'forces_eVperAng_{potential_name}')
    nlsave(save_path, ref_forces_list, object_id=f'reference_forces_eVperAng_{potential_name}')

    # Compute RMSE for energy (eV/atom) and forces (eV/Angstrom)
    energy_rmse = compute_rmse(energies_list, ref_energies_list)
    force_rmse = compute_rmse(forces_list, ref_forces_list)

    # Store results as built-in floats: json cannot serialise e.g. numpy.float32
    accuracy_benchmark_results[potential_name] = {
        "energy_rmse_per_atom": float(energy_rmse),
        "force_rmse": float(force_rmse),
    }

# Save results to a JSON file
output_file = "accuracy_benchmark_results.json"
with open(output_file, "w") as f:
    json.dump(accuracy_benchmark_results, f, indent=4)

nlprint(f"Accuracy benchmark results saved to {output_file}")

# =================================================================================================
# PERFORMANCE BENCHMARKING
# =================================================================================================
# Set up the sample system - TiSi2 amorphous with 4992 atoms
# We start with a sample amorphous structure of TiSi2 with 78 atoms and then repeat it 4x4x4
# times to create a larger system with 78 * 64 = 4992 atoms.
# NOTE(review): the element list below holds 39 Ti and 39 Si atoms (a 1:1 ratio), not the 1:2
# ratio the name "TiSi2" suggests -- confirm the intended composition.

# %% TiSi2_amorphous_78atoms

# Cubic cell: three orthogonal lattice vectors of identical length
vector_a, vector_b, vector_c = (
    [10.625693742234722, 0.0, 0.0]*Angstrom,
    [0.0, 10.625693742234722, 0.0]*Angstrom,
    [0.0, 0.0, 10.625693742234722]*Angstrom,
)
lattice = UnitCell(vector_a, vector_b, vector_c)

# Atom species, in the same order as the coordinate/velocity rows: 39 Ti followed by 39 Si
elements = [Titanium] * 39 + [Silicon] * 39

# Define coordinates: fractional (cell-relative) positions, one row per atom, in the same
# order as `elements`. A few components lie slightly outside [0, 1) (e.g. 1.0042, -0.0209);
# presumably these are unwrapped positions from an MD snapshot -- left exactly as recorded.
fractional_coordinates = [[ 0.090888840804,  0.430034440455,  0.083228527215],
                          [ 0.67230516415 ,  0.373854387657,  0.788498485252],
                          [ 0.30997368786 ,  0.941228089022,  0.63698886758 ],
                          [ 0.246133375882,  0.518241346326,  0.852444995443],
                          [ 0.058313349914,  0.139190353416,  0.732430554387],
                          [ 0.665380199606,  0.469443178644,  0.340258572067],
                          [ 0.871114957757,  0.372440279275,  0.566818199753],
                          [ 0.418145139008,  0.378543042779,  0.464820092452],
                          [ 0.372880199213,  0.15090861513 ,  0.090397984221],
                          [ 0.017167969015,  0.802401056665,  0.487185093555],
                          [ 0.025243434376,  0.621790602418,  0.931108848416],
                          [ 0.641590108368,  0.599759764188,  1.004190421896],
                          [ 0.518461427134,  0.877340312153,  0.007304749057],
                          [ 0.242825527866,  0.244322987974,  0.858962930657],
                          [ 0.128733275756,  1.028164225797,  0.510405915499],
                          [ 0.558027627618,  0.902265403418,  0.721728752948],
                          [ 0.838925101714,  0.060575194116,  0.152651948225],
                          [ 0.319243329388,  0.88924291312 ,  0.165059989167],
                          [ 1.078983404727,  0.886047646202,  0.818286165921],
                          [ 0.139629146876,  0.336911701001,  0.513043668324],
                          [ 0.955547875987,  0.22290020898 , -0.020851029771],
                          [ 0.335545217578,  0.649905355858,  0.063945905245],
                          [ 0.664245153037,  1.011415471854,  0.483814304373],
                          [ 0.751078848308,  0.771263692868,  0.561260337535],
                          [ 0.667623575118,  0.077034737307,  0.879728370082],
                          [ 0.410368042827,  0.685650950965,  0.318730581711],
                          [ 0.853566881307, -0.02188436702 ,  0.682338300628],
                          [ 0.380355438966,  0.713457560796,  0.653020323648],
                          [ 0.198987149742,  0.007210278085,  0.97399801997 ],
                          [ 0.734863997697,  0.199250566918,  0.335685523462],
                          [ 0.556681829862,  0.922097186236,  0.263468410058],
                          [ 0.023184892766,  0.218136216403,  0.323308232641],
                          [ 0.271681707565,  0.399756220001,  0.261525775641],
                          [ 0.064631528515,  0.975093259206,  0.257301666748],
                          [ 0.847925445366,  0.592555684468,  0.189891374758],
                          [ 0.110906190573,  0.649052955591,  0.238962571239],
                          [ 0.658927001028,  0.339504887395,  0.068132753176],
                          [ 0.411318475639,  0.411584850968,  0.974728530978],
                          [ 0.754864392876,  0.837957068087,  0.876121768854],
                          [ 0.845340451385,  0.214826201923,  0.759134150448],
                          [ 0.314571363056,  0.817467622628,  0.866784410945],
                          [ 0.648349565395,  0.330637562838,  0.551987079984],
                          [ 0.135824249838,  0.786563501672,  1.036811152795],
                          [ 0.485612059367,  0.288072963235,  0.250438071807],
                          [ 0.653254043091,  0.143034461021,  0.666477830695],
                          [ 0.260005963379,  0.167068868886,  0.342826425332],
                          [ 0.533720649375,  0.806626875133,  0.481482832891],
                          [ 0.921436540325,  0.154776547086,  0.544797616448],
                          [ 0.375786088313,  0.966871518345,  0.381227017387],
                          [ 0.492548036647,  0.513109331134,  0.165497659057],
                          [ 0.210833206136,  0.847890348758,  0.371140266205],
                          [ 1.031509463724,  0.445478512888,  0.322496337355],
                          [ 0.866659180905,  0.322313902832,  0.162265340345],
                          [ 0.508397230865,  0.175132791239,  0.457038934988],
                          [ 0.866848756365,  0.607412025401,  0.421885776959],
                          [ 0.920391935231,  0.741618667121,  0.713659716472],
                          [ 0.04699105457 ,  0.381060108897,  0.831896918711],
                          [ 0.322922086377,  0.181127596272,  0.631696422013],
                          [ 0.424087842944,  0.359875745175,  0.712947362232],
                          [ 0.634007965536,  0.65546715367 ,  0.764772798973],
                          [ 0.737087380689,  0.845185942351,  0.141722950563],
                          [ 0.498555228499,  0.224762892372,  0.879696433448],
                          [ 0.620987502613,  0.681678020615,  0.216978345687],
                          [ 0.618404394453,  0.107150515242,  0.118142407228],
                          [ 0.451133735816,  0.633604859239,  0.875353694979],
                          [ 0.406874001044,  1.041340886496,  0.864654886489],
                          [ 0.930316471308, -0.017928355973,  0.953811503169],
                          [ 0.84131565048 ,  0.444669608252,  0.967674588967],
                          [ 0.720588085491,  0.542132842703,  0.603305632464],
                          [ 0.50911494439 ,  0.592104331618,  0.507766357059],
                          [ 0.04901467146 ,  0.544739940968,  0.588470826724],
                          [ 0.232605120876,  0.581543563153,  0.43413067519 ],
                          [ 0.883329167398,  0.985059856324,  0.398205854245],
                          [ 0.157488018159,  0.703506715035,  0.700639002345],
                          [ 0.87974629708 ,  0.527717747216,  0.7658975975  ],
                          [ 0.300172518909,  0.49697251298 ,  0.620378634371],
                          [ 0.159740445781,  0.202629423467,  0.123940860961],
                          [ 0.941832914819,  0.806641467836,  0.172560004557]]

# Define velocities: initial atomic velocities in Angstrom/fs, one row per atom, in the
# same order as `elements` and `fractional_coordinates`
velocities = [[-3.732156537470e-03,  1.559806818504e-03,  2.995320796778e-05],
              [ 3.840256380006e-03,  8.015103092250e-04, -2.848272080262e-03],
              [ 6.539906715783e-03,  2.979700306886e-03, -5.581067674912e-03],
              [-6.393021776520e-03, -3.589430624039e-03,  7.194844753844e-03],
              [-4.676434979139e-03,  3.545686697892e-03, -4.270551200281e-03],
              [ 3.356836497720e-03, -4.100606363423e-03,  3.395279678865e-03],
              [ 3.659387993901e-03,  8.245739219022e-03,  9.125053680288e-04],
              [-3.394917153306e-03, -3.090302697046e-03,  4.903877416279e-03],
              [-3.073048017902e-03,  1.366825421173e-03, -2.623611091697e-03],
              [ 1.731334402728e-03,  1.088294600374e-02, -4.758986945212e-03],
              [ 8.472661314250e-03,  3.093559101681e-03, -1.368636878053e-04],
              [-4.232579657881e-03,  5.220928774858e-03, -2.155962471377e-03],
              [-1.218618740194e-03, -6.590298421791e-03, -6.589828739022e-03],
              [ 1.806646866215e-03, -5.034282010868e-04,  3.947971659211e-03],
              [-5.901977953972e-03,  8.297764608319e-04, -5.428078315952e-03],
              [-5.897255378135e-05, -2.230182802157e-03,  3.860508505672e-03],
              [-7.166259560678e-03, -2.635628189030e-03, -6.406133056724e-04],
              [ 1.048046954352e-03, -1.990397819551e-03, -1.629709467377e-03],
              [ 2.122314284875e-03, -5.717007541258e-03,  8.871193165893e-03],
              [ 6.786709297449e-04,  6.838107287636e-03,  1.762991459730e-03],
              [ 1.896602698622e-03, -4.549462990767e-03,  5.866125204046e-03],
              [ 8.730121727352e-03, -6.158255540773e-04,  3.512219016097e-03],
              [ 3.821024552355e-03,  2.342198230897e-03, -6.054167892639e-03],
              [-5.703215774058e-03, -4.343703783013e-03,  1.610421380183e-03],
              [ 2.588321614260e-03, -3.400002102204e-03, -1.666642000617e-03],
              [ 2.112753073508e-03,  8.185301128007e-04,  1.682696849508e-03],
              [ 6.083872124356e-03,  1.590702309997e-04, -3.257346939895e-03],
              [-3.922543612478e-03,  1.729643598785e-03,  4.047250869730e-03],
              [ 8.726096112664e-03,  2.207878499846e-04,  1.279209530106e-03],
              [-8.486578257291e-03, -1.855143199812e-03, -2.574640686845e-03],
              [ 7.915111370658e-04,  5.081535831340e-03, -6.231777293418e-03],
              [-9.050734936297e-04,  8.973165905430e-04,  1.215224842934e-03],
              [ 5.978880124031e-03,  4.348243369616e-03,  5.588136540590e-03],
              [ 1.039482329777e-03,  8.516844253215e-03,  7.519366015766e-04],
              [-1.779626681460e-03, -7.724756080253e-03, -5.258420624349e-03],
              [-1.895391044277e-03, -2.313033350892e-03,  8.275377622994e-03],
              [ 3.217439731923e-03, -4.059581848320e-03, -3.847947872497e-03],
              [-8.057576349364e-03, -2.680909201223e-03,  8.948310375193e-04],
              [-1.723680098311e-03,  6.631668164850e-03, -9.331598919055e-04],
              [ 7.364910054293e-03, -3.806056544239e-03, -1.529798338851e-03],
              [ 3.761558266294e-03, -3.767261197745e-03, -4.718839070139e-03],
              [ 4.396145162663e-03, -3.714033346177e-03, -8.473891872715e-03],
              [-1.500772148163e-03,  3.440134417696e-03,  5.507550572042e-03],
              [-1.207838920122e-02,  1.153062049432e-02, -7.794494911865e-03],
              [ 6.792113343414e-03, -5.364491857507e-04, -3.019925570692e-03],
              [ 7.578591815833e-03, -8.211986877744e-04,  5.025520794735e-03],
              [ 6.868090609309e-04,  7.417198116302e-03,  1.112742573700e-02],
              [-5.612925879605e-03,  1.821097363230e-03, -2.999625056864e-03],
              [-1.502178087734e-02,  4.686034817086e-03, -6.423523120584e-03],
              [ 1.406774587936e-03, -4.824351459258e-03,  4.134484034139e-03],
              [ 9.318467211320e-03, -5.667046798358e-03,  7.208598735239e-03],
              [-3.802147798916e-04, -9.724783631318e-03, -7.707803359569e-03],
              [-1.022420804996e-02,  3.335350395112e-04,  1.358013315679e-03],
              [-2.905047385829e-03, -3.441672535928e-03, -6.322199751788e-03],
              [ 1.480051396942e-02, -1.907080431649e-03, -2.468231425103e-03],
              [-5.251949946696e-03,  1.188109025674e-02, -1.563452130744e-04],
              [-1.521967270566e-03, -1.232909132362e-03,  2.786740478024e-03],
              [ 4.692381493442e-03,  1.275309939253e-03,  3.878565308462e-03],
              [-3.902586102424e-03, -1.129679598034e-02, -2.400017131730e-03],
              [-4.837797720465e-03,  1.957091114959e-03, -4.148966694309e-03],
              [ 1.643111014778e-03, -1.542482588443e-03,  7.338262387887e-03],
              [ 2.606938429246e-03,  3.204549872196e-03, -3.802130477733e-03],
              [-4.151824634565e-03, -1.282261707032e-03,  1.206806692822e-02],
              [ 2.480450268915e-03, -6.119089506453e-03, -7.069634069089e-03],
              [-5.207309947566e-03, -9.972417836544e-03,  7.148571163622e-04],
              [ 5.052342127593e-04, -3.852650420651e-03,  2.223572472686e-03],
              [-6.371818775696e-03,  4.490552018681e-03,  3.536029104237e-03],
              [ 4.349604555820e-03,  1.241234110991e-03,  9.010933154649e-03],
              [ 8.519716368350e-03, -5.062687157752e-03, -3.637243784343e-03],
              [ 1.800779285415e-03,  6.251220796035e-03, -5.258259915983e-03],
              [-1.185964128440e-02,  1.151035351088e-02, -8.286994542520e-03],
              [ 3.294561320875e-03, -2.524023107902e-03,  8.080461393645e-03],
              [-4.310602567791e-03, -7.656080632369e-03, -4.748165472041e-03],
              [-4.308868760950e-03,  8.618505970200e-03, -6.131020631209e-03],
              [ 3.927748318372e-03, -8.222397470259e-03,  2.233542520072e-03],
              [-4.760222008774e-04, -1.038214821176e-02,  2.296396551773e-03],
              [-1.187129486079e-03, -1.173739746458e-04,  1.828270657218e-04],
              [ 9.217234231557e-04,  3.484178711462e-03,  2.448158464259e-03]]*Angstrom/fs

# Set up the 78-atom bulk configuration from the lattice, element list,
# fractional coordinates and initial velocities defined above
tisi2_amorphous_78atoms = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates,
    velocities=velocities
    )

# Replicate the 78-atom cell 4 x 4 x 4 times: 78 * 64 = 4992 atoms
tisi2_amorphous_4992atoms = tisi2_amorphous_78atoms.repeat(4,4,4)

# Dictionary to store benchmark results for each model
performance_benchmark_results = {}

# Number of MD steps timed for every potential (identical for all, so set once)
num_steps = 1000

for potential_number, (potential_name, qatkpt_path) in enumerate(potentials.items()):
    nlprint(f"Processing potential {potential_number + 1}/{len(potentials)}: {potential_name}")
    # Set up a calculator with the given potential.
    # NOTE: this construction mirrors the accuracy loop; keep the two in sync.
    if potential_name == 'mtp':
        # MTP case
        potentialSet = TremoloXPotentialSet(name='Moment Tensor Potential')
        potentialSet.addParticleType(
            ParticleType(
                symbol='Si',
                mass=28.0855 * atomicMassUnit,
                charge=None,
                sigma=None,
                sigma14=None,
                epsilon=None,
                epsilon14=None,
                atomicNumber=14,
                tags=[],
            )
        )
        potentialSet.addParticleType(
            ParticleType(
                symbol='Ti',
                mass=47.867 * atomicMassUnit,
                charge=None,
                sigma=None,
                sigma14=None,
                epsilon=None,
                epsilon14=None,
                atomicNumber=22,
                tags=[],
            )
        )
        _potential = MTPPotential(
            file=qatkpt_path,
            suppress_intercept=False,
            group_name='',
        )
        potentialSet.addPotential(_potential)
    elif potential_name == 'foundation':
        # Foundation MACE model case (built in; the file path is not used)
        potentialSet = TorchX_MACE_MP_0b3_medium(dtype='float32', enforceLTX=False)
    else:
        # Trained/finetuned MACE model case
        model_path = qatkpt_path

        potentialSet = TremoloXPotentialSet(name='my_mace_potential')
        potential = TorchXPotential(
            dtype='float32',
            device='cuda',
            file=model_path,
        )
        for symbol in potential.get_symbols(model_path):
            potentialSet.addParticleType(ParticleType(symbol))
        potentialSet.addPotential(potential)

    calculator = TremoloXCalculator(parameters=potentialSet)

    # Attach calculator to a fresh copy so every potential starts from the same state
    configuration = tisi2_amorphous_4992atoms.copy()
    configuration.setCalculator(calculator)

    num_atoms = len(configuration)

    # Run num_steps steps of NVT MD with default Berendsen settings
    method = NVTBerendsen()

    # perf_counter() is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted during the run.
    # NOTE(review): the timed window includes any one-off start-up cost of the
    # first MD step; consider a short untimed warm-up run if only steady-state
    # throughput should be reported.
    start_time = time.perf_counter()

    md_trajectory = MolecularDynamics(
        configuration=configuration,
        trajectory_filename=None,
        steps=num_steps,
        method=method,
    )

    end_time = time.perf_counter()
    time_per_atom_per_step = (end_time - start_time) / (num_atoms * num_steps)

    # Store results (seconds per atom per MD step)
    performance_benchmark_results[potential_name] = {
        "avg_time_per_atom_per_step": time_per_atom_per_step
    }

# Save results to a JSON file
output_file = "performance_benchmark_results.json"
with open(output_file, "w") as f:
    json.dump(performance_benchmark_results, f, indent=4)

nlprint(f"Performance benchmark results saved to {output_file}")
