#!/usr/bin/env python3
"""
compare_snr_stats.py

Compare mean C/N0 (SNR) values from two groups of GNSS HDF5 files, broken
down by constellation (SVType) and frequency band.

Each group is specified as one or more directories; every *_gnss.h5 file
found (recursively) is included in that group's aggregate statistics.

Expects HDF5 files with structure:
    snr_raw/<antenna>/<band>/<constellation>/<sv_prn>/<ant_idx>/
        dataset fields: Week, Time, elev, az, CNo, ...

Usage:
    python compare_snr_stats.py \
        --dirs-a dir1/ dir2/ --label-a "Group A" \
        --dirs-b dir3/       --label-b "Group B" \
        --min-elev 10

Example:
    python compare_snr_stats.py \
        --dirs-a 20260216_north_pillars_with_klondike_radios_off_cell_on/ \
        --label-a "cell ON" \
        --dirs-b 20260217_north_pillars_with_klondike_radios_off_cell_off/ \
        --label-b "cell OFF" \
        --min-elev 10
"""

import argparse
import os
import h5py
import numpy as np
import pandas as pd


def find_h5_files(directories):
    """Recursively find all *_gnss.h5 files in the given directories.

    Directories are visited in the order given. Within each one the walk is
    deterministic: sub-directories and file names are both visited in sorted
    order. A file reachable from more than one input directory (for example
    when one directory is nested inside another) is returned only once, so
    its observations are not counted twice downstream.

    Args:
        directories: Iterable of directory paths to search.

    Returns:
        List of paths to the matching files, in traversal order. Directories
        that do not exist contribute nothing (``os.walk`` skips them).
    """
    h5_files = []
    seen = set()
    for d in directories:
        for root, dirs, files in os.walk(d):
            # os.walk descends into `dirs` in list order; sorting in place
            # makes the traversal independent of filesystem listing order.
            dirs.sort()
            for fname in sorted(files):
                if not fname.endswith("_gnss.h5"):
                    continue
                path = os.path.join(root, fname)
                # Key on the resolved path so overlapping inputs and symlinks
                # to the same file collapse to a single entry.
                identity = os.path.realpath(path)
                if identity in seen:
                    continue
                seen.add(identity)
                h5_files.append(path)
    return h5_files


def compute_mean_snr(filepath, label, antenna="0"):
    """Summarise CNo per band/constellation for one HDF5 file.

    Opens ``filepath`` read-only and walks
    ``snr_raw/<antenna>/<band>/<constellation>/<sv_prn>/<ant_idx>``. For each
    leaf the ``CNo`` and ``elev`` fields are read and only samples with
    ``CNo > 0`` are kept.

    Args:
        filepath: Path of the HDF5 file to read.
        label: Value stored under the ``"Dataset"`` key of every row.
        antenna: Name of the group below ``snr_raw`` to read.

    Returns:
        A list with one dict per (band, constellation) that has at least one
        retained sample. Bands and constellations are visited in sorted
        order. Each dict holds the summary fields plus the raw ``"cno"`` and
        ``"elev"`` arrays. ``"NumSVs"`` counts every SV entry under the
        constellation, including SVs that contributed no retained samples.
        The list is empty when ``snr_raw/<antenna>`` is absent.
    """
    rows = []
    antenna_path = f"snr_raw/{antenna}"
    with h5py.File(filepath, "r") as h5:
        if antenna_path not in h5:
            return rows
        antenna_grp = h5[antenna_path]
        for band_name in sorted(antenna_grp.keys()):
            constellations = antenna_grp[band_name]
            for constellation in sorted(constellations.keys()):
                satellites = constellations[constellation]
                cno_chunks = []
                elev_chunks = []
                sv_names = list(satellites.keys())
                for sv_name in sv_names:
                    per_sv = satellites[sv_name]
                    for ant_key in per_sv.keys():
                        leaf = per_sv[ant_key]
                        cno_samples = leaf["CNo"][:]
                        elev_samples = leaf["elev"][:]
                        keep = cno_samples > 0
                        cno_chunks.append(cno_samples[keep])
                        elev_chunks.append(elev_samples[keep])

                # Nothing was read at all for this constellation.
                if not cno_chunks:
                    continue
                cno_all = np.concatenate(cno_chunks)
                elev_all = np.concatenate(elev_chunks)
                # Data was read, but no sample had a positive CNo.
                if len(cno_all) == 0:
                    continue
                rows.append({
                    "Dataset": label,
                    "Band": band_name,
                    "Constellation": constellation,
                    "NumSVs": len(sv_names),
                    "NumObs": len(cno_all),
                    "MeanCNo_all": round(float(np.mean(cno_all)), 2),
                    "StdCNo_all": round(float(np.std(cno_all)), 2),
                    "cno": cno_all,
                    "elev": elev_all,
                })
    return rows


