#!/usr/bin/env python

#
# One of our product specifications is position solution latency. This
# script gives an indication of the solution latency. In the I/O library
# the NMEA:ZDA message will provide the current time with a resolution
# of 1ms. Therefore, to get an estimate of the latency, for years we have
# enabled GGA and ZDA; ZDA is generated after the GGA. Therefore if you
# subtract the GGA time (time of position) from ZDA time (time the ZDA
# message was generated, which is very soon after the GGA message is
# generated) it will give an indication of the latency. This does not
# capture any further delays after the IOTX task, for example it does
# not capture delays in the drivers or any physical layer delays.
#
# The code has been written as a class. However, the script will work
# standalone. The class structure was used and the general structure of
# "replayCorr.py" used so that a python layer can be added on top of
# this file to in parallel monitor multiple receivers.
#
# Copyright Trimble Inc., 2019
#

# Help text for the command line; passed to argparse as the parser
# description in the __main__ section at the bottom of this file.
usage="""\
To use:
  provide the socket information
  provide the required frequency rate (will setup NMEA)
  optionally provide the prefix to generate more unique file names

  Important: The port must be read/write.
"""

import matplotlib
# Allow running headless from the command line
matplotlib.use("agg")

from numpy import *
from pylab import *

import datetime
import serial
import argparse
import leapseconds
import subprocess as sp
import re
import math
import RXTools
import time
import os.path
import signal
import zipfile
import bisect

def signal_handler(signal, frame):
    """SIGINT handler: tell the user and terminate with exit status 0.

    signal - signal number passed in by the runtime (unused)
    frame  - interrupted stack frame passed in by the runtime (unused)

    Raises SystemExit(0).
    """
    print('You pressed Ctrl+C!')
    print("Exiting")
    # Raise SystemExit directly (this is what sys.exit(0) does). The file
    # never imports "sys" itself, so calling sys.exit() only worked if one
    # of the star-imports happened to leak that name into this module.
    raise SystemExit(0)

