#!/usr/bin/env python

# Operator instructions for setting up a replay session.  This text is passed
# to argparse as the command-line description, so it is what `--help` prints;
# treat it as user-facing output, not as a comment.
usage="""\
To use:
1. Log CMRx/RTCM in a T0x file (T0x_filename)
2. Have RF samples logged at the same time
3. Connect a receiver to the Ettus or other sample logging system
4. For single-port usage:
    - enable port bi-directional mode ('port' argument)
   For dual-port usage:
    - allow RTCM input on port #1 ('port' argument)
    - turn on 10Hz ZDA on port #2 ('nmea_port' argument)
Note: needs viewdat in path
"""

import datetime
import serial
import argparse
import leapseconds
import subprocess as sp
import re
import math
import bisect
import RXTools
import time
import os.path
import socket

class ReplayCorr(object):
    '''Replay logged rec98 (RTCM/CMRx) correction data to a receiver.

    The receiver's NMEA ZDA sentences provide the current time; every
    correction record whose time stamp is not later than that time (and that
    has not been sent yet) is written to the correction port.

    Typical use:
        rc = ReplayCorr(t0x_filename, port)
        rc.do_loop()
        rc.close()
    '''

    # Seconds in one GPS week (float so that week arithmetic stays in floats).
    _SECS_PER_WEEK = 7 * 24 * 60 * 60.

    # Sentence headers accepted as ZDA: "$G?ZDA" (any talker starting with G)
    # and "$INZDA".
    _ZDA_HEADER = re.compile(rb'\$(?:G.|IN)ZDA')

    # Extracts the port number from the response to DCOL command 0x6f.
    _PORT_PATTERN = re.compile(rb'PORT,([0-9]+),')

    # ZDA "time,day,month,year" layouts: with and without fractional seconds.
    _ZDA_TIME_FORMATS = ('%H%M%S.%f,%d,%m,%Y', '%H%M%S,%d,%m,%Y')

    # Canned packet that is sent for every ZDA sentence in 'fake' mode.
    _FAKE_PACKET = (
        0xd3, 0x00, 0x15, 0x3e,
        0xe0, 0x00, 0x02, 0x39,
        0xbc, 0x6b, 0xf8, 0x22,
        0x35, 0xf7, 0xf5, 0x7d,
        0x2c, 0x48, 0xf3, 0x87,
        0xb1, 0xe8, 0x00, 0x00,
        0xb9, 0xde, 0x1f,
    )

    def __init__(self,t0x_filename,port,verbose=True,nmea_port=None,log_nmea_filename=None,baudrate=9600, fake=False, log_failed_zda_msg_filename=None):
        '''Open the log files and connect to the receiver port(s).

        t0x_filename = T0x file with rec98 data
        port = Where to send data, e.g. /dev/ttyS0 or socket://10.1.150.x:5018
          If "nmea_port" is not set, then the code will turn on 10Hz ZDA on
          this port with Trimcomm.
        Optional:
          verbose = print progress information.
          nmea_port = if you already have NMEA set up on a separate port
            (e.g., socket://10.1.150.x:5019) then specify this and the code
            won't send Trimcomm commands to turn NMEA on.  This is useful
            for non-Trimble receivers.
          log_nmea_filename = If set (e.g., log_file.txt), then log all data
            from the NMEA port to this file (the file is overwritten).
          baudrate = baud rate used for both ports.
          fake = send a canned packet for every ZDA sentence instead of
            replaying t0x_filename.
          log_failed_zda_msg_filename = If set, append every received line
            that could not be used as a ZDA sentence to this file.

        Raises whatever open() or serial.serial_for_url() raise; anything
        that was already opened is closed again before the error propagates.
        See the module-level 'usage' text for the receiver setup.
        '''
        self.t0x_filename = t0x_filename
        self.fake = fake
        self.port = port
        self.nmea_port = nmea_port
        self.verbose = verbose
        self.baudrate = baudrate
        self.corr_time = []
        self.corr_data = []
        self.lastIndex = 0
        self.log_nmea_file = None
        # NOTE: despite its name this attribute holds the open file object
        # (or None); the name is kept for backward compatibility.
        self.log_failed_zda_msg_filename = None
        self.s = None
        self.s_nmea = None
        self.rcvr_port_idx = -1
        try:
            if log_nmea_filename is not None:
                self.log_nmea_file = open(log_nmea_filename,'wb')
            if log_failed_zda_msg_filename is not None:
                # 'ab' creates the file when missing and appends otherwise.
                self.log_failed_zda_msg_filename = open(log_failed_zda_msg_filename,'ab')
            self.connect_to_remote()
        except Exception:
            # Don't leak half-opened resources when construction fails.
            self._close_ports()
            self._close_logs()
            raise

    def _close_ports(self):
        '''Close whichever receiver ports are currently open.'''
        if self.s is not None:
            self.s.close()
            self.s = None
        if self.s_nmea is not None:
            self.s_nmea.close()
            self.s_nmea = None

    def _close_logs(self):
        '''Close whichever log files are currently open.'''
        if self.log_nmea_file is not None:
            self.log_nmea_file.close()
            self.log_nmea_file = None
        if self.log_failed_zda_msg_filename is not None:
            self.log_failed_zda_msg_filename.close()
            self.log_failed_zda_msg_filename = None

    def connect_to_remote(self):
        '''Connect to remote receiver for getting NMEA data and sending corrections.

        Any previously opened port is closed first, so this can also be used
        to reconnect.  Raises serial.SerialException if a port can't be opened.
        '''
        if self.s is not None:
            self.s.close()
        if self.s_nmea is not None:
            self.s_nmea.close()
        self.s = serial.serial_for_url(self.port,timeout=5,baudrate=self.baudrate)
        if self.verbose: print(self.port,self.s)
        if self.nmea_port is not None:
            self.s_nmea = serial.serial_for_url(self.nmea_port,timeout=5,baudrate=self.baudrate)
            if self.verbose: print(self.nmea_port,self.s_nmea)

    @staticmethod
    def _parse_rec98_line(line):
        '''Parse one tab-separated line of 'viewdat -d98 -mb' output.

        Returns (full GPS seconds, list of payload byte values), or None for
        a blank line.  Field 1 is used as the week number, field 2 as
        milliseconds into the week, and fields 4.. as the payload bytes.
        '''
        if not line.strip():
            return None
        tokens = line.rstrip().split(b'\t')
        weeknum = int(tokens[1])
        weekseconds = float(tokens[2]) / 1000.0
        full_secs = weeknum * ReplayCorr._SECS_PER_WEEK + weekseconds
        return full_secs, [int(x) for x in tokens[4:]]

    def read_corr_data(self):
        '''Read RTCM/CMR rec98 data from T0x into a memory array.

        Runs 'viewdat' (must be in the path) and appends one entry per
        record to self.corr_time / self.corr_data.
        '''
        if self.verbose: print("Reading corr data")
        cmd = ['viewdat','-d98','-mb',self.t0x_filename]
        # The context manager closes the pipe and waits for viewdat to exit.
        with sp.Popen(cmd, stdout=sp.PIPE) as proc:
            for line in proc.stdout:
                record = self._parse_rec98_line(line)
                if record is None:
                    continue
                self.corr_time.append(record[0])
                self.corr_data.append(record[1])
        if self.verbose: print("Done reading corr data")

    def checksumtest(self,s):
        '''Returns True if NMEA input 's' (bytes or bytearray) passes its checksum.

        The checksum is the XOR of all bytes between the leading '$' and the
        '*'.  A sentence without a '*' or with non-hex checksum digits fails.
        '''
        tokens = s.split(b'*')
        if len(tokens) < 2:
            return False
        try:
            cs1 = int(tokens[1][:2].decode(),16)
        except ValueError:
            # Covers both non-hex digits and undecodable bytes.
            return False
        cs2 = 0
        for val in tokens[0][1:]:
            cs2 ^= val
        return (cs1 == cs2)

    def log_failed_msg(self,msg):
        '''Append 'msg' (bytes), prefixed with the local time and the port,
        to the failed-ZDA log if one was requested.'''
        if self.log_failed_zda_msg_filename is not None:
            self.log_failed_zda_msg_filename.write( bytearray(str(datetime.datetime.now())+" @ "+str(self.port)+"::",encoding='utf-8') + msg + b'\n')

    @staticmethod
    def _parse_zda_utc(fields):
        '''Turn the ZDA fields [time, day, month, year] into a datetime.

        Returns None when the fields don't form a valid date/time.
        '''
        try:
            time_str = b','.join(fields).decode('ascii')
        except ValueError:
            return None
        for fmt in ReplayCorr._ZDA_TIME_FORMATS:
            try:
                return datetime.datetime.strptime(time_str, fmt)
            except ValueError:
                continue
        return None

    def ZDA_to_gps_secs(self,s):
        '''Converts a ZDA NMEA string 's' to full GPS seconds since the start of GPS.

        Returns -1. when 's' is not a ZDA sentence, fails its checksum, has
        no date yet, or carries an unparsable date/time.  Non-ZDA input and
        unparsable date/time fields are also written to the failed-ZDA log.
        '''
        if not self._ZDA_HEADER.match(s):
            self.log_failed_msg(s)
            return -1.
        if not self.checksumtest(s):
            return -1.

        # Split only the part before '*' so the checksum can never end up
        # inside a field.
        tokens = s.split(b'*')[0].split(b',')
        if len(tokens) < 5:
            return -1.
        if len(tokens[2]) == 0:
            # Empty day field: the receiver has no date yet.
            return -1.

        utc_date = self._parse_zda_utc(tokens[1:5])
        if utc_date is None:
            self.log_failed_msg(s)
            return -1.

        # convert to GPS time
        gps_epoch = datetime.datetime(1980,1,6)
        gps_time = leapseconds.utc_to_gps(utc_date) - gps_epoch
        return gps_time.total_seconds()

    def full_secs_to_week_sec(self,secs):
        '''Split full GPS seconds into (week number, seconds into that week).'''
        week = int(secs / self._SECS_PER_WEEK)
        secs_in_week = secs - week*self._SECS_PER_WEEK
        return week, secs_in_week

    def try_to_reconnect(self, timeout_sec=60):
        '''Try to reconnect sockets in case of a network error or a receiver
        reboot, but give up after timeout_sec.

        Returns: 0 if the connection was re-established
                -1 if we gave up
        '''
        # monotonic() so that a wall-clock adjustment can't distort the timeout
        t1 = time.monotonic()
        while True:
            print(time.strftime(" Retry at %H:%M:%S", time.gmtime()))
            try:
                self.connect_to_remote()
                print(" Reconnected!")
                return 0
            except serial.SerialException:
                time.sleep(1)
                if time.monotonic() - t1 > timeout_sec:
                    print(" Giving up on connection...")
                    return -1

    def send_corr(self):
        '''Get current ZDA timestamp and send rec98 correction
        Returns: 0 if we're waiting for time alignment (or we just reconnected)
                 1 if correction data was sent
                -1 if there is no more rec98 data to send (or reconnecting failed)
        '''
        # ZDA arrives on the NMEA port if there is one, else on the main port.
        nmea_src = self.s if self.s_nmea is None else self.s_nmea
        try:
            received = bytearray(nmea_src.read_until())
        except serial.SerialException:
            received = bytearray()

        # Nothing within the port timeout is treated as a lost connection.
        if len(received) == 0:
            print("Problem with read connection...")
            return self.try_to_reconnect()

        if self.log_nmea_file is not None:
            self.log_nmea_file.write(received)

        ZDA_gps_secs = self.ZDA_to_gps_secs(received.rstrip())
        if ZDA_gps_secs < 0.:
            return 0

        if self.fake:
            # send fake data for every NMEA packet
            self.corr_time = [ZDA_gps_secs-.1, ZDA_gps_secs+.1]
            self.lastIndex = 0

        if not self.corr_time:
            # Without this an empty data set would "wait" forever.
            print("Ran out of correction data",self.port)
            return -1

        # Find index of corr_time immediately after ZDA_gps_secs
        index = bisect.bisect(self.corr_time, ZDA_gps_secs)
        if index <= 0:
            if self.verbose:
                print("Waiting for real-time system to catch up")
            return 0
        if index >= len(self.corr_time):
            print("Ran out of correction data",self.port)
            print(" ZDA",self.full_secs_to_week_sec(ZDA_gps_secs))
            print(" corr",self.full_secs_to_week_sec(self.corr_time[-1]))
            return -1

        # drop corr_data[] that is really old at startup...
        if self.lastIndex == 0 and index > 10:
            self.lastIndex = index-1

        # Send all unsent data up to 'index'
        zda_week_secs = self.full_secs_to_week_sec(ZDA_gps_secs)[1]
        for idx in range(self.lastIndex,index):
            if self.verbose:
                print("corr len %d ZDA=%.3f corr=%.3f" %
                      (len(self.corr_data[idx]),
                       zda_week_secs,
                       self.full_secs_to_week_sec(self.corr_time[idx])[1]
                      ))
            try:
                self.s.write(self.corr_data[idx])
            except serial.SerialException:
                # lastIndex is left alone, so the unsent records are retried
                # after a successful reconnect.
                print("Problem with write connection...")
                return self.try_to_reconnect()

        self.lastIndex = index
        return 1

    def do_DCOL(self,cmd_num,cmd_data):
        '''Send DCOL command # (cmd_num) with data (cmd_data) over
        current receiver port.  Return binary data response (a bytearray of
        at most 1024 bytes; empty if the receiver sent nothing).'''
        self.s.reset_input_buffer()
        cmd=RXTools.formDColCommand(cmd_num,cmd_data)
        self.s.write(cmd)
        # Give the receiver time to answer before draining the input buffer.
        time.sleep(1)
        data = bytearray()
        while len(data) < 1024:
            pending = self.s.in_waiting
            if not pending:
                break
            chunk = self.s.read(min(pending, 1024 - len(data)))
            if not chunk:
                break
            data += chunk
        return data

    def enable_ZDA(self):
        '''Enable 10Hz NMEA ZDA on current receiver port (and disable all
        other outputs on the port).  Tries up to 5 times and prints a message
        if it never succeeds.'''
        for try_count in range(5):
            # get port #
            data = self.do_DCOL(0x6f,[])
            pat_data = self._PORT_PATTERN.findall(data)
            if len(pat_data) == 0:
                print(f"enable_ZDA - failed to find pattern {try_count}")
                time.sleep(5)
                continue
            self.rcvr_port_idx = int(pat_data[0])+1

            # turn off all outputs and enable 10Hz ZDA
            self.do_DCOL(0x51,[0xa,self.rcvr_port_idx])
            data = self.do_DCOL(0x51,[0x7,self.rcvr_port_idx,29,1,0])
            # 0x06 is ASCII ACK (presumably the DCOL acknowledge); the
            # response may also be empty if the receiver didn't answer.
            if data and data[0] == 0x06:
                return
            print(f"enable_ZDA - failed to turn on 10Hz ZDA {try_count}")
            time.sleep(5)
        print("Couldn't enable ZDA")

    def disable_ZDA(self):
        '''Disable ZDA (and all outputs) on the port.

        Does nothing when the receiver port index was never determined
        (enable_ZDA not called or failed) or when the port is not open.
        '''
        if self.s is None or self.rcvr_port_idx < 0:
            return
        self.do_DCOL(0x51,[0xa,self.rcvr_port_idx])

    def _disable_ZDA_best_effort(self):
        '''Disable ZDA, but only report (don't raise) a dead connection.'''
        try:
            self.disable_ZDA()
        except serial.SerialException:
            print("Couldn't disable ZDA")

    def do_loop(self):
        '''Main loop - runs until we are out of valid correction data.

        In single-port mode ZDA is enabled first and disabled again when the
        loop ends, including when it ends because of an exception such as
        KeyboardInterrupt.
        '''
        if self.fake:
            print("Fake corr data")
            for _ in range(2):
                # Placeholder times; send_corr() replaces them for every ZDA.
                self.corr_time.append( 0. )
                self.corr_data.append( list(self._FAKE_PACKET) )
        else:
            if not os.path.isfile(self.t0x_filename):
                print("Cannot find input file '%s'" % self.t0x_filename)
                return
            self.read_corr_data()
            if not self.corr_time:
                print("No correction data found in '%s'" % self.t0x_filename)
                return
        if self.s_nmea is None:
            self.enable_ZDA()
        try:
            while self.send_corr() >= 0:
                pass
        finally:
            if self.s_nmea is None:
                self._disable_ZDA_best_effort()

    def close(self):
        '''Call this to turn off ZDA and disconnect from the receiver port.
        Also closes the NMEA port and the log files.  Safe to call twice.'''
        self._disable_ZDA_best_effort()
        self._close_ports()
        self._close_logs()

