###############################################################################
# Copyright (c) 2010-2025 Trimble Inc.
# $Id: NoPiDI_Process.py,v 1.22 2024/06/18 21:35:49 acartmel Exp $
###############################################################################
#
# NoPiDI_Process.py
#
# This file contains all the functions used to compute the statistics on the
# NoPi output
# The plots are generated from these functions
# The classes used to store the statistics data are also defined here
#
###############################################################################

from NoPiDI_Plot import *
from NoPiDI_Utils import add_utils_dir_to_path
add_utils_dir_to_path()

from NoPiUT_Common_Meas import *
from NoPiUT_Const import *
from NoPiDI_Utils import find


###############################################################################
# Two functions used to enable or disable generating plots during debugging
###############################################################################

def enable_plots() :
  """Switch plot generation on by setting the module global draw_plots."""
  global draw_plots
  draw_plots = True


def disable_plots() :
  """Switch plot generation off by clearing the module global draw_plots."""
  global draw_plots
  draw_plots = False


###############################################################################
# Class definitions to contain the measurement statistics
###############################################################################

class cl_stats :
  """
  Measurement statistics for one data set (a single satellite, or all
  satellites together).  Every field starts at zero; compute_stats() fills
  them in and process_meas() sets sv_id for the per-satellite entries.
  """
  def __init__( self ) :
    # Satellite identifier followed by the epoch counters: total rows,
    # rows flagged as reference SV, rows valid for the measurement type
    self.sv_id = self.tot_epochs = self.ref_epochs = self.vld_epochs = 0

    # Min / max / mean of the valid data
    self.vld_min = self.vld_max = self.vld_mean = 0.0
    # Standard deviation and mean absolute value of the valid data
    self.vld_std = self.vld_mav = 0.0

    # Cycle slip counters: all slips, slips at or above the elevation
    # threshold, and number of valid epochs at or above that threshold
    self.all_cyc_slips = self.elv_cyc_slips = self.elv_epochs = 0

class cl_acq_stats :
  """
  Acquisition analysis results for one data set.  Scalars start at zero and
  the four lists start empty; acq_analysis() populates them.
  """
  def __init__( self ) :
    # Satellite identifier, number of acquisition events, number of valid epochs
    self.sv_id = self.acq_epochs = self.vld_epochs = 0

    # Means: non-mod200 residuals, mod200 residuals, combined data
    self.mean_res = self.mean_mod200 = self.data_mean = 0.0
    # Standard deviations, same grouping as the means
    self.std_res = self.std_mod200 = self.data_std = 0.0
    # Percentage of residuals above the outlier threshold
    self.res_abv_thresh = self.mod200_abv_thresh = 0.0

    # Epoch counters since acquisition (non-mod200 / mod200 samples).
    # Each list must be its own object, hence no chained assignment here.
    self.res_epoch, self.mod200_epoch = [], []
    # Residual values (mod200 / non-mod200 samples)
    self.res_mod200, self.residual = [], []

class cl_meas :
  """
  Container of cl_stats entries for one measurement type.  process_meas()
  stores the statistics over all satellites in sv[0], appends one entry per
  satellite after it, and records the final list length in num_svs.
  """
  def __init__( self ) :
    self.num_svs = 0
    # Start with a single zeroed entry occupying slot 0
    self.sv = [ cl_stats() ]

class cl_acq :
  """
  Container of cl_acq_stats entries for one measurement type.
  do_acq_analysis() stores the all-satellite result in sv[0], appends one
  entry per satellite after it, and records the final list length in num_svs.
  """
  def __init__( self ) :
    self.num_svs = 0
    # Start with a single zeroed entry occupying slot 0
    self.sv = [ cl_acq_stats() ]


###############################################################################
# Load a NoPi results file for a single signal combo
# Loads into global variables
###############################################################################
def read_combo_file( diffs_summ, fname ) :
  """
  Load one NoPi results file through load_combo_file() and publish the
  result in the module globals din and flags, which process_meas() and
  do_acq_analysis() read.
  """
  global din, flags

  # Drop any previously loaded data before attempting the load
  din, flags = [], []

  din, flags = load_combo_file( diffs_summ, fname )


###############################################################################
# Get indexes into din[] for:
# idx_ref - when the SV was reference for double difference measurements
# idx_vld - when the data is valid for the given measurement type
###############################################################################
def get_data_idx( flags, flags_mask, meas_type ) :
  """
  Return ( idx_ref, idx_vld ) for the given flags.

  idx_ref - indexes where the DIFFS_FLG_SV_REF bit is set; always an empty
            list for MEAS_TYPE_SD_CNO, where the flag is not consulted
  idx_vld - indexes where any bit of flags_mask is set
  """
  is_sd_cno = ( meas_type == cNoPiConst.MEAS_TYPE_SD_CNO )

  idx_ref = [] if is_sd_cno else find( flags & cNoPiConst.DIFFS_FLG_SV_REF )
  idx_vld = find( flags & flags_mask )

  return( idx_ref, idx_vld )


