#!/usr/bin/env python
#
# Take T0x data with DO_EVEREST_CALIBRATION #defined and
# calculate Everest calibration
#
# Upkeep:
#  - analyze_list[] - what satellite signals/types to process?
#  - prep_data() - add correlation function groups
#  - do we need to exclude Beidou GEO data?

from mutils import *
from mutils.GnssConst import *
from mutils.RTConst import *
from mutils.RcvrConst import *
import time, os, sys, collections
import pandas as pd
import argparse
from glob import glob

parser = argparse.ArgumentParser(description="""\
Take data collected from a unit with DO_EVEREST_CALIBRATION #defined
and calculate the Everest calibration table values.
""")

parser.add_argument('filename', help='T04 filename')
parser.add_argument('-u','--use_unhealthy', help='Allow unhealthy data', action="store_true")
# type=float so a user-supplied mask is numeric like the default; without it
# "-e 10" would reach the elevation comparison as the string '10'.
parser.add_argument('-e','--elev', help='Elevation mask [deg]', default=15., type=float)
parser.add_argument('-s','--systems', help='List of systems to analyze.  Same as viewdat, e.g. GPS+GLN = g,r', default="g,r,e,q,b,i")
parser.add_argument('-p','--plot', help='Show diagnostic plots', action="store_true")
parser.add_argument('-i','--individual', help='Show diagnostics for each individual satellite', action="store_true")
parser.add_argument('-m','--min_sv_len', help='Min # of points per SV', default=1000, type=int)
parser.add_argument('-v','--verify', help="""\
Add receiver MPhat to MPx.  All scale factors should be very large and "verify w/Everest" MPx values should be close to "smoothed fit" values.""", action="store_true")
args = parser.parse_args()

filename = args.filename
bias_elev_limit = 50.  # cutoff [deg] for high elevation diagnostic
mp_elev_limit = args.elev  # elevation mask [deg] handed to get_LxLy_idx()
remove_unhealthy = not args.use_unhealthy
do_plot = args.plot
show_individual_svs = args.individual
min_sv_len = args.min_sv_len  # minimum # of points to consider for a single SV
# Requested system letters; only membership ("'g' in do_sys") is used.
# Whitespace is stripped so "g, r" behaves the same as "g,r".
do_sys = dict.fromkeys((c.strip() for c in args.systems.split(',')), 1)
print(args)

# To determine RF plan look at rec 35:132 in T04 file:
#   Pol_LGB1 => "RF plan: Polaris Release - 5MHz shifted"
#   Pol_LG1 => "RF plan: Polaris Release - standard"
#   Prog50 => "RF plan: standard" or "RF plan: 5MHz shifted"
# `filename` may be a glob pattern; viewdat is run on the first match.
# Fail with a readable message instead of an IndexError when nothing matches.
_t04_matches = glob(filename)
if not _t04_matches:
    sys.exit('No file matches %r' % filename)
rf_plan_str = get_first_str( "viewdat -d35:132 "+_t04_matches[0], " RF plan: " )
print(rf_plan_str)
# get_first_str() output is compared against bytes literals (undecoded text).
_rf_plan_names = {
    b' RF plan: Polaris Release - 5MHz shifted': "Polaris LGB1",
    b' RF plan: Polaris Release - standard': "Polaris LG1",
    b' RF plan: standard': "Progeny 50MHz standard",
    b' RF plan: 5MHz shifted': "Progeny 50MHz shifted",
}
print("******** " + _rf_plan_names.get(rf_plan_str, "Unknown RF plan"))

