#!/usr/bin/env python

# $Id: monitorLatency.py,v 1.6 2024/02/05 04:45:04 wlentz Exp $

# One of our product specifications is position solution latency. This
# script gives an indication of the solution latency. In the I/O library
# the NMEA:ZDA message provides the current time with a resolution
# of 1ms. Therefore, to get an estimate of the latency, for years we have
# enabled GGA and ZDA; ZDA is generated after the GGA. If you
# subtract the GGA time (time of position) from the ZDA time (time the ZDA
# message was generated, which is very soon after the GGA message is
# generated) it gives an indication of the latency. This does not
# capture any further delays after the IOTX task; for example, it does
# not capture delays in the drivers or any physical layer delays.
#
# The code has been written as a class. However, the script will work
# standalone. The class structure was used and the general structure of
# "replayCorr.py" used so that a python layer can be added on top of
# this file to in parallel monitor multiple receivers.
#
# Copyright Trimble Inc., 2019-2020
#

# Help text shown as the description of the command-line argument parser.
usage="""\
To use:
  provide the socket information
  provide the required frequency rate (will setup NMEA)
  optionally provide the prefix to generate more unique file names

  Important: The port must be read/write.
"""

import matplotlib
# Allow running headless from the command line
matplotlib.use("agg")

from numpy import *
from pylab import *

import datetime
#from datetime import timezone
import serial
import argparse
import leapseconds
import subprocess as sp
import re
import sys
import math
import RXTools as rx
import mutils as m
import plot_spa as spa
import time
import os.path
import signal
import zipfile
import bisect
import threading
import queue
import psutil
import traceback

# The monitorLatency instance built by the command-line entry point. It
# lives at module scope so that signal_handler() can reach it and perform
# an orderly shutdown.
rc = None

# Tasks that get their own highlight colour in the context/mutex graphs.
# Only the task names are listed here; the matching task IDs are looked up
# in the downloaded data, as they can vary from product to product and
# from release to release.
taskLUT = [
    {'type': 'Ethernetwritetask', 'color': 'b'},
    {'type': 'IOtransmit',        'color': 'g'},
    {'type': 'PosManager',        'color': 'c'},
    {'type': 'IMU',               'color': 'm'},
]


def signal_handler(received_signal, frame):
  '''Signal callback: stop the monitor object (if one has been created)
  and terminate the process with exit status 0.

  received_signal, frame - the standard signal-handler arguments; neither
                           is used.
  '''
  print('You pressed Ctrl+C!')

  # rc is only read here, so no "global" declaration is required.
  monitor = rc
  if monitor is not None:
    print("Shutting down...")
    monitor.stop()

  print("Exiting")
  raise SystemExit(0)