###############################################################################
# Process NoPi results for single signal combo
# Compute the statistics and generates the plots for each measurement type and
# each satellite
###############################################################################
def process_meas( diffs_summ,
                  combo,
                  meas_type,
                  plot_raw_cno_elev,
                  cslip_stats_elev_thr
                ) :
  """
  Compute the statistics (and, when draw_plots is set, the plots) of one
  measurement type for the combo currently held in the globals din / flags.

  The data column for meas_type is multiplied by its scale factor in place,
  i.e. the global din is modified by this call.

  Returns a cl_meas where sv[0] holds the statistics over all satellites,
  sv[1:] hold one entry per satellite (with sv_id set) and num_svs equals
  the length of sv.
  """

  data_col   = get_data_column( diffs_summ, meas_type )
  data_scale = get_data_scale( meas_type )
  # NOTE(review): scales the shared global data, so calling this twice for
  # the same column applies the scale twice - confirm callers rely on one call
  din[:, data_col] = din[:, data_col] * data_scale

  flags_mask = get_flags_mask( meas_type )

  # Statistics across every satellite go into slot 0
  all_ref, all_vld = get_data_idx( flags, flags_mask, meas_type )
  all_stats = compute_stats( diffs_summ,
                             din,
                             all_ref,
                             all_vld,
                             data_col,
                             cslip_stats_elev_thr
                           )
  tmp_meas = cl_meas()
  tmp_meas.sv[0] = all_stats

  def _save( fig, plot_name ) :
    # Every summary figure is saved with the all-satellite statistics
    save_and_close_figure( fig,
                           diffs_summ,
                           all_stats,
                           combo,
                           meas_type,
                           plot_name
                         )

  if ( draw_plots ) :
    dist_diff_data( din,
                    diffs_summ,
                    all_stats,
                    all_vld,
                    combo,
                    meas_type )

  # Figures shared by all satellites (versus time and versus elevation)
  fig_all    = create_figure( True )
  fig_all_el = create_figure( True )

  # Optional figures are handed to plot_diff_data() as keyword arguments
  extra_figs = {}

  # Include the Base/Rover CNo vs elevation plots?
  if ( plot_raw_cno_elev ) :
    fig_base_el = create_figure( True )
    fig_rovr_el = create_figure( True )
    extra_figs[ 'fig_base_cno_el' ] = fig_base_el
    extra_figs[ 'fig_rovr_cno_el' ] = fig_rovr_el

  # Reference SV plot for this combo
  # Same for all D.D measurements so only generate it once
  plot_ref_svs = bool( meas_type == cNoPiConst.MEAS_TYPE_DD_CARR )
  if ( plot_ref_svs ) :
    fig_ref_svs = create_figure( True )
    extra_figs[ 'fig_ref_svs' ] = fig_ref_svs

  for sv in unique( din[:, 1] ) :
    # Rows belonging to this satellite
    sv_rows = find( din[:, 1] == sv )
    if ( len( sv_rows ) == 0 ) :
      continue

    sv_ref, sv_vld = get_data_idx( flags[ sv_rows ], flags_mask, meas_type )
    sv_data = din[sv_rows, :]

    # Statistics for this satellite, appended after the all-SV entry
    sv_stats = compute_stats( diffs_summ,
                              sv_data,
                              sv_ref,
                              sv_vld,
                              data_col,
                              cslip_stats_elev_thr
                            )
    sv_stats.sv_id = sv
    tmp_meas.sv.append( sv_stats )

    if ( draw_plots ) :
      plot_diff_data( sv_data,
                      diffs_summ,
                      sv_stats,
                      sv_ref,
                      sv_vld,
                      combo,
                      meas_type,
                      fig_all,
                      fig_all_el,
                      plot_raw_cno_elev,
                      plot_ref_svs,
                      **extra_figs
                    )

  if ( draw_plots ) :
    if ( meas_type == cNoPiConst.MEAS_TYPE_SD_CNO ) :
      plot_time_avg_sdiff_cno( diffs_summ, fig_all )

    _save( fig_all,    'TIME_PLOT' )
    _save( fig_all_el, 'ELEV_PLOT' )

    if ( plot_ref_svs ) :
      _save( fig_ref_svs, 'REF_SVS_PLOT' )

    if ( plot_raw_cno_elev ) :
      # Update CNo versus elevation plots with average CNo versus elevation
      plot_avg_cno_elev( fig_base_el )
      _save( fig_base_el, 'BASE_CNO' )
      plot_avg_cno_elev( fig_rovr_el )
      _save( fig_rovr_el, 'ROVR_CNO' )

  # Counts the all-SV entry in slot 0 plus one entry per satellite
  tmp_meas.num_svs = len( tmp_meas.sv )

  return( tmp_meas )