# Signals to calibrate for each --systems letter.  The order of this table
# (not the order given on the command line) sets the processing order.
_signals_by_system = (
    ('g', (
        (RT_SatType_GPS, SUBTYPE_L1CA),
        (RT_SatType_GPS, SUBTYPE_L2C),
        (RT_SatType_GPS, SUBTYPE_L5),
    )),
    ('r', (
        (RT_SatType_GLONASS, SUBTYPE_G1C),
        (RT_SatType_GLONASS, SUBTYPE_G1P),
        (RT_SatType_GLONASS, SUBTYPE_G2C),
        (RT_SatType_GLONASS, SUBTYPE_G2P),
        #(RT_SatType_GLONASS, SUBTYPE_G3),
    )),
    ('e', (
        (RT_SatType_GALILEO, SUBTYPE_E1),
        (RT_SatType_GALILEO, SUBTYPE_E5A),
        (RT_SatType_GALILEO, SUBTYPE_E5B),
    )),
    ('q', (
        (RT_SatType_QZSS, SUBTYPE_L1CA),
        (RT_SatType_QZSS, SUBTYPE_L1SAIF),
        (RT_SatType_QZSS, SUBTYPE_L1C),
        (RT_SatType_QZSS, SUBTYPE_L2C),
        (RT_SatType_QZSS, SUBTYPE_L5),
    )),
    ('b', (
        (RT_SatType_BEIDOU_B1GEOPhs, SUBTYPE_B1),
        (RT_SatType_BEIDOU_B1GEOPhs, SUBTYPE_B1C),
        (RT_SatType_BEIDOU_B1GEOPhs, SUBTYPE_B2),
        (RT_SatType_BEIDOU_B1GEOPhs, SUBTYPE_B2A),
        (RT_SatType_BEIDOU_B1GEOPhs, SUBTYPE_B2B),
        (RT_SatType_BEIDOU_B1GEOPhs, SUBTYPE_B3),
    )),
    # NOTE - can't get MPx for IRNSS without dual frequency :-(
    ('i', (
        (RT_SatType_IRNSS, SUBTYPE_L5CA),
        (RT_SatType_IRNSS, SUBTYPE_S1CA),
    )),
)

analyze_list = []
for system_letter, system_signals in _signals_by_system:
    if system_letter in do_sys:
        analyze_list += system_signals

if not analyze_list:
    print("Must choose some satellites to analyze")
    sys.exit(1)

# Per-SV records accumulated during one (rt_sat_type, subtype) pass:
#   AllMp    - MPx statistics used for the combined std.dev. report
#   AllCoeff - fit results used for the scale / bias tables
AllMp = collections.namedtuple('AllMp',['num','std_MPr','std_ss','rt_mean_mphat'])
AllCoeff = collections.namedtuple('AllCoeff',['num','scale','group','bias','high_num','high_bias'])
# Start of timing; also used for the 'total time' report at the end.
t1=time.time()
# Skip the (slow) load when `d` already exists at module level, e.g. when the
# script is re-run in an interactive session that kept the previous globals.
# NOTE(review): in that case k, d1 and k1 are reused unchanged too, so a new
# filename or -u option is not picked up until the session is restarted.
if not 'd' in globals():
    print('loading diags')
    # Everest diagnostic records (MMEVDIAG); columns are addressed via `k`.
    d=doload_sed( filename, 'MMEVDIAG' )
    k=dotdict({})
    k.ms=0  # time tag; matched against round(TIME*1000) of d1 in prep_data()
    k.sv=1
    k.sat_type=2
    #k.ch=3
    k.subtype=4
    k.mp_enabled=5  # only rows with this == 1 are used
    k.cno=6
    k.raw_mp=7
    k.mp_hat=8 # filtered raw_mp - has bias removed, so don't use for bias calcs
    #k.mp_filt=9
    #k.PR=10
    #k.el=11

    # Rec 35:19 measurement data from viewdat; column indices come back in k1.
    d1,k1=vd2arr(filename,rec='-d35:19 -z --dec=200')
    if remove_unhealthy:
        # keep only rows whose SV_FLAGS do not have the unhealthy bit set
        i1 = find( (d1[:,k1.SV_FLAGS].astype(int) & k1.dSV_FLAGS_UNHEALTHY) != k1.dSV_FLAGS_UNHEALTHY )
        d1 = d1[i1,:]
    print('load time %.2f' % (time.time()-t1))


