###############################################################################
# Copyright 2020 - 2023 Trimble Inc.
# $Id: rt27_meas_ddiff.py,v 1.6 2024/02/11 20:58:50 wlentz Exp $
# $Source: /home/wlentz/cvs_stinger/GPSTools/NoPiLive/rt27_meas_ddiff.py,v $
###############################################################################

""" RT27 Measurement Single/Double Difference Calculation Thread
    Thread class to compute single & double differences between two sets (base
    and rover) of RT27 measurements.
"""

import threading
import time
import numpy
from mutils import get_sat_type, get_sub_type, RcvrConst
from curve_fit_collector import CurveFitCollectorThread
from rt27_collector import RT27MeasCollectorThread

###############################################################################
# Constants

# Measurement quality grades produced/consumed by the differencing code.
# The numeric values are ordered on purpose: the carrier-phase helpers treat
# any grade <= _MEAS_QUALITY_FREQ_LOCK as unusable for carrier differencing.
(_MEAS_QUALITY_IGNORE,      # 0: never use in the differencing calculation
 _MEAS_QUALITY_FREQ_LOCK,   # 1: not currently produced by calc_meas_quality()
 _MEAS_QUALITY_TEST_SV,     # 2: usable as a test SV only
 _MEAS_QUALITY_REF_SV,      # 3: good enough to be the reference SV
) = range(4)

###############################################################################
# Classes and functions related to finding common epochs between the base and
# rover RT27 measurement data.

class _RT27CommonCfg():
    """ Class containing the RT27 common epoch index data (the indexes into
        one of the RT27 input data arrays).
    """
    def __init__(self):
        # -1 means "nothing matched yet", so the next search starts at 0.
        self.last_meas_idx = -1
        self.cmn_idx = []
    def reset(self):
        """ Reset the common epochs data. """
        self.last_meas_idx = -1
        self.cmn_idx = []
    def num_epochs(self):
        """ Returns the number of common epochs found so far. """
        return len(self.cmn_idx)

class RT27CommonMeas():
    """ Class that searches for common epochs between two arrays of RT27 data,
        base and rover. Accesses the data directly from the RT27 data
        collector (num_epochs(), get_time(idx) and get_meas(idx)).
    """
    def __init__(self, base_meas_collector, rovr_meas_collector, verbose):
        """ Provide base_meas_collector & rovr_meas_collector, the two RT27
            data collectors. verbose enables a print per common epoch found.
        """
        self.base_cltr = base_meas_collector
        self.rovr_cltr = rovr_meas_collector
        self.base_cfg = _RT27CommonCfg()
        self.rovr_cfg = _RT27CommonCfg()
        self.verbose = verbose

    def num_epochs(self):
        """ Returns the number of common epochs found so far. """
        return len(self.base_cfg.cmn_idx)

    def del_all_epochs(self):
        """ Deletes all of the common epochs found so far. """
        self.base_cfg.reset()
        self.rovr_cfg.reset()

    def get_meas(self, idx):
        """ Returns the measurement times and the base & rover RT27
            measurements for the requested common index, idx, as a tuple:
            (base_time, rovr_time, base_meas, rovr_meas)
            If idx is out of range, returns (None, None, [], []) so callers can
            always unpack four values.
        """
        if 0 <= idx < len(self.base_cfg.cmn_idx):
            base_idx = self.base_cfg.cmn_idx[idx]
            base_time = self.base_cltr.get_time(base_idx)
            base_meas = self.base_cltr.get_meas(base_idx)
            rovr_idx = self.rovr_cfg.cmn_idx[idx]
            rovr_time = self.rovr_cltr.get_time(rovr_idx)
            rovr_meas = self.rovr_cltr.get_meas(rovr_idx)
            return base_time, rovr_time, base_meas, rovr_meas
        return None, None, [], []

    def common_epochs(self):
        """ The updater function. Call to update the list of common epochs when
            new base or rover RT27 measurements data is available, or just call
            periodically.
            Returns the number of new common epochs found.
        """
        len_base = self.base_cltr.num_epochs()
        len_rovr = self.rovr_cltr.num_epochs()
        tst_base = self.base_cfg.last_meas_idx + 1
        tst_rovr = self.rovr_cfg.last_meas_idx + 1

        new_epochs = 0
        while tst_base < len_base and tst_rovr < len_rovr:
            # time_base/rovr[0] = week #
            # time_base/rovr[1] = ToW
            # time_base/rovr[2] = Bias
            time_base = self.base_cltr.get_time(tst_base)
            time_rovr = self.rovr_cltr.get_time(tst_rovr)
            if time_base is None or time_rovr is None:
                # Skip whichever epoch has no usable time; without advancing an
                # index the loop would spin forever on the same pair.
                if time_base is None:
                    tst_base += 1
                if time_rovr is None:
                    tst_rovr += 1
                continue

            # Order epochs by (week #, ToW) so that mismatched week numbers
            # (e.g. across a week rollover) advance the earlier side instead of
            # stalling the search.
            key_base = (time_base[0], time_base[1])
            key_rovr = (time_rovr[0], time_rovr[1])
            if key_base == key_rovr:
                if self.verbose:
                    print("Common epoch "+str(time_base[1])+" "+str(tst_base)+" "+str(tst_rovr))
                self.base_cfg.cmn_idx.append(tst_base)
                self.rovr_cfg.cmn_idx.append(tst_rovr)
                self.base_cfg.last_meas_idx = tst_base
                self.rovr_cfg.last_meas_idx = tst_rovr
                tst_base += 1
                tst_rovr += 1
                new_epochs += 1
            elif key_base > key_rovr:
                tst_rovr += 1
            else:
                tst_base += 1
        return new_epochs


