###############################################################
#  Open-circuit voltage (OCV) of Li_x NMC (622, 811, etc.) via Li removal from an SQS supercell
###############################################################
from NanoLanguage import *
import random
import matplotlib.pyplot as plt
import numpy

# setVerbosity / MinimalLog come from NanoLanguage, so configure logging
# only after the import that provides them.
setVerbosity(MinimalLog)

###############################################################
# 0. USER INPUT
###############################################################

# ---- Choose target cathode composition ----
# E.g., "NMC811", "NMC622", "NMC532"; used to build all file names below.
NMC_TAG = "NMC721"

# ---- Input optimized SQS supercell ----
NMC_SQS_FILE = f"{NMC_TAG}_SQS_final_optimized.hdf5"

# ---- Output ----
output_hdf5 = f"ocv_{NMC_TAG}_all.hdf5"         # every relaxed candidate
output_best_hdf5 = f"ocv_{NMC_TAG}_best.hdf5"   # lowest-energy candidate per x
output_plot = f"ocv_{NMC_TAG}.png"

# ---- Lithiation grid (modifiable) ----
# 21 points (0, 0.05, ..., 1). These are later rounded to integer Li counts
# in the supercell, so the effective grid may be coarser and non-uniform.
x_target = [float(x) for x in numpy.linspace(0.0, 1.0, 21)]

# ---- Sampling settings ----
# Candidates per x: 1 cluster + 1 uniform + (n_samples_per_x - 2) random,
# i.e. 1 + 1 + 148 for the value below.
n_samples_per_x = 150

# ---- Geometry relaxation settings ----
max_forces = 0.01 * eV/Angstrom
max_steps  = 1000
OPTIMIZE_CELL = True   # Set to False to keep the SQS lattice vectors fixed

###############################################################
# 1. Load SQS structure & identify Li sublattice
###############################################################

# Use the last BulkConfiguration stored in the SQS file.
sqs_candidates = nlread(NMC_SQS_FILE, BulkConfiguration)
if len(sqs_candidates) == 0:
    raise ValueError(f"No BulkConfiguration found in {NMC_SQS_FILE}")
sqs_struct = sqs_candidates[-1]

elems = sqs_struct.elements()
Li_indices = [i for i, e in enumerate(elems) if e.symbol() == "Li"]

n_total_Li = len(Li_indices)
nlprint(f"{NMC_TAG}: Total Li sites in supercell = {n_total_Li}")
if n_total_Li == 0:
    # The x -> nLi conversion below divides by n_total_Li; fail with a clear
    # message instead of a bare ZeroDivisionError.
    raise ValueError(
        f"No Li atoms found in {NMC_SQS_FILE}; cannot build Li_x {NMC_TAG} structures."
    )

# Convert x_target -> integer Li counts. Targets that round to the same count
# are merged (note: Python's round() rounds halves to even).
nLi_list = sorted({int(round(x * n_total_Li)) for x in x_target})
nLi_list = [n for n in nLi_list if 0 <= n <= n_total_Li]
x_list = [n / n_total_Li for n in nLi_list]

nlprint(f"nLi_list: {nLi_list}")
nlprint(f"Effective x_list: {x_list}")

# Store Li coordinates (row j of Li_coords belongs to atom Li_indices[j])
cart_all = numpy.array(sqs_struct.cartesianCoordinates().inUnitsOf(Angstrom))
Li_coords = cart_all[Li_indices]


###############################################################
# 2. Vacancy pattern sampling (FPS, clustered, random)
###############################################################

def fps_selection(coords, k):
    """Farthest-point sampling for a uniform vacancy distribution.

    Starts from the point farthest from the centroid of ``coords``. It then
    repeatedly adds the unchosen point whose distance to its nearest
    already-chosen point is largest. Ties go to the lowest index.

    Parameters
    ----------
    coords : array_like
        Point coordinates, one row per point.
    k : int
        Number of points to select.

    Returns
    -------
    list of int
        Sorted row indices of the selected points. Returns all indices when
        ``k >= N`` and an empty list when ``k <= 0``.

    Notes
    -----
    Distances are plain Cartesian distances; periodic images of the
    supercell are not taken into account.
    """
    coords = numpy.asarray(coords, dtype=float)
    N = coords.shape[0]
    if k >= N:
        return list(range(N))
    if k <= 0:
        return []

    cm = coords.mean(axis=0)
    d0 = numpy.linalg.norm(coords - cm, axis=1)
    first = int(numpy.argmax(d0))
    chosen = [first]
    taken = numpy.zeros(N, dtype=bool)
    taken[first] = True

    # Distance from every point to its nearest chosen point. It is updated
    # with one O(N) pass per pick instead of rebuilding a distance matrix.
    min_to_chosen = numpy.linalg.norm(coords - coords[first], axis=1)
    while len(chosen) < k:
        # Exclude already-chosen points; argmax returns the first (lowest-index) maximum.
        candidates = numpy.where(taken, -numpy.inf, min_to_chosen)
        next_idx = int(numpy.argmax(candidates))
        chosen.append(next_idx)
        taken[next_idx] = True
        numpy.minimum(min_to_chosen,
                      numpy.linalg.norm(coords - coords[next_idx], axis=1),
                      out=min_to_chosen)
    return sorted(chosen)