###############################################################################
# Process NoPi results for acquisition analysis
# Compute the acquisition statistics and generates the plots for each
# measurement type and each satellite
###############################################################################
def do_acq_analysis( diffs_summ, combo, meas_type ) :
  """
  Run the acquisition analysis of one measurement type for the combo
  currently held in the globals din / flags, plotting when draw_plots is set.

  Returns a cl_acq where sv[0] holds the result over all satellites, sv[1:]
  hold one entry per satellite (with sv_id set) and num_svs equals the
  length of sv.
  """

  # The results of these two lookups are not used below; the calls are kept
  # as they were in case the helpers have side effects
  data_col   = get_data_column( diffs_summ, meas_type )
  data_scale = get_data_scale( meas_type )

  resolve_sdiff = diffs_summ.combo[ combo ].resolve_sdiff

  # Analysis across every satellite goes into slot 0
  all_acq = acq_analysis( diffs_summ,
                          din,
                          flags,
                          meas_type,
                          resolve_sdiff
                        )
  tmp_acq = cl_acq()
  tmp_acq.sv[0] = all_acq

  if ( draw_plots ) :
    dist_acq_data( diffs_summ,
                   all_acq,
                   combo,
                   meas_type
                 )

  fig_all_acq = create_figure( False )

  for sv in unique( din[:, 1] ) :
    # Rows belonging to this satellite
    sv_rows = find( din[:, 1] == sv )
    if ( len( sv_rows ) == 0 ) :
      continue

    sv_acq = acq_analysis( diffs_summ,
                           din[sv_rows, :],
                           flags[ sv_rows ],
                           meas_type,
                           resolve_sdiff
                         )
    sv_acq.sv_id = sv
    tmp_acq.sv.append( sv_acq )

    if ( draw_plots ) :
      dist_acq_data( diffs_summ,
                     sv_acq,
                     combo,
                     meas_type
                   )
      plot_acq_data( diffs_summ,
                     sv_acq,
                     fig_all_acq,
                     combo,
                     meas_type
                   )

  if ( draw_plots ) :
    figure( fig_all_acq.number )

    # Only the current y limits are kept; the x range is fixed to the
    # analysis window expressed in epochs
    _, _, ymin, ymax = axis( )
    xmin = 0
    xmax = divide( cNoPiConst.ACQ_ANALYSIS_DUR, diffs_summ.drate_ms ) + 1

    # Push ymin down and ymax up by 15% (whatever their sign) so matplotlib
    # does not collapse the plotting axis to a tight fit
    ymin *= 1.15 if ( ymin < 0 ) else 0.85
    ymax *= 0.85 if ( ymax < 0 ) else 1.15
    axis( [ xmin, xmax, ymin, ymax ] )

    save_and_close_figure( fig_all_acq,
                           diffs_summ,
                           all_acq,
                           combo,
                           meas_type,
                           'TIME_ACQ_PLOT'
                         )

  # Counts the all-SV entry in slot 0 plus one entry per satellite
  tmp_acq.num_svs = len( tmp_acq.sv )

  return( tmp_acq )


