from NanoLanguage import *
import numpy as np
import random
import matplotlib.pyplot as plt
from math import sqrt
from NL.CommonConcepts.Configurations.Utilities import fractionalToCartesian

# Set minimal log verbosity. This runs after the NanoLanguage star import,
# which is what brings setVerbosity and MinimalLog into scope.
setVerbosity(MinimalLog)

# --- User settings ----------------------------------------------------------
# Structure builders: delithiated host, lithiated template, bulk Li metal.
from configuration import FePO4, LiFePO4, Li_metal4

# Output HDF5 file; every optimized configuration and its TotalEnergy is saved here.
output_hdf5 = 'ocv_LiFePO4_all.hdf5'

# Repetition of the primitive unit cell used to build the FePO4/LiFePO4 supercells.
repeat_cells = (2, 4, 1)  # (na, nb, nc)

# x_list is generated automatically after Li_sites_total is determined.
# Number of Li/vacancy orderings sampled per intermediate x. Each such x yields
# 1 clustered + 1 uniform + max(0, n_samples_per_x - 2) random candidates, so
# at least 2 candidates are relaxed even if this is set below 2.
n_samples_per_x = 150   # Li/vacancy orderings per intermediate x (see above)

# Geometry-optimization settings (shared by every relaxation in this script).
max_forces = 0.01 * eV/Angstrom
max_steps = 1000

# --- Choose calculator here ------------------------------------------------
# Options: 'DFT' (calculator_hse06_medium from the DFTcalculator module) or
# 'MATTERSIM' (TremoloX calculator with the TorchX MatterSim potential).
CALCULATOR = 'MATTERSIM'   # use 'DFT' only if the DFTcalculator module is available

# --- Calculator factory -----------------------------------------------------
def get_calculator(charge=None):
    """
    Return a new calculator of the type selected by the module-level CALCULATOR.

    Parameters
    ----------
    charge : optional
        Forwarded to the DFT calculator factory; ignored for MatterSim.

    Returns
    -------
    A freshly constructed calculator object.

    Raises
    ------
    RuntimeError
        If CALCULATOR is not 'DFT'/'MATTERSIM' (case-insensitive), or the
        selected backend cannot be imported/constructed. The original error
        is chained as ``__cause__`` so its traceback is not lost.
    """
    choice = CALCULATOR.upper()

    if choice == 'DFT':
        try:
            from DFTcalculator import calculator_hse06_medium as dft_factory
        except Exception as e:
            raise RuntimeError("DFT selected but cannot import calculator_hse06_medium from DFTcalculator: " + str(e)) from e
        return dft_factory(charge=charge)

    if choice == 'MATTERSIM':
        try:
            potentialSet = TorchX_MatterSim_v1_0_0_5M(dtype='float32', enforceLTX=False)
            return TremoloXCalculator(parameters=potentialSet)
        except Exception as e:
            raise RuntimeError("MatterSim selected but required binding not available: " + str(e)) from e

    raise RuntimeError(f"Unknown CALCULATOR '{CALCULATOR}'. Valid: 'DFT','MATTERSIM'")

# --- Basic helpers ----------------------------------------------------------
def repeat_configuration(cfg, rep):
    """Return the supercell ``cfg.repeat(na, nb, nc)`` for ``rep = (na, nb, nc)``."""
    # Unpack explicitly so anything other than exactly three counts fails loudly.
    reps_a, reps_b, reps_c = rep
    supercell = cfg.repeat(reps_a, reps_b, reps_c)
    return supercell

def get_Li_sites_from_LiFePO4(cfg_full):
    """Return the indices (in atom order) of every lithium atom in ``cfg_full``."""
    lithium_z = Lithium.atomicNumber()
    li_sites = []
    for atom_index, z in enumerate(cfg_full.atomicNumbers()):
        if z == lithium_z:
            li_sites.append(atom_index)
    return li_sites

