import random

import numpy

setVerbosity(MinimalLog)

############################################################
# 0. USER PARAMETERS
############################################################

# ---- Target NMC composition (fractions of the TM sublattice) ----
# Example NMC811: Ni=0.8, Mn=0.1, Co=0.1
# Example NMC622: Ni=0.6, Mn=0.2, Co=0.2
NI_FRAC = 0.6
MN_FRAC = 0.2
CO_FRAC = 0.2

# Sanity checks. A negative fraction could still sum to 1.0, but would later
# produce a negative Mn count and a meaningless Zn -> Mn/Co split.
if min(NI_FRAC, MN_FRAC, CO_FRAC) < 0.0:
    raise ValueError("Ni, Mn and Co fractions must be non-negative")
if not numpy.isclose(NI_FRAC + MN_FRAC + CO_FRAC, 1.0, atol=1e-6):
    raise ValueError("Ni+Mn+Co fractions must sum to 1.0")

# ---- Supercell repetition (of the 12-atom hexagonal R-3m LiCoO2 cell) ----
REPEAT_A = 5
REPEAT_B = 4
REPEAT_C = 1   # the hexagonal template already holds three TM layers along c

# ---- SQS settings ----
SQS_GENERATIONS     = 150
SQS_POPULATION_SIZE = 100
SQS_PROMOTE         = 10
SQS_HEREDITY_PROB   = 0.5
SQS_PERMUTE_PROB    = 0.5

# Maximum cluster diameters keyed by cluster order (pairs, triplets, quadruplets).
CLUSTER_MAX_DIAMETERS = {
    2: 5.9*Angstrom,
    3: 5.3*Angstrom,
    4: 2.9*Angstrom,
}

# ---- Zn → Mn/Co splitting method ----
# "random" : random splitting
# "fps"    : farthest-point style uniform distribution of Mn
SPLIT_METHOD = "random"

# Fail fast: SPLIT_METHOD is only consumed after the (expensive) SQS search,
# so a typo would otherwise be reported only after all generations have run.
if SPLIT_METHOD.lower() not in ("random", "fps"):
    raise ValueError("Unknown SPLIT_METHOD. Use 'random' or 'fps'.")

# ---- Geometry optimization settings ----
def get_calculator():
    """Return the calculator attached to the SQS before relaxation.

    Edit this to the preferred calculator (PBE+U, HSE, etc.). The default
    wraps the TorchX_MatterSim_v1_0_0_5M potential in a TremoloX calculator.
    """
    return TremoloXCalculator(
        parameters=TorchX_MatterSim_v1_0_0_5M(dtype='float32', enforceLTX=False),
    )

DO_RELAX      = True                  # run OptimizeGeometry on the final SQS
MAX_FORCES    = 0.01 * eV/Angstrom    # force convergence threshold
MAX_STEPS     = 1000                  # maximum optimizer steps
OPTIMIZE_CELL = True                  # also relax the cell; set False to keep the template lattice


############################################################
# 1. Define R-3m LiCoO2 template (hexagonal setting, 3 formula units)
############################################################

lattice = Hexagonal(2.81125698*Angstrom, 13.90945643*Angstrom)

# 3 Li, 3 Co, 6 O -- order must match the coordinate rows below.
elements = [Lithium] * 3 + [Cobalt] * 3 + [Oxygen] * 6

fractional_coordinates = [
    # Li
    [0.0       , 0.0       , 0.0       ],
    [0.66666667, 0.33333333, 0.33333333],
    [0.33333333, 0.66666667, 0.66666667],
    # Co (transition-metal sites to be alloyed later)
    [0.0       , 0.0       , 0.5       ],
    [0.66666667, 0.33333333, 0.83333333],
    [0.33333333, 0.66666667, 0.16666667],
    # O
    [0.0       , 0.0       , 0.2400068 ],
    [0.0       , 0.0       , 0.7599932 ],
    [0.66666667, 0.33333333, 0.57334013],
    [0.66666667, 0.33333333, 0.09332653],
    [0.33333333, 0.66666667, 0.90667347],
    [0.33333333, 0.66666667, 0.42665987],
]