class getLatency(object):
    """Estimate position-solution latency from paired NMEA GGA/ZDA messages.

    Every GGA that is followed by a ZDA yields one latency sample: the ZDA
    time minus the GGA time. Samples are accumulated per "epoch bin" (the
    slot within the second that the GGA time falls into; there are `freq`
    bins) as min / max / sum / count, plus a sorted history per bin that is
    used for the 68% and 95% values. When the hour of the ZDA time changes,
    a plot and one line in "<prefix>Summary.txt" are written and the
    statistics are reset.
    """

    # Number of latency samples seen since the last diagnostic print
    count = 0
    # Hour ("HH") of the last ZDA processed; 'XX' until the first one arrives
    oldTime = 'XX'

    # Placeholder stored in minValue[] while a bin has no samples. The
    # original code used 99999 at start-up but 9999 after the hourly reset;
    # a single constant keeps both cases consistent.
    _NO_DATA_MIN = 99999

    # Sentence matchers (any two-character talker ID). Raw byte literals
    # avoid the invalid "\$" escape of a plain b'...' literal.
    _GGA_PATTERN = re.compile(br'\$..GGA')
    _ZDA_PATTERN = re.compile(br'\$..ZDA')

    # Output frequency [Hz] -> rate parameter of the NMEA I/O command
    _FREQ_TO_RATE = {100: 16, 50: 15, 20: 13, 10: 1, 5: 2}

    def __init__(self, port, prefix=None, log=False, freq=0, short_diags=False):
        """Open the receiver port and set up empty per-bin statistics.

        port        - where to talk to the receiver, e.g. /dev/ttyS0 or
                      socket://10.1.150.x:5018. Must be read/write.
        freq        - frequency of the NMEA output in Hz
        Optional:
          prefix      - prefix for the generated files
          log         - when true, log the received NMEA data to a file.
                        Note the file rolls over every hour
          short_diags - print shorter diagnostics to the screen

        Raises ValueError when freq is not a positive integer value.
        """
        self.freq = int(freq)
        if self.freq <= 0:
            # Without this check the default (freq=0) failed below with an
            # unexplained ZeroDivisionError.
            raise ValueError("freq must be a positive NMEA frequency in Hz, got %r" % (freq,))
        self.msPerEpoch = 1000.0 / self.freq
        self.minValue = [self._NO_DATA_MIN] * self.freq
        self.maxValue = [0] * self.freq
        self.sumValue = [0] * self.freq
        self.numInSum = [0] * self.freq
        self.short_diags = short_diags
        # Starts at -1 so the very first ZDA (which always differs from the
        # empty last_postype) brings the counter to 0.
        self.postype_changes = -1
        self.last_postype = b""
        # Storage array for all logged data as a function of epoch within
        # the second. bisect keeps each list sorted so the percentiles can
        # be read off directly.
        self.history = [[] for _ in range(self.freq)]

        self.prefix = prefix if prefix is not None else ''
        self.summary_fid = open(self.prefix + 'Summary.txt', 'a')

        self.port = port

        if log:
            now = datetime.datetime.utcnow()
            self.logHour = now.hour
            self.NMEA_fid_root = self._nmea_log_root(now)
            self.NMEA_fid = open(self.NMEA_fid_root + '.txt', 'wb')
        else:
            self.NMEA_fid = None

        self.s = serial.serial_for_url(port, timeout=5)
        self.s_nmea = None
        self.rcvr_port_idx = -1

    def _nmea_log_root(self, now):
        '''Return the NMEA log file name (without extension) for hour "now".'''
        return self.prefix + '-' + now.strftime('%Y%m%d%H') + '-NMEA'

    @staticmethod
    def _percentile(sorted_values, fraction):
        '''Return the entry at index int(len * fraction) of an already
        sorted list, or NaN when the list is empty.'''
        if not sorted_values:
            return float('nan')
        return sorted_values[int(len(sorted_values) * fraction)]

    def _bin_stats(self, i):
        '''Return (min, max, average, 68%, 95%, count) for bin i.
        Only valid when the bin holds at least one sample.'''
        return (self.minValue[i],
                self.maxValue[i],
                self.sumValue[i] / self.numInSum[i],
                self._percentile(self.history[i], 0.68),
                self._percentile(self.history[i], 0.95),
                self.numInSum[i])

    def _reset_bin(self, i):
        '''Forget all samples collected for bin i.'''
        self.minValue[i] = self._NO_DATA_MIN
        self.maxValue[i] = 0
        self.sumValue[i] = 0
        self.numInSum[i] = 0
        self.history[i] = []

    def plotLatency(self, filename):
        '''Plot min / max / average / 68% / 95% latency per millisecond bin
        and save the figure as "<filename>.png".

        Bins without samples are plotted as NaN (i.e. not drawn).
        '''
        nan = float('nan')
        lowest, highest, average = [], [], []
        sixtyEight, ninetyFive = [], []
        for i in range(self.freq):
            if self.numInSum[i] > 0:
                lowest.append(self.minValue[i])
                highest.append(self.maxValue[i])
                average.append(self.sumValue[i] / self.numInSum[i])
            else:
                # Do not plot the min placeholder or divide by zero
                lowest.append(nan)
                highest.append(nan)
                average.append(nan)
            # history is already sorted. (The 95% series used to be
            # computed with 0.68 by mistake.)
            sixtyEight.append(self._percentile(self.history[i], 0.68))
            ninetyFive.append(self._percentile(self.history[i], 0.95))

        fig = figure()
        ax = fig.add_subplot(111)
        bin_ms = array(range(0, self.freq)) * self.msPerEpoch

        # NOTE(review): the small +/- offsets are presumably there so that
        # markers with identical values do not hide each other - confirm.
        plot(bin_ms, array(lowest) - 0.0002, 'gx', label='Min Latency')
        plot(bin_ms, array(highest) + 0.0002, 'rx', label='Max Latency')
        plot(bin_ms, array(average), 'bx', label='Avg Latency')
        plot(bin_ms, array(sixtyEight) - 0.0001, 'cx', label='68%')
        plot(bin_ms, array(ninetyFive) + 0.0001, 'mx', label='95%')

        ax.set_xlim([0, 1000])
        ydata = list(ax.get_ylim())
        ydata[0] = 0
        ax.set_ylim(ydata)
        xlabel('Millisecond Bin')
        ylabel('Latency [s]')
        grid(True)
        legend()
        tight_layout()
        # Prevent the axis numbers having an offset
        ax.get_xaxis().get_major_formatter().set_useOffset(False)
        ax.get_yaxis().get_major_formatter().set_useOffset(False)
        # Save the data as a PNG file. This is done before show() so the
        # file is written while the figure is guaranteed to still exist.
        savefig(filename + '.png', dpi=150)
        show()
        close()

    def checksumtest(self, s):
        '''Returns True if NMEA input 's' passes its checksum.

        Returns False (instead of raising) when the sentence has no '*'
        or the two characters after it are not valid hex.
        '''
        tokens = s.split(b'*')
        if len(tokens) < 2:
            return False
        try:
            expected = int(tokens[1][:2].decode(), 16)
        except ValueError:
            # Covers non-hex text as well as undecodable bytes
            return False
        # XOR of every byte between the leading '$' and the '*'
        actual = 0
        for val in bytearray(tokens[0][1:]):
            actual ^= val
        return expected == actual

    def NMEA_to_gps_secs(self, s, processDate):
        '''Converts an NMEA string 's' to full GPS seconds since the start of GPS.

        processDate - True: take time and date from fields 1-4 of the
                      sentence (ZDA). False: take the time from field 1
                      and the date from the current UTC date (GGA).

        When processDate is True and the hour differs from the previous
        call, the hourly plot/summary is written (if bin 0 has data).

        Returns -1. when the sentence has fewer than 5 fields, field 2 is
        empty, or the time/date cannot be parsed.
        '''
        tokens = s.split(b',')
        if len(tokens) < 5 or len(tokens[2]) == 0:
            return -1.

        try:
            if processDate:
                time_str = bytearray(b',').join(tokens[1:5]).decode('ascii')
                utc_date = datetime.datetime.strptime(time_str, '%H%M%S.%f,%d,%m,%Y')
            else:
                now = datetime.datetime.utcnow()
                time_str = (tokens[1].decode('ascii') + ' ' + str(now.year)
                            + ' ' + str(now.month) + ' ' + str(now.day))
                utc_date = datetime.datetime.strptime(time_str, '%H%M%S.%f %Y %m %d')
        except ValueError:
            # Malformed time/date field: report "no time" rather than
            # aborting the whole measurement.
            return -1.

        # convert to GPS time
        gps_epoch = datetime.datetime(1980, 1, 6)
        gps_time = leapseconds.utc_to_gps(utc_date) - gps_epoch

        # Test to see if we are processing full date (ZDA) and the time
        # has bumped to the next hour. If so plot the data and write
        # information to the summary file
        if processDate and time_str[0:2] != self.oldTime:
            self.oldTime = time_str[0:2]
            if self.numInSum[0] > 0:
                self._write_hourly_report(utc_date)

        return gps_time.total_seconds()

    def _write_hourly_report(self, utc_date):
        '''Plot the collected data, append one line to the summary file and
        reset all bins.'''
        stamp = utc_date.strftime('%Y-%m-%dT%H')
        self.plotLatency(self.prefix + '-' + stamp + '-Latency-' + str(self.freq) + 'Hz')

        out = self.summary_fid
        out.write("%d %d %d %d %d %d " % (utc_date.year, utc_date.month,
                                          utc_date.day, utc_date.hour,
                                          min(self.numInSum), self.freq))
        for i in range(self.freq):
            if self.numInSum[i] > 0:
                out.write("%.6f %.6f %.6f %.6f %.6f " % self._bin_stats(i)[:5])
            else:
                out.write("%.6f %.6f NaN NaN NaN " % (self.minValue[i], self.maxValue[i]))
            self._reset_bin(i)
        out.write("\n")
        out.flush()

    def full_secs_to_week_sec(self, secs):
        '''Split full GPS seconds into (week number, seconds into the week).'''
        secs_per_week = 24*7*60*60.
        week = int(secs / secs_per_week)
        return week, secs - week * secs_per_week

    def print_diags(self, epoch, GGA_gps_secs, ZDA_gps_secs):
        '''Print diagnostics once self.count has reached freq + 1.

        Short mode prints a single line (overall min, max, worst 95% and
        the position type change counter). Otherwise the line for the
        current epoch bin is printed followed by a table of all bins.
        The counter is reset whenever something is printed.
        '''
        due = (self.count == (self.freq + 1))

        if self.short_diags:
            if due:
                self.count = 0
                ninetyFive = 0.0
                for i in range(self.freq):
                    if self.numInSum[i] > 0:
                        curr_val = self._percentile(self.history[i], 0.95)
                        if curr_val > ninetyFive:
                            ninetyFive = curr_val
                print("T=%.3f min=%.6f max=%.6f 95%%=%.6f pos_change=%d" % (
                    GGA_gps_secs,
                    min(self.minValue),
                    max(self.maxValue),
                    ninetyFive,
                    self.postype_changes))
            return

        if due:
            # Only bins inside the table can be reported (epoch is -1 when
            # the sample was rejected)
            if 0 <= epoch < self.freq:
                delta = ZDA_gps_secs - GGA_gps_secs
                if self.numInSum[epoch] > 0:
                    print("%.2f %2d %.6f -- %.6f %.6f %.6f %.6f %.6f %d"
                          % ((GGA_gps_secs, epoch, delta) + self._bin_stats(epoch)))
                else:
                    print("%.2f %2d %.6f -- NaN NaN NaN 0" % (GGA_gps_secs, epoch, delta))
            self.count = 0

        if self.count == 0:
            for i in range(self.freq):
                if self.numInSum[i] > 0:
                    print("%d %.6f %.6f %.6f %.6f %.6f %d" % ((i,) + self._bin_stats(i)))
                else:
                    print("%d NaN NaN NaN 0" % i)

    def _log_received(self, received):
        '''Append the raw received bytes to the NMEA log. When the UTC hour
        has changed, first zip and remove the finished file and open a new
        one.'''
        now = datetime.datetime.utcnow()
        if now.hour != self.logHour:
            self.logHour = now.hour
            self.NMEA_fid.close()

            # zip the file; the context manager makes sure the archive is
            # finalized and closed before the text file is removed
            finished = self.NMEA_fid_root + '.txt'
            with zipfile.ZipFile(self.NMEA_fid_root + '.zip', mode='w') as archive:
                archive.write(finished, compress_type=zipfile.ZIP_DEFLATED)
            # Remove the file we zipped
            os.remove(finished)

            self.NMEA_fid_root = self._nmea_log_root(now)
            self.NMEA_fid = open(self.NMEA_fid_root + '.txt', 'wb')

        self.NMEA_fid.write(received)

    def _record_latency(self, GGA_gps_secs, delta):
        '''Add one latency sample to the bin of its GGA time.

        Returns the computed bin index, or -1 when delta fails the sanity
        range and no bin was computed.
        '''
        if not (-0.05 < delta < 1.00):
            return -1

        millisecond = round(1000 * (GGA_gps_secs - int(GGA_gps_secs)))
        epoch = int(millisecond / self.msPerEpoch)
        if 0 <= epoch < self.freq:
            if delta < self.minValue[epoch]:
                self.minValue[epoch] = delta
            if delta > self.maxValue[epoch]:
                self.maxValue[epoch] = delta
            self.sumValue[epoch] += delta
            self.numInSum[epoch] += 1
            # One sorted list per epoch; bisect keeps the list sorted.
            bisect.insort(self.history[epoch], delta)
        return epoch

    def parse_nmea(self):
        '''Read lines until a valid GGA followed by a valid ZDA has been
        seen, then record the latency (ZDA time - GGA time).

        Returns 1 after a pair was processed, 0 if the ZDA time is negative.
        '''
        ZDA_gps_secs = -1
        GGA_gps_secs = -1
        got_gga = False

        while ZDA_gps_secs == -1:
            received = bytearray(self.s.readline())
            s = received.rstrip()

            if self.NMEA_fid is not None:
                self._log_received(received)

            if self._GGA_PATTERN.match(s) and self.checksumtest(s):
                GGA_gps_secs = self.NMEA_to_gps_secs(s, False)
                got_gga = True
            elif got_gga and self._ZDA_PATTERN.match(s) and self.checksumtest(s):
                ZDA_gps_secs = self.NMEA_to_gps_secs(s, True)
                # Count changes of the first three characters of the
                # sentence ('$' plus the two-character talker ID)
                if s[0:3] != self.last_postype:
                    self.postype_changes += 1
                    self.last_postype = s[0:3]

        if ZDA_gps_secs < 0.:
            return 0

        delta = ZDA_gps_secs - GGA_gps_secs

        # Build up a data table if it passes basic sanity tests.
        # It should not be negative, but it might from an inertial
        # system as the data is extrapolated under certain conditions
        if delta < 0.0:
            print("Negative delta: %.6f %.6f" % (GGA_gps_secs, ZDA_gps_secs))
        epoch = self._record_latency(GGA_gps_secs, delta)

        # Output a diagnostic occasionally. print_diags tests for number
        # of bins + 1 so the data slowly moves through all millisecond bins
        self.count += 1
        self.print_diags(epoch, GGA_gps_secs, ZDA_gps_secs)

        return 1

    def do_DCOL(self, cmd_num, cmd_data):
        '''Send DCOL command # (cmd_num) with data (cmd_data) over
        current receiver port.  Return binary data response.

        The input buffer is cleared first; after writing the command the
        method waits one second and then collects whatever is waiting,
        stopping once 1024 bytes have been gathered.
        '''
        self.s.reset_input_buffer()
        cmd = RXTools.formDColCommand(cmd_num, cmd_data)
        self.s.write(cmd)
        time.sleep(1)
        data = b''
        # NOTE(review): read() is called with its default size on purpose
        # here; in_waiting is only used as a "more data?" flag - confirm
        # before batching the reads.
        while self.s.in_waiting and len(data) < 1024:
            data += self.s.read()
        return bytearray(data)

    def enable_NMEA(self):
        '''Enable NMEA ZDA/GGA on current receiver port (and disable all
        other outputs on the port).

        Prints a message and exits the program (SystemExit) when self.freq
        is not one of the supported frequencies. Raises IndexError when the
        receiver's reply does not contain a "PORT,<n>," entry.
        '''
        print("Enabling NMEA....")

        # Convert the frequency to the rate parameter for the NMEA I/O
        # command. This is checked before anything is sent so an invalid
        # frequency does not leave the port with all outputs disabled.
        rate = self._FREQ_TO_RATE.get(self.freq)
        if rate is None:
            print("Not a valid frequency " + str(self.freq))
            raise SystemExit()

        # it is hard to get the port number if a lot of data is already
        # streaming. Start by turning off all data
        self.do_DCOL(0x51, [0xa, 0xff])

        # get port #
        data = self.do_DCOL(0x6f, [])
        pat = re.compile(b'PORT,([0-9]+),')
        self.rcvr_port_idx = int(pat.findall(data)[0]) + 1

        # Enable the GGA and ZDA
        self.do_DCOL(0x51, [0x7, self.rcvr_port_idx, 26, rate, 0])
        self.do_DCOL(0x51, [0x7, self.rcvr_port_idx, 29, rate, 0])
        print("  .. done!")

    def disable_NMEA(self):
        '''Disable ZDA/NMEA (and all outputs) on the port.'''
        self.do_DCOL(0x51, [0xa, self.rcvr_port_idx])

    def do_loop(self):
        '''Main loop - enables NMEA and then processes GGA/ZDA pairs.

        NOTE(review): parse_nmea() currently only returns 0 or 1, so the
        loop condition never fails and this only ends via an exception.
        '''
        self.enable_NMEA()
        while self.parse_nmea() >= 0:
            pass

        self.disable_NMEA()

    def close(self):
        '''Call this to turn off the NMEA output, disconnect from the
        receiver port and close the summary/log files.'''
        try:
            # This used to call the non-existent disable_ZDA()
            self.disable_NMEA()
        finally:
            self.s.close()
            self.summary_fid.close()
            if self.NMEA_fid is not None:
                self.NMEA_fid.close()

if __name__ == "__main__":
    # Route Ctrl+C to signal_handler
    signal.signal(signal.SIGINT, signal_handler)
    parser = argparse.ArgumentParser(description=usage,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument("port", help="socket://10.1.150.xxx:5018).  Must be writeable")
    # type=int lets argparse reject a non-numeric frequency with a usage
    # message instead of a traceback from the getLatency constructor
    parser.add_argument("freq", type=int, help="NMEA frequency in Hz")
    parser.add_argument("--prefix", help="file prefix")
    parser.add_argument("--log_NMEA", help="Log the NMEA data", action="store_true", default=False)
    parser.add_argument("--short", help="Print shorter diagnostics", action="store_true", default=False)
    args = parser.parse_args()

    # Connect to the receiver and measure until interrupted or an error occurs
    rc = getLatency(args.port,
                    prefix=args.prefix,
                    log=args.log_NMEA,
                    freq=args.freq,
                    short_diags=args.short)
    rc.do_loop()


