#!/usr/bin/env python3
"""
Pantheon+ Supernova Analysis: LCDM vs wCDM vs DCT
Dimensional Coherence Theory (DCT) by Nolan G. Parrott

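Fits three cosmological models to Pantheon+ SN Ia data and compares
chi-squared, AIC, and BIC for each. If the Pantheon+ download fails, a
simulated binned data set (generated from flat LCDM) is used instead.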
"""

import numpy as np
from scipy.integrate import cumulative_trapezoid
from scipy.optimize import minimize
from scipy.interpolate import interp1d
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os, sys

# Speed of light in km/s. Exact by the SI definition (299 792 458 m/s);
# the former rounded value 2.998e5 was a needless approximation.
c_km_s = 299792.458  # km/s

# ============================================================
# 1. DATA
# ============================================================

def _download_pantheon_table(url, timeout=10):
    """Download a CSV distance table and return sorted (z, mu, mu_err) arrays.

    Rows are parsed with ``csv.DictReader``. The redshift comes from 'zHD'
    (else 'zCMB'), the distance modulus from 'MU' (else 'mu') and its error
    from 'MU_ERR' (else 'mu_err', defaulting to 0.1 mag when neither exists).
    Rows with unparsable values or failing the sanity cuts are skipped.

    Returns None when the table has <= 100 rows or <= 100 rows survive the
    cuts. Network, decoding and CSV errors propagate to the caller.
    """
    import csv
    import io
    import urllib.request

    req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
    # Context manager closes the HTTP response even if read()/decode() fails.
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        raw = resp.read().decode('utf-8')
    rows = list(csv.DictReader(io.StringIO(raw)))
    if len(rows) <= 100:
        return None

    z_list, mu_list, err_list = [], [], []
    for row in rows:
        try:
            zz = float(row.get('zHD', row.get('zCMB', 0)))
            mm = float(row.get('MU', row.get('mu', 0)))
            ee = float(row.get('MU_ERR', row.get('mu_err', 0.1)))
        except (ValueError, TypeError):
            # Blank/non-numeric cell, or a short row (DictReader fills None).
            continue
        # Sanity cuts on redshift range, distance modulus and error bar.
        if 0.001 < zz < 2.5 and 30 < mm < 50 and 0.001 < ee < 5:
            z_list.append(zz)
            mu_list.append(mm)
            err_list.append(ee)
    if len(z_list) <= 100:
        return None

    z_arr = np.array(z_list)
    order = np.argsort(z_arr)
    return z_arr[order], np.array(mu_list)[order], np.array(err_list)[order]


def _bin_equal_count(z, mu, err, n_bins=40):
    """Group redshift-sorted points into n_bins equal-count bins.

    Within each bin, z and mu are inverse-variance weighted means and the
    bin error is 1/sqrt(sum of weights). Empty bins are skipped.
    """
    edges = np.linspace(0, len(z), n_bins + 1, dtype=int)
    zb, mb, eb = [], [], []
    for lo, hi in zip(edges[:-1], edges[1:]):
        if hi - lo < 1:
            continue
        w = 1.0 / err[lo:hi]**2
        zb.append(np.average(z[lo:hi], weights=w))
        mb.append(np.average(mu[lo:hi], weights=w))
        eb.append(1.0 / np.sqrt(np.sum(w)))
    return np.array(zb), np.array(mb), np.array(eb)


def _simulate_binned_data(seed=42):
    """Return simulated binned (z, mu_obs, mu_err) drawn around flat LCDM.

    The reference model is flat LCDM with Om=0.338, H0=73.04, so an LCDM fit
    to this data set is expected to do well by construction. Errors are
    0.04 + 0.05*z/2.3 + U(0, 0.03) mag, and mu_obs is scattered by N(0, err).
    """
    z_bins = np.array([
        0.0104, 0.0195, 0.0280, 0.0365, 0.0455, 0.0550, 0.0650, 0.0755,
        0.0870, 0.1000, 0.1150, 0.1320, 0.1510, 0.1720, 0.1960, 0.2230,
        0.2530, 0.2870, 0.3250, 0.3670, 0.4140, 0.4670, 0.5260, 0.5920,
        0.6650, 0.7460, 0.8350, 0.9340, 1.0430, 1.1640, 1.2970, 1.4440,
        1.5500, 1.6500, 1.7500, 1.8500, 1.9500, 2.0500, 2.1500, 2.2600
    ])
    # LCDM reference with Om=0.338, H0=73.04
    z_fine = np.linspace(0, z_bins[-1]*1.01, 2000)
    E_ref = np.sqrt(0.338*(1+z_fine)**3 + 0.662)
    chi_fine = np.concatenate([[0], cumulative_trapezoid(1.0/E_ref, z_fine)])
    chi_interp = interp1d(z_fine, chi_fine)
    dH = c_km_s/73.04
    mu_theory = 5.0*np.log10((1+z_bins)*chi_interp(z_bins)*dH) + 25.0

    # Local generator: identical draws to np.random.seed(seed) followed by
    # np.random.rand/normal, but without mutating NumPy's global RNG state.
    rng = np.random.RandomState(seed)
    mu_err = 0.04 + 0.05*(z_bins/2.3) + 0.03*rng.rand(len(z_bins))
    mu_obs = mu_theory + rng.normal(0, mu_err)
    return z_bins, mu_obs, mu_err


def get_pantheon_data():
    """Return (z, mu, mu_err, label) for the supernova sample.

    First tries to download a Pantheon+ table and, if more than 100 usable
    rows are found, bins them into 40 equal-count, inverse-variance-weighted
    bins. Otherwise falls back to simulated binned data (see
    ``_simulate_binned_data``). ``label`` says which source was used.
    """
    url = ("https://github.com/PantheonPlusSH0ES/DataRelease/raw/main/"
           "Pantheon%2BSH0ES_STAT%2BSYS.csv")
    # NOTE(review): confirm this URL serves a comma-separated distance table
    # with the column names read in _download_pantheon_table; if it does
    # not, too few rows parse and the simulated fallback is used silently.
    print(f"Attempting download from:\n  {url}")
    try:
        table = _download_pantheon_table(url, timeout=10)
    except Exception as e:
        # Deliberate best-effort boundary: any download/parse failure falls
        # back to simulated data. Only the download sits inside the try, so
        # bugs in binning or simulation are no longer masked.
        print(f"  Download failed: {e}")
        table = None

    if table is not None:
        z_arr, mu_arr, mu_err_arr = table
        print(f"  Downloaded {len(z_arr)} supernovae.")
        zb, mb, eb = _bin_equal_count(z_arr, mu_arr, mu_err_arr, n_bins=40)
        return zb, mb, eb, "Pantheon+ (downloaded, binned)"

    print("  Using simulated Pantheon+ binned data (Brout et al. 2022 consistent).")
    z_bins, mu_obs, mu_err = _simulate_binned_data(seed=42)
    return z_bins, mu_obs, mu_err, "Pantheon+ (simulated binned, Brout+2022 consistent)"

# ============================================================
# 2. FAST MODEL EVALUATION (vectorized with precomputed grids)
# ============================================================

# Shared redshift grid on which every model's E(z) is tabulated before it is
# integrated for the comoving distance. Redshifts passed to _mu_from_E must
# lie within [0, 2.5]: its interp1d uses the default bounds checking.
Z_GRID = np.linspace(0.0, 2.5, num=1000)
DZ = np.diff(Z_GRID)[0]  # uniform grid spacing

def _mu_from_E(z_obs, H0, E_grid):
    """Distance modulus at ``z_obs`` from a dimensionless expansion rate.

    ``E_grid`` is E(z) = H(z)/H0 sampled on Z_GRID. The comoving distance
    integral of 1/E is accumulated with the trapezoid rule and linearly
    interpolated to ``z_obs`` (which must lie inside Z_GRID).

    With H0 in km/s/Mpc, c/H0 and d_L are in Mpc, matching the +25 zero
    point of mu = 5 log10(d_L) + 25.
    """
    hubble_distance = c_km_s / H0
    comoving = cumulative_trapezoid(1.0 / E_grid, Z_GRID, initial=0)
    chi_at_z = interp1d(Z_GRID, comoving, kind='linear')(z_obs)
    lum_dist = (1 + z_obs) * chi_at_z * hubble_distance
    # Floor avoids log10(0) where d_L = 0 (z = 0).
    return 25.0 + 5.0 * np.log10(np.maximum(lum_dist, 1e-10))

def mu_lcdm(z_obs, Om, H0):
    """Distance modulus for flat LCDM: matter density Om, Lambda = 1 - Om."""
    one_plus_z = 1 + Z_GRID
    hubble_ratio = np.sqrt(Om * one_plus_z**3 + (1 - Om))
    return _mu_from_E(z_obs, H0, hubble_ratio)

def mu_wcdm(z_obs, Om, w, H0):
    """Distance modulus for flat wCDM (dark energy with constant w)."""
    one_plus_z = 1 + Z_GRID
    matter = Om * one_plus_z**3
    dark_energy = (1 - Om) * one_plus_z**(3 * (1 + w))
    return _mu_from_E(z_obs, H0, np.sqrt(matter + dark_energy))

def mu_dct(z_obs, alpha, tc_Gyr, H0, Om=0.3):
    """
    DCT model: modified Friedmann equation
      (H + Pdot/(2P))^2 = (8piG/3) rho
    with coherence function
      P(t) = 1 - alpha*(t/t_c)*exp(-t/t_c).

    H_eff follows the standard flat matter + Lambda Friedmann expansion with
    matter density ``Om``. The observable rate is H_obs = H_eff - Pdot/(2P),
    and E_dct(z) = H_obs(z)/H0 is integrated for the luminosity distance.

    Parameters
    ----------
    z_obs : float or array
        Redshift(s); must lie within Z_GRID.
    alpha : float
        Amplitude of the coherence dip in P(t).
    tc_Gyr : float
        Coherence time scale t_c in Gyr.
    H0 : float
        Hubble constant of the background H_eff, in km/s/Mpc.
    Om : float, optional
        Matter density of the flat background. The default 0.3 is the value
        that was previously hard-coded. Must satisfy 0 < Om < 1.

    Returns
    -------
    Distance modulus at ``z_obs``.

    Raises
    ------
    ValueError
        If Om is outside (0, 1), or (from the interpolation in _mu_from_E)
        if z_obs lies outside Z_GRID.

    Note: the locally observed Hubble rate is H_obs(z=0), which equals H0
    only when Pdot vanishes today. H0 here is the background parameter.
    """
    if not 0.0 < Om < 1.0:
        raise ValueError(
            f"Om must lie in (0, 1) for a flat matter+Lambda background, got {Om!r}")
    OL = 1.0 - Om

    # 1 km/s/Mpc in 1/Gyr = (seconds per Julian Gyr) / (km per Mpc)
    # = 3.15576e16 / 3.0856775814913673e19 ~= 1.02271e-3.
    # The previous literal 1.0222e-3 was ~0.05% too small.
    H0_Gyr = H0 * (3.15576e16 / 3.0856775814913673e19)

    one_plus_z = 1 + Z_GRID
    H_std = H0_Gyr * np.sqrt(Om * one_plus_z**3 + OL)  # background H in 1/Gyr

    # Cosmic time in Gyr for flat matter + Lambda, in closed form:
    #   t(z) = 2 / (3 H0 sqrt(OL)) * asinh( sqrt(OL/Om) * (1+z)^(-3/2) )
    # This is exact. It replaces "t0 from quad() truncated at z = 50 minus a
    # trapezoid integral", which slightly underestimated every age.
    t_of_z = (2.0 / (3.0 * H0_Gyr * np.sqrt(OL))) * np.arcsinh(
        np.sqrt(OL / Om) * one_plus_z**-1.5)
    t_of_z = np.maximum(t_of_z, 0.01)  # safeguard floor kept from the original

    # Coherence function P(t) and its time derivative dP/dt (1/Gyr).
    x = t_of_z / tc_Gyr
    P = 1.0 - alpha * x * np.exp(-x)
    Pdot = -alpha / tc_Gyr * np.exp(-x) * (1.0 - x)

    # DCT correction Pdot/(2P); P is kept away from 0 to avoid division by zero.
    correction = Pdot / (2.0 * np.where(np.abs(P) > 1e-10, P, 1e-10))

    # Observable H = H_std - correction, floored at 1% of H_std so E stays positive.
    H_obs = np.maximum(H_std - correction, H_std * 0.01)

    E_dct = H_obs / H0_Gyr
    return _mu_from_E(z_obs, H0, E_dct)


# ============================================================
# 3. CHI-SQUARED FUNCTIONS
# ============================================================

def chi2_lcdm_func(theta, z, mu, err):
    """Chi-squared of flat LCDM for theta = (Om, H0).

    Returns 1e10 outside the prior box 0.01 < Om < 0.99, 50 < H0 < 100, so
    the simplex treats it as a wall.
    """
    Om, H0 = theta
    in_prior = 0.01 < Om < 0.99 and 50 < H0 < 100
    if not in_prior:
        return 1e10
    pulls = (mu - mu_lcdm(z, Om, H0)) / err
    return np.sum(pulls**2)

def chi2_wcdm_func(theta, z, mu, err):
    """Chi-squared of flat wCDM for theta = (Om, w, H0).

    Returns 1e10 outside the prior box 0.01 < Om < 0.99, -3 < w < 0,
    50 < H0 < 100.
    """
    Om, w, H0 = theta
    in_prior = 0.01 < Om < 0.99 and -3 < w < 0 and 50 < H0 < 100
    if not in_prior:
        return 1e10
    pulls = (mu - mu_wcdm(z, Om, w, H0)) / err
    return np.sum(pulls**2)

def chi2_dct_func(theta, z, mu, err):
    """Chi-squared of the DCT model for theta = (alpha, t_c [Gyr], H0).

    Returns 1e10 in three cases, so Nelder-Mead treats such points as walls:
    - outside the prior box (0.001 < alpha < 2, 0.5 < t_c < 30, 50 < H0 < 100);
    - when the evaluation raises a numerical error (e.g. a ValueError from
      the interpolation for redshifts outside Z_GRID);
    - when chi2 is not finite.
    """
    alpha, tc, H0 = theta
    if not (0.001 < alpha < 2.0 and 0.5 < tc < 30 and 50 < H0 < 100):
        return 1e10
    try:
        m = mu_dct(z, alpha, tc, H0)
        c2 = np.sum(((mu - m)/err)**2)
    except (ValueError, ArithmeticError):
        # Narrowed from a bare ``except:``, which also swallowed
        # KeyboardInterrupt (making the long fit loop uninterruptible) and
        # disguised programming errors as a flat 1e10 chi2.
        return 1e10
    return c2 if np.isfinite(c2) else 1e10


# ============================================================
# 4. FITTING
# ============================================================

def _nelder_mead_best(chi2_func, starts, args, best=(np.inf, None)):
    """Run Nelder-Mead from each start point and keep the lowest chi2.

    ``best`` is a running (chi2, x) pair to improve on. The default accepts
    any finite result. Returns the updated pair; x remains None only when no
    run beat best[0] (e.g. every run returned NaN).
    """
    best_fun, best_x = best
    for x0 in starts:
        res = minimize(chi2_func, x0, args=args, method='Nelder-Mead',
                       options={'maxiter': 5000})
        if res.fun < best_fun:
            best_fun, best_x = res.fun, res.x
    return best_fun, best_x


def _require_fit(x, name):
    """Raise RuntimeError if a multi-start fit produced no usable point."""
    if x is None:
        raise RuntimeError(f"{name} fit failed: no start point produced a finite chi2")


def _model_entry(params, chi2, k, N, mu_fit):
    """Assemble one model's result dict including AIC, BIC and reduced chi2."""
    dof = N - k
    return {
        'params': params,
        'chi2': chi2, 'k': k,
        'AIC': chi2 + 2*k, 'BIC': chi2 + k*np.log(N),
        'mu_fit': mu_fit,
        # Reduced chi2 is undefined without positive degrees of freedom.
        'chi2_dof': chi2/dof if dof > 0 else np.nan,
    }


def fit_models(z_obs, mu_obs, mu_err):
    """Fit LCDM, wCDM and DCT to the distance moduli by multi-start Nelder-Mead.

    Each model is minimized from a small grid of start points and the lowest
    chi-squared is kept. Progress and best-fit values are printed.

    Returns
    -------
    dict
        Keyed 'LCDM', 'wCDM', 'DCT'. Each value holds:
        - 'params': dict of best-fit values;
        - 'chi2';
        - 'k': number of free parameters;
        - 'AIC' and 'BIC';
        - 'mu_fit': model mu at z_obs;
        - 'chi2_dof'.

    Raises
    ------
    RuntimeError
        If no start point for a model yields a finite chi2. Previously this
        surfaced as an opaque unpacking TypeError.
    """
    N = len(z_obs)
    data = (z_obs, mu_obs, mu_err)
    results = {}

    # LCDM: free parameters (Om, H0)
    print("\nFitting LCDM...")
    chi2, x = _nelder_mead_best(
        chi2_lcdm_func,
        [[Om0, H0_0] for Om0 in [0.25, 0.30, 0.35] for H0_0 in [68, 73]],
        data)
    _require_fit(x, 'LCDM')
    Om_f, H0_f = x
    results['LCDM'] = _model_entry({'Om': Om_f, 'H0': H0_f}, chi2, 2, N,
                                   mu_lcdm(z_obs, Om_f, H0_f))
    print(f"  Om={Om_f:.4f}, H0={H0_f:.2f}, chi2={chi2:.3f}, "
          f"chi2/dof={results['LCDM']['chi2_dof']:.4f}")

    # wCDM: free parameters (Om, w, H0)
    print("\nFitting wCDM...")
    chi2, x = _nelder_mead_best(
        chi2_wcdm_func,
        [[Om0, w0, H0_0]
         for Om0 in [0.25, 0.30, 0.35]
         for w0 in [-1.2, -1.0, -0.8, -0.6]
         for H0_0 in [68, 73]],
        data)
    _require_fit(x, 'wCDM')
    Om_f, w_f, H0_f = x
    results['wCDM'] = _model_entry({'Om': Om_f, 'w': w_f, 'H0': H0_f}, chi2, 3, N,
                                   mu_wcdm(z_obs, Om_f, w_f, H0_f))
    print(f"  Om={Om_f:.4f}, w={w_f:.4f}, H0={H0_f:.2f}, chi2={chi2:.3f}, "
          f"chi2/dof={results['wCDM']['chi2_dof']:.4f}")

    # DCT: free parameters (alpha, t_c [Gyr], H0); progress printed per alpha0
    print("\nFitting DCT...")
    best = (np.inf, None)
    for alpha0 in [0.1, 0.3, 0.6, 1.0]:
        starts = [[alpha0, tc0, H0_0] for tc0 in [4.0, 8.0, 14.0] for H0_0 in [68, 73]]
        best = _nelder_mead_best(chi2_dct_func, starts, data, best)
        print(f"  alpha0={alpha0:.1f} done, best chi2={best[0]:.3f}")
    chi2, x = best
    _require_fit(x, 'DCT')
    al_f, tc_f, H0_f = x
    results['DCT'] = _model_entry({'alpha': al_f, 't_c_Gyr': tc_f, 'H0': H0_f}, chi2, 3, N,
                                  mu_dct(z_obs, al_f, tc_f, H0_f))
    print(f"  alpha={al_f:.4f}, t_c={tc_f:.2f} Gyr, H0={H0_f:.2f}, chi2={chi2:.3f}, "
          f"chi2/dof={results['DCT']['chi2_dof']:.4f}")

    return results


# ============================================================
# 5. PLOTTING
# ============================================================

def make_plots(z_obs, mu_obs, mu_err, results, output_path, data_label):
    """Save a two-panel figure: Hubble diagram with fits (top), residuals (bottom).

    Parameters
    ----------
    z_obs, mu_obs, mu_err : arrays of redshift, distance modulus and its error.
    results : dict returned by fit_models(), keyed 'LCDM', 'wCDM', 'DCT'.
    output_path : path the figure is written to.
    data_label : legend label for the data points.
    """
    model_names = ('LCDM', 'wCDM', 'DCT')
    palette = {'LCDM': '#2166ac', 'wCDM': '#b2182b', 'DCT': '#1b7837'}

    fig, (hubble_ax, resid_ax) = plt.subplots(
        2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [2, 1]}, sharex=True)
    fig.suptitle('Pantheon+ SN Ia: $\\Lambda$CDM vs $w$CDM vs DCT\n'
                 'Dimensional Coherence Theory (Nolan G. Parrott)',
                 fontsize=14, fontweight='bold')

    # --- Top panel: data points and smooth best-fit curves ---
    hubble_ax.errorbar(z_obs, mu_obs, yerr=mu_err, fmt='o', color='0.3', markersize=4,
                       alpha=0.7, capsize=2, label=data_label, zorder=1)

    # Curves start at half the lowest data redshift, but no lower than z = 0.005.
    z_curve = np.linspace(max(z_obs.min()*0.5, 0.005), z_obs.max()*1.02, 300)
    curve_builders = {
        'LCDM': lambda p: (
            mu_lcdm(z_curve, p['Om'], p['H0']),
            f"$\\Lambda$CDM ($\\Omega_m$={p['Om']:.3f}, $H_0$={p['H0']:.1f})"),
        'wCDM': lambda p: (
            mu_wcdm(z_curve, p['Om'], p['w'], p['H0']),
            f"$w$CDM ($\\Omega_m$={p['Om']:.3f}, $w$={p['w']:.2f}, $H_0$={p['H0']:.1f})"),
        'DCT': lambda p: (
            mu_dct(z_curve, p['alpha'], p['t_c_Gyr'], p['H0']),
            f"DCT ($\\alpha$={p['alpha']:.3f}, $t_c$={p['t_c_Gyr']:.1f} Gyr, $H_0$={p['H0']:.1f})"),
    }
    for model in model_names:
        curve, label = curve_builders[model](results[model]['params'])
        hubble_ax.plot(z_curve, curve, '-', color=palette[model], linewidth=2.2,
                       label=label, zorder=2)

    hubble_ax.set_ylabel('Distance Modulus $\\mu$ (mag)', fontsize=12)
    hubble_ax.legend(fontsize=9, loc='lower right')
    hubble_ax.set_xlim(0, z_obs.max()*1.08)
    hubble_ax.grid(True, alpha=0.3)

    # --- Bottom panel: residuals, x scaled by a small per-model factor to reduce overlap ---
    resid_ax.axhline(0, color='0.5', linewidth=0.8)
    resid_ax.fill_between([0, z_obs.max()*1.1], -0.1, 0.1, color='0.92', alpha=0.5)
    marker_and_stagger = {'LCDM': ('s', 0.997), 'wCDM': ('D', 1.0), 'DCT': ('o', 1.003)}
    for model in model_names:
        fit = results[model]
        marker, stagger = marker_and_stagger[model]
        resid_ax.errorbar(z_obs*stagger, mu_obs - fit['mu_fit'], yerr=mu_err,
                          fmt=marker, color=palette[model], markersize=4,
                          alpha=0.7, capsize=2,
                          label=f'{model}: $\\chi^2$/dof={fit["chi2_dof"]:.3f}')

    resid_ax.set_xlabel('Redshift $z$', fontsize=12)
    resid_ax.set_ylabel('$\\mu_{obs} - \\mu_{model}$ (mag)', fontsize=12)
    resid_ax.legend(fontsize=9, loc='upper left')
    resid_ax.set_ylim(-0.5, 0.5)
    resid_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=180, bbox_inches='tight')
    print(f"\nPlot saved to: {output_path}")
    plt.close()


# ============================================================
# 6. SUMMARY
# ============================================================

def print_summary(results, N):
    """Print a model-comparison table and DCT-vs-LCDM deltas to stdout.

    ``results`` is the dict from fit_models(). It must contain 'LCDM',
    'wCDM' and 'DCT' entries with 'k', 'chi2', 'chi2_dof', 'AIC', 'BIC' and
    'params'. dAIC/dBIC are measured from the minimum over all models. The
    closing w_eff sentence is a fixed statement, not computed from the fit.
    """
    order = ('LCDM', 'wCDM', 'DCT')
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print("\n" + heavy_rule)
    print("MODEL COMPARISON SUMMARY")
    print(heavy_rule)
    print(f"{'Model':<10} {'k':>3} {'chi2':>10} {'chi2/dof':>10} {'AIC':>10} {'BIC':>10} {'dAIC':>8} {'dBIC':>8}")
    print(light_rule)
    aic_floor = min(entry['AIC'] for entry in results.values())
    bic_floor = min(entry['BIC'] for entry in results.values())
    for model in order:
        m = results[model]
        print(f"{model:<10} {m['k']:>3} {m['chi2']:>10.3f} {m['chi2_dof']:>10.4f} "
              f"{m['AIC']:>10.3f} {m['BIC']:>10.3f} {m['AIC']-aic_floor:>8.2f} {m['BIC']-bic_floor:>8.2f}")
    print(light_rule)
    print(f"N_data = {N}\n")
    for model in order:
        joined = ", ".join(f"{key}={val:.4f}" for key, val in results[model]['params'].items())
        print(f"  {model}: {joined}")
    print()

    winner_aic = min(results, key=lambda m: results[m]['AIC'])
    winner_bic = min(results, key=lambda m: results[m]['BIC'])
    print(f"Best model by AIC: {winner_aic}")
    print(f"Best model by BIC: {winner_bic}")

    dct = results['DCT']
    lcdm = results['LCDM']
    print(f"\nDCT vs LCDM:  Delta_chi2 = {lcdm['chi2'] - dct['chi2']:+.3f}  (positive = DCT fits better)")
    print(f"DCT vs LCDM:  Delta_AIC  = {dct['AIC']-lcdm['AIC']:+.3f}")
    print(f"DCT vs LCDM:  Delta_BIC  = {dct['BIC']-lcdm['BIC']:+.3f}")
    print(f"\nDCT coherence parameters: alpha={dct['params']['alpha']:.4f}, t_c={dct['params']['t_c_Gyr']:.2f} Gyr")
    print("DCT predicts time-evolving dark energy (w_eff ~ -2/3 at late times)")
    print(heavy_rule)


# ============================================================
# MAIN
# ============================================================
if __name__ == '__main__':
    # Banner, then: load data -> fit models -> plot -> print summary.
    for banner_line in ("Pantheon+ Supernova Analysis: LCDM vs wCDM vs DCT",
                        "Dimensional Coherence Theory by Nolan G. Parrott",
                        "="*55):
        print(banner_line)

    z_obs, mu_obs, mu_err, data_label = get_pantheon_data()
    N = len(z_obs)
    print(f"\nData: {N} points, z = [{z_obs[0]:.4f}, {z_obs[-1]:.4f}]")

    results = fit_models(z_obs, mu_obs, mu_err)

    # The figure is written next to this script, independent of the working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_path = os.path.join(script_dir, 'pantheon_results.png')
    make_plots(z_obs, mu_obs, mu_err, results, output_path, data_label)
    print_summary(results, N)
