#!/usr/bin/env python3
"""
plotCN0vsEl.py
  For usage details "plotCN0vsEl.py -h"

  @description: Form a per-signal plot of C/N0 vs Elevation for all signals found in T04 files. Optionally, create a skyplot per signal showing the C/N0 of each receiver.
  @author: AAA

  Copyright Trimble, Inc. 2024
"""

import datetime

import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import mutils as m

# Sequential matplotlib colormap names. Each loaded receiver (and each extra
# antenna-1 "receiver") is assigned the next name, so its skyplot points use
# their own colour family.
sequential_cmaps = """
    Blues YlOrRd Greens Reds Oranges Purples YlOrBr OrRd PuRd
    RdPu BuPu GnBu PuBu YlGnBu PuBuGn BuGn YlGn
""".split()

# Upper C/N0 limit: top of the Elevation vs C/N0 y-axis and of the skyplot
# colour scale (units presumably dB-Hz).
CN0_MAX = 60

# viewdat's UTC conversion does not appear
# to account for leap seconds, so this doesn't either
def _gps_to_utc(gps_weeks, gps_sow) -> datetime.datetime:
    """Convert a GPS week number and seconds-of-week to a naive datetime.

    No leap-second correction is applied (deliberately, to match viewdat as
    noted in the comment above). The result is therefore offset from true UTC
    by the GPS-UTC leap-second difference.

    Args:
        gps_weeks: GPS week number counted from 1980-01-06. The code assumes a
            full week count, not a 1024-week rolled-over value.
        gps_sow: Seconds into the GPS week; may be fractional.

    Both arguments may be Python or numpy scalars. They are passed through
    float() because datetime.timedelta rejects numpy integer types
    (e.g. numpy.int64) with a TypeError.

    Returns:
        A timezone-naive datetime.datetime.
    """
    gps_epoch = datetime.datetime(1980, 1, 6, 0, 0, 0)
    delta = datetime.timedelta(weeks=float(gps_weeks), seconds=float(gps_sow))
    return gps_epoch + delta

# Bit of the MEAS flags word that marks master-sub-channel data
# (flag.dMEAS_MASTER_SUB).
_MEAS_MASTER_SUB = 512

# Inclusive CSTATE ranges accepted as PLL tracking (io/tx/rt_dat_defs.h):
#   master sub-channel (mRec35sub19_TrackingStatus_CSTATE_*):
#     14 PLL_A_WIDE, 15 PLL_A_NARROW, 16 PLL_B, 17 PLL_C, 18 PLL_D, 19 PLL_E, 20 SBAS
#   slave sub-channel (mRec35sub19_TrackingStatus_SLAVE_CSTATE_*):
#     9 LOCKED, 10 AFC_B, 11 PLL_B, 12 AFC_C
_MASTER_PLL_CSTATES = (14, 20)
_SLAVE_PLL_CSTATES = (9, 12)

# Elevation bins (degrees) for the binned-average C/N0 curve. Each bin is
# [centre - 2.5, centre + 2.5). np.arange excludes its stop value, so the stop
# must be 90 (not 87.5) for the 85-90 degree bin to exist. The extra centre at
# 90 extends the curve to the zenith.
_EL_BIN_HALF_WIDTH = 2.5
_EL_BIN_CENTERS = np.append(np.arange(2.5, 90, 5), 90)


def _with_utc_index(df):
    """Return a copy of *df* indexed by a 'UTC' datetime built from WN and TIME.

    WN/TIME are converted to Python scalars before calling _gps_to_utc. An
    empty *df* yields an empty result (row-wise apply on an empty frame does
    not return a column).
    """
    utc = [
        _gps_to_utc(week, sow)
        for week, sow in zip(df["WN"].tolist(), df["TIME"].tolist())
    ]
    return df.assign(UTC=utc).set_index("UTC")


def _select_signal(df, sat_type, freq, track):
    """Return the rows of *df* belonging to one (SAT_TYPE, FREQ, TRACK) signal."""
    return df[(df.SAT_TYPE == sat_type) & (df.FREQ == freq) & (df.TRACK == track)]


def _pll_only(signal_df):
    """Keep the rows of a non-empty per-signal frame whose CSTATE is a PLL state.

    Whether the data comes from the master sub-channel is decided from the
    MEAS flags of the first row only. The matching CSTATE range is then
    applied to every row.
    """
    is_master_subchan = (int(signal_df.MEAS.iloc[0]) & _MEAS_MASTER_SUB) != 0
    low, high = _MASTER_PLL_CSTATES if is_master_subchan else _SLAVE_PLL_CSTATES
    return signal_df[(signal_df.CSTATE >= low) & (signal_df.CSTATE <= high)]


