#!/usr/bin/env python
##############################
#
# Compares the observables (35:19) in a T04 file with the observables
# generated by a Spirent simulator (using the sat_data*.csv file the
# Spirent device generates).
#
# Assumptions:
#  - The receiver code/calibration is disabled (if it is not, it will
#    show up as part of any observed bias).
#  - The user has a viewdat newer than 2018-10-29 that supports the
#    '--clk_steer' flag. The data must either already have clock steering
#    enabled or be clock-steerable by viewdat, so that it lines up with
#    the Spirent truth data.
#
# 
# Copyright Trimble Inc 2018-2019
#
##############################

import matplotlib
# Allow running headless from the command line
matplotlib.use("agg")

from numpy import *
from pylab import *
import numpy as np
import pandas as pd
import mutils as m
import math
import sys
import argparse

def find(x):
  """Return the indices at which `x` is true (MATLAB-style ``find``).

  `x` is anything numpy can turn into an array (boolean masks, pandas
  boolean Series, numeric arrays). Indices are taken along the first axis,
  so for the 1-D masks used in this script the result is the positions of
  the True / nonzero elements.
  """
  nonzeroIndices = np.asarray(x).nonzero()
  return nonzeroIndices[0]

def commandHelp():
  """Show the command-line help and terminate with exit status 1.

  Uses the module-level `parser` that is built in the script section.
  """
  parser.print_help()
  raise SystemExit(1)

def _residual_title(satType, svPrefix, sv, mu, sig, numEpochs, units='m'):
  """Return the title text for a residual plot.

  Format: '<satType> <svPrefix> <sv> $\\sigma$ = <sig><units> $\\mu$ = <mu><units>
  epochs = <numEpochs>'. sigma and mu are printed to 3 decimals.
  """
  return (satType + ' ' + svPrefix + ' ' + str(sv) + ' '
          + r'$\sigma$ = ' + "{:.3f}".format(sig) + units
          + r' $\mu$ = '   + "{:.3f}".format(mu)  + units
          + ' epochs = '   + str(numEpochs))


def plotResidual(time,
                 residual,
                 antInfo,
                 satType,
                 svPrefix,
                 sv,
                 mu,sig,numEpochs,
                 plotType,
                 axisStr,
                 units=None):
  """Plot one satellite's residual time series and save it as a PNG.

  Args:
    time:      x-axis values, plotted against the label 'Time [Sec]'.
    residual:  residual values, plotted as red dots.
    antInfo:   filename prefix (e.g. 'Ant1-' or '').
    satType:   signal name, used in the title and the filename.
    svPrefix:  satellite ID prefix, e.g. 'PRN' or 'ALM'.
    sv:        satellite ID.
    mu, sig:   mean / standard deviation shown in the title.
    numEpochs: epoch count shown in the title.
    plotType:  filename suffix, e.g. 'Pseudorange' or 'Doppler'.
    axisStr:   y-axis label.
    units:     units appended to sigma/mu in the title. When None, this is
               'Hz' if plotType == 'Doppler' and 'm' otherwise.

  Reads the module globals FixXAxis, xStart and xStop (set by the script
  section) to optionally apply a common x-axis range.

  Writes <antInfo><satType>-<svPrefix><sv>-<plotType>.png at 150 dpi.
  """
  if units is None:
    # Doppler residuals are in Hz; pseudorange residuals are in metres.
    units = 'Hz' if plotType == 'Doppler' else 'm'

  fig=figure()
  ax=fig.add_subplot(111)
  plot(time, residual, 'r.')
  if(FixXAxis):
    xlim([xStart,xStop])
  xlabel('Time [Sec]')
  ylabel(axisStr)
  grid(True)
  title(_residual_title(satType, svPrefix, sv, mu, sig, numEpochs, units))

  # Prevent the axis numbers having an offset. Do this before
  # tight_layout() so the layout accounts for the full-width tick labels.
  ax.get_xaxis().get_major_formatter().set_useOffset(False)
  ax.get_yaxis().get_major_formatter().set_useOffset(False)
  tight_layout()
  # Save the data as a PNG file
  savefig(antInfo + satType + '-' + svPrefix + str(sv) + "-" + plotType + ".png", dpi=150)
  close()