###############################################################################
# Classes and functions related to finding the SVs, of the requested signal
# types, common to the base and rover at a given common epoch.

def test_meas_type(meas, req_sat_type, req_sub_type):
    """ Test if the RT27 measurement matches the requested sat_type and
        sub_type (both in the Stinger enumeration).
        Returns True when both match, otherwise False.
        In the future this should also check signal health, lock-point,
        ephemeris availability etc. so that only plottable signals pass.
    """
    # Translate the raw RT27 signal description into Stinger types (mutils).
    raw_sat = meas['satType']
    stinger_sat = get_sat_type(raw_sat).SAT_TYPE
    stinger_sub = get_sub_type(raw_sat, meas['freq'], meas['track']).SUBTYPE

    if stinger_sat != req_sat_type:
        return False
    if stinger_sub != req_sub_type:
        return False

    # TODO: check health, ephemeris and lock-point here.
    return True

def calc_meas_quality(meas):
    """ Grade a single RT27 measurement for the differencing calculation.
        Returns:
        _MEAS_QUALITY_IGNORE  - unusable: svID is 0, or the lock-point is not
                                known (bit 0x20 of measFlags[0] clear)
        _MEAS_QUALITY_TEST_SV - bit 0x10 of svFlags[0] is set (the SV health
                                flag); usable only as a test SV
        _MEAS_QUALITY_REF_SV  - otherwise; may be the reference SV
    """
    # Guard order matters: measFlags is only read for a non-zero svID.
    if meas['svID'] == 0 or not meas['measFlags'][0] & 0x20:
        return _MEAS_QUALITY_IGNORE
    sv_flagged = meas['svFlags'][0] & 0x10
    return _MEAS_QUALITY_TEST_SV if sv_flagged else _MEAS_QUALITY_REF_SV

class CommonSVConfig():
    """ Per-SV data for one receiver (base or rover) at a common epoch. """
    def __init__(self, meas, quality, corr=0.0, cfit=None):
        """ meas    - the RT27 measurement for this SV
            quality - one of the _MEAS_QUALITY_* grades
            corr    - baseline correction used in the double-difference code
                      calculation (default 0.0)
            cfit    - curve-fit parameters, once looked up (default None)
        """
        self.meas, self.quality = meas, quality
        self.corr, self.cfit = corr, cfit

    def set_correction(self, corr):
        """ Store the measurement baseline correction. """
        self.corr = corr

    def set_cfit(self, cfit):
        """ Store the curve-fit parameters for this measurement. """
        self.cfit = cfit

