import datetime
import threading
import time
import serial
import re
import json
import sys
import matplotlib.pyplot as plt
import numpy as np

#
# Script to monitor the NMEA height from receivers and update
# a scrolling plot. By default it connects to a couple of test receivers;
# supply a JSON file as the single command-line argument to connect to
# other or additional receivers.
#
# Copyright Trimble Inc 2022
#

# Length of each per-receiver ring buffer: 100 s of history at 10 Hz.
NumPoints = 100 * 10

# Shared sample store. The __main__ section appends one [times, heights]
# pair of NumPoints-long lists per receiver; reader threads overwrite entries.
rxData = list()

class nmeaHandle(object):
    '''Stream NMEA from one receiver and record its GGA time/height samples.

    The receiver is reached through pyserial's ``socket://host:port`` URL.
    Each valid GGA sentence writes (seconds-of-day, height) into the
    module-level ``rxData[rxNum]`` buffers at a rolling index that wraps
    back to 0 after ``NumPoints`` samples.

    Args:
        IPAddr: IP address or host name of the receiver.
        port: TCP port the receiver streams NMEA on.
        rxNum: index of this receiver's slot in ``rxData``.
        verbose: print connection diagnostics when True.

    Raises:
        serial.SerialException: if the initial connection cannot be opened.
    '''

    def __init__(self,IPAddr,port,rxNum,verbose=True):
        self.verbose = verbose
        if self.verbose:
            print(IPAddr)
            print(port)
        self.port = 'socket://' + IPAddr + ':' + str(port)
        self.baudrate = 9600
        self.s = None
        self.index = 0      # next write position in the rxData ring buffer
        self.num = rxNum
        self.connect_to_remote()

    def connect_to_remote(self):
        '''(Re)open the connection to the receiver, closing any existing one.

        The 5 s read timeout makes read_until() return early (possibly empty)
        when the receiver goes quiet. Raises serial.SerialException if the
        connection cannot be opened.
        '''
        if self.s is not None:
            self.s.close()
        self.s = serial.serial_for_url(self.port,timeout=5,baudrate=self.baudrate)
        if self.verbose: print(self.port,self.s)

    def checksumtest(self,s):
        '''Return True if NMEA sentence ``s`` (bytes) passes its checksum.

        The checksum is the XOR of every byte between the leading '$' and
        the '*', compared with the two hex digits after the '*'. Sentences
        with no '*' or an unparseable checksum field fail the test.
        '''
        tokens = s.split(b'*')
        if len(tokens) < 2:
            return False
        try:
            cs1 = int(tokens[1][:2].decode(),16)
        except ValueError:      # includes UnicodeDecodeError
            return False
        cs2 = 0
        for val in tokens[0][1:]:   # skip the leading '$'
            cs2 ^= val
        return cs1 == cs2

    def getGGA(self,s):
        '''Parse one GGA sentence and store its time and height in ``rxData``.

        Args:
            s: a single NMEA sentence (bytes/bytearray) with line endings
               stripped.

        Returns:
            None when a sample was stored; -1. when ``s`` is not a GGA
            sentence, fails its checksum, has too few fields, or has empty or
            non-numeric time/height fields (e.g. no position fix).
        '''
        global rxData

        # Check we have a GGA message and it passes the checksum
        if not re.match(rb'\$G.GGA',s):
            return -1.
        if not self.checksumtest(s):
            return -1.

        tokens = s.split(b',')
        if len(tokens) < 15:
            return -1.

        # TODO: parse latitude/longitude (fields 2-5, including N/S and E/W)
        try:
            # hhmmss.ss -> seconds in the day
            utc = tokens[1]
            secs = (float(utc[0:2]) * 3600
                    + float(utc[2:4]) * 60
                    + float(utc[4:]))
            # In the standard GGA layout field 9 is the altitude above mean
            # sea level and field 11 the geoid separation; their sum is the
            # height above the ellipsoid.
            hgt = float(tokens[9]) + float(tokens[11])
        except ValueError:
            # Empty fields (no fix) or corrupt digits: skip the sentence.
            return -1.

        rxData[self.num][0][self.index] = secs
        rxData[self.num][1][self.index] = hgt
        self.index += 1
        if self.index >= NumPoints:
            self.index = 0

    def try_to_reconnect(self, timeout_sec=60):
        '''Retry connect_to_remote once a second for up to ``timeout_sec``.

        Used after a network error or a receiver reboot. Returns 0 once
        reconnected, or -1 after giving up.
        '''
        t1 = time.monotonic()   # immune to wall-clock adjustments
        while True:
            print(time.strftime(" Retry at %H:%M:%S", time.gmtime()))
            try:
                self.connect_to_remote()
                print(" Reconnected!")
                return 0
            except serial.SerialException:
                time.sleep(1)
                if time.monotonic() - t1 > timeout_sec:
                    print(" Giving up on connection...")
                    return -1

    def do_loop(self):
        '''Read sentences forever and feed each one to getGGA.

        An empty read (the read timed out with no data) or a serial/socket
        error triggers try_to_reconnect. On success, reading resumes. Returns
        -1 once a reconnect attempt gives up.
        '''
        while True:
            try:
                received = bytearray(self.s.read_until())
            except (serial.SerialException, OSError) as err:
                # pyserial's socket handler raises on disconnect/reset.
                if self.verbose:
                    print(err)
                received = bytearray()

            if len(received) == 0:
                print("Problem with read connection...")
                if self.try_to_reconnect() != 0:
                    return -1
                continue

            # Malformed or non-GGA sentences are rejected inside getGGA.
            self.getGGA(received.rstrip())

    def close(self):
        '''Close the connection, if one was opened.'''
        if self.s is not None:
            self.s.close()

