#!/usr/bin/env python3
# pylint: disable=line-too-long,invalid-name

# CycleSlipHeatMap.py
#  For usage details "CycleSlipHeatMap.py -h"
#
# Description:
#  Given a T04 file, create a heatmap of slips per epoch per signal, and plot this on a satellite map
#  Useful for determining location of clumps of cycle slips on a dynamic data file.
#
# Original Author(s): Andy Archinal, initial framework by Stuart Riley
# Copyright Trimble Inc 2022
#######################################################

import argparse
from pathlib import Path
import os

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import colors
import matplotlib.cm as cmx
import numpy as np
import mutils as m
from salem import GoogleVisibleMap, Map
import urllib.error

# Basic requirements.txt for this script (2022)
# matplotlib==3.5.2
# mutils==1.0.5
# numpy==1.22.4
# salem==0.3.7

MAPS_API_KEY_FILENAME = ".googlemapsapikey"

# Parse arguments.
# Numeric options are converted by argparse (type=float) so a malformed value is
# reported as a usage error up front instead of failing later in the analysis.
parser = argparse.ArgumentParser(description="Create a heatmap of cycle slips from a T04 file. If your data is not in T04, use t0x2t0x to update the format. This also assumes a 5Hz obs/pos (matched) rate.")
parser.add_argument("filename", help="Input T04 file")
parser.add_argument('-s','--showFigures', help='Show the figures and summary data, instead of saving to a file.', action="store_true")
parser.add_argument('-m','--elevation_mask', type=float, help='Ignore satellites at or below a specified elevation', default=0)
parser.add_argument('-b','--start_time', type=float, help="Start of the analysis window, in the time base of the records' TIME field (default: first position epoch)", default=None)
parser.add_argument('-e','--end_time', type=float, help="End of the analysis window, in the time base of the records' TIME field (default: last position epoch)", default=None)
args = parser.parse_args()

filename = args.filename
if not args.showFigures:
    # Allow running headless from the command line
    matplotlib.use("agg")

# RXFile is the path handed to mutils; RXStr is used in plot titles and as the
# prefix of the saved PNG filenames.
RXFile = args.filename
RXStr = filename


# Obs record
obs_RX = m.vd2cls(RXFile, "-d35:19")
# Position record
pos_RX = m.vd2cls(RXFile, "-d35:2")

# Calculate the Obs Rate from the spacing of the unique obs epochs.
# The median spacing is used because the mean is inflated by tracking outages,
# and it is rounded (not truncated) so time tags such as 0.09999999 s apart
# still resolve to 100 ms. An unsupported rate, or too few epochs to measure a
# spacing at all, falls back to the assumed 5Hz (200 ms).
Millisecond = 200
obsTimestamps = np.unique(obs_RX.TIME)
if len(obsTimestamps) > 1:
    obsRateMilliseconds = int(round(np.median(np.diff(obsTimestamps)) * 1e3))
    if obsRateMilliseconds in (10, 20, 50, 100, 200):
        Millisecond = obsRateMilliseconds
# Obs epochs per second
Rate = int(1000 / Millisecond)

# Satellite type -> iterable of (freq, track) pairs present in the obs record
obs_RX_sigs = m.get_signals(obs_RX, obs_RX.k)

# Analysis window; defaults to the span of the position records.
startTime = np.min(pos_RX.TIME) if args.start_time is None else float(args.start_time)
stopTime = np.max(pos_RX.TIME) if args.end_time is None else float(args.end_time)

# Load the Google Maps API key from a dotfile in the home directory. This is
# best effort: a missing key only produces the warning below, and the map
# download reports its own failure if the key is absent or wrong.
googlemapsapikey = ""
try:
    googlemapsapikey = (Path.home() / MAPS_API_KEY_FILENAME).read_text(encoding='utf-8').strip()
except FileNotFoundError:
    pass
except OSError as err:
    # e.g. permission denied: report it rather than crashing, then fall through
    # to the generic "missing key" warning.
    print(f'Could not read {MAPS_API_KEY_FILENAME}: {err}')

if googlemapsapikey == "":
    print('You need to enter your Google Maps API key!\n You can either hardcode it in this script, or, put the key in a dotfile named .googlemapsapikey in your home directory.')

# Everything in this first section depends only on the position records inside
# the analysis window, not on the signal. It is therefore computed once:
#  - the filtered positions and their epoch indices,
#  - the map extents,
#  - the Google map image (fetching it per signal would repeat the same request).
posFilt_RX = pos_RX[(pos_RX.TIME >= startTime) & (pos_RX.TIME <= stopTime)]
if len(posFilt_RX) == 0:
    raise SystemExit(f'No position records between {startTime} and {stopTime}; nothing to plot.')

# Number of obs epochs in the window at the detected rate. The extra 1 s of
# epochs (Rate) is a buffer for time tags that round past the nominal window end.
numEpochs = Rate + int(Rate * (stopTime - startTime))
if numEpochs > len(posFilt_RX):
    print(f'RX pos is missing points {numEpochs} > {len(posFilt_RX)}')

# Place every position on the same epoch grid used for the slip counts.
# Each position is then coloured by the slips at its own time tag. Pairing the
# i-th position with the i-th epoch instead breaks after any gap in the
# position records (e.g. due to obstructions).
posEpochIndex = np.round((posFilt_RX.TIME[:] - startTime) * Rate).astype(int)

# Get the map extents
minLat = np.min(posFilt_RX.LAT)
maxLat = np.max(posFilt_RX.LAT)
minLon = np.min(posFilt_RX.LON)
maxLon = np.max(posFilt_RX.LON)