class CommonSVsConfig():
    """ Paired per-SV data (CommonSVConfig) for base and rover. Element i of
        .base and element i of .rovr always describe the same SV.
    """
    def __init__(self):
        self.base = []
        self.rovr = []

    def append(self, base_meas, rovr_meas):
        """ Add a base/rover measurement pair, grading each with
            calc_meas_quality(). Baseline correction and curve-fit are
            expected to be filled in later.
        """
        for target, meas in ((self.base, base_meas), (self.rovr, rovr_meas)):
            target.append(CommonSVConfig(meas, calc_meas_quality(meas)))

    def num_svs(self):
        """ Number of common SV measurement pairs held. """
        return len(self.base)

def find_common_svs(base_meas, rovr_meas, combo):
    """ Find the RT27 measurements of the requested combo that are common to
        the base and rover measurements for this epoch.
        combo[0] is the sat_type and combo[1] the sub_type (Stinger
        enumeration), as used by test_meas_type().
        Returns a CommonSVsConfig whose .base and .rovr lists hold the paired
        base/rover measurements. The SV ordering of the two lists matches and
        follows the order of rovr_meas. If the base has several matching
        measurements for one svID, the first one is used.
    """
    sat_type = combo[0]
    sub_type = combo[1]

    # svID -> first matching base measurement (setdefault keeps the first),
    # replacing a linear search of the base list for every rover SV.
    base_by_sv = {}
    for meas in base_meas:
        if test_meas_type(meas, sat_type, sub_type):
            base_by_sv.setdefault(meas['svID'], meas)

    cmn_svs = CommonSVsConfig()
    for meas in rovr_meas:
        if not test_meas_type(meas, sat_type, sub_type):
            continue
        rovr_sv = meas['svID']
        if rovr_sv in base_by_sv:
            cmn_svs.append(base_by_sv[rovr_sv], meas)
    return cmn_svs


###############################################################################

def find_reference_sv(cmn_sv_cfg, meas_time, cfit_data, verbose=False):
    """ Choose the reference SV (the SV common to every double difference)
        from the provided common SV configuration data.

        Only base entries graded _MEAS_QUALITY_REF_SV are candidates. When
        cfit_data is given, the best curve-fit for each candidate is looked up
        and stored on its base entry; candidates without a curve-fit are
        dropped. The remaining candidate with the highest elevation wins
        (the earliest one on a tie).

        Returns the index into cmn_sv_cfg.base/.rovr of the selected SV, or
        None if no candidate qualified.
    """
    top_elev = -1
    top_idx = None
    top_svid = 0
    for idx in range(cmn_sv_cfg.num_svs()):
        entry = cmn_sv_cfg.base[idx]
        if entry.quality != _MEAS_QUALITY_REF_SV:
            continue

        meas = entry.meas
        if cfit_data is not None:
            # NOTE(review): meas_time[1] is divided by 1000 to get ToW [s],
            # i.e. it is treated as milliseconds here, while other comments
            # in this file label it seconds -- confirm the units.
            week, tow_sec = meas_time[0], meas_time[1] / 1000.0
            fit = cfit_data.best_cfit(week, tow_sec, meas['svID'],
                                      get_sat_type(meas['satType']).SAT_TYPE)
            entry.set_cfit(fit)
            if fit is None:
                continue

        elev = meas['elev']
        if elev > top_elev:
            top_elev, top_idx, top_svid = elev, idx, meas['svID']

    if verbose:
        print('Reference SV '+str(top_svid)+' elev. '+str(top_elev))
    return top_idx


###############################################################################
# Classes and function related to carrier ambiguity resolution.