def _binned_mean_cn0(signal_df):
    """Return the mean CNO for each bin in _EL_BIN_CENTERS (NaN for empty bins)."""
    el = signal_df["EL"]
    return np.array([
        signal_df.loc[
            (el >= centre - _EL_BIN_HALF_WIDTH) & (el < centre + _EL_BIN_HALF_WIDTH),
            "CNO",
        ].mean()
        for centre in _EL_BIN_CENTERS
    ])


def _time_span(frames):
    """Return (first, last) index value over the non-empty frames, or None if all are empty."""
    non_empty = [f for f in frames if not f.empty]
    if not non_empty:
        return None
    return min(f.index.min() for f in non_empty), max(f.index.max() for f in non_empty)


def _finish_figure(fig, show_figures, filename):
    """Show the figure interactively, or save it to *filename* (300 dpi) and close it."""
    if show_figures:
        plt.show()
    else:
        fig.savefig(filename, dpi=300)
        plt.close(fig)


if __name__ == "__main__":
    import argparse
    import itertools
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-logs",
        "--receiver_list",
        required=True,
        nargs="+",
        help="List of T04 files to compare. If you pass in a file containing dual-antenna data, this script will create a new 'receiver' with your filename and _ant_1 for the second antenna. Does not support rec27!",
    )
    parser.add_argument(
        "-snr_floor",
        "--snr_floor",
        required=False,
        default=20,
        type=float,
        help="SNR threshold below which data will be discarded. Default = 20",
    )
    parser.add_argument(
        "-el_floor",
        "--el_floor",
        required=False,
        default=2.5,
        type=float,
        help="Elevation threshold below which data will be discarded. Default = 2.5",
    )
    parser.add_argument(
        "-s",
        "--show_figures",
        help="Show the figures and summary data, instead of saving to a file.",
        action="store_true",
    )
    parser.add_argument(
        "-skyplot",
        "--skyplot",
        help="Create a skyplot per signal showing the C/N0 of each receiver. Darker colors indicate higher C/N0.",
        action="store_true",
    )
    parser.add_argument(
        "-el_vs_cn0",
        "--el_vs_cn0",
        help="Create a plot of elevation vs C/N0 of all receivers for each tracked signal.",
        action="store_true",
    )

    args = parser.parse_args()
    if not args.show_figures:
        # Allow running headless from the command line
        matplotlib.use("agg")

    # With neither plot requested, default to the Elevation vs C/N0 plot.
    if not args.skyplot and not args.el_vs_cn0:
        args.el_vs_cn0 = True

    working_rx_dict = {}  # receiver name -> observations indexed by UTC
    rx_cmap_dict = {}  # receiver name -> skyplot colormap name
    # Cycle so that more receivers than colormaps reuses colours instead of
    # raising IndexError part-way through loading.
    cmap_cycle = itertools.cycle(sequential_cmaps)

    common_sigs = None
    # dict.fromkeys drops repeated file names while keeping command-line order.
    for rx in dict.fromkeys(args.receiver_list):
        # Load the data, while using viewdat to decimate to 1Hz
        obs_raw = m.vd2cls(rx, "-d35:19 --dec=1000")

        # Find and optionally update the signals in the data files
        obs_rx_sigs = m.get_signals(obs_raw)

        # NOTE: only the first file's signal list is plotted; receivers that
        # lack one of those signals are reported as "No data" while plotting.
        if common_sigs is None:
            common_sigs = obs_rx_sigs

        # Filter per the input arguments, then split by antenna
        above_floors = (obs_raw.EL >= args.el_floor) & (obs_raw.CNO >= args.snr_floor)
        obs = obs_raw[above_floors & (obs_raw.ANTENNA == 0)]
        obs_ant_1 = obs_raw[above_floors & (obs_raw.ANTENNA == 1)]

        new_receivers = [(rx, obs)]
        if len(obs_ant_1) > 0:
            # Insert "_ant_1" before the extension: foo.T04 -> foo_ant_1.T04
            stem, ext = os.path.splitext(rx)
            ant_1_name = stem + "_ant_1" + ext
            print(f'Creating a new \'receiver\' {ant_1_name} in the logs list for antenna 1 of {rx}')
            new_receivers.append((ant_1_name, obs_ant_1))

        for name, data in new_receivers:
            # Convert to a DataFrame indexed by UTC time
            working_rx_dict[name] = _with_utc_index(pd.DataFrame(data, columns=data.k.keys()))
            rx_cmap_dict[name] = next(cmap_cycle)

    print("Data Loaded, finding common times")

    # Keep only epochs present in every receiver so the comparison is like-for-like.
    common_datetimes = sorted(set.intersection(*(set(df.index) for df in working_rx_dict.values())))
    time_filtered_rx_dict = {name: df[df.index.isin(common_datetimes)] for name, df in working_rx_dict.items()}

    if len(working_rx_dict) > len(sequential_cmaps):
        print("Warning: more receivers than colormaps; skyplot colormaps will be reused.")

    if not common_sigs:
        print("No common signals found between all files")
        raise SystemExit(1)

    print("Done. Plotting...")

    # Skyplot marker size per receiver position: log-spaced 100 .. 1, largest
    # for the first receiver.
    snr_size = np.flip(np.logspace(0, 2, len(time_filtered_rx_dict)))

    for sat_type, sat_signals in common_sigs.items():
        for freq, track in sat_signals:
            signal_type_label = m.get_sub_type(sat_type, freq, track).fullstr

            # PLL-only rows per receiver; None marks a receiver with no data
            # at all for this signal.
            pll_by_rx = {}
            for rx, df in time_filtered_rx_dict.items():
                signal_df = _select_signal(df, sat_type, freq, track)
                pll_by_rx[rx] = None if signal_df.empty else _pll_only(signal_df)
            span = _time_span([d for d in pll_by_rx.values() if d is not None])

            if args.el_vs_cn0:
                # Elevation vs C/N0 Plot
                fig, ax = plt.subplots()
                fig.set_size_inches(w=18.5, h=10.5, forward=True)
                ax.set_title(signal_type_label)
                ax.set_xlabel('Elevation')
                ax.set_ylabel('C/N0')

                for rx, pll_df in pll_by_rx.items():
                    if pll_df is None:
                        print(f'No data for {rx} {signal_type_label}')
                        continue
                    ax.scatter(pll_df['EL'], pll_df['CNO'], s=1, label=f'{rx}')
                    ax.plot(_EL_BIN_CENTERS, _binned_mean_cn0(pll_df), label=f'{rx} binned average', linestyle='-', marker='x')

                ax.set_xlim(0, 90)
                ax.set_ylim(args.snr_floor, CN0_MAX)
                fig.tight_layout()
                ax.legend(loc='lower right')
                ax.grid(True)
                if span is None:
                    print(f"Plotting {signal_type_label} Elevation vs C/N0 (no PLL data)")
                else:
                    start, end = span
                    print(f"Plotting {signal_type_label} Elevation vs C/N0 {start} to {end} spanning {end - start}")
                _finish_figure(fig, args.show_figures, signal_type_label + " elevation vs cn0.png")

            if args.skyplot:
                # C/N0 Skyplot: one colormap per receiver, colour encodes C/N0
                fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
                fig.set_size_inches(w=18.5, h=10.5, forward=True)
                ax.set_title(signal_type_label)

                for rx_index, (rx, pll_df) in enumerate(pll_by_rx.items()):
                    if pll_df is None:
                        print(f'No data for {rx} {signal_type_label}!!')
                        continue
                    # theta = azimuth (degrees -> radians); r = 90 - elevation,
                    # so the zenith is at the centre. vmin/vmax pin the colour
                    # scale to [snr_floor, CN0_MAX] for every receiver; without
                    # them matplotlib rescales each receiver to its own min/max.
                    ax.scatter(
                        np.radians(pll_df['AZ']),
                        90 - pll_df['EL'],
                        c=pll_df['CNO'],
                        cmap=rx_cmap_dict[rx],
                        vmin=args.snr_floor,
                        vmax=CN0_MAX,
                        s=snr_size[rx_index] * 2,
                        label=f'{rx}',
                    )

                ax.set_theta_direction(-1)
                ax.set_theta_zero_location('N')
                ax.set_yticklabels([])
                fig.tight_layout()
                ax.legend(loc='best')
                if span is None:
                    print(f"Plotting {signal_type_label} Skyplot (no PLL data)")
                else:
                    start, end = span
                    print(f"Plotting {signal_type_label} Skyplot {end - start}")
                _finish_figure(fig, args.show_figures, signal_type_label + " skyplot.png")