def add_atoms_once(base_cfg, new_elements, new_coords):
    """
    Return a new BulkConfiguration equal to ``base_cfg`` plus extra atoms.

    Add all atoms in a single call rather than in a loop: every call rebuilds
    the whole configuration. Only the Bravais lattice, elements and Cartesian
    coordinates of ``base_cfg`` are carried over.

    Parameters
    ----------
    base_cfg : BulkConfiguration
        Host configuration (not modified).
    new_elements : sequence
        Elements of the atoms to append.
    new_coords : sequence
        Cartesian coordinates of the atoms to append; one per element.

    Raises
    ------
    ValueError
        If ``new_elements`` and ``new_coords`` differ in length. Previously
        ``zip`` silently dropped the surplus atoms.
    """
    new_elements = list(new_elements)
    new_coords = list(new_coords)
    if len(new_elements) != len(new_coords):
        raise ValueError(
            f"add_atoms_once: got {len(new_elements)} elements but "
            f"{len(new_coords)} coordinates."
        )
    elems = list(base_cfg.elements()) + new_elements
    coords = list(base_cfg.cartesianCoordinates()) + new_coords
    return BulkConfiguration(
        bravais_lattice = base_cfg.bravaisLattice(),
        elements = elems,
        cartesian_coordinates = coords
    )

# --- Sampling helper functions (clustered, uniform, random) -----------------
def _get_frac_positions(cfg):
    """Return list of fractional coordinates for the configuration (ordered)."""
    return list(cfg.fractionalCoordinates())

def _get_distance_matrix_cart(frac_positions, host_cfg):
    """
    Compute minimum-image pairwise distances (Cartesian) for fractional coords
    mapped into the host cell.
    frac_positions: list of 3-tuples (fractional coords) in same ordering as template.
    """
    frac = np.array(frac_positions, dtype=float)
    N = len(frac)
    # fractional pair differences with PBC wrap
    frac_diff = frac.reshape(N,1,3) - frac.reshape(1,N,3)
    frac_diff -= np.round(frac_diff)   # wrap into [-0.5,0.5]
    # convert fractional differences to cartesian using host primitive vectors
    cell = host_cfg.primitiveVectors()
    a = np.asarray(cell[0])  # these are PhysicalQuantity objects; numpy will convert
    b = np.asarray(cell[1])
    c = np.asarray(cell[2])
    # broadcast linear combination
    cart_diff = frac_diff[:,:,0:1]*a + frac_diff[:,:,1:2]*b + frac_diff[:,:,2:3]*c
    dist = np.linalg.norm(cart_diff, axis=2)
    return dist

def _maxmin_greedy_selection(frac_positions, host_cfg, k):
    """
    Farthest-point / max-min greedy selection among given fractional positions.
    frac_positions: list (reduced) - only positions corresponding to Li candidate sites.
    returns indices (0..len(frac_positions)-1) chosen.
    """
    N = len(frac_positions)
    if k >= N:
        return list(range(N))
    dist = _get_distance_matrix_cart(frac_positions, host_cfg)
    avgd = dist.mean(axis=1)
    chosen = [int(np.argmax(avgd))]   # start at most remote on average
    while len(chosen) < k:
        remaining = [i for i in range(N) if i not in chosen]
        # compute distance from each remaining to the chosen set (min)
        min_to_chosen = np.min(dist[np.ix_(remaining, chosen)], axis=1)
        pick = remaining[int(np.argmax(min_to_chosen))]
        chosen.append(pick)
    return sorted(chosen)

def _cluster_greedy_selection(frac_positions, host_cfg, k):
    """
    Greedy clustering selection: pick a seed at smallest avg dist (most central),
    then iteratively add nearest neighbours to the cluster.
    returns indices in reduced frac_positions list.
    """
    N = len(frac_positions)
    if k >= N:
        return list(range(N))
    dist = _get_distance_matrix_cart(frac_positions, host_cfg)
    avgd = dist.mean(axis=1)
    seed = int(np.argmin(avgd))
    chosen = [seed]
    while len(chosen) < k:
        remaining = [i for i in range(N) if i not in chosen]
        min_to_cluster = np.min(dist[np.ix_(remaining, chosen)], axis=1)
        pick = remaining[int(np.argmin(min_to_cluster))]
        chosen.append(pick)
    return sorted(chosen)