class CarrierAmbiguities():
    """ Class containing the carrier phase ambiguity resolution functions
        and any currently resolved ambiguities.
    """
    def __init__(self):
        self.use_sdiff_amb = False
        # svID of the reference SV the stored ambiguities belong to.
        self.ref_sv = None
        self.clear_ambiguities()

    def clear_ambiguities(self):
        """ Reset all of the carrier ambiguities computed so far. Call if
            - the reference SV changes, or
            - the reference SV measurements report a cycle slip, or
            - the signal combo changes, or
            - the RT27 measurements have been cleared.
        """
        # Use parallel lists rather than a list of objects to quickly search
        # for existing data for an SV; entry i of each list belongs together.
        self.svid = []
        self.ref_amb = []
        self.tst_amb = []

    def append(self, svid, ref_amb, tst_amb):
        """ Append a new carrier ambiguity to the list. """
        self.svid.append(svid)
        self.ref_amb.append(ref_amb)
        self.tst_amb.append(tst_amb)

    def clear_sv_ambiguities(self, svid):
        """ Reset the carrier ambiguity for the given SV. Call when a cycle
            slip is reported in the RT27 measurements. Does nothing if no
            ambiguity is stored for svid.
        """
        if svid in self.svid:
            idx = self.svid.index(svid)
            # Delete by position; list.remove() would delete by value.
            del self.svid[idx]
            del self.ref_amb[idx]
            del self.tst_amb[idx]

    def check_ref_sv(self, ref_sv):
        """ Check if the reference SV has changed from that used in the
            previous calculation. Clears, or in the future swaps, the
            ambiguity data for the new SV.
        """
        if ref_sv == self.ref_sv:
            return
        self.ref_sv = ref_sv
        self.clear_ambiguities()

    def num_ambiguities(self):
        """ Return the number of carrier ambiguities resolved. """
        return len(self.svid)


###############################################################################
# Functions related to forming the single and double difference measurements
# in the NoPi output format.

def _resolve_sdiff_ambig():
    """ Placeholder for resolving the carrier ambiguity with the single
        difference method, needed when the reference and test signals have
        different wavelengths. Not implemented yet; returns None.
    """
    return None

def _resolve_ddiff_ambig(carr_ambig, svid, ddiff):
    """ Resolve the carrier ambiguity with the double difference method (the
        reference and test signals share one wavelength).
        ddiff is rounded to the nearest integer, halves rounded away from
        zero, and stored in carr_ambig as the ref ambiguity (test ambiguity
        0.0). Returns the index of the newly stored entry.
    """
    half = 0.5 if ddiff > 0 else -0.5
    carr_ambig.append(svid, int(ddiff + half), 0.0)
    return carr_ambig.num_ambiguities() - 1

def _calc_sdcarr(svid, carr_ambig, base_tst, rovr_tst):
    """ Single difference (base - rover) carrier phase for one SV with its
        integer ambiguity removed.
        Returns None if either measurement is graded at or below
        _MEAS_QUALITY_FREQ_LOCK. If no ambiguity is stored for svid yet, one
        is resolved from this single difference and added to carr_ambig.
    """
    if min(base_tst.quality, rovr_tst.quality) <= _MEAS_QUALITY_FREQ_LOCK:
        return None

    sdiff = base_tst.meas['phase'] - rovr_tst.meas['phase']
    # A stored ambiguity is reused until it is cleared (e.g. cycle slip).
    try:
        slot = carr_ambig.svid.index(svid)
    except ValueError:
        slot = _resolve_ddiff_ambig(carr_ambig, svid, sdiff)
    # Only ref_amb is used, even though this comes from the test signal.
    return sdiff - carr_ambig.ref_amb[slot]

