from NanoLanguage import *
import numpy as np
import matplotlib.pyplot as plt
import os

# =============================================================
# User input
# =============================================================

# HDF5 file path for each NMC composition label.
nmc_files = dict(
    NMC532="/remote/quantumatk/atik/Battery-OCV/NMC/OCV-NMC532/ocv_NMC532_best.hdf5",
    NMC622="/remote/quantumatk/atik/Battery-OCV/NMC/OCV-NMC622/ocv_NMC622_best.hdf5",
    NMC721="/remote/quantumatk/atik/Battery-OCV/NMC/OCV-NMC721/ocv_NMC721_best.hdf5",
    NMC811="/remote/quantumatk/atik/Battery-OCV/NMC/OCV-NMC811-2/ocv_NMC811_best.hdf5",
)

# Line appearance per composition, kept together as (hex color, marker) pairs
# so that a series' color and marker cannot drift apart.
_series_style = {
    "NMC532": ("#F80000", "o"),  # red, circle
    "NMC622": ("#F77F00", "s"),  # orange, square
    "NMC721": ("#06A77D", "^"),  # green, triangle up
    "NMC811": ("#7e005f", "D"),  # purple, diamond
}
colors = {label: color for label, (color, _) in _series_style.items()}
markers = {label: marker for label, (_, marker) in _series_style.items()}

# Lithiation grid used during the OCV calculation: 21 evenly spaced values
# from 0.0 to 1.0, stored as plain Python floats.
# MUST match the x_list used in the OCV workflow, because the stored
# configurations are looked up by an id derived from each x value.
x_list = np.linspace(0.0, 1.0, 21).tolist()

# =============================================================
# Plot style (light mode, journal-ready)
# =============================================================

# Typeface and text sizes.
_font_settings = {
    "font.family": "serif",
    "font.serif": ["Times New Roman", "DejaVu Serif"],
    "font.size": 10,
    "axes.labelsize": 11,
    "axes.titlesize": 12,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 9,
}

# Resolution for on-screen figures and saved files.
_resolution_settings = {
    "figure.dpi": 300,
    "savefig.dpi": 300,
}

# Stroke widths for data lines, axes frame, grid and ticks.
_stroke_settings = {
    "lines.linewidth": 1.5,
    "axes.linewidth": 1.2,
    "grid.linewidth": 0.5,
    "xtick.major.width": 1.2,
    "ytick.major.width": 1.2,
    "xtick.minor.width": 0.8,
    "ytick.minor.width": 0.8,
}

# Apply all groups to the global Matplotlib configuration in one call.
plt.rcParams.update({
    **_font_settings,
    **_resolution_settings,
    **_stroke_settings,
})

# =============================================================
# Helper: extract volume and c-axis by object_id
# =============================================================

def extract_structural_data(hdf5_file, x_list):
    """Collect cell volume and c-vector length for each stored lithiation.

    For every value in ``x_list`` a ``BulkConfiguration`` is looked up in
    ``hdf5_file`` under the object id ``cfg_x_<x>``, where ``<x>`` is the
    value formatted with three decimals and the decimal point replaced by
    ``p`` (for example ``cfg_x_0p050``). Values whose configuration cannot
    be read are skipped; the skipped ids are reported in one log line so a
    wrong path or a mismatched x-grid does not go unnoticed.

    Args:
        hdf5_file: Path of the HDF5 file to read.
        x_list: Iterable of lithiation values to look up.

    Returns:
        A tuple ``(x_vals, volumes, c_axis)`` of equally long 1-D NumPy
        arrays: the x values that were found, the cell volume in cubic
        Angstrom (absolute scalar triple product of the primitive vectors),
        and the length in Angstrom of the third primitive vector. All three
        arrays are empty when nothing could be read.
    """
    x_vals = []
    volumes = []
    c_axis = []
    skipped = []

    for x in x_list:
        tag = f"x_{x:.3f}".replace(".", "p")
        obj_id = f"cfg_{tag}"

        try:
            cfg = nlread(hdf5_file, BulkConfiguration, object_id=obj_id)[0]
        except Exception as err:
            # Best effort: compositions that were not saved are skipped, but
            # the id and error type are remembered so the skip is visible.
            skipped.append(f"{obj_id} ({type(err).__name__})")
            continue

        lattice = cfg.bravaisLattice()
        a_vec, b_vec, c_vec = lattice.primitiveVectors()

        # Strip the units so plain NumPy vector algebra can be used.
        a = np.array(a_vec.inUnitsOf(Angstrom))
        b = np.array(b_vec.inUnitsOf(Angstrom))
        c = np.array(c_vec.inUnitsOf(Angstrom))

        # abs() keeps the volume positive regardless of vector handedness.
        V = abs(np.dot(a, np.cross(b, c)))
        c_len = np.linalg.norm(c)

        x_vals.append(x)
        volumes.append(V)
        c_axis.append(c_len)

    if skipped:
        nlprint(
            f"  -> Skipped {len(skipped)} configuration(s) not readable from "
            f"{hdf5_file}: " + ", ".join(skipped)
        )

    return np.array(x_vals), np.array(volumes), np.array(c_axis)

# =============================================================
# Plot 1: Relative volume change
# =============================================================

# Build and save the relative-volume figure: one curve per NMC composition,
# each expressed in percent relative to its most lithiated configuration.
nlprint("\n" + "="*60)
nlprint("Processing NMC volume data...")
nlprint("="*60)