# Convert a T04 sat_type code to the satellite type name used in the
# Spirent logged data ('Sat_type' column).
_RT_TO_SPIRENT_SAT_TYPE = {
  0:  "GPS",
  2:  "GLONASS",
  4:  "Quasi-Zenith",
  9:  "IRNSS",
  10: "BeiDou",
}

def rt_sat_type_to_spirent(rt_sat_type):
  """Map an RT/T04 sat_type code to its Spirent Sat_type string.

  Returns None for sat_type codes that have no mapping in the table above.
  """
  return _RT_TO_SPIRENT_SAT_TYPE.get(rt_sat_type)

#################
#
# Script Start
#
#################

def _count_header_rows(filename):
  """Return the number of lines preceding the Spirent column-header line.

  The column names are on the first line that starts with 'TOW'; the lines
  before it are free-form header information to skip.

  Raises:
    ValueError: if no line starts with 'TOW' (instead of looping forever
      at end of file).
  """
  with open(filename, 'r') as f:
    for rowNum, line in enumerate(f):
      if line.startswith('TOW'):
        return rowNum
  raise ValueError("No line starting with 'TOW' found in " + filename)


def _parse_time_tag(tag):
  """Parse a -b/-e time tag of the form 'week,secs' or 'secs'.

  Returns:
    (week, secs): week is an int, or None when only seconds were given.

  Raises:
    ValueError: if the tag has more than two fields or a field is not numeric.
  """
  fields = tag.split(',')
  if len(fields) == 1:
    return None, float(fields[0])
  if len(fields) == 2:
    return int(fields[0]), float(fields[1])
  raise ValueError("Expected 'secs' or 'week,secs', got %r" % tag)


def _time_window_mask(week, secs, startWeek, startSecs, stopWeek, stopSecs):
  """Return a boolean mask of epochs inside the inclusive [start, stop] window.

  When a bound has a week, the comparison uses the full (week, secs) time
  tag, so windows that cross a GPS week boundary work. When the week is
  None, only seconds are compared for that bound.
  """
  week = np.asarray(week)
  secs = np.asarray(secs)
  if startWeek is None:
    afterStart = secs >= startSecs
  else:
    afterStart = (week > startWeek) | ((week == startWeek) & (secs >= startSecs))
  if stopWeek is None:
    beforeStop = secs <= stopSecs
  else:
    beforeStop = (week < stopWeek) | ((week == stopWeek) & (secs <= stopSecs))
  return afterStart & beforeStop


def _robust_residual_stats(residual, grossLimit=100.0, nSigma=8.0):
  """Mean and standard deviation of a residual series after outlier rejection.

  First drops gross errors (|residual| >= grossLimit). It then applies a
  deliberately relaxed nSigma test around the resulting mean, so bad data of
  interest is not filtered out. The sigma test is inclusive, so a constant
  series (sigma == 0) is kept rather than rejected entirely.

  Returns:
    (idx, mu, sig): indices into `residual` that survived, plus the mean and
    standard deviation of those samples. mu and sig are NaN if idx is empty.
  """
  residual = np.asarray(residual, dtype=float)
  idx = np.where((residual < grossLimit) & (residual > -grossLimit))[0]
  if len(idx) == 0:
    return idx, np.nan, np.nan
  mu  = np.mean(residual[idx])
  sig = np.std(residual[idx])
  idx = np.where(np.abs(residual - mu) <= nSigma * sig)[0]
  if len(idx) == 0:
    return idx, np.nan, np.nan
  return idx, np.mean(residual[idx]), np.std(residual[idx])