def aggregate_results(all_results, label):
    """Pool per-file rows into one row per (band, constellation).

    Args:
        all_results: Rows carrying ``"Band"``, ``"Constellation"``,
            ``"NumSVs"``, ``"cno"`` and ``"elev"`` entries.
        label: Value stored under ``"Dataset"`` in every merged row.

    Returns:
        Merged rows sorted by (band, constellation). The ``cno``/``elev``
        arrays of all inputs are concatenated and the mean/std are recomputed
        over the pooled samples. ``"NumSVs"`` is the largest per-input SV
        count, not a count of distinct SVs across inputs.
    """
    grouped = {}
    for row in all_results:
        slot = grouped.setdefault(
            (row["Band"], row["Constellation"]),
            {"cno": [], "elev": [], "n_svs": 0},
        )
        slot["cno"].append(row["cno"])
        slot["elev"].append(row["elev"])
        if row["NumSVs"] > slot["n_svs"]:
            slot["n_svs"] = row["NumSVs"]

    merged = []
    for key in sorted(grouped):
        band, const = key
        slot = grouped[key]
        pooled_cno = np.concatenate(slot["cno"])
        pooled_elev = np.concatenate(slot["elev"])
        mean_cno = float(np.mean(pooled_cno))
        std_cno = float(np.std(pooled_cno))
        merged.append({
            "Dataset": label,
            "Band": band,
            "Constellation": const,
            "NumSVs": slot["n_svs"],
            "NumObs": len(pooled_cno),
            "MeanCNo_all": round(mean_cno, 2),
            "StdCNo_all": round(std_cno, 2),
            "cno": pooled_cno,
            "elev": pooled_elev,
        })
    return merged


def add_elev_filtered_stats(results, min_elev):
    """Annotate each row with stats for samples at or above ``min_elev``.

    Every dict in ``results`` is modified in place and gains two keys whose
    names embed ``min_elev`` exactly as it formats in an f-string:
    ``MeanCNo_elev<min_elev>`` (mean CNo rounded to two decimals, or NaN when
    no sample qualifies) and ``NumObs_elev<min_elev>`` (qualifying sample
    count).

    Args:
        results: Rows carrying parallel ``"cno"`` and ``"elev"`` arrays.
        min_elev: Inclusive lower elevation bound.
    """
    mean_key = f"MeanCNo_elev{min_elev}"
    count_key = f"NumObs_elev{min_elev}"
    for row in results:
        above = row["elev"] >= min_elev
        n_above = int(above.sum())
        if n_above > 0:
            row[mean_key] = round(float(np.mean(row["cno"][above])), 2)
        else:
            # Avoid taking the mean of an empty selection.
            row[mean_key] = np.nan
        row[count_key] = n_above


def print_table(df, value_col, title, label1, label2):
    """Print a side-by-side comparison of one metric for two datasets.

    Args:
        df: Frame with ``Band``, ``Constellation``, ``Dataset`` and
            ``value_col`` columns.
        value_col: Name of the metric column to compare.
        title: Heading printed inside the banner.
        label1: ``Dataset`` value shown in the first value column.
        label2: ``Dataset`` value shown in the second value column.

    Each row shows both values, the delta (``label2`` minus ``label1``) and
    the delta as a percentage of the ``label1`` value. Anything that cannot
    be computed is printed as ``N/A``.
    """
    table = df.pivot_table(
        values=value_col,
        index=["Band", "Constellation"],
        columns="Dataset",
        aggfunc="first",
    )

    def _cell(value, spec, suffix=""):
        # NaN marks a missing dataset value or an undefined difference.
        if np.isnan(value):
            return "N/A"
        return format(value, spec) + suffix

    banner = "=" * 100
    for line in (banner, title, f"  {label1}", f"  {label2}", banner, ""):
        print(line)

    columns = f"{'Band':<12} {'Constellation':<12} {label1:>25} {label2:>25} {'Delta':>8} {'Δ%':>8}"
    print(columns)
    print("-" * len(columns))

    for key in sorted(table.index):
        band, const = key
        first = table.loc[key].get(label1, np.nan)
        second = table.loc[key].get(label2, np.nan)

        if np.isnan(first) or np.isnan(second):
            diff = np.nan
        else:
            diff = second - first

        # The percentage is undefined without a delta or with a zero baseline.
        if np.isnan(diff) or first == 0:
            rel = np.nan
        else:
            rel = diff / first * 100

        first_txt = _cell(first, ".2f")
        second_txt = _cell(second, ".2f")
        diff_txt = _cell(diff, "+.2f")
        rel_txt = _cell(rel, "+.1f", "%")
        print(f"{band:<12} {const:<12} {first_txt:>25} {second_txt:>25} {diff_txt:>8} {rel_txt:>8}")
    print()