def cluster_selection(coords, k):
    """Clustered vacancies: grow one compact region of selected points.

    Starts from the point closest to the centroid of ``coords``. It then
    repeatedly adds the unchosen point closest to any already-chosen point
    (greedy nearest-neighbour growth). Ties go to the lowest index.

    Parameters
    ----------
    coords : array_like
        Point coordinates, one row per point.
    k : int
        Number of points to select.

    Returns
    -------
    list of int
        Sorted row indices of the selected points. Returns all indices when
        ``k >= N`` and an empty list when ``k <= 0``.

    Notes
    -----
    Distances are plain Cartesian distances; periodic images of the
    supercell are not taken into account.
    """
    coords = numpy.asarray(coords, dtype=float)
    N = coords.shape[0]
    if k >= N:
        return list(range(N))
    if k <= 0:
        return []

    cm = coords.mean(axis=0)
    d0 = numpy.linalg.norm(coords - cm, axis=1)
    first = int(numpy.argmin(d0))
    chosen = [first]
    taken = numpy.zeros(N, dtype=bool)
    taken[first] = True

    # Distance from every point to its nearest chosen point, updated per pick.
    nearest = numpy.linalg.norm(coords - coords[first], axis=1)
    while len(chosen) < k:
        # Exclude already-chosen points; argmin returns the first (lowest-index) minimum.
        candidates = numpy.where(taken, numpy.inf, nearest)
        next_idx = int(numpy.argmin(candidates))
        chosen.append(next_idx)
        taken[next_idx] = True
        numpy.minimum(nearest,
                      numpy.linalg.norm(coords - coords[next_idx], axis=1),
                      out=nearest)
    return sorted(chosen)

def random_selection(N, k, seed=0):
    """Pick ``k`` distinct indices from ``range(N)`` reproducibly.

    A private ``random.Random(seed)`` generator is used, so the global random
    state is untouched and the same seed always gives the same pattern.
    The indices are returned in ascending order.
    """
    picked = random.Random(seed).sample(range(N), k)
    picked.sort()
    return picked


###############################################################
# 3. Remove Li atoms
###############################################################

def remove_atoms(cfg, remove_list):
    """Return a new BulkConfiguration equal to ``cfg`` without the listed atoms.

    Parameters
    ----------
    cfg : BulkConfiguration
        Source configuration; it is not modified.
    remove_list : iterable of int
        Indices of the atoms to drop.

    Only the Bravais lattice, the element list and the Cartesian coordinates
    are passed to the new configuration.
    """
    # A set gives O(1) membership tests instead of scanning the list per atom.
    to_remove = set(remove_list)
    elems = list(cfg.elements())
    coords = list(cfg.cartesianCoordinates())
    keep = [i for i in range(len(elems)) if i not in to_remove]
    return BulkConfiguration(
        cfg.bravaisLattice(),
        [elems[i] for i in keep],
        cartesian_coordinates=[coords[i] for i in keep],
    )


###############################################################
# 4. Build Li_x configurations
###############################################################