def build_configuration_for_x_with_modes(base_FePO4_rep, base_LiFePO4_rep, x, seed=None,
                                         include_clustered=True, include_uniform=True, n_random=1):
    """
    Build candidate Li_xFePO4 configurations for one composition x.

    nLi = round(x * total_sites) Li atoms are added to the FePO4 host at a
    subset of the Li sites of the LiFePO4 template. Template fractional
    coordinates are mapped into the host cell, so the two supercells may have
    different lattice parameters.

    Parameters
    ----------
    base_FePO4_rep : delithiated host supercell.
    base_LiFePO4_rep : lithiated template supercell that provides the Li sites.
    x : float, fraction of Li sites to fill (must map into 0..total_sites).
    seed : seed for the random generator used by the random samples.
    include_clustered : add one compact-cluster ordering ('clustered').
    include_uniform : add one max-min spread ordering ('uniform').
    n_random : number of random orderings ('random_<i>').

    Returns
    -------
    list of (cfg, nLi, label)
        label is one of 'clustered', 'uniform', 'random_<i>', 'empty', 'full'.
        At the endpoints (nLi == 0 or nLi == total_sites) only a copy of the
        matching template is returned.
        NOTE: random samples are not de-duplicated; for small nLi they may
        repeat each other or the clustered/uniform orderings.

    Raises
    ------
    ValueError
        If the template has no Li sites, or x maps outside 0..total_sites.
    """
    Li_indices = get_Li_sites_from_LiFePO4(base_LiFePO4_rep)
    total_sites = len(Li_indices)
    if total_sites == 0:
        raise ValueError("No Li sites found; check LiFePO4 builder or repetition.")
    nLi = int(round(x * total_sites))
    # Without this check a negative nLi produced a 1-Li 'clustered' structure,
    # and nLi > total_sites produced over-filled structures when n_random == 0.
    if not 0 <= nLi <= total_sites:
        raise ValueError(
            f"x={x} gives nLi={nLi}, outside the valid range 0..{total_sites}."
        )

    # endpoints
    if nLi == 0:
        return [(base_FePO4_rep.copy(), 0, 'empty')]
    if nLi == total_sites:
        return [(base_LiFePO4_rep.copy(), nLi, 'full')]

    # Fractional coordinates of ALL template atoms (indexed via Li_indices).
    frac_all = _get_frac_positions(base_LiFePO4_rep)
    host_cfg = base_FePO4_rep
    host_vectors = host_cfg.primitiveVectors()   # same for every inserted atom
    # Reduced list: fractional coordinates of the Li candidate sites only.
    frac_li_sites = [frac_all[i] for i in Li_indices]

    def _host_with_li_at(template_atom_indices):
        """Host copy with one Li added at each given template atom position."""
        coords = [fractionalToCartesian(frac_all[i], host_vectors)
                  for i in template_atom_indices]
        return add_atoms_once(base_FePO4_rep, [Lithium] * len(coords), coords)

    chosen_configurations = []

    if include_clustered:
        picked = _cluster_greedy_selection(frac_li_sites, host_cfg, nLi)
        cfg = _host_with_li_at([Li_indices[i] for i in picked])
        chosen_configurations.append((cfg, nLi, 'clustered'))

    if include_uniform:
        picked = _maxmin_greedy_selection(frac_li_sites, host_cfg, nLi)
        cfg = _host_with_li_at([Li_indices[i] for i in picked])
        chosen_configurations.append((cfg, nLi, 'uniform'))

    rng = random.Random(seed)
    for r in range(n_random):
        rand_choice = sorted(rng.sample(Li_indices, nLi))
        chosen_configurations.append((_host_with_li_at(rand_choice), nLi, f'random_{r}'))

    return chosen_configurations

# --- Energy / relaxation wrapper -------------------------------------------
def optimize_and_get_energy(cfg, calc, traj_name):
    """
    Relax ``cfg`` (atoms and cell) with ``calc`` and return its total energy.

    Parameters
    ----------
    cfg : configuration to relax; ``calc`` is attached to it in place.
    calc : calculator used for the relaxation and the energy evaluation.
    traj_name : trajectory file name. Currently unused: trajectory output
        is disabled (it is not passed to OptimizeGeometry).

    Returns
    -------
    (float, configuration)
        Total energy in eV and the optimized configuration. If
        OptimizeGeometry raises, the error is logged and the energy of the
        unrelaxed ``cfg`` is returned instead.
    """
    cfg.setCalculator(calc)
    cfg.update()
    # Fallback: keep the input structure if the optimizer fails.
    relaxed = cfg
    try:
        relaxed = OptimizeGeometry(
            configuration=cfg,
            max_forces=max_forces,
            max_steps=max_steps,
            optimize_cell=True,
        )
    except Exception as exc:
        nlprint("OptimizeGeometry failed:", exc)
    energy_eV = TotalEnergy(relaxed).evaluate().inUnitsOf(eV)
    return float(energy_eV), relaxed