class monitorLatency(object):
    count = 0
    oldTime = 'XX'
    def __init__(self, port, latency_thres_ms, latency_num, 
                       prefix=None, 
                       user='admin',
                       password='password',
                       log=False,
                       freq=0,
                       start_nmea_output=False ):
        '''Open the receiver port and set up the latency bookkeeping.

        port              - where the NMEA stream is read from, e.g.
                            /dev/ttyS0 or socket://10.1.150.x:5018
        latency_thres_ms  - threshold [ms]; an hour of data whose worst
                            latency exceeds it triggers the worker thread
        latency_num       - number of worst-latency epochs the worker
                            thread analyses
        prefix            - optional prefix for generated file names
        user / password   - receiver credentials used by the worker
                            thread; None selects 'admin' / 'password'
        log               - True enables logging of the received NMEA
                            data to a file that rolls over every hour
        freq              - frequency of the NMEA output in Hz, must be > 0
        start_nmea_output - True makes start() send the commands that
                            enable NMEA output (port must be read/write)

        Raises ValueError when freq is not a positive integer.
        '''
        # Local import: only needed here to pick the host out of the URL.
        from urllib.parse import urlsplit

        # Validate the rate before any thread, file or port is created: it
        # is used below both as a divisor and as a list length.
        self.freq = int(freq)
        if self.freq <= 0:
          raise ValueError('freq must be a positive NMEA rate in Hz, got %r' % (freq,))

        # Option to enable sending commands to start NMEA output when
        # TCP/IP port is 2-way.
        self.start_nmea_output = start_nmea_output

        # State shared between the NMEA parser and dataThread(). The flag
        # is guarded by data_mutex; the cache is filled by the parser
        # before the flag is raised.
        self.newData = False
        self.data_mutex = threading.Lock()
        self.data_thread_enabled = True
        self.cacheTime = None
        self.cacheLatency = None

        # None means "use the default credentials"
        self.user = user if user is not None else 'admin'
        self.password = password if password is not None else 'password'

        self.latency_thres_ms = latency_thres_ms
        self.latency_num = latency_num

        # Leap second call is expensive, so only do it once (at least on a
        # PC under WSL) and keep the GPS-UTC offset in seconds.
        utc_date = datetime.datetime.utcnow()
        self.LS = (leapseconds.utc_to_gps(utc_date) - utc_date).seconds
        self.gps_epoch = datetime.datetime(1980,1,6)

        # Per-epoch-within-the-second statistics (one slot per NMEA epoch)
        self.msPerEpoch = 1000/self.freq
        self.minValue = [99999] * self.freq
        self.maxValue = [0] * self.freq
        self.sumValue = [0] * self.freq
        self.numInSum = [0] * self.freq

        # Every accepted sample since the last hour roll-over: GPS time of
        # the position and the matching latency in milliseconds.
        self.history_time    = []
        self.history_latency = []

        self.prefix = prefix if prefix is not None else ''
        self.summary_fid = open(self.prefix + 'Summary.txt','a')
        self.hist_fid = open(self.prefix + 'Hist.txt','a')

        self.port = port

        if log:
          now = datetime.datetime.utcnow()
          self.logHour = now.hour
          # <prefix>-YYYYMMDDHH-NMEA, same pattern as the hourly roll-over
          self.NMEA_fid_root = self.prefix + '-' + now.strftime('%Y%m%d%H') + '-NMEA'
          self.NMEA_fid = open(self.NMEA_fid_root + '.txt','wb')
        else:
          self.NMEA_fid = None

        # Host part of a URL-style port (socket://host:port); None for a
        # plain device name such as /dev/ttyS0.
        self.IPAddr = urlsplit(port).hostname

        self.s = serial.serial_for_url(port,timeout=5)
        self.s_nmea = None
        self.rcvr_port_idx = -1   # unknown until enable_NMEA() queries it
        self.FileHandle = None
        self.FilePrefix = None

        # Start the worker last, so it never sees a half-built object and
        # is not left running if anything above raised.
        self.data_thread = self.pkt_thread = threading.Thread( target=self.dataThread)
        self.data_thread.daemon = True # make thread stop when program stops
        self.data_thread.start()

    def dataThread(self):
      '''Worker thread body: post-process an hour of latency data.

      Loops forever. Every pass it reads self.newData under data_mutex;
      when the flag is clear it sleeps 10 seconds and polls again. When
      the flag is set it:
        - lists the receiver directory and takes the last T04 file name
          (after sorting) as the file that matches the cached data
        - writes the cached time/latency pairs to "<file>.latency.txt"
          and a latency-vs-time plot to "<file>.png"
        - downloads the T04 file and, for the self.latency_num largest
          latencies, prints helper commands and saves annotated
          context/mutex plots
        - deletes every T04 file that was in the directory listing
      Any exception raised while doing so is printed with its traceback
      and the pass is abandoned. Finally the flag is cleared.

      NOTE(review): self.data_thread_enabled is never consulted here, so
      the loop has no exit condition of its own.
      '''
      while(True):
        # Take a snapshot of the flag under the lock
        self.data_mutex.acquire()
        newData = self.newData
        self.data_mutex.release()

        if(newData):
          # Assume the path
          path = '/Internal/'

          # Must be T04 format for this to work
          filetype = 'T04'
            
          try:
            ret = rx.GetDirectory(self.IPAddr,self.user,self.password,path)

            # Keep the listing lines that mention "file"; the name is the
            # second whitespace-separated field with its first 5
            # characters stripped.
            fileList = []
            for line in ret.splitlines():
              if("file" in line):
               data = line.split()
               filename = data[1][5:]
               if( ("." + filetype) in filename):
                 fileList.append(filename)

            # After sorting, the last entry is the one we work with
            fileList.sort()     
            print('File download',fileList[-1])
            
            # Remove the weeks from the time tag. The week count is taken
            # from the last sample and subtracted from the shared array
            # in place.
            weeks = int(self.cacheTime[-1]/ (86400*7) )
            self.cacheTime -= weeks*86400*7

            # Make a copy so the other thread can't impact us
            cacheTime    = self.cacheTime.copy()
            cacheLatency = self.cacheLatency.copy()

            # Save the latency data - add '.latency.txt' to the matching T04 filename
            with open(fileList[-1] + '.latency.txt','w') as fid:
              for i in range(len(cacheTime)):
                fid.write('%.3f %.6f\n' % (cacheTime[i], cacheLatency[i]))

            # Overview plot of every latency sample in the cached hour
            plot(cacheTime,cacheLatency,'.')
            xlabel('Time [GPS Secs]')
            ylabel('Latency [ms]')
            grid()
            tight_layout()
            m.save_compressed_png('%s.png' % (fileList[-1]), dpi=600)
            close()

            # Now download the T04 file - we'll spend most time here!
            rx.getFileHTTP(self.IPAddr,self.user,self.password,'.',path,fileList[-1],fileList[-1],progress=True,chunksize=1024*1024)

            # now get the task IDs from the data file
            taskData = []
            IDs,names,cores = m.parse_load_diags_task_names(fileList[-1])

            # First deal with RTK, which is a special case: the band runs
            # from the 'PreTN' task to the last task whose name starts
            # with 'TN'.
            start = IDs[names.index('PreTN')]
            # Find the last task that starts with TN
            # (m.find presumably returns the indices of the True entries -
            # TODO confirm against mutils)
            i = m.find(array([x.startswith('TN') for x in names]))
            stop = IDs[i[-1]]

            taskData.append({'start':start, 'stop':stop, 'color':'r'})
            
            # Now deal with the rest of the task IDs: one band per taskLUT
            # entry, spanning the lowest to the highest matching ID
            names = array(names)
            IDs = array(IDs)
            for thisTask in taskLUT:
              i = m.find(names == thisTask['type'])
              thisID = IDs[i]

              taskStart = min(thisID)
              taskStop  = max(thisID)
              taskData.append({'start':taskStart,
                                'stop':taskStop,
                                'color':thisTask['color']})

            # Sort in reverse order (largest latency first)
            latencySort = sort(cacheLatency)[::-1]

            for index in range(self.latency_num):
              # Locate the first sample that has this latency value
              latency = latencySort[index]
              i = m.find(latency == cacheLatency)
              idx = i[0]

              print(idx,cacheTime[idx],cacheLatency[idx])
              # Analyse the 10ms window that starts at the position time
              t0 = cacheTime[idx]
              t1 = t0 + 0.01

              # Matlab - context switch only
              print('viewdat -d35:271 --mat=context.mat -s%.3f -e%.3f %s   ' % (t0,t1,fileList[-1]))
              print('plotContextTiming(%s,%.3f,1,0);' % (fileList[-1],t0))
              # Now output helper debug - this can be used for more interactive debugging
              # in Matlab or Python:
              # Python - context and mutex
              print('python plot_spa.py -l %s %.3f %.3f' % (fileList[-1],t0,t1) )
              # python wrapper to the above
              #print('python plot_spa_segment.py -l %s %.3f %.3f %.6f' % (fileList[-1],t0,t1,maxVal) )


              # Extract the mutex and context switch data
              tsk,mut = spa.load_spa( fileList[-1], t0, t1, spa.mode_long_mutexes)
              # Plot it
              spa.do_plots( tsk, mut, spa.mode_long_mutexes)

              # The following adjusts these plots:
              # 1. Adds a vertical dashed line showing the latency for this epoch
              # 2. Adds rectangles to highlight different tasks / task groups.
              # 3. Saves as a PNG
              latency /= 1000 # Convert from ms to seconds (t0 is in seconds)
              ID = 0
              for i in get_fignums():
                figure(i)

                # Add a dashed line showing the latency for this 10ms period
                yMin, yMax = gca().get_ylim()
                xMin, xMax = gca().get_xlim()
                plot([t0 + latency,t0 + latency],[yMin,yMax],'k-.')
              
                # Add a rectangle highlighting certain tasks
                for thisTask in taskData:
                  polygon = Polygon([(xMin, thisTask['start']-0.5), (xMin, thisTask['stop']+0.5), (xMax, thisTask['stop']+0.5), (xMax, thisTask['start']-0.5)], facecolor=thisTask['color'], alpha=0.1)
                  gca().add_patch(polygon)

                # Restore the x range captured above and pin y to start at 0
                gca().set_xlim(xMin, xMax)
                gca().set_ylim(0, yMax)
                tight_layout()

                # Save as a PNG. The first figure is named as the context
                # plot; the following ones as per-core mutex plots.
                if(ID==0):
                  m.save_compressed_png('%s-%.3f-context-rank-%d.png' % (fileList[-1],t0,index+1),dpi=600)
                else:
                  m.save_compressed_png('%s-%.3f-mutex-core-%d-rank-%d.png' % (fileList[-1],t0,ID-1,index+1),dpi=600)
                ID += 1
                close()


            print('Context/Mutex/Download complete')
          
            # This test uses test firmware that enables context and mutex switch data.
            # It creates over 1GB/hour, which is a far greater rate (10x) than normal.
            # We've run into some lost data, likely an issue with autodelete not kicking
            # in quickly enough. To help, we can delete all the files in fileList[] - 
            # we don't need any of these old files. We've either downloaded them, or
            # skipped them as the latency was not above the threshold.
            # fileList was the directory listing of T04 files soon after the main
            # thread triggered this code
            for thisFile in fileList:
              rx.DeleteFile(self.IPAddr,self.user,self.password,path,thisFile)
          except Exception:
            # Best effort: report the problem and carry on with the next
            # trigger. Note the receiver files are not deleted on this path.
            isotime = datetime.datetime.utcnow().isoformat()
            print(isotime,'Recovering from data error')
            # print the traceback information
            traceback.print_exc()
          
          # We've finished processing and using any shared data. Clear the flag so 
          # that the main NMEA thread can send more data. We'll then continue to
          # complete the data analytics in this thread - this is relatively quick
          self.data_mutex.acquire()
          self.newData = False
          self.data_mutex.release()

        else:
          # Nothing to do - poll the flag again in 10 seconds
          time.sleep(10)

    def checksumtest(self,s):
        '''Return True if NMEA sentence 's' (bytes or bytearray) passes its checksum.

        The checksum is the XOR of every byte before the '*', skipping the
        first byte (the '$'). It is compared with the two hex digits that
        follow the '*'.

        Returns False, instead of raising, when the sentence has no '*' or
        when the checksum field is not valid hex.
        '''
        tokens = s.split(b'*')
        if len(tokens) < 2:
            # No checksum field at all (truncated or non-NMEA line)
            return False

        try:
            expected = int(tokens[1][:2].decode(),16)
        except ValueError:
            # Non-hex text, an empty field or undecodable bytes
            # (UnicodeDecodeError is a ValueError)
            return False

        computed = 0
        for val in tokens[0][1:]:
            computed ^= val
        return (expected == computed)

    def NMEA_to_gps_secs(self,s,processDate):
        '''Convert the UTC time in NMEA sentence 's' to full GPS seconds
        since the start of GPS.

        s           - comma separated NMEA sentence (bytes or bytearray)
        processDate - True: take the date from fields 2-4 of the sentence
                      (used for ZDA) and run the hourly roll-over check;
                      False: combine the sentence time with the current
                      UTC date of this computer (used for GGA)

        Returns the GPS time in seconds, or -1. when the sentence has
        fewer than 5 fields, an empty third field, or a time/date that
        cannot be parsed.

        Side effect (processDate only): when the hour of the sentence
        differs from the previously seen hour, a one-line summary of the
        collected history is printed. If its worst latency exceeds
        self.latency_thres_ms the history is handed to the worker thread
        via self.cacheTime / self.cacheLatency / self.newData. The history
        is then cleared.
        '''
        tokens = s.split(b',')
        if len(tokens) < 5:
            return -1.
        if len(tokens[2]) == 0:
            return -1.

        try:
          if processDate:
            time_str = b','.join(tokens[1:5]).decode('ascii')
            utc_date = datetime.datetime.strptime(time_str,'%H%M%S.%f,%d,%m,%Y')
          else:
            # NOTE(review): the date comes from the local clock, so around
            # UTC midnight it may differ from the date of the ZDA that
            # follows - confirm this is acceptable.
            now = datetime.datetime.utcnow()
            time_str = tokens[1].decode('ascii') + ' ' + str(now.year) + ' ' + str(now.month) + ' ' + str(now.day)
            utc_date = datetime.datetime.strptime(time_str,'%H%M%S.%f %Y %m %d')
        except ValueError:
          # Malformed time/date (or non-ASCII bytes): treat it like any
          # other unusable sentence rather than aborting the main loop.
          return -1.

        # convert to GPS time
        gps_time = (utc_date - self.gps_epoch).total_seconds() + self.LS

        # Test to see if we are processing full date (ZDA) and the time
        # has bumped to the next hour. If so report the hour and, when the
        # threshold was exceeded, trigger the worker thread.
        hour = time_str[0:2]
        if processDate and hour != self.oldTime:
          self.oldTime = hour
          # The history can be empty even though numInSum is not (no
          # sample was accepted during the hour); max()/min() of an empty
          # list would raise.
          if self.numInSum[0] > 0 and self.history_latency:
            worst = max(self.history_latency)
            date = utc_date.strftime('%Y-%m-%dT%H')
            print('%s %ld %.4f %.4f' % (date,len(self.history_latency),worst,min(self.history_latency)) )

            if worst > self.latency_thres_ms:
              # Publish the snapshot and raise the flag in one step
              with self.data_mutex:
                self.cacheLatency = self.history_latency.copy()
                self.cacheTime    = array(self.history_time)
                self.newData = True

            # Clear the history
            self.history_time = []
            self.history_latency = []

        return gps_time

    def full_secs_to_week_sec(self,secs):
        '''Split a count of seconds into (whole weeks, seconds into that week).

        The week count is truncated towards zero with int(); the
        remainder is returned as a float.
        '''
        secs_per_week = 24*7*60*60.
        whole_weeks = int(secs / secs_per_week)
        remainder = secs - whole_weeks*secs_per_week
        return whole_weeks, remainder

    def _rotate_nmea_log(self, now):
        '''Close and zip the current hourly NMEA log, delete the raw text
        file, then open a new log named after the hour in 'now'.'''
        self.logHour = now.hour
        self.NMEA_fid.close()

        finished = self.NMEA_fid_root
        # Close the archive explicitly (context manager) so that it is
        # complete on disk before the raw text file is removed.
        with zipfile.ZipFile(finished + '.zip', mode='w') as archive:
          archive.write(finished + '.txt', compress_type=zipfile.ZIP_DEFLATED)
        os.remove(finished + '.txt')

        self.NMEA_fid_root = self.prefix + '-' + now.strftime('%Y%m%d%H') + '-NMEA'
        self.NMEA_fid = open(self.NMEA_fid_root + '.txt','wb')

    def parse_nmea(self):
        '''Read lines until one GGA followed by one ZDA has been parsed,
        then record the latency (ZDA time minus GGA time).

        Every line read is also written to the hourly NMEA log when
        logging is enabled. A latency that passes the sanity window is
        added to the per-epoch min/max/sum tables and to the history
        lists (latency stored in milliseconds, time in GPS seconds).

        Returns 1 after a GGA/ZDA pair has been handled. The "return 0"
        path is kept as a safety net for a negative ZDA time. The method
        never returns a negative value and blocks in readline() (subject
        to the port timeout) until data arrives.
        '''
        ZDA_gps_secs = -1
        gotGGA = False

        while(ZDA_gps_secs == -1):
          # Do regular socket/serial read.
          received = bytearray(self.s.readline())
          s = received.rstrip()

          if self.NMEA_fid is not None:
            now = datetime.datetime.utcnow()
            if(now.hour != self.logHour):
              self._rotate_nmea_log(now)
            self.NMEA_fid.write(received)

          # Talker IDs starting with G or I are accepted. Raw literals
          # keep "\$" from being an invalid string escape.
          if(re.match(rb'\$[GI].GGA',s) and self.checksumtest(s)):
            GGA_gps_secs = self.NMEA_to_gps_secs(s,False)
            gotGGA = True
          elif(gotGGA and re.match(rb'\$[GI].ZDA',s) and self.checksumtest(s)):
            # We have the GGA and ZDA data:
            #   GGA_gps_secs - GPS seconds of the validity of the position
            #   ZDA_gps_secs - GPS time of when IOTX created the NMEA message
            # NOTE: an earlier revision also time-stamped the reception of
            # each line to estimate the Ethernet delay (receiver output,
            # network, router and host delays) and logged it to an
            # "EthernetTime-<date>.txt" file. That diagnostic is disabled.
            ZDA_gps_secs = self.NMEA_to_gps_secs(s,True)

        if ZDA_gps_secs < 0.:
          return 0

        delta = ZDA_gps_secs - GGA_gps_secs

        # Build up a data table if it passes basic sanity tests.
        # It should not be negative, but it might from an inertial
        # system as the data is extrapolated under certain conditions
        if -0.05 < delta < 1.00:
          # Slot of this position within its second
          millisecond = round(1000*(GGA_gps_secs - int(GGA_gps_secs)))
          epoch = int(millisecond / self.msPerEpoch)
          if 0 <= epoch < self.freq:
            if(delta < self.minValue[epoch]):
              self.minValue[epoch] = delta
            if(delta > self.maxValue[epoch]):
              self.maxValue[epoch] = delta
            self.sumValue[epoch] += delta
            self.numInSum[epoch] += 1
            # History for the hourly summary / worker thread
            self.history_latency.append(delta*1000.0)
            # This should be in GPS time
            self.history_time.append(GGA_gps_secs)

        # Output a diagnostic occasionally (once every 'freq' pairs)
        self.count += 1
        if (self.count % self.freq) == 0:
          print("min %.2f max %.2f [ms]"%(1e3*min(self.minValue),1e3*max(self.maxValue)))
        return 1

    def do_DCOL(self,cmd_num,cmd_data):
        '''Send DCOL command # (cmd_num) with data (cmd_data) over
        current receiver port.  Return binary data response.

        The input buffer is flushed first, the command is written, and
        after a one second wait whatever has arrived is collected, up to
        1024 bytes. Returns the response as a bytearray (empty when
        nothing arrived).
        '''
        self.s.reset_input_buffer()
        # The RXTools module is imported under the alias "rx"
        cmd = rx.formDColCommand(cmd_num,cmd_data)
        self.s.write(cmd)

        # Give the receiver time to answer before draining the port
        time.sleep(1)

        data = bytearray()
        while self.s.in_waiting and len(data) < 1024:
            data += self.s.read()
        return data

    def enable_NMEA(self):
        '''Enable NMEA ZDA/GGA on current receiver port (and disable all
        other outputs on the port).

        The frequency is validated first, so an unsupported value leaves
        the receiver untouched: a message is printed and the process
        exits with status 1 (SystemExit). Raises RuntimeError when the
        receiver reply does not contain the port number.
        '''
        # Convert the frequency [Hz] to the rate parameter for the NMEA
        # I/O command
        rate_for_freq = {100: 16, 50: 15, 20: 13, 10: 1, 5: 2}
        rate = rate_for_freq.get(self.freq)
        if rate is None:
          print("Not a valid frequency " + str(self.freq))
          sys.exit(1)

        # it is hard to get the port number if a lot of data is already
        # streaming. Start by turning off all data
        self.do_DCOL(0x51,[0xa,0xff])

        # get port #
        data = self.do_DCOL(0x6f,[])
        ports = re.findall(b'PORT,([0-9]+),', data)
        if not ports:
          raise RuntimeError('No PORT entry in the receiver reply; is the port read/write?')
        self.rcvr_port_idx = int(ports[0])+1

        # Enable the GGA and ZDA
        self.do_DCOL(0x51,[0x7,self.rcvr_port_idx,26,rate,0])
        self.do_DCOL(0x51,[0x7,self.rcvr_port_idx,29,rate,0])

    def disable_NMEA(self):
        '''Turn off ZDA/NMEA, and every other output, on the receiver
        port recorded in self.rcvr_port_idx.'''
        all_outputs_off = [0xa, self.rcvr_port_idx]
        self.do_DCOL(0x51, all_outputs_off)

    def start(self):
        '''Send the commands that enable NMEA output, but only when that
        option was selected at construction time.'''
        if not self.start_nmea_output:
          return
        self.enable_NMEA()

    def stop(self):
        '''Close the "EthernetTime-*" file if one is open.

        Safe to call more than once: the handle is dropped after closing.
        '''
        # Identity test: None is a singleton, so "is not" is the correct
        # comparison.
        if self.FileHandle is not None:
          self.FileHandle.close()
          self.FileHandle = None

    def do_loop(self):
        '''Main loop: optionally start the NMEA output, then keep parsing
        GGA/ZDA pairs for as long as parse_nmea() returns a non-negative
        value.'''
        self.start()

        keep_going = True
        while keep_going:
            keep_going = self.parse_nmea() >= 0

        # Called here for form; on Ctrl-C / kill it is signal_handler()
        # that actually calls stop().
        self.stop()

    def close(self):
        '''Call this to turn off ZDA and disconnect from the receiver port.

        The outputs are only switched off when the receiver port index is
        known, i.e. enable_NMEA() has run; until then it holds the -1
        placeholder. The port is closed even if switching the outputs off
        fails.
        '''
        try:
            if self.rcvr_port_idx >= 0:
                self.disable_NMEA()
        finally:
            self.s.close()