# Setup the map and get the map data from Google
g = GoogleVisibleMap(
    x=[minLon, maxLon],
    y=[minLat, maxLat],
    scale=2,
    maptype="satellite",
    key=googlemapsapikey,
)

# The Google static image is a standard RGB image. Fetching it is best effort:
# if it fails the plots are still produced, just without the satellite background.
ggl_img = None
# pylint: disable=broad-except
try:
    ggl_img = g.get_vardata()
except urllib.error.HTTPError as httpError:
    if httpError.code == 403:
        print(f'No map loaded!! Maps exception {httpError}\nIs your API key in the correct place?')
    else:
        print(f'No map loaded!! Maps exception {httpError}')
except Exception as e:
    print(f'No map loaded!! Maps exception {e}')

for sat_type, sat_signals in obs_RX_sigs.items():
    for freq, track in sat_signals:
        signal_type_label = m.get_sub_type(sat_type, freq, track).fullstr

        # First, filter the data by time, satellite type, track, frequency and elevation.
        obsFilt_RX_stage_1 = obs_RX[
            (obs_RX.TIME >= startTime)
            & (obs_RX.TIME <= stopTime)
            & (obs_RX.SAT_TYPE == sat_type)
            & (obs_RX.FREQ == freq)
            & (obs_RX.TRACK == track)
            & (obs_RX.EL > float(args.elevation_mask))
        ]
        if len(obsFilt_RX_stage_1) == 0:
            print(f'No observations for {signal_type_label} in the selected window; skipping.')
            continue

        # Record 35:19 can also log FLL data, so make sure it is phase locked.
        # Only the first element is tested for dMEAS_MASTER_SUB.
        # NOTE(review): this assumes all records of one signal come from the same
        # kind of sub-channel.
        isMasterSubchanData = (obsFilt_RX_stage_1.MEAS[0].astype(int) & obsFilt_RX_stage_1.f['dMEAS_MASTER_SUB']) != 0

        if isMasterSubchanData:
            # io/tx/rt_dat_defs.h:#define mRec35sub19_TrackingStatus_CSTATE_PLL_A_WIDE       (14) //  E
            # io/tx/rt_dat_defs.h:#define mRec35sub19_TrackingStatus_CSTATE_PLL_A_NARROW     (15) //  F
            # io/tx/rt_dat_defs.h:#define mRec35sub19_TrackingStatus_CSTATE_PLL_B            (16) // 10
            # io/tx/rt_dat_defs.h:#define mRec35sub19_TrackingStatus_CSTATE_PLL_C            (17) // 11
            # io/tx/rt_dat_defs.h:#define mRec35sub19_TrackingStatus_CSTATE_PLL_D            (18) // 12
            # io/tx/rt_dat_defs.h:#define mRec35sub19_TrackingStatus_CSTATE_PLL_E            (19) // 13
            # io/tx/rt_dat_defs.h:#define mRec35sub19_TrackingStatus_CSTATE_SBAS             (20) // 14
            obsFilt_RX = obsFilt_RX_stage_1[
                (obsFilt_RX_stage_1.CSTATE >= 14)
                & (obsFilt_RX_stage_1.CSTATE <= 20)
            ]
        else:
            # io/tx/rt_dat_defs.h:define mRec35sub19_TrackingStatus_SLAVE_CSTATE_PLL_B            (11)
            obsFilt_RX = obsFilt_RX_stage_1[
                (obsFilt_RX_stage_1.CSTATE == 11)
            ]

        # Count, per epoch, how many satellites of this signal flagged a cycle slip.
        slips_RX = np.zeros(numEpochs)
        rec35_19_slipFlag = obsFilt_RX.f.dMEAS_SLIP
        # Iterate over the SVs actually present rather than a fixed PRN range,
        # so PRNs above a hard-coded limit are not silently dropped.
        for sv in np.unique(obsFilt_RX.SV):
            svData_RX = obsFilt_RX[obsFilt_RX.SV == sv]
            index = np.round((svData_RX.TIME[:] - startTime) * Rate).astype(int)
            slipped = (svData_RX.MEAS[:].astype(int) & rec35_19_slipFlag) != 0
            slips_RX[index] += slipped.astype(int)

        # Slip count at each position's epoch, and the distinct counts present.
        posSlips = slips_RX[posEpochIndex]
        types = np.unique(posSlips)

        # One colour per distinct slip count
        jet = plt.get_cmap("jet")
        cNorm = colors.Normalize(vmin=0, vmax=len(types))
        scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=jet)

        fig, ax1 = plt.subplots()
        if ggl_img is not None:
            ax1.imshow(ggl_img)
        ax1.set_title(signal_type_label + " Cycle Slips\n" + RXStr)
        sm = Map(g.grid, factor=1, countries=False)
        sm.visualize(ax=ax1)
        # transform the lat/long into the image coord frame
        x_, y_ = sm.grid.transform(posFilt_RX.LON, posFilt_RX.LAT)
        # Plot each distinct slip count in its own colour.
        # The legend uses the actual count, not its rank.
        for colorIndex, slipCount in enumerate(types):
            index = np.flatnonzero(posSlips == slipCount)
            ax1.scatter(
                x_[index], y_[index], s=1, color=scalarMap.to_rgba(colorIndex), label=(str(int(slipCount)) + " Slips")
            )
        ax1.legend(prop={'size': 4})
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()

        if args.showFigures:
            plt.show()
        else:
            # A '/' in the signal label (if fullstr ever contains one) would
            # otherwise be treated as a directory separator.
            fig.savefig(RXStr + signal_type_label.replace('/', '_') + ".png", dpi=300)
            plt.close(fig)