bulk = BulkConfiguration(
    bravais_lattice=lattice,
    elements=elements,
    fractional_coordinates=fractional_coordinates,
)
nlsave("LiCoO2_template.hdf5", bulk)


############################################################
# 2. Build supercell
############################################################

supercell = bulk.repeat(REPEAT_A, REPEAT_B, REPEAT_C)

# Template data reused when building the AlloyConfiguration.
coords = supercell.fractionalCoordinates()
elems  = supercell.elements()
lat    = supercell.bravaisLattice()

# Every Co position of the template is a transition-metal (TM) site.
n_tm = len([element for element in elems if element.symbol() == "Co"])
nlprint(f"Number of TM sites in supercell: {n_tm}")


############################################################
# 3. Set up AlloyConfiguration: TM = Ni vs Zn (pseudo-element)
############################################################

# Mn and Co are lumped into one pseudo-species X (stored as Zn), so the SQS
# search only has to handle a binary Ni/X TM sublattice.
X_FRAC = MN_FRAC + CO_FRAC
NI_OCC, X_OCC = NI_FRAC, X_FRAC

nlprint(f"Binary SQS composition on TM sites: Ni={NI_OCC:.3f}, X(Zn)={X_OCC:.3f}")

# Template symbol -> builder of the AlloySite that occupies that position.
_site_builders = {
    "Li": lambda: AlloySite(Lithium=1.0),
    "O":  lambda: AlloySite(Oxygen=1.0),
    "Co": lambda: AlloySite(Nickel=NI_OCC, Zinc=X_OCC),
}

sites = []
for elem in elems:
    symbol = elem.symbol()
    builder = _site_builders.get(symbol)
    if builder is None:
        raise ValueError("Unexpected element in template: " + symbol)
    sites.append(builder())

alloy_config = AlloyConfiguration(
    bravais_lattice=lat,
    sites=sites,
    fractional_coordinates=coords,
)


############################################################
# 4. Run SQS for Ni–Zn binary alloy
############################################################

# Evolutionary-search options, gathered from the USER PARAMETERS section.
sqs_options = dict(
    number_of_generations=SQS_GENERATIONS,
    population_size=SQS_POPULATION_SIZE,
    number_to_promote=SQS_PROMOTE,
    heredity_probability=SQS_HEREDITY_PROB,
    permutation_probability=SQS_PERMUTE_PROB,
)
sqs = EvolutionarySQS(alloy_config, CLUSTER_MAX_DIAMETERS, **sqs_options)

best_binary_sqs = sqs.bestStructure()
nlsave("NMC_binary_Ni_Zn_SQS.hdf5", best_binary_sqs)


############################################################
# 5. Split Zn → Mn and Co with correct ratio
############################################################

# Checked before looking for Zn so a pure-Ni composition gets this specific
# message rather than the generic "no Zn sites" error below.
if X_FRAC <= 0.0:
    raise RuntimeError("X_FRAC is zero; Ni fraction is 1. This is not NMC.")

# Find all Zn (pseudo-species X) atoms in the binary SQS.
zn_indices = list(indicesFromExpression(best_binary_sqs, 'e=Zn'))
n_zn = len(zn_indices)
nlprint(f"Number of Zn (X) sites after SQS: {n_zn}")

if n_zn == 0:
    raise RuntimeError("No Zn sites found in SQS result; something is wrong.")

# Desired Mn and Co fractions within X:
# Mn/X = MN_FRAC / (MN_FRAC + CO_FRAC), similarly for Co
p_mn = MN_FRAC / X_FRAC
p_co = CO_FRAC / X_FRAC

# Integer counts; Co takes the remainder so the totals always equal n_zn.
n_mn_target = int(round(p_mn * n_zn))
n_co_target = n_zn - n_mn_target

nlprint(f"Splitting Zn: nMn={n_mn_target}, nCo={n_co_target} (from pMn={p_mn:.3f}, pCo={p_co:.3f})")