if __name__ == "__main__":
    # Ctrl-C (SIGINT) and "kill PID" (SIGTERM) share one shutdown handler.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    parser = argparse.ArgumentParser(
        description=usage,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    # Positional arguments
    parser.add_argument("port", help="socket://10.1.150.xxx:5018).  Must be writeable")
    parser.add_argument("freq", help="NMEA frequency in Hz")

    # Optional values; each is None when not given on the command line
    parser.add_argument("--prefix", help="file prefix")
    parser.add_argument("--user", help="username - default admin")
    parser.add_argument("--password", help="password - default password")
    parser.add_argument("--latency", help="latency threshold in milliseconds - default 7.0")
    parser.add_argument("--num", help="Find the worst N - at least one must be over the threshold - default 1")

    # Boolean switches (False unless present)
    parser.add_argument("--log_NMEA", action="store_true", help="Log the NMEA data")
    parser.add_argument("--start_nmea_output", action="store_true",
                        help="Send commands to start NMEA output.")

    args = parser.parse_args()

    # Fall back to the documented defaults for values that were not given
    latency_thres_ms = 7.0 if args.latency is None else float(args.latency)
    latency_num = 1 if args.num is None else int(args.num)

    # Assigned to the module-level name so signal_handler() can stop it
    rc = monitorLatency(args.port,
                        latency_thres_ms,
                        latency_num,
                        prefix=args.prefix,
                        user=args.user,
                        password=args.password,
                        log=args.log_NMEA,
                        freq=args.freq,
                        start_nmea_output=args.start_nmea_output)
    rc.do_loop()
