#!/usr/bin/env python3
"""
compare_multi_snr.py

Compare mean C/N0 (SNR) values across N test-scenario directories, broken
down by constellation (SVType) and frequency band.

The first directory is treated as the reference/baseline.  Each directory
must contain a comments.md with a "- Label: <text>" line.

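Outputs:
    1. Baseline-relative delta table  -> .csv  +  Jira wiki markup .txt
    2. Grouped bar chart faceted by band -> .png
    When --min-elev > 0, an elevation-filtered variant of each output is
    also written (file names carry an "_elev<N>" suffix).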

Usage:
    python skills/compare_multi_snr.py dir1/ dir2/ dir3/ \\
        [--min-elev 10] [--antenna 0] [--output-dir figures/]
"""

import argparse
import os
import re
import sys

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# ---------------------------------------------------------------------------
# HDF5 helpers (adapted from compare_snr_stats.py)
# ---------------------------------------------------------------------------

def find_h5_files(directory):
    """Recursively find all ``*_gnss.h5`` files under *directory*.

    The walk is deterministic. The sub-directory list is sorted in place, so
    ``os.walk`` descends in sorted order, and the file names within each
    directory are sorted too. Previously only the file names were sorted,
    which left the overall order dependent on the filesystem.

    Args:
        directory: Root directory to search.

    Returns:
        list[str]: Paths of matching files, in sorted top-down walk order.
        Empty if nothing matches (or *directory* does not exist).
    """
    h5_files = []
    for root, dirs, files in os.walk(directory):
        # A top-down os.walk honours in-place edits to *dirs*.
        dirs.sort()
        h5_files.extend(
            os.path.join(root, fname)
            for fname in sorted(files)
            if fname.endswith("_gnss.h5")
        )
    return h5_files


def compute_mean_snr(filepath, label, antenna="0"):
    """Extract per-band, per-constellation C/N0 arrays from one HDF5 file.

    Reads the group ``snr_raw/<antenna>`` and walks it as
    ``<band>/<constellation>/<sv_prn>/<sub_key>``. Each leaf object provides
    ``"CNo"`` and ``"elev"`` arrays. Samples with ``CNo <= 0`` (or NaN) are
    discarded, and the matching elevation samples are dropped with them so
    the two arrays stay aligned.

    Args:
        filepath: Path to a ``*_gnss.h5`` file.
        label: Scenario label copied into every result dict.
        antenna: Key under ``snr_raw`` to read (default ``"0"``).

    Returns:
        list[dict]: One dict per (band, constellation) with at least one
        valid sample. Each dict holds ``label``, ``Band``, ``Constellation``,
        ``NumSVs`` (number of SV groups that contributed at least one valid
        sample), ``NumObs``, ``MeanCNo``, ``StdCNo`` (population std, both
        rounded to 2 dp), and the filtered ``cno`` / ``elev`` arrays.
        The list is empty if the antenna group is absent.
    """
    results = []
    with h5py.File(filepath, "r") as f:
        group_path = f"snr_raw/{antenna}"
        if group_path not in f:
            return results
        base = f[group_path]
        for band_name in sorted(base.keys()):
            band_grp = base[band_name]
            for constellation in sorted(band_grp.keys()):
                const_grp = band_grp[constellation]
                all_cno, all_elev, n_svs = [], [], 0
                for sv_prn in const_grp.keys():
                    sv_grp = const_grp[sv_prn]
                    sv_has_valid = False
                    # NOTE(review): these sub-keys are presumably per-antenna
                    # entries; verify against the file writer.
                    for ant_idx in sv_grp.keys():
                        ds = sv_grp[ant_idx]
                        cno = ds["CNo"][:]
                        elev = ds["elev"][:]
                        valid = cno > 0
                        if valid.any():
                            all_cno.append(cno[valid])
                            all_elev.append(elev[valid])
                            sv_has_valid = True
                    # Count only SVs that actually contribute samples, so
                    # NumSVs is consistent with NumObs / MeanCNo.
                    if sv_has_valid:
                        n_svs += 1
                if not all_cno:
                    # No valid sample for this constellation in this band.
                    continue
                combined_cno = np.concatenate(all_cno)
                combined_elev = np.concatenate(all_elev)
                results.append({
                    "label": label,
                    "Band": band_name,
                    "Constellation": constellation,
                    "NumSVs": n_svs,
                    "NumObs": len(combined_cno),
                    "MeanCNo": round(float(np.mean(combined_cno)), 2),
                    "StdCNo": round(float(np.std(combined_cno)), 2),
                    "cno": combined_cno,
                    "elev": combined_elev,
                })
    return results


def aggregate_results(all_results, label):
    """Combine per-file result dicts into one entry per (band, constellation).

    Observation arrays from every file are concatenated before the mean and
    standard deviation are recomputed, so every sample carries equal weight.
    ``NumSVs`` is the largest per-file SV count seen for the key. It is not
    a count of distinct SVs across files.

    Args:
        all_results: Iterable of dicts shaped like ``compute_mean_snr`` output.
        label: Scenario label stored in every merged dict.

    Returns:
        list[dict]: Merged dicts, sorted by (Band, Constellation).
    """
    cno_parts, elev_parts, sv_counts = {}, {}, {}
    for entry in all_results:
        key = (entry["Band"], entry["Constellation"])
        cno_parts.setdefault(key, []).append(entry["cno"])
        elev_parts.setdefault(key, []).append(entry["elev"])
        sv_counts[key] = max(sv_counts.get(key, 0), entry["NumSVs"])

    merged = []
    for band, constellation in sorted(cno_parts):
        key = (band, constellation)
        cno = np.concatenate(cno_parts[key])
        elev = np.concatenate(elev_parts[key])
        merged.append({
            "label": label,
            "Band": band,
            "Constellation": constellation,
            "NumSVs": sv_counts[key],
            "NumObs": len(cno),
            "MeanCNo": round(float(np.mean(cno)), 2),
            "StdCNo": round(float(np.std(cno)), 2),
            "cno": cno,
            "elev": elev,
        })
    return merged


def apply_elev_filter(results, min_elev):
    """Attach ``MeanCNo_filt`` to every result dict, in place.

    ``MeanCNo_filt`` is the mean C/N0, rounded to 2 dp, over the samples
    whose elevation is ``>= min_elev``. It is ``np.nan`` when no sample
    qualifies. Returns ``None``.
    """
    for entry in results:
        kept = entry["cno"][entry["elev"] >= min_elev]
        if kept.size:
            entry["MeanCNo_filt"] = round(float(np.mean(kept)), 2)
        else:
            entry["MeanCNo_filt"] = np.nan


# ---------------------------------------------------------------------------
# Label extraction
# ---------------------------------------------------------------------------

def read_label(directory):
    """Return the scenario label declared in ``<directory>/comments.md``.

    The first line matching ``- Label: <text>`` wins, provided its text is
    non-empty after stripping. Whitespace around the ``-`` and ``:`` is
    optional. Blank ``- Label:`` lines are skipped rather than returned as
    an empty label.

    The file is read as UTF-8 so non-ASCII labels do not depend on the
    platform locale.

    Exits via ``sys.exit`` with an error message if the file is missing or
    contains no non-empty ``- Label:`` entry.
    """
    comments_path = os.path.join(directory, "comments.md")
    if not os.path.isfile(comments_path):
        sys.exit(f"Error: no comments.md in {directory}")
    with open(comments_path, encoding="utf-8") as fh:
        for line in fh:
            m = re.match(r"^-\s*Label\s*:\s*(.+)", line)
            if not m:
                continue
            label = m.group(1).strip()
            # An empty label would yield blank table columns and file names.
            if label:
                return label
    sys.exit(f"Error: no '- Label:' entry found in {comments_path}")


def label_to_filename_token(label):
    """Turn a scenario label into a string safe to embed in a file name.

    Whitespace is replaced with ``_``, and so are path separators and
    characters reserved in Windows file names (/ \\ : * ? " < > |). This
    stops a label such as ``"FW 1.2/rev B"`` from redirecting output into a
    sub-directory or producing an invalid path. Labels containing only
    ordinary characters and spaces map exactly as before (each space
    becomes ``_``).
    """
    return re.sub(r'[\s/\\:*?"<>|]', "_", label)


# ---------------------------------------------------------------------------
# Output helpers
# ---------------------------------------------------------------------------

def build_output_stem(labels, output_dir):
    """Return the shared output path prefix (without extension) for *labels*.

    The stem is ``<output_dir>/compare_snr__<tok1>__<tok2>...``, where each
    token is ``label_to_filename_token(label)`` taken in the given order.
    """
    joined = "__".join(map(label_to_filename_token, labels))
    return os.path.join(output_dir, f"compare_snr__{joined}")


def build_delta_dataframe(group_stats, labels, value_key="MeanCNo"):
    """Build a table of reference values and per-scenario deltas.

    ``labels[0]`` is the reference. The table has one row per
    (Band, Constellation) pair seen in any scenario, sorted by that pair.

    Columns are ``Band``, ``Constellation`` and ``"<ref> (ref)"``. Then, for
    every other scenario, come ``"<label>"`` (its value) and
    ``"<label> delta"`` (value minus reference, rounded to 2 dp). Missing
    values, and deltas that involve a missing value, are ``NaN``.

    Args:
        group_stats: Mapping label -> list of result dicts containing
            "Band", "Constellation" and *value_key*.
        labels: Scenario labels; ``labels[0]`` is the reference.
        value_key: Result-dict key to compare (default "MeanCNo").

    Returns:
        pandas.DataFrame: The comparison table. It keeps the full column set
        even when there are no rows.
    """
    ref_label = labels[0]
    # One (Band, Constellation) -> value lookup per scenario, built once
    # rather than rebuilt for every output row.
    lookups = {
        label: {
            (r["Band"], r["Constellation"]): r[value_key]
            for r in group_stats[label]
        }
        for label in labels
    }
    ref_lookup = lookups[ref_label]

    columns = ["Band", "Constellation", f"{ref_label} (ref)"]
    for label in labels[1:]:
        columns += [label, f"{label} delta"]
    # De-duplicate while keeping order, mirroring how the row dicts
    # collapse repeated keys.
    columns = list(dict.fromkeys(columns))

    rows = []
    for band, const in sorted(set().union(*lookups.values())):
        ref_val = ref_lookup.get((band, const), np.nan)
        row = {"Band": band, "Constellation": const,
               f"{ref_label} (ref)": ref_val}
        for label in labels[1:]:
            val = lookups[label].get((band, const), np.nan)
            row[label] = val
            if np.isnan(ref_val) or np.isnan(val):
                row[f"{label} delta"] = np.nan
            else:
                row[f"{label} delta"] = round(val - ref_val, 2)
        rows.append(row)

    return pd.DataFrame(rows, columns=columns)


def export_csv(df, path):
    """Write *df* to *path* as CSV (without the index column) and report it."""
    df.to_csv(path_or_buf=path, index=False)
    message = f"  CSV  -> {path}"
    print(message)


def export_jira(df, path):
    """Write *df* as a Jira wiki-markup table to *path* and report it.

    Header cells are delimited by ``||`` and data cells by ``|``. Cell
    values are formatted as follows:

    * Floats are rendered with two decimals.
    * Floats in delta columns get an explicit sign. A delta column is one
      whose name ends in ``" delta"``, the suffix ``build_delta_dataframe``
      uses. A label that merely contains the word "delta" no longer makes
      its value columns signed.
    * NaN floats become ``N/A``.
    * Anything else is rendered with ``str``.

    The file is written as UTF-8 so non-ASCII labels round-trip.
    """
    delta_cols = {col for col in df.columns if str(col).endswith(" delta")}

    def _fmt(col, v):
        # Format one cell according to its column and value type.
        if isinstance(v, float):
            if np.isnan(v):
                return "N/A"
            return f"{v:+.2f}" if col in delta_cols else f"{v:.2f}"
        return str(v)

    lines = ["||" + "||".join(map(str, df.columns)) + "||"]
    for _, row in df.iterrows():
        lines.append(
            "|" + "|".join(_fmt(col, row[col]) for col in df.columns) + "|"
        )
    with open(path, "w", encoding="utf-8") as fh:
        fh.write("\n".join(lines) + "\n")
    print(f"  Jira -> {path}")


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

def plot_grouped_bars(group_stats, labels, value_key, ylabel, title, path):
    """Save a grouped bar chart of *value_key*, with one subplot per band.

    Each subplot has one bar group per constellation present in that band.
    Within a group there is one bar per scenario, in *labels* order, and all
    subplots share a single figure legend. Subplots are laid out in at most
    three columns.

    A (band, constellation) pair that a scenario lacks is passed to
    matplotlib as NaN (no bar) instead of a fabricated value of 0. The
    figure is always closed, even if drawing or saving fails.

    Args:
        group_stats: Mapping label -> list of result dicts containing
            "Band", "Constellation" and *value_key*.
        labels: Scenario labels, in the order the bars should appear.
        value_key: Result-dict key to plot (e.g. "MeanCNo").
        ylabel: Y-axis label used on every subplot.
        title: Figure title.
        path: Output image path.
    """
    all_bands = sorted({
        r["Band"] for stats in group_stats.values() for r in stats
    })
    n_bands = len(all_bands)
    if n_bands == 0:
        print("  Warning: no data to plot")
        return

    # (Band, Constellation) -> value for each scenario, built once.
    lookups = {
        label: {
            (r["Band"], r["Constellation"]): r[value_key]
            for r in group_stats[label]
        }
        for label in labels
    }

    ncols = min(n_bands, 3)
    nrows = (n_bands + ncols - 1) // ncols
    n_scenarios = len(labels)
    bar_width = 0.8 / n_scenarios
    colors = plt.cm.Set2(np.linspace(0, 1, max(n_scenarios, 3)))

    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 5 * nrows),
                             squeeze=False)
    try:
        # Row-major flattening: index i is axes[i // ncols][i % ncols].
        flat_axes = axes.ravel()
        for ax, band in zip(flat_axes, all_bands):
            constellations_in_band = sorted({
                r["Constellation"]
                for stats in group_stats.values()
                for r in stats
                if r["Band"] == band
            })
            x = np.arange(len(constellations_in_band))

            for s_idx, label in enumerate(labels):
                vals = [lookups[label].get((band, c), np.nan)
                        for c in constellations_in_band]
                # Center the group of n_scenarios bars on each tick.
                offset = (s_idx - n_scenarios / 2 + 0.5) * bar_width
                ax.bar(x + offset, vals, bar_width, label=label,
                       color=colors[s_idx])

            ax.set_title(band, fontweight="bold")
            ax.set_xticks(x)
            ax.set_xticklabels(constellations_in_band)
            ax.set_ylabel(ylabel)
            ax.grid(axis="y", alpha=0.3)

        # hide unused subplots
        for ax in flat_axes[n_bands:]:
            ax.set_visible(False)

        handles, lbls = flat_axes[0].get_legend_handles_labels()
        fig.legend(handles, lbls, loc="upper center", ncol=n_scenarios,
                   bbox_to_anchor=(0.5, 1.02), fontsize=10)
        fig.suptitle(title, y=1.06, fontsize=13, fontweight="bold")
        fig.tight_layout()
        fig.savefig(path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"  PNG  -> {path}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Command-line entry point.

    Steps:

    1. Read each directory's label from comments.md. The first directory
       is the reference.
    2. Reject duplicate labels.
    3. Load and aggregate all ``*_gnss.h5`` files per directory.
    4. Write the baseline-relative delta table (CSV + Jira) and the grouped
       bar chart.

    When ``--min-elev`` > 0, elevation-filtered variants of every output
    are written as well.
    """
    parser = argparse.ArgumentParser(
        description="Compare mean C/N0 across N GNSS test scenarios."
    )
    parser.add_argument("dirs", nargs="+",
                        help="Directories to compare (first = reference)")
    parser.add_argument("--min-elev", type=float, default=0.0,
                        help="Minimum elevation for filtered stats (default: 0)")
    parser.add_argument("--antenna", default="0",
                        help="Antenna index in snr_raw group (default: 0)")
    parser.add_argument("--output-dir", default="figures/",
                        help="Output directory (default: figures/)")
    args = parser.parse_args()

    if len(args.dirs) < 2:
        parser.error("Need at least 2 directories to compare")

    # Read labels
    labels = []
    for d in args.dirs:
        label = read_label(d)
        labels.append(label)
        print(f"  {d}  ->  \"{label}\"")
    print()

    # Per-scenario data and table columns are keyed by label, so a repeated
    # label would silently overwrite the earlier scenario's results.
    first_dir_for_label = {}
    for d, label in zip(args.dirs, labels):
        if label in first_dir_for_label:
            parser.error(
                f"Duplicate label \"{label}\" in {first_dir_for_label[label]} "
                f"and {d}; each scenario needs a unique label"
            )
        first_dir_for_label[label] = d

    # Create the output directory only after the inputs have been validated.
    os.makedirs(args.output_dir, exist_ok=True)

    filtered = args.min_elev > 0
    # ":g" renders 10.0 as "10" (same as int()), but keeps fractional
    # thresholds such as 0.5 instead of truncating them to "0".
    elev_tag = f"elev{args.min_elev:g}"

    # Collect and aggregate data per scenario
    group_stats = {}
    for d, label in zip(args.dirs, labels):
        h5_files = find_h5_files(d)
        if not h5_files:
            parser.error(f"No *_gnss.h5 files found in: {d}")
        print(f"[{label}] {len(h5_files)} HDF5 file(s)")
        raw = []
        for fp in h5_files:
            raw.extend(compute_mean_snr(fp, label, antenna=args.antenna))
        if not raw:
            print(f"  Warning: no valid snr_raw/{args.antenna} data in {d}")
        # Aggregating a single file's results reproduces them, so one code
        # path serves both the single-file and multi-file cases.
        group_stats[label] = aggregate_results(raw, label)

        if filtered:
            apply_elev_filter(group_stats[label], args.min_elev)

    print()

    stem = build_output_stem(labels, args.output_dir)

    # --- Output 1: baseline-relative delta table ---
    print("Baseline-relative delta table (all elevations):")
    df_all = build_delta_dataframe(group_stats, labels, value_key="MeanCNo")
    export_csv(df_all, f"{stem}.csv")
    export_jira(df_all, f"{stem}_jira.txt")

    if filtered:
        print(f"\nBaseline-relative delta table (elev >= {args.min_elev}):")
        df_filt = build_delta_dataframe(group_stats, labels, value_key="MeanCNo_filt")
        export_csv(df_filt, f"{stem}_{elev_tag}.csv")
        export_jira(df_filt, f"{stem}_{elev_tag}_jira.txt")

    # --- Output 2: grouped bar chart ---
    print("\nGrouped bar chart:")
    plot_grouped_bars(
        group_stats, labels,
        value_key="MeanCNo",
        ylabel="Mean C/N0 (dB-Hz)",
        title="Mean C/N0 by Band and Constellation",
        path=f"{stem}.png",
    )

    if filtered:
        print(f"\nGrouped bar chart (elev >= {args.min_elev}):")
        plot_grouped_bars(
            group_stats, labels,
            value_key="MeanCNo_filt",
            ylabel=f"Mean C/N0 (dB-Hz), elev >= {args.min_elev}°",
            title=f"Mean C/N0 by Band and Constellation (elev >= {args.min_elev}°)",
            path=f"{stem}_{elev_tag}.png",
        )

    print("\nDone.")


if __name__ == "__main__":
    main()