def _periodic_distances(frac_coords, frac_point, cell):
    """Return minimum-image distances (Angstrom) from each row of frac_coords to frac_point.

    frac_coords : (n, 3) array of fractional coordinates.
    frac_point  : (3,) fractional coordinate of the reference site.
    cell        : (3, 3) lattice vectors as rows, in Angstrom
                  (assumed convention of BravaisLattice.primitiveVectors()).
    """
    delta = frac_coords - frac_point
    delta = delta - numpy.round(delta)  # wrap into [-0.5, 0.5]
    # Rounding alone is not always the minimum image in a non-orthogonal
    # (e.g. hexagonal) cell, so also test the 26 neighbouring images.
    shifts = numpy.array(
        [[i, j, k] for i in (-1, 0, 1) for j in (-1, 0, 1) for k in (-1, 0, 1)],
        dtype=float,
    )
    images = delta[:, None, :] + shifts[None, :, :]   # (n, 27, 3)
    return numpy.linalg.norm(images @ cell, axis=2).min(axis=1)


method = SPLIT_METHOD.lower()

if method == "random":
    # Random assignment of Mn / Co over Zn
    random.shuffle(zn_indices)
    mn_indices = zn_indices[:n_mn_target]
    co_indices = zn_indices[n_mn_target:]

elif method == "fps":
    # Farthest-point style: spread Mn as uniformly as possible, measuring
    # distances with periodic boundary conditions (the supercell is periodic,
    # so plain Cartesian distances would bias the choice toward cell edges).
    cart = numpy.array(best_binary_sqs.cartesianCoordinates().inUnitsOf(Angstrom))
    cell = numpy.array(best_binary_sqs.bravaisLattice().primitiveVectors().inUnitsOf(Angstrom))
    zn_frac = cart[zn_indices] @ numpy.linalg.inv(cell)

    chosen_local = []
    # nearest[i] = distance from Zn site i to the closest Mn chosen so far.
    nearest = numpy.full(n_zn, numpy.inf)
    next_local = 0  # deterministic seed: first Zn site
    # range(n_mn_target) ensures no Mn is placed when the target is 0.
    for _ in range(n_mn_target):
        chosen_local.append(next_local)
        # Incremental update: only distances to the newly chosen site are new.
        nearest = numpy.minimum(
            nearest, _periodic_distances(zn_frac, zn_frac[next_local], cell)
        )
        nearest[chosen_local] = -1.0  # never re-pick a chosen site
        next_local = int(numpy.argmax(nearest))

    mn_indices = [zn_indices[i] for i in chosen_local]
    mn_set = set(mn_indices)
    co_indices = [idx for idx in zn_indices if idx not in mn_set]

else:
    raise ValueError("Unknown SPLIT_METHOD. Use 'random' or 'fps'.")

# Apply replacements Zn → Mn and Co. Empty index lists (e.g. a composition
# without Mn) are skipped rather than passed to setElement.
final_nmc = best_binary_sqs
for target_indices, element in ((mn_indices, Manganese), (co_indices, Cobalt)):
    if len(target_indices) > 0:
        final_nmc = setElement(final_nmc, target_indices, element)

# Save raw SQS NMCxyz, e.g. 622, 811. round() rather than int() so that values
# like 1 - 0.8 - 0.1 == 0.0999... still produce the intended digit ("1", not "0").
comp_tag = "".join(str(round(frac * 10)) for frac in (NI_FRAC, MN_FRAC, CO_FRAC))
nlsave(f"NMC{comp_tag}_SQS_final.hdf5", final_nmc)


############################################################
# 6. Optional geometry optimization
############################################################

if not DO_RELAX:
    nlprint("Skipping relaxation as DO_RELAX = False.")
else:
    # Attach the calculator and evaluate it before handing off to the optimizer.
    final_nmc.setCalculator(get_calculator())
    final_nmc.update()

    opt_nmc = OptimizeGeometry(
        configuration=final_nmc,
        max_forces=MAX_FORCES,
        max_steps=MAX_STEPS,
        optimize_cell=OPTIMIZE_CELL,
    )
    nlsave(f"NMC{comp_tag}_SQS_final_optimized.hdf5", opt_nmc)