# --- Prepare base configurations -------------------------------------------
nlprint(f"Repeating unit cells with: {repeat_cells}")
# Host (delithiated) and template (lithiated) supercells use the same
# repetition; the template's Li sites are later mapped into the host cell.
base_FePO4_rep = repeat_configuration(FePO4(), repeat_cells)
base_LiFePO4_rep = repeat_configuration(LiFePO4(), repeat_cells)
# Bulk Li metal: its energy per atom is the Li reference in the voltage formula.
base_Li_bulk = Li_metal4()

# Sanity check: the repeated LiFePO4 template must contain at least one Li site.
Li_sites_total = len(get_Li_sites_from_LiFePO4(base_LiFePO4_rep))
nlprint(f"Total Li sites in repeated LiFePO4: {Li_sites_total}")
if Li_sites_total == 0:
    raise RuntimeError("No Li sites detected after repetition; aborting.")

# Generate x_list so that every composition corresponds to an integer Li count.
x_list = [n / Li_sites_total for n in range(Li_sites_total + 1)]
# Subsample to reduce the workload: keep every k-th composition (k = 1 keeps
# them all). The fully lithiated endpoint x = 1 is always kept, even when
# Li_sites_total is not a multiple of k, so the voltage curve spans 0..1.
k = 2
x_list = x_list[::k] + ([x_list[-1]] if Li_sites_total % k else [])
nlprint(f"Generated x_list with {len(x_list)} points: {[f'{x:.3f}' for x in x_list]}")

# Show mapping from requested x to actual integer nLi and actual x_used
for x in x_list:
    nLi = int(round(x * Li_sites_total))
    x_used = float(nLi) / float(Li_sites_total)
    if not np.isclose(x * Li_sites_total, nLi):
        nlprint(f"Warning: requested x={x:.3f} -> rounded nLi={nLi} -> actual x_used={x_used:.3f}")
    else:
        nlprint(f"x={x:.3f} ok -> nLi={nLi}")

# Build and relax bulk Li metal with the same calculator type. Its energy per
# atom is the Li reference used in the voltage formula below.
nlprint(f"Using calculator: {CALCULATOR}")
calc = get_calculator()
base_Li_bulk.setCalculator(calc)
base_Li_bulk.update()
try:
    opt_Li_bulk = OptimizeGeometry(
        configuration=base_Li_bulk,
        max_forces=max_forces,
        max_steps=max_steps,
        optimize_cell=True,
    )
    E_Li_bulk = float(TotalEnergy(opt_Li_bulk).evaluate().inUnitsOf(eV))
except Exception as e:
    # No voltages can be computed without the Li reference: report and abort.
    nlprint(f"Li bulk optimize failed; Error: {e}")
    raise

nlprint(f"E(Li bulk supercell) [eV]: {E_Li_bulk}")
# Every atom in the bulk reference is counted as Li.
n_Li_atoms_bulk = len(opt_Li_bulk.atomicNumbers())
E_Li_per_atom = E_Li_bulk / n_Li_atoms_bulk
nlprint(f"E(Li) per atom [eV]: {E_Li_per_atom}")

# --- Loop over compositions ------------------------------------------------
energies = {}   # energies[x] = list of (E_supercell, nLi, label)
for x in x_list:
    samples = energies[x] = []

    # Endpoints (x=0 and x=1): one relaxation of the matching template.
    if np.isclose(x, 0.0):
        endpoint = ("0", "FePO4", base_FePO4_rep, 0)
    elif np.isclose(x, 1.0):
        endpoint = ("1", "LiFePO4", base_LiFePO4_rep, Li_sites_total)
    else:
        endpoint = None

    if endpoint is not None:
        x_tag, template_name, template, n_endpoint = endpoint
        nlprint(f"x={x_tag} endpoint: using {template_name} template")
        E, optcfg = optimize_and_get_energy(template.copy(), get_calculator(),
                                            f"opt_x{x_tag}_endpoint.hdf5")
        samples.append((E, n_endpoint, 'endpoint'))
        nlsave(output_hdf5, TotalEnergy(optcfg), object_id=f"Etot_x{x_tag}_endpoint")
        nlsave(output_hdf5, optcfg, object_id=f"optcfg_x{x_tag}_endpoint")
        continue

    # Intermediate x: clustered + uniform + random orderings.
    nlprint(f"Building candidates for x={x:.3f}")
    candidates = build_configuration_for_x_with_modes(
        base_FePO4_rep, base_LiFePO4_rep, x,
        seed=42, include_clustered=True, include_uniform=True,
        n_random=max(0, n_samples_per_x - 2),
    )
    nlprint(f" -> {len(candidates)} candidates generated for x={x:.3f}")

    # Relax every candidate (fresh calculator each time) and store its energy.
    x_str = f"{x:.3f}".replace('.', 'p')
    for icand, (cfg_cand, nLi, label) in enumerate(candidates):
        run_tag = f"x{x_str}_{label}_s{icand}"
        calc_local = get_calculator()
        nlprint(f"Optimizing x={x:.3f} candidate {icand} label={label} nLi={nLi}")
        E, optcfg = optimize_and_get_energy(cfg_cand, calc_local, f"opt_{run_tag}.hdf5")
        samples.append((E, nLi, label))
        nlsave(output_hdf5, TotalEnergy(optcfg), object_id=f"Etot_{run_tag}")
        nlsave(output_hdf5, optcfg, object_id=f"optcfg_{run_tag}")