# (sv, rt_sat_type) pairs to iterate over in the main loop
sv_list = get_sv_list( d1,k1 )

if do_plot:
    close('all')

######################################################################
# Given a satellite, return MPx and diagnostic data
######################################################################
def prep_data( sv, rt_sat_type, subtype ):
    """Collect MPx and the matching Everest diagnostics for one satellite signal.

    Reads the module-level data `d`/`k` (MMEVDIAG) and `d1`/`k1` (rec 35:19),
    plus `min_sv_len`, `mp_elev_limit` and `args.verify`.

    Args:
        sv: satellite number as stored in d[:,k.sv] / d1[:,k1.SV].
        rt_sat_type: RT_SatType_* constant of the satellite.
        subtype: SUBTYPE_* constant of the signal to calibrate.

    Returns:
        None when fewer than `min_sv_len` usable points are available,
        otherwise the tuple
        (sat_type, t, i, MPr, i1_Lx, rcvr_scale, rcvr_mp, group):
          sat_type   - SAT_TYPE_* value used in the MMEVDIAG records
          t          - epochs common to both data sets (ms tags scaled by 0.001)
          i          - row indices into `d`, aligned with t
          MPr        - MPx from get_MPx() on the Lx/Ly rows of `d1`
                       (receiver MPHat subtracted when --verify is given)
          i1_Lx      - row indices into `d1` for the calibrated signal, aligned with t
          rcvr_scale - receiver scale constant .4/4096
          rcvr_mp    - receiver MPHat on those rows (all NaN if constant)
          group      - calibration table group (FDMA channel for GLONASS)

    Raises:
        ValueError: if (rt_sat_type, subtype) is not recognised.
    """
    group = 0
    rcvr_scale = .4/4096.
    sat_type = -1  # stays negative if the signal is not recognised
    band1 = -1
    band2 = -1
    track1 = -1
    track2 = -1

    if (rt_sat_type == RT_SatType_GPS or rt_sat_type == RT_SatType_QZSS) and subtype == SUBTYPE_L1CA:
        if rt_sat_type == RT_SatType_GPS:
            sat_type = SAT_TYPE_GPS
        else:
            sat_type = SAT_TYPE_QZSS
        # L1 C/A paired with L2 for the dual-frequency MPx combination
        band1, track1 = dRec27L1Band, dRec27Track_CA
        band2, track2 = dRec27L2Band, -1
        # Hard-coded per-PRN table groups for L1 C/A.
        # NOTE(review): the basis of this grouping is not visible here.
        if sv in [8,22]:
            group = 0
        elif sv in [7,15,17,21,24,195]:
            group = 1
        else:
            group = 2
    else:
        # Look up the band/track for (rt_sat_type, subtype).
        for rt_sat,sat_info in get_sub_type.sub_type_dict.items():
            if rt_sat.SatType == rt_sat_type and sat_info.subtype == subtype:
                sat_type = get_sat_type(rt_sat_type).SAT_TYPE
                band1, track1 = rt_sat.band, rt_sat.track
                # Choose the second frequency: scan sv_to_LxLy_info() in
                # reverse and stop at the first band != band1 that has more
                # than min_sv_len rows for this SV.
                # NOTE(review): if none qualifies, band2/track2 keep the
                # values of the last entry scanned, which may equal band1 --
                # confirm get_LxLy_idx()/get_MPx() handle that case.
                for info in reversed(sv_to_LxLy_info( d1, k1, sv, rt_sat_type )):
                    band2, track2 = info[0], info[1]
                    if band1!=band2:
                        i = find( (d1[:,k1.SV]==sv) & (d1[:,k1.SAT_TYPE]==rt_sat_type)
                                  & (d1[:,k1.FREQ]==band2) & (d1[:,k1.TRACK]==track2) )
                        if len(i) > min_sv_len:
                            break

                break
    if sat_type < 0:
        raise ValueError('Unknown RT sat type %d subtype %d'%(rt_sat_type,subtype))

    # Everest diagnostic rows for this SV/signal with MP estimation enabled
    i=find( (d[:,k.sv]==sv) & (d[:,k.sat_type]==sat_type)\
            & (d[:,k.subtype]==subtype) \
            & (d[:,k.mp_enabled]==1) )
    i1_Lx, i1_Ly = get_LxLy_idx( d1, k1, sv, rt_sat_type, \
                                 band1, track1, band2, track2, mp_elev_limit )
    if len(i1_Lx) < min_sv_len or len(i1_Ly) < min_sv_len:
        return None
    # Align both data sets on common epochs: d1 TIME scaled to ms vs. d ms tag
    t,i1a,ia = intersect( around(d1[i1_Lx,k1.TIME]*1000), d[i,k.ms] )
    if len(t) < min_sv_len:
        return None

    i=i[ia]
    i1_Lx=i1_Lx[i1a]
    i1_Ly=i1_Ly[i1a]
    t*=0.001
    MPr,_,_ = get_MPx( d1, k1, i1_Lx, i1_Ly, filt_Tc=-1. )
    rcvr_mp = d1[i1_Lx,k1.MPHat]  # fancy indexing -> copy, so d1 is untouched below
    # A single constant MPHat value is treated as "no receiver estimate".
    if len(unique(rcvr_mp)) == 1:
        rcvr_mp *= nan
    if args.verify:
        MPr -= rcvr_mp

    # GLONASS calibration is grouped by FDMA channel (from the first aligned row)
    if sat_type == SAT_TYPE_GLONASS:
        group = int(d1[i1_Lx[0],k1.FDMA])

    return ( sat_type, t, i, MPr, i1_Lx, rcvr_scale, rcvr_mp, group )