def _calc_ddcarr(svid, carr_ambig, base_ref, rovr_ref, base_tst, rovr_tst):
    """ Double difference carrier phase for one test SV against the reference
        SV, with the resolved ambiguities removed:
            (base_ref - rovr_ref) - (base_tst - rovr_tst) - ref_amb + tst_amb
        Returns None if any of the four measurements is graded at or below
        _MEAS_QUALITY_FREQ_LOCK. A missing ambiguity for svid is resolved from
        this epoch's double difference and stored in carr_ambig.
    """
    grades = (base_ref.quality, base_tst.quality,
              rovr_ref.quality, rovr_tst.quality)
    if min(grades) <= _MEAS_QUALITY_FREQ_LOCK:
        return None

    sdiff_ref = base_ref.meas['phase'] - rovr_ref.meas['phase']
    sdiff_tst = base_tst.meas['phase'] - rovr_tst.meas['phase']

    # A stored ambiguity is reused until it is cleared (e.g. cycle slip).
    try:
        slot = carr_ambig.svid.index(svid)
    except ValueError:
        slot = _resolve_ddiff_ambig(carr_ambig, svid, sdiff_ref - sdiff_tst)

    return sdiff_ref - sdiff_tst - carr_ambig.ref_amb[slot] + carr_ambig.tst_amb[slot]

def _check_ddcarr(ddcarr):
    """ Map a double (or single) difference carrier value to the NoPi
        (carrier, status) pair. A None carrier becomes 0.0 and its valid bit
        stays clear.
    """
    # 0x218: dd_code, sd_cno & dd_dopp valid.
    # 0x21c: the same plus the dd_carr valid bit (0x004).
    if ddcarr is not None:
        return (ddcarr, 0x21c)
    return (0.0, 0x218)

def form_nopi_diffs(tow, is_ref_sv, carr_ambig, base_ref, rovr_ref,
                    base_tst, rovr_tst, cmn_clk):
    """ Calculate the single/double differences for one test SV and return
        them as one NoPi output row (numpy array):
            [tow, svid, azi, elev, cno, sdcno, ddcode, ddcarr, dddopp, status]
        svid/azi/elev/cno come from the base test measurement; sdcno is the
        base - rover C/No.
        If cmn_clk (common clock for base and rover) is True, the code,
        carrier and Doppler columns hold single differences (base - rover) of
        the test SV. Otherwise they are zero for the reference SV itself
        (is_ref_sv) and double differences against the reference SV for all
        other SVs.
    """
    svid = base_tst.meas['svID']
    azi = base_tst.meas['azi']
    elev = base_tst.meas['elev']
    cno = base_tst.meas['CNo']
    sdcno = base_tst.meas['CNo'] - rovr_tst.meas['CNo']
    if cmn_clk:
        ddcode = base_tst.meas['pseudo'] - rovr_tst.meas['pseudo']
        ddcarr = _calc_sdcarr(svid, carr_ambig, base_tst, rovr_tst)
        # Single difference base - rover, consistent with the code, carrier
        # and C/No differences above.
        dddopp = base_tst.meas['doppler'] - rovr_tst.meas['doppler']
        (ddcarr, status) = _check_ddcarr(ddcarr)
    elif is_ref_sv:
        ddcode = 0.0
        ddcarr = 0.0
        dddopp = 0.0
        # Status bits: SV was reference, sd_cno valid
        status = int('11', 16)
    else:
        ddcode = base_ref.meas['pseudo'] - rovr_ref.meas['pseudo'] - base_ref.corr \
               - base_tst.meas['pseudo'] + rovr_tst.meas['pseudo'] + rovr_ref.corr
        ddcarr = _calc_ddcarr(svid, carr_ambig,
                              base_ref, rovr_ref, base_tst, rovr_tst)
        dddopp = base_ref.meas['doppler'] - rovr_ref.meas['doppler'] \
               - base_tst.meas['doppler'] + rovr_tst.meas['doppler']
        (ddcarr, status) = _check_ddcarr(ddcarr)
    diffs = numpy.array([tow, svid, azi, elev, cno, sdcno,
                         ddcode, ddcarr, dddopp, status])
    return diffs


###############################################################################