###############################################################################
# Compute the statistics for a single measurement
###############################################################################
def compute_stats( diffs_summ, din_sv, idx_ref, idx_vld, data_col, slp_elv ) :
  """
  Compute the statistics of one data column over the valid epochs.

  diffs_summ - summary object passed on to the column lookup helpers
  din_sv     - data rows to analyse (all satellites, or one satellite)
  idx_ref    - row indexes into din_sv where the SV was the reference
  idx_vld    - row indexes into din_sv where the measurement is valid
  data_col   - column of din_sv holding the measurement
  slp_elv    - elevation threshold used for the cycle slip counters

  Returns a cl_stats.  min/max/mean/mean-absolute are zero when there are
  no valid epochs and the standard deviation is zero with fewer than two.
  The cycle slip counters are only filled in when data_col is the double
  difference carrier column and there is at least one valid epoch.
  """

  tmp_stats = cl_stats()
  tmp_stats.tot_epochs = len( din_sv )
  tmp_stats.ref_epochs = len( idx_ref )
  tmp_stats.vld_epochs = len( idx_vld )

  if ( len( idx_vld ) > 0 ) :
    vld_data = din_sv[ idx_vld, data_col ]
    tmp_stats.vld_min  = nanmin( vld_data )
    tmp_stats.vld_max  = nanmax( vld_data )
    tmp_stats.vld_mean = mean( vld_data )
    tmp_stats.vld_mav  = mean( fabs( vld_data ) )
  else :
    tmp_stats.vld_min  = 0
    tmp_stats.vld_max  = 0
    tmp_stats.vld_mean = 0
    tmp_stats.vld_mav  = 0

  if ( len( idx_vld ) > 1 ) :
    tmp_stats.vld_std = std( din_sv[ idx_vld, data_col ] )
  else :
    tmp_stats.vld_std = 0

  ecol = get_elev_data_column( 'Base' )
  fcol = get_flags_column( diffs_summ )
  pcol = get_data_column( diffs_summ, cNoPiConst.MEAS_TYPE_DD_CARR )
  if ( data_col == pcol and len(idx_vld) ):
    # idx_vld indexes rows of din_sv, so the flags and elevations must be
    # read from din_sv as well.  Reading them from the global din picked up
    # the wrong rows whenever din_sv was a per-satellite subset.
    vld_flags = din_sv[ idx_vld, fcol ].astype(int)
    vld_elev  = din_sv[ idx_vld, ecol ]

    # slips indexes into the valid epochs, not into din_sv directly
    slips = find( vld_flags & cNoPiConst.DIFFS_FLG_CYC_SLIP )
    tmp_stats.all_cyc_slips = len(slips)
    tmp_stats.elv_cyc_slips = 0
    if len(slips) :
      # Cycle slips at or above the elevation threshold
      tmp_stats.elv_cyc_slips = len( find( vld_elev[ slips ] >= slp_elv ) )

    # Valid epochs at or above the elevation threshold
    tmp_stats.elv_epochs = len( find( vld_elev >= slp_elv ) )

  return( tmp_stats )