######################################################################
# Go through each satellite type to test and print results
######################################################################
for desired_rt_sat_type, desired_subtype in analyze_list:
    # Per-SV results for this (rt_sat_type, subtype) pass.  Both lists get a
    # row only for SVs that pass the noise check, so they stay aligned.
    all_mp = []
    all_coeff = []
    for sv,rt_sat_type in sv_list:
        if rt_sat_type != desired_rt_sat_type:
            continue

        result = prep_data( sv, rt_sat_type, desired_subtype )
        if result is None:
            continue
        sat_type, t, i, MPr, i1_Lx, rcvr_scale, rcvr_mp, group = result

        # compute mean from unsmoothed data
        # for scale, use least-squares on smoothed data
        raw_xi = 16.*d[i,k.raw_mp]
        xi = d[i,k.mp_hat]
        computed_bias = mean(raw_xi)
        scale_xi = lstsq(reshape(xi-mean(xi),(-1,1)),MPr,rcond=None)[0][0]
        # Fitted MP estimate.  Offsetting by computed_bias rather than
        # mean(xi) only shifts it by a constant, which every use below
        # (std.dev. and the mean-removed plot) ignores.
        myfit_ss = scale_xi*(xi - computed_bias)

        # for high-elevation, compute mean from unsmoothed data
        i2 = find( d1[i1_Lx,k1.EL] >= bias_elev_limit )
        if len(i2) > min_sv_len:
            computed_high_el_bias = mean(raw_xi[i2])
        else:
            i2 = []
            computed_high_el_bias = 0.

        # prep_data() returns the FDMA channel as the group for GLONASS
        fdma = group if sat_type == SAT_TYPE_GLONASS else None

        std_MPr = std(MPr)
        std_ss = std(MPr-myfit_ss)
        rt_mean_mphat = mean(rcvr_mp)
        if show_individual_svs:
            print( 'sv %d/%d%s subtype %d len %d meas_scale %.3f mean cno %.1f mean MPx %.2f:' %
                   (sv, sat_type,
                    "/%d"%fdma if fdma is not None else "",
                    desired_subtype,
                    len(i1_Lx),
                    scale_xi*4096.,
                    mean(d1[i1_Lx,k1.CNO]),
                    rt_mean_mphat
                   ) )
            if args.verify:
                print( '  Verify w/Everest MPx %.3f [m] = 100%%'% std_MPr )
            else:
                print( '  No Everest MPx       %.3f [m] = 100%%'% std_MPr )
            print( '  smooth fit           %.3f [m] = %.1f%%' %
                   (std_ss,100*std_ss/std_MPr) )
            print( '  bias %.1f' % (computed_bias))

        # SVs whose MPx std.dev. exceeds 1.5 m are excluded from the
        # combined results below.
        if std_MPr > 1.5:
            print( '  ** dropping noisy data for sv %d sat_type %d' % (sv, sat_type) )
        else:
            all_mp.append( AllMp(num=len(MPr), std_MPr=std_MPr, std_ss=std_ss,
                                 rt_mean_mphat=rt_mean_mphat) )

            all_coeff.append( AllCoeff(num=len(xi), scale=4096.*scale_xi, group=group,
                                       bias=computed_bias,
                                       high_num=len(i2), high_bias=computed_high_el_bias ) )

        if do_plot:
            figure()
            ss = MPr - myfit_ss
            plot( t, MPr, label='MPx no Ev' )
            plot( t, ss - mean(ss), label='MPx w/Ev' )
            # raw string: '\s' is not a valid escape and warns on newer Pythons
            title(r'sv %d/%d subtype %d no Ev $\sigma=%.3f$ w/Ev $\sigma=%.3f$'%
                  (sv,sat_type,desired_subtype,std_MPr,std_ss))
            legend()

    all_mp = pd.DataFrame(all_mp)
    all_coeff = pd.DataFrame(all_coeff)

    if len(all_coeff) == 0:
        print( 'no data for rt_sat_type=%d subtype=%d' % (desired_rt_sat_type, desired_subtype) )
        continue

    print( 'Overall results (rt_sat_type %d, subtype %d, len %d):' % (desired_rt_sat_type, desired_subtype, sum(all_mp.num)) )

    # Point-count-weighted averages across all retained SVs.
    best_coeff = sum(all_coeff.num*all_coeff.scale)/sum(all_coeff.num)

    num = all_mp.num
    sum_num = sum(num)
    std_no_ev = sum(num*all_mp.std_MPr)/sum_num
    std_sm_fit = sum(num*all_mp.std_ss)/sum_num
    rt_mean_mphat = sum(num*all_mp.rt_mean_mphat)/sum_num
    print( ' Realtime mean(MPHat): %.3f' % rt_mean_mphat )
    print( ' Combined MP std.dev.:' )
    if args.verify:
        print( '  verify w/Everest  %.3f [m] = 100%%' % (std_no_ev) )
    else:
        print( '  no Everest        %.3f [m] = 100%%' % (std_no_ev) )
    print( '  smoothed fit      %.3f [m] = %.0f%%' % (std_sm_fit, 100*std_sm_fit/std_no_ev ) )

    # rcvr_scale is the constant (.4/4096) returned by the last prep_data() call.
    best_mp_scale = rcvr_scale / (best_coeff/4096.)
    print( ' Stinger constants:' )
    print( '  best mp_scale = %.3f' % (best_mp_scale) )

    all_groups = unique(all_coeff.group)
    print( ' Tables:' )
    for group in all_groups:
        grp = all_coeff[all_coeff.group == group]
        group_len = sum(grp.num)
        high_len = sum(grp.high_num)
        high_bias = nan
        if high_len > 0:
            high_bias = sum(grp.high_num*grp.high_bias)/high_len
        all_bias = sum(grp.num*grp.bias)/group_len
        group_name = ''
        if len(all_groups) > 1:
            group_name = ' group %d' % group
        # Report this group's own point count, not the total over all groups.
        print( '  overall bias%s = %.1f (len %d)' % (group_name,all_bias,group_len) )
        print( '  high elev bias%s = %.1f (len %d)' % (group_name,high_bias,high_len) )

# Report elapsed wall-clock time since t1, then display any figures created.
elapsed_total = time.time() - t1
print('total time %.2f' % elapsed_total)
if do_plot:
    show()