if __name__ == '__main__':
  ######################################################################
  # Parse arguments
  parser = argparse.ArgumentParser(description='Parse observables from a T04 file and compares to a Spirent Truth file')
  parser.add_argument('T04file', help='T04 files - required')
  parser.add_argument('SpirentFile', help='Spirent file - required')
  parser.add_argument('-b','--start', help='start time tag in GPS week,secs e.g. -b 1234,12345.0 or just secs e.g. -b 12345.0')
  parser.add_argument('-e','--stop', help='stop time tag in GPS week,secs e.g. -e 1234,20000.0 or just secs e.g. -e 20000.0')
  parser.add_argument('-d','--doppler', help='process the doppler',action="store_true")
  parser.add_argument('-x','--xaxis', help='When set the x-axis is common for all plots',action="store_true")
  parser.add_argument('-m','--min_len', help='minimum points for summary info',default=1000,type=int)
  parser.add_argument('-a','--all_sigs', help='Analyze all signals.  Currently needs iono/tropo/Tgd off!',action="store_true")
  # NOTE(review): --antenna is parsed but never read below; every antenna
  # present in the T04 file is analyzed.
  parser.add_argument('--antenna', help='antenna # to analyze',default=0,type=int)
  args = parser.parse_args()
  ######################################################################

  T04Filename     = args.T04file
  spirentFilename = args.SpirentFile

  # FixXAxis (plus xStart/xStop, set per antenna below) is read as a
  # module global by plotResidual().
  ProcessDoppler = args.doppler
  FixXAxis       = args.xaxis

  # Requested time window. A week of None means only the seconds are
  # compared for that bound.
  startWeek, startSecs = None, 0.0
  if(args.start):
    try:
      startWeek, startSecs = _parse_time_tag(args.start)
    except ValueError:
      print("Incorrect start time")
      commandHelp()

  stopWeek, stopSecs = None, np.inf
  if(args.stop):
    try:
      stopWeek, stopSecs = _parse_time_tag(args.stop)
    except ValueError:
      print("Incorrect stop time")
      commandHelp()

  print('T04 File     = ' + T04Filename)
  print('Spirent File = ' + spirentFilename)

  # Clock steering is required so the observables line up with the
  # Spirent truth data (see the assumptions in the file header).
  t04_opts = ['--clk_steer=on']
  if not args.all_sigs:
    t04_opts += ['-t']
  (all_ref,k)=m.vd2arr(T04Filename,rec='-d35:19',opt=t04_opts)

  # LUT mapping the T04 signal types to the Spirent satellite type
  all_sigs = m.get_signals(all_ref,k)
  types = []
  for sat_type in all_sigs.keys():
    for freq,track in all_sigs[sat_type]:
      spirent = rt_sat_type_to_spirent(sat_type)
      if spirent is None:
        print("Skipping RT sat_type %d"%sat_type)
        continue
      subType = m.get_sub_type(sat_type,freq,track)
      types.append( {"T04SVType":sat_type,
                     "T04Freq":freq,
                     "T04Track":track,
                     "Spirent":spirent,
                     "sigName":subType.sigstr.replace(' ','_'),
                     "desc":subType.fullstr} )

  # Read the Spirent data, skipping the header lines that precede the
  # 'TOW...' column-header line.
  try:
    header_rows = _count_header_rows(spirentFilename)
  except ValueError as err:
    print(err)
    sys.exit(1)
  svData = pd.read_csv( spirentFilename, skiprows=header_rows )

  # Spirent breaks signals down into "groups" A-D.
  # Choose the group you are interested in here:
  groupMap = {"GPS": "Group A",
              "GLONASS": "Group A",
              "GALILEO": "Group A",
              "IRNSS": "Group C", # for L5
              "BeiDou": "Group A",
              "Quasi-Zenith": "Group A"
             }

  # Epochs inside the requested start/stop window (all antennas)
  inWindow = _time_window_mask(all_ref[:,k.WN], all_ref[:,k.TIME],
                               startWeek, startSecs, stopWeek, stopSecs)

  all_antennas = unique(all_ref[:,k.ANTENNA])
  for ant_num in all_antennas:
    if len(all_antennas) > 1:
      print("antenna %d"%ant_num)
      antInfo = 'Ant%d-'%ant_num
    else:
      antInfo = ''

    # T04 data for this antenna within the requested time window
    ref = all_ref[find( (all_ref[:,k.ANTENNA] == ant_num) & inWindow ),:]
    if(len(ref) == 0):
      print("No data in the requested time window for antenna %d"%ant_num)
      continue

    # Get the min/max time. If --xaxis is set, plotResidual() uses these
    # (as globals) to set a common x-axis range on all plots.
    # ToDo - week extend
    xStart = min(ref[:,k.TIME])
    xStop  = max(ref[:,k.TIME])

    # Per-signal summary statistics, keyed by the signal description
    sysSig_pseudo  = {}
    sysMu_pseudo   = {}
    sysSig_doppler = {}
    sysMu_doppler  = {}

    # Loop over each signal type (satellite system / frequency / tracking)
    for sigInfo in types:
      T04ThisSys = ref[find(  (ref[:,k.SAT_TYPE] == sigInfo['T04SVType'])
                            & (ref[:,k.FREQ]     == sigInfo['T04Freq'])
                            & (ref[:,k.TRACK]    == sigInfo['T04Track'])),:]

      desc        = sigInfo['desc']
      sigName     = sigInfo['sigName']
      spirentType = sigInfo['Spirent']
      isGlonass   = (sigInfo['T04SVType'] == 2)
      # If the data is GLONASS we index by almanac number
      svPrefix = 'ALM' if isGlonass else 'PRN'

      # Sorted list of unique satellite IDs for this signal
      svs = sorted(set(T04ThisSys[:,k.SV].astype(int)))

      # Epoch-weighted accumulators for the per-signal summary. Pseudorange
      # and Doppler keep separate epoch counts because their outlier
      # rejection can keep different numbers of epochs.
      muSum_pseudo   = 0.0
      sigSum_pseudo  = 0.0
      numSum_pseudo  = 0
      muSum_doppler  = 0.0
      sigSum_doppler = 0.0
      numSum_doppler = 0

      if(len(T04ThisSys) > 0):
        m.plot_tracking(T04ThisSys, k, sat_types=[sigInfo['T04SVType']], title_txt=sigName)
        tight_layout()
        # Save the data as a PNG file
        savefig(antInfo + sigName + "-Track.png", dpi=150)
        close()

      # Loop for all satellites in the data file for the current signal
      for sv in svs:
        # Get the Spirent rows for the current satellite/system
        if(isGlonass):
          spirentIdx = find( (svData.Sat_type == spirentType) & (svData.Sat_ID == sv) )
          if(len(spirentIdx) == 0):
            print("No Spirent data for %s %s %d" % (spirentType, svPrefix, sv))
            continue
          # Reported as the GLONASS channel ('chn') in the output below
          glnChn = svData.Sat_PRN.iloc[spirentIdx[0]]
        else:
          spirentIdx = find( (svData.Sat_type == spirentType) & (svData.Sat_PRN == sv) )
          channels = unique( svData.iloc[spirentIdx].Channel )
          if len( channels ) > 1:
            # The SV appears on several simulator channels; keep the first
            spirentIdx = find( (svData.Sat_type == spirentType) & (svData.Sat_PRN == sv) & (svData.Channel==channels[0]) )
            print("Spirent is simulating %s SV %d more than once"%(spirentType,sv))

        spirentData = svData.iloc[spirentIdx]

        # Get the T04 data for the current system / satellite
        T04Data = T04ThisSys[find(T04ThisSys[:,k.SV] == sv),:]

        # Find common epochs between the T04 and Spirent file. Convert the
        # T04 time tags to integer milliseconds to match TOW_ms. Round
        # rather than truncate: e.g. 1.005*1000 evaluates to
        # 1004.9999999999999 and would otherwise miss its Spirent epoch.
        t04Ms = np.rint(T04Data[:,k.TIME]*1000).astype(int)
        t, i_t04, i_spirent = m.intersect( t04Ms, spirentData.TOW_ms )
        if(len(t) == 0):
          continue

        # We have common data, now generate plots and statistics
        SVDataCommon  = T04Data[i_t04,:]
        SpirentCommon = spirentData.iloc[i_spirent]

        # Measured minus simulated pseudorange and Doppler
        group = groupMap[SpirentCommon.iloc[0].Sat_type]
        residual_pseudo  = SVDataCommon[:,k.RANGE] - SpirentCommon['P-Range '+group].values
        residual_doppler = SVDataCommon[:,k.DOPP]  - SpirentCommon['Doppler_shift '+group].values

        # Pseudorange statistics with gross errors and outliers removed
        i, mu, sig = _robust_residual_stats(residual_pseudo)
        if(len(i) > 0):
          if(isGlonass):
            print("Pseudo  %s %s %2d chn %2d sigma = %.3f m mu = %.3f m epochs = %d" % (desc,svPrefix,sv,glnChn,sig,mu,len(i)))
          else:
            print("Pseudo  %s %s %2d sigma = %.3f m mu = %.3f m epochs = %d" % (desc,svPrefix,sv,sig,mu,len(i)))

          if(len(i) > args.min_len):
            # Enough epochs: add to the epoch-weighted summary and plot
            muSum_pseudo  += mu * len(i)
            sigSum_pseudo += sig * len(i)
            numSum_pseudo += len(i)

            plotResidual(SVDataCommon[i,k.TIME],
                         residual_pseudo[i],
                         antInfo,
                         sigName,
                         svPrefix,
                         sv, mu, sig,len(i),
                         "Pseudorange",
                         "Pseudorange Residual [m]")

        if(ProcessDoppler):
          i, mu, sig = _robust_residual_stats(residual_doppler)
          if(len(i) > 0):
            if(isGlonass):
              print("Doppler %s %s %2d chn %2d sigma = %.3f Hz mu = %.3f Hz epochs = %d" % (desc,svPrefix,sv,glnChn,sig,mu,len(i)))
            else:
              print("Doppler %s %s %2d sigma = %.3f Hz mu = %.3f Hz epochs = %d" % (desc,svPrefix,sv,sig,mu,len(i)))

            if(len(i) > args.min_len):
              # Enough epochs: add to the epoch-weighted summary
              muSum_doppler  += mu * len(i)
              sigSum_doppler += sig * len(i)
              numSum_doppler += len(i)

            plotResidual(SVDataCommon[i,k.TIME],
                         residual_doppler[i],
                         antInfo,
                         sigName,
                         svPrefix,
                         sv, mu, sig,len(i),
                         "Doppler",
                         "Doppler Residual [Hz]")

      # We've processed all the data for the current signal. Now form the
      # epoch-weighted mean of the per-SV mean/sigma (NaN if none qualified).
      if(numSum_pseudo > 0):
        sysMu_pseudo[desc]  = muSum_pseudo/float(numSum_pseudo)
        sysSig_pseudo[desc] = sigSum_pseudo/float(numSum_pseudo)
      else:
        sysMu_pseudo[desc]  = np.nan
        sysSig_pseudo[desc] = np.nan

      if(numSum_doppler > 0):
        sysMu_doppler[desc]  = muSum_doppler/float(numSum_doppler)
        sysSig_doppler[desc] = sigSum_doppler/float(numSum_doppler)
      else:
        sysMu_doppler[desc]  = np.nan
        sysSig_doppler[desc] = np.nan

    # We've completed all processing, now output the summary statistics for
    # each signal
    print('\nSummary - Pseudorange')
    for desc, sysMu in sysMu_pseudo.items():
      if not np.isnan(sysMu):
        print('%20s'%desc
              + ' sig = ' + "{:.3f}".format(sysSig_pseudo[desc])
              + 'm mu = ' + "{:.3f}".format(sysMu)
              + 'm')

    if(ProcessDoppler):
      print('\nSummary - Doppler')
      for desc, sysMu in sysMu_doppler.items():
        if not np.isnan(sysMu):
          print('%20s'%desc
                + ' sig = ' + "{:.3f}".format(sysSig_doppler[desc])
                + 'Hz mu = ' + "{:.3f}".format(sysMu)
                + 'Hz')