fig, ax = plt.subplots(figsize=(4.5, 3.5))

for name, file in nmc_files.items():
    # Fail fast on a missing result file instead of plotting a partial figure.
    if not os.path.exists(file):
        nlprint(f"Error: {file} not found")
        raise FileNotFoundError(f"{file} not found")

    nlprint(f"\nLoading {name} data from: {file}")
    x, V, _ = extract_structural_data(file, x_list)
    nlprint(f"  -> Extracted {len(x)} data points for {name}")

    # np.argmax on an empty array raises an opaque ValueError; raise the same
    # exception type with a message that names the actual problem.
    if len(x) == 0:
        nlprint(f"Error: no configurations found in {file}")
        raise ValueError(f"No configurations matching x_list found in {file}")

    # Reference is the most lithiated state available (x = 1 when present).
    idx_ref = np.argmax(x)
    V_ref = V[idx_ref]
    if not np.isclose(x[idx_ref], 1.0):
        # The axis label implies a fully lithiated reference; say so if not.
        nlprint(f"  -> Note: {name} reference taken at x = {x[idx_ref]:.3f} (x = 1 not available)")

    dV_rel = (V - V_ref) / V_ref * 100.0

    ax.plot(x, dV_rel, marker=markers[name], markersize=5, label=name,
            color=colors[name], markerfacecolor=colors[name], markeredgecolor='white',
            markeredgewidth=0.5, linewidth=1.5)

ax.set_xlabel(r"Lithiation $x$ in Li$_x$NMC")
ax.set_ylabel(r"$\Delta V / V_0$ [%]")
ax.set_title("Relative Volume Change vs Lithiation", pad=15)
ax.minorticks_on()

# Improved legend styling
legend = ax.legend(loc='best', frameon=True, shadow=True,
                   fancybox=True, framealpha=0.95)
legend.get_frame().set_edgecolor('#2E86AB')
legend.get_frame().set_linewidth(1.2)

ax.set_xlim(0, 1)

# Operate on this figure explicitly rather than on pyplot's "current" figure.
fig.tight_layout()
fig.savefig("NMC_volume_vs_lithiation.png", dpi=600, bbox_inches='tight')
plt.close(fig)

nlprint("\n[+] Volume plot saved: NMC_volume_vs_lithiation.png")

# =============================================================
# Plot 2: Relative c-axis change
# =============================================================

# Build and save the relative c-axis figure: one curve per NMC composition,
# each expressed in percent relative to its most lithiated configuration,
# then print a summary of the generated files.
nlprint("\n" + "="*60)
nlprint("Processing NMC c-axis data...")
nlprint("="*60)

fig, ax = plt.subplots(figsize=(4.5, 3.5))

for name, file in nmc_files.items():
    # Same guard as for the volume plot: a missing file would otherwise be
    # swallowed while reading and surface later as an unrelated error.
    if not os.path.exists(file):
        nlprint(f"Error: {file} not found")
        raise FileNotFoundError(f"{file} not found")

    nlprint(f"\nLoading {name} c-axis data...")
    x, _, c = extract_structural_data(file, x_list)
    nlprint(f"  -> Extracted {len(c)} c-axis values for {name}")

    # np.argmax on an empty array raises an opaque ValueError; raise the same
    # exception type with a message that names the actual problem.
    if len(x) == 0:
        nlprint(f"Error: no configurations found in {file}")
        raise ValueError(f"No configurations matching x_list found in {file}")

    # Reference is the most lithiated state available (x = 1 when present).
    idx_ref = np.argmax(x)
    c_ref = c[idx_ref]
    if not np.isclose(x[idx_ref], 1.0):
        # The axis label implies a fully lithiated reference; say so if not.
        nlprint(f"  -> Note: {name} reference taken at x = {x[idx_ref]:.3f} (x = 1 not available)")

    dc_rel = (c - c_ref) / c_ref * 100.0

    ax.plot(x, dc_rel, marker=markers[name], markersize=5, label=name,
            color=colors[name], markerfacecolor=colors[name], markeredgecolor='white',
            markeredgewidth=0.5, linewidth=1.5)

ax.set_xlabel(r"Lithiation $x$ in Li$_x$NMC")
ax.set_ylabel(r"$\Delta c / c_0$ [%]")
ax.set_title("Relative c-axis Change vs Lithiation", pad=15)
ax.minorticks_on()

# Improved legend styling
legend = ax.legend(loc='best', frameon=True, shadow=True,
                   fancybox=True, framealpha=0.95)
legend.get_frame().set_edgecolor('#2E86AB')
legend.get_frame().set_linewidth(1.2)

ax.set_xlim(0, 1)

# Operate on this figure explicitly rather than on pyplot's "current" figure.
fig.tight_layout()
fig.savefig("NMC_caxis_vs_lithiation.png", dpi=600, bbox_inches='tight')
plt.close(fig)

nlprint("\n[+] C-axis plot saved: NMC_caxis_vs_lithiation.png")

nlprint("\n" + "="*60)
nlprint("Summary: Successfully generated plots")
nlprint("="*60)
nlprint("  [1] NMC_volume_vs_lithiation.png")
nlprint("  [2] NMC_caxis_vs_lithiation.png")
nlprint("="*60)