def print_obs_table(df, label1, label2):
    """Print observation counts, CNo std dev and SV counts for both datasets.

    Args:
        df: Frame with ``Band``, ``Constellation``, ``Dataset``, ``NumObs``,
            ``StdCNo_all`` and ``NumSVs`` columns.
        label1: ``Dataset`` value for the first column of each pair.
        label2: ``Dataset`` value for the second column of each pair.

    Column headings use only the first eight characters of each label.
    Missing values are printed as ``N/A``.
    """
    def _pivot(metric):
        return df.pivot_table(
            index=["Band", "Constellation"],
            columns="Dataset",
            values=metric,
            aggfunc="first",
        )

    def _cell(table, key, label, render):
        value = table.loc[key].get(label, np.nan)
        return "N/A" if np.isnan(value) else render(value)

    obs = _pivot("NumObs")
    std = _pivot("StdCNo_all")
    svs = _pivot("NumSVs")

    # (pivot, heading prefix, renderer) for each pair of columns, in print order.
    metrics = (
        (obs, "Obs ", lambda v: f"{int(v):,}"),
        (std, "Std ", lambda v: f"{v:.2f}"),
        (svs, "SVs ", lambda v: f"{int(v)}"),
    )

    banner = "=" * 100
    print(banner)
    print("OBSERVATION COUNTS AND STD DEV (All elevations)")
    print(banner)
    print()

    headings = [f"{'Band':<12}", f"{'Constellation':<12}"]
    for _table, prefix, _render in metrics:
        headings.append(f"{prefix + label1[:8]:>14}")
        headings.append(f"{prefix + label2[:8]:>14}")
    header = " ".join(headings)
    print(header)
    print("-" * len(header))

    for key in sorted(obs.index):
        band, const = key
        fields = [f"{band:<12}", f"{const:<12}"]
        for table, _prefix, render in metrics:
            for label in (label1, label2):
                fields.append(f"{_cell(table, key, label, render):>14}")
        print(" ".join(fields))
    print()


def print_band_summary(df, label1, label2):
    """Print the mean, min and max ``MeanCNo_all`` delta for each band.

    Args:
        df: Frame with ``Band``, ``Constellation``, ``Dataset`` and
            ``MeanCNo_all`` columns.
        label1: ``Dataset`` value used as the baseline.
        label2: ``Dataset`` value compared against the baseline.

    A delta is ``label2`` minus ``label1`` for one (band, constellation).
    Pairs lacking a value for either dataset are ignored, and a band with no
    complete pair is not listed. The average is an unweighted mean over the
    band's constellations.
    """
    table = df.pivot_table(
        values="MeanCNo_all",
        index=["Band", "Constellation"],
        columns="Dataset",
        aggfunc="first",
    )

    banner = "=" * 100
    print(banner)
    print(f"SUMMARY: AVERAGE DELTA ({label2} minus {label1}) BY BAND")
    print(banner)
    print()

    deltas_by_band = {}
    for key in sorted(table.index):
        band = key[0]
        baseline = table.loc[key].get(label1, np.nan)
        compared = table.loc[key].get(label2, np.nan)
        if np.isnan(baseline) or np.isnan(compared):
            continue
        if band not in deltas_by_band:
            deltas_by_band[band] = []
        deltas_by_band[band].append(compared - baseline)

    header = f"{'Band':<12} {'Avg Delta (dB-Hz)':>18} {'Min Delta':>12} {'Max Delta':>12}"
    print(header)
    print("-" * len(header))
    for band in sorted(deltas_by_band):
        deltas = deltas_by_band[band]
        avg, low, high = np.mean(deltas), np.min(deltas), np.max(deltas)
        print(f"{band:<12} {avg:>+18.2f} {low:>+12.2f} {high:>+12.2f}")
    print()