# --- Postprocessing and OCV computation -----------------------------------
E_x = {}
nLi_x = {}
for x in x_list:
    samples_at_x = energies[x]
    if not samples_at_x:
        raise RuntimeError(f"No samples computed for x={x:.3f}")
    # Keep the lowest-energy ordering as the representative of this x.
    E_best, nLi_best, tag_best = min(samples_at_x, key=lambda sample: sample[0])
    E_x[x] = E_best
    nLi_x[x] = nLi_best
    nlprint(f"Chosen E(x={x:.3f}) = {E_x[x]} eV   nLi = {nLi_x[x]}   tag={tag_best}")

# --- Compute voltages ------------------------------------------------------
voltages = []
x_mid = []
for x_i, x_j in zip(x_list[:-1], x_list[1:]):
    delta_n = nLi_x[x_j] - nLi_x[x_i]
    if delta_n <= 0:
        raise ValueError("x_list must be strictly increasing in terms of Li count (nLi).")
    # Average voltage over [x_i, x_j]: minus the energy change per inserted Li,
    # with bulk Li metal as the Li reference.
    V = -(E_x[x_j] - E_x[x_i] - delta_n * E_Li_per_atom) / float(delta_n)
    voltages.append(V)
    x_mid.append(0.5 * (x_i + x_j))
    nlprint(f"V (x {x_i:.3f} -> {x_j:.3f}) = {V:.3f} V")

# --- Plotting --------------------------------------
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'DejaVu Serif'],
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 9,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'lines.linewidth': 1.5,
    'axes.linewidth': 1.2,
    'grid.linewidth': 0.5,
})
fig, ax = plt.subplots(figsize=(4.5, 3.5))
if voltages:
    # voltages[i] is the average OCV over [x_list[i], x_list[i+1]]. A 'post'
    # step on the composition grid draws it over exactly that interval. The
    # previous 'mid' steps between interval midpoints only put the jumps at
    # the grid points when the x spacing was uniform.
    ax.step(x_list, voltages + [voltages[-1]],
            where='post', linewidth=2.5, color='#2E86AB', label='Average OCV', zorder=3)
ax.scatter(x_mid, voltages, s=80, facecolors='#E63946',
           edgecolors='#A4031F', linewidths=1.5, alpha=0.9, label='Calculated points', zorder=4)
ax.set_xlabel(r'Lithiation $x$ (Li$_x$FePO$_4$)')
ax.set_ylabel('OCV [V]')
ax.set_title(r'OCV vs $x$ for Li$_x$FePO$_4$', pad=15)
ax.minorticks_on()
legend = ax.legend(loc='best', frameon=True, shadow=True, fancybox=True, framealpha=0.95)
legend.get_frame().set_edgecolor('#2E86AB')
legend.get_frame().set_linewidth(1.2)
ax.set_xlim(0, 1)
plt.tight_layout()
plt.savefig('ocv_vs_x.png', dpi=600, bbox_inches='tight')
plt.close(fig)   # release the figure; nothing else is drawn on it

# Plain-text copy of the plotted points: interval midpoint and average voltage.
with open('ocv_vs_x.txt', 'w', encoding='utf-8') as f:
    f.write("# x_mid   V (V)\n")
    for xm, V in zip(x_mid, voltages):
        f.write(f"{xm:.3f}   {V:.3f}\n")