def build_configs_for_nLi(sqs_ref, nLi, Li_indices, Li_coords,
                          n_random=10, seed=0):
    """Build candidate Li_x configurations with ``nLi`` Li atoms left.

    Parameters
    ----------
    sqs_ref : BulkConfiguration
        Fully lithiated reference structure.
    nLi : int
        Number of Li atoms to keep.
    Li_indices : list of int
        Atom indices (in ``sqs_ref``) of all Li sites.
    Li_coords : numpy.ndarray
        Coordinates of those Li sites, row j belonging to ``Li_indices[j]``.
    n_random : int
        Number of random vacancy patterns to draw.
    seed : int
        Base seed; random draw ``r`` uses ``seed + r``.

    Returns
    -------
    list of (configuration, int, str)
        ``(structure, nLi, label)`` tuples. At the end points (no vacancies,
        or no Li left) there is a single entry. Otherwise the list holds the
        clustered pattern, the uniform (FPS) pattern and the random patterns.
        Vacancy patterns identical to one already built are skipped, so the
        list can be shorter than ``2 + n_random``.
    """
    n_total = len(Li_indices)
    nVac = n_total - nLi
    configs = []

    # Endpoints: only one possible arrangement each.
    if nVac == 0:
        configs.append((sqs_ref.copy(), nLi, "full"))
        return configs
    if nLi == 0:
        cfg0 = remove_atoms(sqs_ref, Li_indices)
        configs.append((cfg0, 0, "empty"))
        return configs

    seen_patterns = set()
    n_duplicates = 0

    def _add_pattern(vac_loc, label):
        """Append the structure for local vacancy indices unless already built."""
        nonlocal n_duplicates
        key = frozenset(vac_loc)
        if key in seen_patterns:
            # An identical structure would only repeat an expensive relaxation.
            n_duplicates += 1
            return
        seen_patterns.add(key)
        vac_glob = [Li_indices[i] for i in vac_loc]
        configs.append((remove_atoms(sqs_ref, vac_glob), nLi, label))

    # Clustered
    _add_pattern(cluster_selection(Li_coords, nVac), "vac_cluster")

    # Uniform (FPS)
    _add_pattern(fps_selection(Li_coords, nVac), "vac_uniform")

    # Randoms
    for r in range(n_random):
        _add_pattern(random_selection(n_total, nVac, seed + r), f"vac_random_{r}")

    if n_duplicates:
        nlprint(f"nLi = {nLi}: skipped {n_duplicates} duplicate vacancy patterns")

    return configs


###############################################################
# 5. Relaxation + energy evaluation
###############################################################

def get_calculator():
    """Create a new TremoloX calculator driven by the MatterSim v1.0.0 5M potential.

    The potential is built with ``dtype='float32'`` and ``enforceLTX=False``.
    A fresh calculator object is returned on every call.
    """
    return TremoloXCalculator(
        parameters=TorchX_MatterSim_v1_0_0_5M(dtype='float32', enforceLTX=False),
    )

def relax_and_energy(cfg):
    """Relax ``cfg`` with a fresh calculator and return its total energy.

    The module-level settings ``max_forces``, ``max_steps`` and
    ``OPTIMIZE_CELL`` control the geometry optimization. ``cfg`` itself gets
    the calculator attached and is updated before optimization.

    Returns
    -------
    (float, configuration)
        Total energy of the optimized structure in eV, and the optimized
        configuration returned by ``OptimizeGeometry``.
    """
    cfg.setCalculator(get_calculator())
    cfg.update()

    relaxed = OptimizeGeometry(
        cfg,
        max_forces=max_forces,
        max_steps=max_steps,
        optimize_cell=OPTIMIZE_CELL,
    )

    energy_eV = TotalEnergy(relaxed).evaluate().inUnitsOf(eV)
    return float(energy_eV), relaxed


###############################################################
# 6. Compute Li metal reference (needed for voltage)
###############################################################

# Last BulkConfiguration stored in the Li metal file.
Li_bulk = nlread("Li_metal4.hdf5", BulkConfiguration)[-1]
Li_bulk.setCalculator(get_calculator())
Li_bulk.update()

# The reference cell is always fully relaxed (optimize_cell=True), whatever
# OPTIMIZE_CELL is set to for the cathode structures.
opt_Li = OptimizeGeometry(
    Li_bulk,
    max_forces=max_forces,
    max_steps=max_steps,
    optimize_cell=True,
)

# Energy per atom of the relaxed metal, in eV.
E_Li_total = float(TotalEnergy(opt_Li).evaluate().inUnitsOf(eV))
E_Li_atom = E_Li_total / len(opt_Li.elements())
nlprint(f"Li metal reference: E_Li_atom = {E_Li_atom} eV")


###############################################################
# 7. Loop over all x values
###############################################################

# For each effective x: relax every candidate, store all of them, and keep
# the lowest-energy one as the ground state at that composition.
E_x = []
nLi_x = []
best_labels = []