if __name__ == "__main__":
    # Command-line entry point: parse the options, connect to the receiver
    # and replay corrections until the data runs out.
    cli = argparse.ArgumentParser(
        description=usage,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Positional arguments
    cli.add_argument("T0x_file", help="File with rec98 data")
    cli.add_argument(
        "port",
        help="serial or TCP port to playback on (e.g. /dev/ttyS0 or socket://10.1.150.xxx:5018).  Must be writeable",
    )

    # Optional arguments
    cli.add_argument(
        "--nmea_port",
        help="serial or TCP port with NMEA ZDA data (e.g. /dev/ttyS0 or socket://10.1.150.xxx:5018)",
    )
    cli.add_argument("--log_filename", help="Log all NMEA data to a file?")
    cli.add_argument(
        "-v", "--verbose",
        action="store_true",
        default=False,
        help="Show info when we send a packet?",
    )
    cli.add_argument(
        "--baud",
        type=int,
        default=9600,
        help="Baud rate for serial port? (e.g., 9600)",
    )
    cli.add_argument(
        "--fake",
        action="store_true",
        default=False,
        help="Use fake data instead of T0x_file.  Useful for debugging.",
    )
    opts = cli.parse_args()

    replayer = ReplayCorr(
        opts.T0x_file,
        opts.port,
        verbose=opts.verbose,
        nmea_port=opts.nmea_port,
        log_nmea_filename=opts.log_filename,
        baudrate=opts.baud,
        fake=opts.fake,
    )
    replayer.do_loop()
