#!/usr/bin/env python
"""
 Show single-diff observables between receivers or antennas.

 Examples:
     2-receiver usage:
         ./single_diff.py ip_addr1:port1 ip_addr2:port2
     2-antenna usage:
         ./single_diff.py ip_addr:port

 Make sure to configure the I/O, for example:
     TCP/IP port 5018
     Output RT17/RT27
     Epoch Interval = 5Hz
     Multi-System Support

 Limitations:
  - This only works on dual-antenna systems or clock-steered zero-baselines
    (unless you only compare C/No).
"""
from matplotlib.figure import Figure
from matplotlib import animation
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, \
    NavigationToolbar2Tk
import threading
from six.moves import tkinter as Tk
import numpy as np
import socket
from collections import defaultdict, deque
import argparse
import RT27
import mutils

# Known signals/measurements to plot.
# Signal menu label -> (sat_type, freq, track) tuple; filled in by main().
all_sigs = dict()
# Measurements offered in the "Meas" menu; animate_plots() needs matching
# handling for each one.
all_meas = "C/No range phase".split()


class State:
    """Plain attribute bag shared by the GUI callbacks.

    main() populates it; animate_plots(), select_sig() and select_meas()
    read and update its attributes.
    """


class GetDataThread(threading.Thread):
    """Daemon thread that reads RT27 data from a TCP port and decodes it.

    Decoded observation epochs are buffered in ``self.obs`` until the GUI
    thread drains them with get_data().
    """

    def __init__(self, ip_addr, port, parser=None):
        """
        Args:
            ip_addr: receiver IP address or host name.
            port: TCP port number (int).
            parser: optional decoder with ``process_data(bytearray)`` and
                ``get_decoded_obs()`` methods; defaults to
                ``RT27.ParseRT27(show_info=False)``.
        """
        super().__init__(daemon=True)
        self.ip_addr = ip_addr
        self.port = port
        if parser is None:
            parser = RT27.ParseRT27(show_info=False)
        self.rt27_parser = parser
        self.obs = []
        # run() appends to self.obs in this thread while get_data() swaps it
        # out from the GUI thread; the lock keeps epochs from being appended
        # to a list the caller has already taken and processed.
        self._lock = threading.Lock()

    def run(self):
        """Connect, then decode incoming data until the peer closes.

        Socket errors (including the 10 s timeout) propagate and end the
        thread; the socket is closed in every case.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            sock.settimeout(10)
            sock.connect((self.ip_addr, self.port))
            while True:
                chunk = sock.recv(5000)
                if not chunk:
                    # recv() returns b'' once the peer closed the connection;
                    # without this check the loop would spin forever.
                    print(f"{self.ip_addr}:{self.port} closed the connection")
                    break
                self.rt27_parser.process_data(bytearray(chunk))
                data = self.rt27_parser.get_decoded_obs()
                if len(data) > 0:
                    with self._lock:
                        self.obs.extend(data)

    def get_data(self):
        """Return all epochs decoded since the previous call (may be [])."""
        with self._lock:
            data, self.obs = self.obs, []
        return data


def select_sig(sig_txt, s):
    """Menu callback: switch the plotted signal to ``sig_txt``.

    Discards the per-SV histories and sets ``s.axis`` to None so the next
    animate_plots() call rebuilds the figure for the new signal.

    Args:
        sig_txt: key of ``all_sigs`` (the menu label) to select.
        s: shared State object.
    """
    s.sel_sig = all_sigs[sig_txt]
    s.sel_sig_txt = sig_txt
    # sv1/sv2 are created by animate_plots() only once data has arrived, so
    # a selection made before that must not touch them.
    if hasattr(s, 'sv1'):
        s.sv1.clear()
        s.sv2.clear()
    s.axis = None


def select_meas(meas_txt, s):
    """Menu callback: switch the plotted measurement to ``meas_txt``.

    Discards the per-SV histories and sets ``s.axis`` to None so the next
    animate_plots() call rebuilds the figure for the new measurement.

    Args:
        meas_txt: one of the entries of ``all_meas``.
        s: shared State object.
    """
    s.sel_meas_txt = meas_txt
    # sv1/sv2 are created by animate_plots() only once data has arrived, so
    # a selection made before that must not touch them.
    if hasattr(s, 'sv1'):
        s.sv1.clear()
        s.sv2.clear()
    s.axis = None


# Column layout of the per-epoch rows built by append_data():
#   time tag, C/No, pseudorange, carrier phase, phase-sync flag
# (phase-sync flag: phase lock with the positive/negative lock point known).
(dTime_col,
 dCno_col,
 dRange_col,
 dPhase_col,
 dPhaseSync_col) = range(5)


def append_data(d_in, sel_sig, ant_num, d_out):
    """Collect one signal/antenna's measurements into per-SV histories.

    For each measurement in the decoded epochs ``d_in`` whose
    ``(sat_type, freq, track)`` equals ``sel_sig`` and whose ``ant_num``
    equals ``ant_num``, the row
    ``[int(hdr.sec), cno, pseudo, phase, bit 5 of meas_flags[0]]`` is
    appended to ``d_out[sv_id]``.

    Args:
        d_in: iterable of decoded epochs (objects with ``hdr.sec`` and
            ``meas``).
        sel_sig: (sat_type, freq, track) tuple to keep.
        ant_num: antenna number to keep.
        d_out: mapping sv_id -> history with ``append``; missing keys must
            be auto-created (animate_plots() passes a defaultdict).
    """
    for epoch in d_in:
        matches = (
            m for m in epoch.meas
            if (m.sat_type, m.freq, m.track) == sel_sig
            and m.ant_num == ant_num
        )
        for m in matches:
            d_out[m.sv_id].append([
                int(epoch.hdr.sec),
                m.cno,
                m.pseudo,
                m.phase,
                (m.meas_flags[0] >> 5) & 1,
            ])


def animate_plots(_, s):
    """FuncAnimation callback: ingest new observations and redraw the plot.

    Drains both data threads and appends the observations of the selected
    signal/antenna to the per-SV histories ``s.sv1`` / ``s.sv2``. It then
    plots unit #1 minus unit #2 of the selected measurement for every SV
    present in both, plus optional slip markers. The axes and histories are
    rebuilt whenever ``s.axis`` is None (e.g. after a menu selection).

    Args:
        _: frame counter passed by FuncAnimation (unused).
        s: shared State object set up by main().
    """
    d1 = s.rt1.get_data()
    d2 = s.rt2.get_data()
    # get_data() empties the thread buffers, so whatever came back must be
    # kept: returning when only one unit had new epochs would permanently
    # drop the other unit's data. Only skip when neither had anything.
    if len(d1) == 0 and len(d2) == 0:
        return
    if s.axis is None:
        s.fig.clf()
        s.axis = s.fig.add_subplot()
        s.axis.grid(True)
        s.lines = {}
        s.slip_lines = {}
        s.sv1 = defaultdict(lambda: deque(maxlen=s.max_len))
        s.sv2 = defaultdict(lambda: deque(maxlen=s.max_len))
        s.ph = np.nan  # common sub-cycle phase offset, not yet estimated

    append_data(d1, s.sel_sig, s.ant1, s.sv1)
    append_data(d2, s.sel_sig, s.ant2, s.sv2)

    # Column differenced for each measurement in all_meas.
    meas_col = {'C/No': dCno_col,
                'range': dRange_col,
                'phase': dPhase_col}[s.sel_meas_txt]
    sv_x_y = []
    sv_slip_x_y = []
    min_t = s.min_t
    ph_medians = []
    for sv in set(s.sv1) & set(s.sv2):
        sv1 = np.array(s.sv1[sv])
        sv1 = sv1[sv1[:, dTime_col] >= min_t]
        sv2 = np.array(s.sv2[sv])
        sv2 = sv2[sv2[:, dTime_col] >= min_t]
        # Epochs present in both histories, with their row index in each.
        t, i1, i2 = np.intersect1d(
            sv1[:, dTime_col], sv2[:, dTime_col], return_indices=True)
        if len(t) > 2:
            # Track the smallest positive epoch spacing (data rate) ...
            dt = t[-1] - t[-2]
            if dt > 0 and (dt < s.dt or s.dt < 0):
                s.dt = dt
            if s.dt > 0:
                # ... and don't show epochs older than max_len intervals.
                new_min_t = t[-1] - s.dt*s.max_len
                if new_min_t > s.min_t:
                    s.min_t = new_min_t
        y = sv1[i1, meas_col] - sv2[i2, meas_col]
        if s.sel_meas_txt == 'phase':
            both_sync = np.where(sv1[i1, dPhaseSync_col].astype(int) &
                                 sv2[i2, dPhaseSync_col].astype(int))[0]
            if len(both_sync) > 0:
                # Sub-cycle offset from epochs where both units know the
                # lock point.
                ph_medians.append(np.mod(np.median(y[both_sync]), 1))
                if np.isfinite(s.ph):
                    # Try to use a global offset for all SVs
                    y -= s.ph
            # Remove integer cycles
            y -= np.around(y)
        sv_x_y.append([sv, t*1e-3, y])

        if s.show_slips:
            # Mark epochs where either unit's phase-sync flag changed.
            isl = np.where((np.diff(sv1[i1, dPhaseSync_col]) != 0) |
                           (np.diff(sv2[i2, dPhaseSync_col]) != 0))[0]
            sv_slip_x_y.append([sv, t[isl+1]*1e-3, y[isl+1]])

    if len(ph_medians) > 1:
        s.ph = np.median(ph_medians)

    got_data = set()
    got_slips = set()
    for sv, x, y in sv_x_y:
        if len(x) == 0:
            continue
        got_data.add(sv)
        if sv in s.lines:
            s.lines[sv].set_data(x, y)
        else:
            line, = s.axis.plot(x, y, '-', label='%d' % sv)
            s.lines[sv] = line
    for sv, x, y in sv_slip_x_y:
        if len(x) == 0:
            continue
        got_slips.add(sv)
        if sv in s.slip_lines:
            s.slip_lines[sv].set_data(x, y)
        else:
            # A non-empty slip series implies common epochs, so the SV's
            # main line exists and its colour can be reused.
            line, = s.axis.plot(x, y,
                                color=s.lines[sv].get_color(),
                                linestyle='None', marker='v',
                                markeredgecolor='k')
            s.slip_lines[sv] = line

    # Remove lines (and histories) of SVs with no common epochs left.
    for sv in set(s.lines) - got_data:
        s.lines.pop(sv).remove()
        s.sv1.pop(sv, None)
        s.sv2.pop(sv, None)
    for sv in set(s.slip_lines) - got_slips:
        s.slip_lines.pop(sv).remove()

    s.axis.set_title('%s %s single-diff' % (s.sel_sig_txt, s.sel_meas_txt))
    s.axis.set_xlabel('GPS TOW [sec]')
    if len(s.lines) > 0:
        s.axis.legend(loc='upper left', bbox_to_anchor=(1.04, 1))
        s.fig.subplots_adjust(right=0.85)
    s.axis.relim()
    s.axis.autoscale_view(True, True, True)
    # Per measurement: y label and the minimum y range always shown.
    ylabel, min_bottom, min_top = {
        'phase': ('single-diff [cycles]', -0.2, 0.2),
        'range': ('single-diff [m]', -10, 10),
        'C/No': ('single-diff [dB]', -2, 15),
    }[s.sel_meas_txt]
    s.axis.set_ylabel(ylabel)
    bottom, top = s.axis.get_ylim()
    if bottom > min_bottom:
        s.axis.set_ylim(bottom=min_bottom)
    if top < min_top:
        s.axis.set_ylim(top=min_top)


def _split_host_port(parser, text):
    """Split a ``"host:port"`` argument into ``(host, port_number)``.

    A malformed value is reported via ``parser.error()`` (usage message and
    exit) instead of an unhandled ValueError traceback.

    Args:
        parser: ArgumentParser used to report errors.
        text: command-line value, e.g. "10.1.2.3:5018".

    Returns:
        Tuple of the host string and the port as an int.
    """
    host, sep, port = text.rpartition(':')
    if not sep or not host or not port.isdigit():
        parser.error('expected "ip_addr:port", got %r' % (text,))
    return host, int(port)


def main():
    """Parse the command line, build the Tk/matplotlib GUI and run it."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter, description=__doc__)
    parser.add_argument('unit1', help='Base RT27, e.g. "10.1.x.y:501x"')
    parser.add_argument('unit2', nargs='?',
                        help='Rover RT27, e.g. "10.1.x.y:502x"')
    parser.add_argument(
        '--no_slips', help="Don't show slips", action="store_true")
    parser.add_argument('--ant1', default=0, type=int,
                        help='Unit #1 antenna number. Default=0')
    parser.add_argument('--ant2', type=int, help='Unit #2 antenna number')
    parser.add_argument('--max_len', default=1000, type=int,
                        help='Max # of time epochs to draw. Default=1000')
    parser.add_argument('--update_ms', default=200, type=int,
                        help='Screen update rate [msecs]')
    args = parser.parse_args()
    unit1_ip, unit1_port = _split_host_port(parser, args.unit1)
    if args.unit2:
        # Two receivers: unit #2 defaults to antenna 0.
        unit2_ip, unit2_port = _split_host_port(parser, args.unit2)
        default_ant2 = 0
    else:
        # One receiver: open a second connection to the same port; unit #2
        # defaults to antenna 1.
        unit2_ip, unit2_port = unit1_ip, unit1_port
        default_ant2 = 1
    if args.ant2 is None:
        args.ant2 = default_ant2

    print(f"Base: {unit1_ip} : {unit1_port}, antenna #{args.ant1}")
    print(f"Rover: {unit2_ip} : {unit2_port}, antenna #{args.ant2}")
    s = State()
    s.root = Tk.Tk()
    s.root.wm_title('single diffs')
    s.fig = Figure(figsize=(8, 6))
    s.axis = None
    mainFrame = Tk.Frame(s.root)
    mainFrame.pack()
    s.canvas = FigureCanvasTkAgg(s.fig, mainFrame)
    toolbar = NavigationToolbar2Tk(s.canvas, mainFrame)
    toolbar.update()
    toolbar.pack(side=Tk.BOTTOM)
    check_frame = Tk.Frame(mainFrame)
    check_frame.pack(side=Tk.RIGHT, fill=Tk.BOTH, expand=1)
    # get_tk_widget() is the public accessor for the canvas widget.
    s.canvas.get_tk_widget().pack(side=Tk.LEFT, fill=Tk.BOTH, expand=1)

    menubar = Tk.Menu(s.root)
    filemenu = Tk.Menu(menubar, tearoff=0)
    filemenu.add_command(label="Quit", command=s.root.destroy)
    menubar.add_cascade(label="File", menu=filemenu, underline=0)
    # Signal menu: one cascade per satellite system (first 3 characters of
    # the signal description), started whenever that prefix changes.
    signal_menu = Tk.Menu(menubar, tearoff=0)
    last_sv_sys = ''
    sub_menu = None
    for track_tuple, sat_info in mutils.get_sub_type.sub_type_dict.items():
        if sat_info.desc not in all_sigs:
            sig_txt = sat_info.desc
            curr_sv_sys = sig_txt[:3]
            if curr_sv_sys not in last_sv_sys:
                sub_menu = None
            if sub_menu is None:
                sub_menu = Tk.Menu(signal_menu, tearoff=0)
                signal_menu.add_cascade(
                    label=curr_sv_sys, menu=sub_menu, underline=0)
            all_sigs[sig_txt] = track_tuple
            sub_menu.add_command(
                label=sig_txt, command=lambda txt=sig_txt: select_sig(txt, s))
            last_sv_sys = curr_sv_sys
    menubar.add_cascade(label="Signal", menu=signal_menu, underline=0)
    meas_menu = Tk.Menu(menubar, tearoff=0)
    for meas_txt in all_meas:
        meas_menu.add_command(
            label=meas_txt, command=lambda txt=meas_txt: select_meas(txt, s))
    menubar.add_cascade(label="Meas", menu=meas_menu, underline=0)

    s.sel_sig_txt = list(all_sigs.keys())[0]
    s.sel_sig = all_sigs[s.sel_sig_txt]
    s.sel_meas_txt = all_meas[0]
    s.ant1 = args.ant1
    s.ant2 = args.ant2
    s.show_slips = not args.no_slips
    s.max_len = args.max_len
    s.dt = -1
    s.min_t = -1

    s.rt1 = GetDataThread(unit1_ip, unit1_port)
    s.rt2 = GetDataThread(unit2_ip, unit2_port)
    s.rt1.start()
    s.rt2.start()
    # Keep a reference to the animation on the state object so it is not
    # deleted before rendering. The frames are plain counters, so there is
    # nothing worth caching.
    s.anim = animation.FuncAnimation(s.fig, animate_plots, fargs=(s,),
                                     interval=args.update_ms,
                                     cache_frame_data=False)
    s.root.config(menu=menubar)
    s.root.mainloop()


if __name__ == "__main__":
    # Launch the GUI only when run as a script, not when imported.
    main()