class RT27MeasDDiffThread(threading.Thread):
    """ Thread class for computing the single/double difference measurements
        from the common epoch base and rover RT27 data.
        The cfit_data is needed for the ephemeris calculation used for
        non-zero baseline and non-clock steered resolution.
    """
    def __init__(self, base_rt27_coll, rovr_rt27_coll, cfit_data, base_xyz, rovr_xyz, cmn_clk,
                 verbose=False):
        """ base_rt27_coll / rovr_rt27_coll - the RT27 measurement collectors.
            cfit_data - curve-fit collector passed to find_reference_sv().
            base_xyz / rovr_xyz - antenna positions (3 coordinates each;
                presumably ECEF metres -- confirm).
            cmn_clk - True if base and rover share a common clock.
            verbose - print progress information.
        """
        threading.Thread.__init__(self)
        self.find_meas = RT27CommonMeas(base_rt27_coll, rovr_rt27_coll, verbose)
        self.carr_ambig = CarrierAmbiguities()
        self.cfit_data = cfit_data
        self.last_cmn = -1
        self.ref_combo = (None, None)
        self.tst_combo = (None, None)
        self.nopi_diffs = None
        self.run_thread = True
        self.common_clock = cmn_clk
        self.verbose = verbose
        # Zero baseline unless some coordinate differs by more than 1e-3
        # (1 mm if the positions are in metres).
        self.zero_baseline = not any(abs(base_xyz[axis] - rovr_xyz[axis]) > 1e-3
                                     for axis in range(3))

    def quit_thread(self):
        """ Call to stop the thread. Used for clean shut-down. """
        self.run_thread = False

    def get_nopi_diffs(self):
        """ Return a numpy array containing all of the single/double difference
            measurements computed so far in the NoPi output format, or None if
            nothing has been computed yet.
        """
        return self.nopi_diffs

    def del_all_data(self):
        """ Delete all of the single/double difference measurements and all of
            the supporting data used by this class.
        """
        self.find_meas.del_all_epochs()
        self.carr_ambig.clear_ambiguities()
        self.last_cmn = -1
        self.nopi_diffs = None

    def set_combo(self, sat_type, sub_type, is_ref):
        """ Set the signal combination: sat_type and sub_type using the Stinger
            enumeration. See, for example, mutils/RcvrConst.py. Need to specify
            both the reference signal combination, is_ref is True, and test
            signal combination, is_ref is False, to fully configure the combo.
            Changing a combo restarts processing from the first common epoch
            and discards the computed diffs and carrier ambiguities.
        """
        if is_ref:
            self.ref_combo = (sat_type, sub_type)
        else:
            self.tst_combo = (sat_type, sub_type)
        self.last_cmn = -1
        self.nopi_diffs = None
        # Ambiguities resolved for the previous combo are not valid for the
        # new one (see CarrierAmbiguities.clear_ambiguities()).
        self.carr_ambig.clear_ambiguities()

    def calc_nopi_diffs(self, tow, cmn_svs_ref, cmn_svs_tst, ref_sv_idx):
        """ Calculate the single/double difference measurements in the NoPi
            output format for one epoch.
            Appends the new diffs. to existing data; does nothing if there are
            no common test SVs.
        """
        base_ref = cmn_svs_ref.base[ref_sv_idx]
        rovr_ref = cmn_svs_ref.rovr[ref_sv_idx]
        ref_svid = base_ref.meas['svID']
        ref_sat_type = base_ref.meas['satType']
        self.carr_ambig.check_ref_sv(ref_svid)
        if self.verbose:
            print('Calc. diffs for ToW '+str(tow)+' RefSV '+str(ref_svid), end='')

        rows = []
        for base_tst, rovr_tst in zip(cmn_svs_tst.base, cmn_svs_tst.rovr):
            # Identify the reference SV by its ID and system, not by list
            # position: when the test combo differs from the reference combo
            # the two common-SV lists are ordered differently.
            is_ref = (base_tst.meas['svID'] == ref_svid
                      and base_tst.meas['satType'] == ref_sat_type)
            rows.append(form_nopi_diffs(tow, is_ref, self.carr_ambig,
                                        base_ref, rovr_ref, base_tst, rovr_tst,
                                        self.common_clock))
        if rows:
            # Build the epoch's 2-D block once instead of growing it row by row.
            new_diffs = numpy.array(rows)
            if self.nopi_diffs is None:
                self.nopi_diffs = new_diffs
            else:
                self.nopi_diffs = numpy.concatenate((self.nopi_diffs, new_diffs), axis=0)
        if self.verbose:
            total = 0 if self.nopi_diffs is None else len(self.nopi_diffs)
            print(' NewDiffs '+str(len(rows))+' TotDiffs '+str(total))

    def process_epochs(self, len_cmn):
        """ Process the common epochs from last_cmn+1 up to len_cmn-1. """
        while (self.last_cmn+1) < len_cmn:
            self.last_cmn += 1
            base_time, rovr_time, base_meas, rovr_meas = self.find_meas.get_meas(self.last_cmn)
            if base_time is None or rovr_time is None:
                # get_meas() reports an unavailable epoch (e.g. after
                # del_all_data()); nothing to difference.
                continue

            # Even steered data can have a few usec residual bias.
            # NOTE(review): the bias field is labelled [ms] elsewhere in this
            # file, which would make 2e-6 a 2 ns threshold -- confirm units.
            have_clk_bias = abs(base_time[2]) > 2e-6 or abs(rovr_time[2]) > 2e-6
            cfit_data = self.cfit_data
            if self.zero_baseline and (self.common_clock or not have_clk_bias):
                cfit_data = None

            cmn_svs_ref = find_common_svs(base_meas, rovr_meas, self.ref_combo)
            ref_sv_idx = find_reference_sv(cmn_svs_ref, base_time, cfit_data)
            if ref_sv_idx is None:
                continue

            if self.ref_combo == self.tst_combo:
                cmn_svs_tst = cmn_svs_ref
            else:
                cmn_svs_tst = find_common_svs(base_meas, rovr_meas, self.tst_combo)
            self.calc_nopi_diffs(base_time[1], cmn_svs_ref, cmn_svs_tst, ref_sv_idx)

    def run(self):
        """ Call to run the thread. The thread will run until quit_thread() is
            called.
        """
        while self.run_thread:
            _ = self.find_meas.common_epochs()
            # Check for unprocessed data
            # Don't check common_epochs return as this won't notice when the
            # processing combo changes
            len_cmn = self.find_meas.num_epochs()
            if (self.last_cmn+1) < len_cmn:
                self.process_epochs(len_cmn)
            else:
                # Wait for new RT27 data
                time.sleep(0.1)
        #print('Generate diffs thread shut-down cleanly')