def nmeaThread(thisRX,thisNum):
    '''Thread entry point: stream one receiver's GGA data into rxData[thisNum].

    Args:
        thisRX: receiver description; its 'IPAddr' and 'port' entries are used.
        thisNum: index of this receiver's slot in rxData.

    The connection is closed when the read loop gives up reconnecting.
    '''
    nmea = nmeaHandle(thisRX['IPAddr'],thisRX['port'],thisNum)
    try:
        nmea.do_loop()
    finally:
        nmea.close()

def _load_receivers():
    '''Return the list of receiver descriptions to monitor.

    With exactly one command-line argument, that argument is read as a JSON
    file holding a list of objects with 'IPAddr', 'port' and 'RXStr' keys.
    Otherwise the built-in default receivers are used.
    '''
    if len(sys.argv) != 2:
        # Default receivers
        return [{'IPAddr':'10.1.150.73','port':5017,'RXStr':'Trimble'},
                {'IPAddr':'10.1.149.146','port':5000,'RXStr':'Septentrio'}]
    with open(sys.argv[1],'r',encoding='utf-8') as f:
        return json.load(f)


def _init_storage(numRx):
    '''Append one NaN-filled [times, heights] buffer pair per receiver to rxData.'''
    # Extend in place so the reader threads and this module share one object.
    rxData.extend([[np.nan] * NumPoints, [np.nan] * NumPoints]
                  for _ in range(numRx))


def _start_threads(DUT):
    '''Start one daemon reader thread per receiver and return the handles.'''
    handles = []
    for index, rx in enumerate(DUT):
        # Daemon threads: the readers loop forever, so they must not keep
        # the process alive after the plot window closes or on Ctrl-C.
        t = threading.Thread(target=nmeaThread, args=(rx, index), daemon=True)
        handles.append(t)
        t.start()
        # 50ms sleep - gives the task time to start before we start the next
        # one. This simply reduces the chance of the debug getting mangled
        # on stdout.
        time.sleep(0.05)
    return handles


def _run_plot(DUT):
    '''Show a live scatter plot of height vs. time until the window is closed.'''
    plt.ion()
    fig, ax = plt.subplots()

    # One marker-only line per receiver, labelled with its 'RXStr'.
    lines = [ax.plot(times, hgts, '.', label=rx['RXStr'])[0]
             for rx, (times, hgts) in zip(DUT, rxData)]

    ax.set_autoscaley_on(True)
    ax.grid()
    fig.legend()
    ax.set_xlabel('Time in day [Sec]')
    ax.set_ylabel('Hgt [m]')

    # Redraw roughly every 50 ms while the figure window is still open.
    while plt.fignum_exists(fig.number):
        time.sleep(0.05)
        for line, (times, hgts) in zip(lines, rxData):
            line.set_xdata(times)
            line.set_ydata(hgts)
        ax.relim()
        ax.autoscale_view()
        fig.canvas.draw()
        fig.canvas.flush_events()


def main():
    '''Start the receiver readers and run the live plot.

    Returns when the plot window is closed or on Ctrl-C. The reader threads
    are daemons, so the process then exits.
    '''
    DUT = _load_receivers()
    _init_storage(len(DUT))
    _start_threads(DUT)
    # Give the readers a moment to connect before opening the plot.
    time.sleep(0.5)
    try:
        _run_plot(DUT)
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    main()
 
