#!/usr/bin/env python3
"""
Confirm the partition-function fix found in session50_deep_gap_audit.py:
the lattice partition function that gives S(beta=0.966) = ln(31)
includes the SU(2) Casimir factor C_j inside the Boltzmann weight.

  Z(beta) = Sum_{j>0} d_j^2  *  exp(-beta * C_j * mu_j)
  S(beta) = beta * <C_j * mu_j> + ln Z

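with mu_j = 1 - lambda_j / z and z = 12 (the vertex degree of the 600-cell
graph); j runs over the 600-cell adjacency spectrum, excluding the zero mode
lambda = 12 (for which mu_j = 0). The script tabulates Z, <E>, S and N_eff on
a grid of beta values and bisects for the beta at which S = ln(31).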
"""
import math

sqrt5 = math.sqrt(5)
phi = (1 + sqrt5) / 2   # golden ratio
z = 12                  # normalisation in mu_j = 1 - lambda_j / z

# 600-cell adjacency spectrum from the quaternionic 2I construction
# (verified Tier 2), with the lambda = 12 zero mode left out.
#
# Each band is written as (eigenvalue, d), where d = 2j + 1 is the dimension
# of the associated SU(2) irrep.  The stored table entry is
#     (lambda, multiplicity d**2, Casimir C_j = j(j+1)),
# using the identity j(j+1) = (d**2 - 1) / 4 for d = 2j + 1.
spectrum_with_casimir = [
    (lam, d * d, (d * d - 1) / 4)
    for lam, d in (
        (3 + 3*sqrt5, 2),   # j=1/2
        (2 + 2*sqrt5, 3),   # j=1
        (3.0,         4),   # j=3/2
        (0.0,         5),   # j=2
        (-2.0,        6),   # j=5/2
        (2 - 2*sqrt5, 3),   # j=1
        (-3.0,        4),   # j=3/2
        (3 - 3*sqrt5, 2),   # j=1/2
    )
]

def thermo(beta, spectrum=None, coordination=None):
    """Return ``(Z, <E>, S, N_eff)`` for the Casimir-weighted lattice ensemble.

    Every band ``(lambda_j, mult_j, C_j)`` contributes ``mult_j`` degenerate
    states of energy ``E_j = C_j * mu_j`` with ``mu_j = 1 - lambda_j / z``.
    At inverse temperature ``beta``::

        Z     = sum_j mult_j * exp(-beta * E_j)
        <E>   = sum_j mult_j * E_j * exp(-beta * E_j) / Z
        S     = beta * <E> + ln Z              (Gibbs entropy over states)
        N_eff = 1 / sum_states p_state**2      (participation number)

    Args:
        beta: Inverse temperature.
        spectrum: Iterable of ``(lambda, mult, C_j)`` bands.  Defaults to the
            module-level ``spectrum_with_casimir``.
        coordination: The ``z`` in ``mu_j = 1 - lambda_j / z``.  Defaults to
            the module-level ``z``.

    Returns:
        Tuple ``(Z, avg, S, N_eff)``.  ``avg``, ``S`` and ``N_eff`` are
        computed from shifted weights and stay finite at large ``beta``;
        ``Z`` itself may underflow to ``0.0`` there.

    Raises:
        ValueError: If the spectrum contains no bands.
    """
    bands = list(spectrum_with_casimir if spectrum is None else spectrum)
    if not bands:
        raise ValueError("spectrum must contain at least one band")
    z_norm = z if coordination is None else coordination

    # (multiplicity, energy) per band, with energy E_j = C_j * mu_j.
    levels = [(mult, Cj * (1 - lam / z_norm)) for lam, mult, Cj in bands]

    # Log-sum-exp shift: factor exp(-shift) out of every Boltzmann weight so
    # the largest reduced factor is exactly 1.  Without it, Z underflows to
    # 0.0 at large beta and the <E> = sum / Z division fails.
    shift = min(beta * energy for _, energy in levels)
    weights = [mult * math.exp(shift - beta * energy) for mult, energy in levels]
    z_reduced = math.fsum(weights)

    energy_sum = math.fsum(w * energy for w, (_, energy) in zip(weights, levels))
    avg = energy_sum / z_reduced
    log_Z = math.log(z_reduced) - shift
    S = beta * avg + log_Z
    # Reported Z = z_reduced * exp(-shift); may underflow to 0.0 at large beta.
    Z = z_reduced * math.exp(-shift)

    # Band probability p_j is shared by mult_j degenerate states, so each
    # state carries p_j / mult_j and sum_states p_state**2 = sum_j p_j**2 / mult_j.
    sum_p2 = math.fsum(
        (w / z_reduced) ** 2 / mult for w, (mult, _) in zip(weights, levels)
    )
    N_eff = 1.0 / sum_p2 if sum_p2 > 0 else float('inf')
    return Z, avg, S, N_eff

# Entropy we are trying to hit, and the beta* value quoted in the paper.
target = math.log(31)
PAPER_BETA = 0.966

print("Lattice partition function with Casimir factor:")
print(f"  Target: S = ln(31) = {target:.4f}")
print()
print(f"  {'beta':>8s} {'Z':>10s} {'<E>':>8s} {'S':>10s} {'N_eff':>10s}")
for b in [0.0, 0.5, 0.8, 0.9, 0.95, 0.96, 0.966, 0.97, 1.0, 1.5, 2.0]:
    Z, avg, S, Neff = thermo(b)
    flag = " <-- ln(31)" if abs(S - target) < 0.005 else ""
    flag2 = " <-- paper β*=0.966" if abs(b - PAPER_BETA) < 1e-6 else ""
    print(f"  {b:8.3f} {Z:10.4f} {avg:8.4f} {S:10.4f} {Neff:10.4f}{flag}{flag2}")

# Bisection for S(beta) = ln(31).  For beta >= 0, S is non-increasing
# (dS/dbeta = -beta * Var(E)), so S > target means the root lies at larger
# beta.  Check the bracket first: otherwise bisection silently converges to
# an endpoint and would be reported as a solution.
lo, hi = 0.001, 100.0
S_lo, S_hi = thermo(lo)[2], thermo(hi)[2]
if not S_hi <= target <= S_lo:
    raise SystemExit(
        f"ln(31) = {target:.4f} is not bracketed by S({lo}) = {S_lo:.4f} "
        f"and S({hi}) = {S_hi:.4f}; cannot solve for beta."
    )
for _ in range(80):
    mid = (lo + hi) / 2
    _, _, S, _ = thermo(mid)
    if S > target:
        lo = mid
    else:
        hi = mid
beta_star = (lo + hi) / 2

print()
print(f"Solved beta where S = ln(31): {beta_star:.4f}")
print(f"Paper claims:                 {PAPER_BETA}")
print(f"Match: {'YES' if abs(beta_star - PAPER_BETA) < 0.005 else 'NO'}")