###############################################################################

def main():
    """ Run one instance of the RT27 measurement difference calculator for
        about 30 seconds, then ask every worker thread to stop.
    """
    base = RT27MeasCollectorThread(name='Base', ip_addr='10.1.150.75', port_num=5017)
    rovr = RT27MeasCollectorThread(name='Rover', ip_addr='10.1.150.92', port_num=5017)

    cfit = CurveFitCollectorThread(ip_addr='10.1.150.75')

    ant_xyz = [-2689308.106, -4302881.090, 3851417.715] # RS-3
    diffs = RT27MeasDDiffThread(base, rovr, cfit, ant_xyz, ant_xyz, False)
    diffs.set_combo(RcvrConst.SAT_TYPE_GPS, RcvrConst.SUBTYPE_L1CA, True)
    diffs.set_combo(RcvrConst.SAT_TYPE_GPS, RcvrConst.SUBTYPE_L1CA, False)

    workers = (base, rovr, cfit, diffs)
    try:
        for worker in workers:
            worker.start()
        time.sleep(30)
    finally:
        # Always signal shut-down, even on Ctrl-C or a start-up failure.
        # Otherwise non-daemon workers such as RT27MeasDDiffThread (whose
        # run() loops until quit_thread()) keep the process alive.
        for worker in workers:
            worker.quit_thread()

if __name__ == "__main__":
    main()