###############################################################################
# Obtain the residuals for the 1 second after reacquiring a satellite
# Store the double difference residuals separately for epochs which are modulo
# 200ms
###############################################################################
def acq_analysis( diffs_summ,
                  din_sv,
                  flags_sv,
                  meas_type,
                  resolve_sdiff ) :
  """
  Collect the residuals observed in the window following each acquisition
  event and summarise them.

  diffs_summ    - summary object; drate_ms gives the epoch spacing used to
                  detect gaps
  din_sv        - data rows (column 0 is the epoch time tag, column 1 the
                  satellite id); may hold one or many satellites
  flags_sv      - flags matching the rows of din_sv
  meas_type     - measurement type, selects the flag mask and data column
  resolve_sdiff - passed on to data_above_threshold()

  An acquisition event is a valid epoch more than drate_ms after the
  previous valid epoch of the same satellite.  Samples whose offset from the
  acquisition epoch is a multiple of 200 go to the mod200 lists, all others
  to the residual lists.  Satellites with two or fewer valid epochs are
  skipped.

  Returns a cl_acq_stats.  The means, standard deviations and
  above-threshold percentages are only computed when at least one satellite
  was analysed.
  """

  drate_ms     = diffs_summ.drate_ms
  flags_msk    = get_flags_mask( meas_type )
  data_col     = get_data_column( diffs_summ, meas_type )
  tmp_acq_data = cl_acq_stats()

  # Number of epochs after an acquisition that are included in the analysis
  analysis_epochs = divide( cNoPiConst.ACQ_ANALYSIS_DUR, drate_ms )

  sv_list = unique( din_sv[:, 1] )

  # Valid epoch count over the whole of din_sv
  tmp_acq_data.vld_epochs = len( find( flags_sv & flags_msk ) )

  # Set once any satellite has enough valid epochs to be analysed.  The
  # per-satellite index list gets its own name so the decision to compute
  # the summary no longer depends on whichever satellite was visited last.
  sv_analysed = False

  # The for loop is necessary to identify acquisition events for each SV
  # uniquely.  This is to handle the special case when do_acq_analysis() passes
  # din_sv with data for all SVs.

  for sv in sv_list :
    idx_sv = find( din_sv[:, 1] == sv )
    din_cur_sv = din_sv[ idx_sv, : ]
    idx_vld_sv = find( flags_sv[ idx_sv ] & flags_msk )

    if ( len( idx_vld_sv ) <= 2 ) :
      continue

    sv_analysed = True

    # WARNING!!! This code treats the first observation timetag for each
    # satellite as an acquisition event.  Will not be always true if
    # the data was collected from a continuously operating receiver.

    # WARNING!!! The 200ms/non-200ms was probably intended to look at
    # full- vs demi-measurements, although this should only impact the
    # pseudorange values. The code separating the data assumed that the
    # first epoch was always a full (200ms) measurements.

    # Initialise such that the first epoch is considered an acquisition.
    # Given cur_epoch (below) must be >= 0, initialise previous_epoch to
    # something < -drate_ms.
    acq_epoch        = -drate_ms-1
    previous_epoch   = -drate_ms-1
    sat_acq          = False
    epochs_since_acq = 0

    for row in idx_vld_sv :
      cur_epoch = din_cur_sv[ row, 0 ]

      # A gap longer than one epoch starts a new acquisition window
      if ( cur_epoch - previous_epoch ) > drate_ms :
        sat_acq = True
        tmp_acq_data.acq_epochs += 1
        acq_epoch = cur_epoch
        epochs_since_acq = 0

      if ( sat_acq and ( epochs_since_acq <= analysis_epochs ) ) :
        sample = din_cur_sv[ row, data_col ]
        if ( ( cur_epoch - acq_epoch )%200.0 == 0.0 ) :
          tmp_acq_data.mod200_epoch.append( epochs_since_acq )
          tmp_acq_data.res_mod200.append( sample )
        else :
          tmp_acq_data.res_epoch.append( epochs_since_acq )
          tmp_acq_data.residual.append( sample )

        # Close the window once its last epoch has been stored
        if ( epochs_since_acq == analysis_epochs ) :
          sat_acq = False

        epochs_since_acq += 1

      previous_epoch = cur_epoch

  if ( sv_analysed ) :

    # The mod200 statistics are needed whether or not other epochs exist
    tmp_acq_data.mean_mod200 = mean( tmp_acq_data.res_mod200 )
    tmp_acq_data.std_mod200 = std( tmp_acq_data.res_mod200 )
    tmp_acq_data.mod200_abv_thresh = data_above_threshold(
                                            tmp_acq_data.res_mod200,
                                            meas_type,
                                            resolve_sdiff
                                                         )

    # If acq_analysis is performed check to make sure there are epochs which
    # are not mod200ms from the time a satellite is acquired
    if ( len( tmp_acq_data.residual ) > 0 ) :

      tmp_acq_data.mean_res = mean( tmp_acq_data.residual )
      tmp_acq_data.std_res = std( tmp_acq_data.residual )
      tmp_acq_data.res_abv_thresh = data_above_threshold(
                                              tmp_acq_data.residual,
                                              meas_type,
                                              resolve_sdiff
                                                        )

      # Compute mean and std. over the entire dataset comprising of both mod200
      # ms and non mod200 ms observations (list concatenation, not addition)
      all_residuals = tmp_acq_data.residual + tmp_acq_data.res_mod200
      tmp_acq_data.data_mean = mean( all_residuals )
      tmp_acq_data.data_std  = std( all_residuals )

    # Else the entire dataset consists of the mod200 ms observations only
    else :
      tmp_acq_data.data_mean = tmp_acq_data.mean_mod200
      tmp_acq_data.data_std  = tmp_acq_data.std_mod200

  return( tmp_acq_data )


###############################################################################
# Hardcoded thresholds for acquisition outlier analysis
###############################################################################
def data_above_threshold( acq_data, meas_type, resolve_sdiff ) :
  """
  Return the percentage of acq_data whose absolute value exceeds the
  hardcoded outlier threshold for meas_type.

  For MEAS_TYPE_DD_CARR the threshold depends on resolve_sdiff (0.2 when
  set, 50.0 otherwise).  An unrecognised meas_type prints a message and
  exits through sys.exit().
  """

  if ( meas_type == cNoPiConst.MEAS_TYPE_DD_CARR ) :
    limit = 0.2 if ( resolve_sdiff ) else 50.0
  elif ( meas_type == cNoPiConst.MEAS_TYPE_DD_CODE ) :
    limit = 2.0
  elif ( meas_type == cNoPiConst.MEAS_TYPE_DD_DOPP ) :
    limit = 1.0
  elif ( meas_type == cNoPiConst.MEAS_TYPE_SD_CNO ) :
    limit = 0.25
  else :
    print('Unknown measurement type')
    sys.exit()

  # Count the outliers, then express them as a share of all samples
  num_above = len( find( absolute( acq_data ) > limit ) )

  return( true_divide( num_above, len( acq_data) )*100.0 )