def _load_group(files, label, antenna):
    """Load every file of one group and reduce it to per-band/constellation rows.

    Args:
        files: Paths of the HDF5 files that make up the group.
        label: Dataset label stamped on every returned row.
        antenna: Antenna key passed through to ``compute_mean_snr``.

    Returns:
        Result dicts as produced by ``compute_mean_snr`` (one file) or
        ``aggregate_results`` (several files).
    """
    raw = []
    for fp in files:
        raw.extend(compute_mean_snr(fp, label, antenna=antenna))
    # A single file already yields one row per (band, constellation), so
    # pooling is only needed when several files contribute.
    return aggregate_results(raw, label) if len(files) > 1 else raw


def main():
    """Parse the command line, load both groups and print the comparison tables.

    Exits through ``parser.error`` (usage message, exit status 2) when:
      * both groups were given the same label,
      * a group contains no ``*_gnss.h5`` file,
      * neither group yields any usable observation for the chosen antenna.
    """
    parser = argparse.ArgumentParser(
        description="Compare mean C/N0 (SNR) from two groups of GNSS HDF5 files by constellation and band."
    )
    parser.add_argument("--dirs-a", nargs="+", required=True,
                        help="One or more directories containing *_gnss.h5 files for group A")
    parser.add_argument("--label-a", required=True,
                        help="Label for group A")
    parser.add_argument("--dirs-b", nargs="+", required=True,
                        help="One or more directories containing *_gnss.h5 files for group B")
    parser.add_argument("--label-b", required=True,
                        help="Label for group B")
    parser.add_argument("--min-elev", type=float, default=0.0,
                        help="Minimum elevation (deg) for filtered stats (default: 0)")
    parser.add_argument("--antenna", default="0",
                        help="Antenna index in snr_raw group (default: 0)")
    args = parser.parse_args()

    # The tables pivot on the label, so identical labels would fold both
    # groups into one column and report a meaningless zero delta.
    if args.label_a == args.label_b:
        parser.error("--label-a and --label-b must be different")

    files_a = find_h5_files(args.dirs_a)
    files_b = find_h5_files(args.dirs_b)

    if not files_a:
        parser.error(f"No *_gnss.h5 files found in: {args.dirs_a}")
    if not files_b:
        parser.error(f"No *_gnss.h5 files found in: {args.dirs_b}")

    print(f"Group A ({args.label_a}): {len(files_a)} file(s)")
    for fp in files_a:
        print(f"  {fp}")
    print(f"Group B ({args.label_b}): {len(files_b)} file(s)")
    for fp in files_b:
        print(f"  {fp}")
    print()

    r1 = _load_group(files_a, args.label_a, args.antenna)
    r2 = _load_group(files_b, args.label_b, args.antenna)

    # With no rows at all the DataFrame has no columns and the pivots below
    # would fail with a KeyError; report the actual problem instead.
    if not r1 and not r2:
        parser.error(
            f"No usable C/N0 observations found under snr_raw/{args.antenna} "
            "in any input file"
        )
    for label, rows in ((args.label_a, r1), (args.label_b, r2)):
        if not rows:
            print(f"WARNING: no usable C/N0 observations for '{label}'; "
                  "its columns will show N/A")

    if args.min_elev > 0:
        add_elev_filtered_stats(r1, args.min_elev)
        add_elev_filtered_stats(r2, args.min_elev)

    # The raw sample arrays are only needed for the statistics above; keep
    # just the scalar summary fields in the table.
    df = pd.DataFrame([
        {k: v for k, v in r.items() if k not in ("cno", "elev")}
        for r in r1 + r2
    ])

    print_table(df, "MeanCNo_all",
                "MEAN C/N0 (dB-Hz) COMPARISON: ALL ELEVATIONS",
                args.label_a, args.label_b)

    if args.min_elev > 0:
        # Must match the key format used by add_elev_filtered_stats.
        elev_col = f"MeanCNo_elev{args.min_elev}"
        print_table(df, elev_col,
                    f"MEAN C/N0 (dB-Hz) COMPARISON: ELEVATION >= {args.min_elev} deg",
                    args.label_a, args.label_b)

    print_obs_table(df, args.label_a, args.label_b)
    print_band_summary(df, args.label_a, args.label_b)


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