for x, nLi in zip(x_list, nLi_list):

    nlprint(f"\n=== x = {x:.3f}, nLi = {nLi} ===")

    configs = build_configs_for_nLi(
        sqs_struct, nLi,
        Li_indices, Li_coords,
        n_random = max(0, n_samples_per_x - 2)
    )

    best_E = float("inf")
    best_cfg = None
    best_label = None

    for icand, (cfg, _, label) in enumerate(configs):
        nlprint(f"Optimizing x={x:.3f} candidate {icand}: {label}")

        E, optcfg = relax_and_energy(cfg)

        # Save all data
        tag = f"x_{x:.3f}".replace('.','p') + f"_{label}_s{icand}"
        nlsave(output_hdf5, optcfg, object_id=f"cfg_{tag}")
        nlsave(output_hdf5, TotalEnergy(optcfg), object_id=f"E_{tag}")

        # A non-finite energy (e.g. a diverged relaxation) must never be
        # selected as the ground state.
        if not numpy.isfinite(E):
            nlprint(f"WARNING: non-finite energy for candidate {icand} ({label}); skipped")
            continue

        if E < best_E:
            best_E = E
            best_cfg = optcfg
            best_label = label

    if best_cfg is None:
        # Fail clearly instead of letting nlsave() receive None below.
        raise RuntimeError(
            f"No candidate with a finite energy for x = {x:.3f} (nLi = {nLi})"
        )

    E_x.append(best_E)
    nLi_x.append(nLi)
    best_labels.append(best_label)

    tag = f"x_{x:.3f}".replace('.','p')
    nlsave(output_best_hdf5, best_cfg, object_id=f"cfg_{tag}")
    nlsave(output_best_hdf5, TotalEnergy(best_cfg), object_id=f"E_{tag}")

# Report the lowest-energy configuration selected for every effective x.
separator = "+" + "-" * 78 + "+"
nlprint("\n" + separator)
nlprint("Summary of chosen ground state configurations:")
nlprint(separator)
for x, E, nLi, label in zip(x_list, E_x, nLi_x, best_labels):
    line = f"Chosen E(x={x:.3f}) = {E} eV   nLi = {nLi}   tag={label}"
    nlprint(line)


###############################################################
# 8. Compute voltages
###############################################################

# Average voltage between consecutive compositions:
#   V = -(E(n2) - E(n1) - (n2 - n1) * E_Li_atom) / (n2 - n1)
# Energies are in eV per supercell and dn counts Li atoms, so V is in volts.
voltages = []
x_mid = []

grid_points = list(zip(x_list, E_x, nLi_x))
for (x_lo, E_lo, n_lo), (x_hi, E_hi, n_hi) in zip(grid_points, grid_points[1:]):
    delta_n = n_hi - n_lo
    step_V = - (E_hi - E_lo - delta_n * E_Li_atom) / delta_n

    voltages.append(step_V)
    x_mid.append(0.5*(x_lo+x_hi))

    nlprint(f"V({x_lo:.3f}->{x_hi:.3f}) = {step_V:.3f} V")


###############################################################
# 9. Plot OCV curve
###############################################################

plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'DejaVu Serif'],
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 9,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'lines.linewidth': 1.5,
    'axes.linewidth': 1.2,
    'grid.linewidth': 0.5,
})

fig, ax = plt.subplots(figsize=(4.5, 3.5))

has_voltages = len(voltages) > 0

if has_voltages:
    # voltages[i] is the average OCV over [x_list[i], x_list[i+1]]. With
    # where='post', each level is drawn from x_list[i] up to x_list[i+1], so
    # the steps change exactly at the composition grid points. This holds
    # even when the integer-rounded grid is non-uniform.
    ax.step(x_list, voltages + [voltages[-1]],
            where='post', linewidth=2.5, color='#2E86AB', label='Average OCV', zorder=3)

# Markers at the interval midpoints.
ax.scatter(x_mid, voltages, s=70, facecolors='#E63946',
           edgecolors='#A4031F', linewidths=1.5, alpha=0.9, zorder=4)

ax.set_xlabel(r'Lithiation $x$ (Li$_x$' + NMC_TAG + ')')
ax.set_ylabel('OCV [V]')
ax.set_title(r'OCV vs $x$ for Li$_x$' + NMC_TAG, pad=15)
ax.minorticks_on()

if has_voltages:
    # Only the step curve carries a label; skip the legend when it is absent.
    legend = ax.legend(loc='best', frameon=True, shadow=True, fancybox=True, framealpha=0.95)
    legend.get_frame().set_edgecolor('#2E86AB')
    legend.get_frame().set_linewidth(1.2)

ax.set_xlim(0, 1)

plt.tight_layout()
plt.savefig(output_plot, dpi=600, bbox_inches='tight')
plt.close(fig)

