#!/usr/bin/env python
usage="""\
Normally this is imported by ProcessResults.py, but you can also
run it stand-alone.  Stand-alone usage:
 ./ProcessResultsSlips.py /mnt/data_drive/SpirentTest/DataDir/RX22-3518/xyz.T04 out_dir
The T04 argument may also be a (quoted) glob pattern; all matching files
are concatenated in sorted order before processing.  The sample config
*.xml is looked up in the same directory as the T04 file(s).
out_dir is deleted and recreated.  It receives one set of plots per time
span in the sample config, e.g. for spans "all" and "trees":
  out_dir/all_num_pts.png
  out_dir/all_num_pts_subplots.png
  out_dir/all_percent_pts.png
  out_dir/all_num_dual.png
  out_dir/all_pdf_num_dual.png
  out_dir/trees_num_pts.png
  out_dir/trees_percent_pts.png
  ...
"""

from mutils import *
from subprocess import check_call, Popen, PIPE
import shlex
import shutil
import os
import pandas as pd
from collections import defaultdict
from glob import glob

def run_SNPC_slip_check(lla_truth, pospac_filename, in_filename, out_slip_path, do_resid=False, in_bg=False, rm_iono=False):
    """Run StingerNavPC's cycle slip check.
    lla_truth = (lat[deg], lon[deg], height[m]) truth (or None)
    pospac_filename = full filename of POSPac truth (or None)
      -> either lla_truth _or_ pospac_filename should be set. One should be None.
         If pospac_filename is set, the dynamic check is used and lla_truth
         is ignored.
    in_filename = input T04 filename, or a glob pattern.  When the pattern
      matches several files they are concatenated, in sorted order, into
      out_slip_path/.rcvr.T04, which is removed after a successful run.
    out_slip_path = directory to put results in (deleted and recreated)
    do_resid = if True, also generate residuals.txt.gz (adds --truth only
      when pospac_filename is set)
    in_bg = run in background and return proc to .wait() on?
    rm_iono = estimate and remove iono from metric?

    Returns the Popen object when in_bg is True, otherwise None.
    Raises RuntimeError if a foreground run exits with a non-zero status.
    """
    if os.path.isdir( out_slip_path ):
        shutil.rmtree( out_slip_path )
    os.makedirs( out_slip_path )
    if pospac_filename is not None:
        # A relative symlink target is resolved against the link's own
        # directory, not the caller's cwd, so always link to an absolute path.
        os.symlink(os.path.abspath(pospac_filename),
                   os.path.join(out_slip_path, 'POSPac.txt'))

    # The command below runs with cwd=out_slip_path, so the input path must
    # be absolute (and shell-quoted) for StingerNavPC to find it from there.
    pattern = os.path.expanduser(in_filename)
    all_in_filenames = sorted(glob(pattern))
    combined = len(all_in_filenames) > 1
    if combined:
        process_filename = '.rcvr.T04'
        with open(os.path.join(out_slip_path, process_filename), "wb") as f_rcvr:
            for fin_name in all_in_filenames:
                with open(fin_name, 'rb') as f_in:
                    shutil.copyfileobj(f_in, f_rcvr)
    elif all_in_filenames:
        process_filename = os.path.abspath(all_in_filenames[0])
    else:
        # Nothing matched: pass the name through and let StingerNavPC report it.
        process_filename = os.path.abspath(pattern)
    quoted_filename = shlex.quote(process_filename)

    cmd = "StingerNavPC %s -k200 --allow_no_iono --compute_elev -e0"%quoted_filename
    if pospac_filename is not None:
        cmd += " --check_cycle_slip=dynamic"
    else:
        cmd += " --check_cycle_slip=%.9f,%.9f,%.3f"%(lla_truth[0],lla_truth[1],lla_truth[2])
        cmd += " --static_truth=%.9f,%.9f,%.3f"%(lla_truth[0],lla_truth[1],lla_truth[2])
    if rm_iono:
        cmd += " --check_cycle_slip_iono"
    if do_resid:
        if pospac_filename is not None:
            cmd += " --truth"
    cmd += " > runlog.txt"
    if combined:
        # Only delete the temporary concatenated file, never the user's input.
        cmd += " && rm %s" % quoted_filename
    cmd += " && pigz -2 [rs]*.txt"
    print("Running: %s"%cmd)
    p = Popen( cmd, cwd=out_slip_path, shell=True )

    if in_bg:
        return p
    if p.wait() != 0:
        raise RuntimeError("Bad StingerNavPC run",p.returncode)

def discrete_pdf( val ):
    """Empirical PDF of a variable that only takes a few distinct values.

    Returns (levels, probs): the sorted distinct values of val, and for each
    one the fraction of entries of val equal to it.  Both are arrays.
    """
    n_total = len(val)
    levels = array(sorted(pd.unique(val)))
    probs = [1.*len(find(val == level))/n_total for level in levels]
    return levels, array(probs)

def read_raw_slip_data( out_slip_path ):
    """Load slip diagnostic data from directory 'out_slip_path'.

    Reads the space-separated, gzipped diagnostic files that StingerNavPC's
    cycle slip check leaves in out_slip_path (see run_SNPC_slip_check).
    Returns:
    ( d_raw, # all PLL data
      d_chk, # data actually checked
      d_false, # false alarm results   (lines containing 'false_alarm')
      d_miss,  # missed detection results (lines containing 'miss_detect')
      d_svs )  # dual frequency RTK info
    d_false / d_miss are empty DataFrames (with the usual columns) when the
    diagnostics contain no such lines.
    """
    import gzip
    import io

    def _lines_to_frame(lines, cols):
        # pd.read_csv raises EmptyDataError on empty input, which would make
        # a run with zero false alarms / misses crash; return an empty frame.
        if not lines:
            return pd.DataFrame(columns=cols)
        return pd.read_csv(io.BytesIO(b''.join(lines)), sep=' ', names=cols)

    k = ['TIME','SV','SAT_TYPE','TRACK','FREQ','CNO']
    d_raw=pd.read_csv(os.path.join(out_slip_path, 'slip_diag_raw.txt.gz'),
                      sep=' ', names=k)
    k = ['TIME','SV','SAT_TYPE','TRACK','FREQ','CNO','metric','iono']
    d_chk=pd.read_csv(os.path.join(out_slip_path, 'slip_diag_check.txt.gz'),
                      sep=' ',names=k)

    # Equivalent to `zgrep false_alarm` / `zgrep miss_detect`, but done in a
    # single decompression pass and without building a shell command.
    k = ['TIME','SV','SAT_TYPE','TRACK','FREQ','CNO','metric','aiding']
    false_lines = []
    miss_lines = []
    with gzip.open(os.path.join(out_slip_path, 'slip_diag.txt.gz'), 'rb') as f_diag:
        for line in f_diag:
            if b'false_alarm' in line:
                false_lines.append(line)
            if b'miss_detect' in line:
                miss_lines.append(line)
    d_false = _lines_to_frame(false_lines, k)
    d_miss = _lines_to_frame(miss_lines, k)

    k=['TIME','N_GPS','N_GLN','N_GAL','N_QZSS','N_BDS','N_IRNSS']
    d_svs=pd.read_csv(os.path.join(out_slip_path, 'slip_diag_num_svs.txt.gz'),
                      sep=' ',names=k)
    return d_raw,d_chk,d_false,d_miss,d_svs

def _span_mask(df, intervals):
    """Return a bool array marking rows of df whose TIME is inside any interval.

    intervals = iterable of (start, stop) pairs; both ends are inclusive.
    """
    mask = zeros(len(df), dtype=bool)
    for start, stop in intervals:
        mask |= ((df.TIME >= start) & (df.TIME <= stop)).values
    return mask


def _count_signal(df, sat_type, freq, track):
    """Return how many rows of df belong to the (sat_type, freq, track) signal."""
    return int(((df.SAT_TYPE == sat_type)
                & (df.TRACK == track)
                & (df.FREQ == freq)).sum())


def _plot_signal_summaries(num_smry, pct_smry, out_slip_path, span_desc):
    """Write the per-signal count and percentage bar charts for one span."""
    ax = num_smry.plot.bar( x='signal', figsize=(8,6) )
    ax.legend(loc='center left',bbox_to_anchor=(1.0, 0.5))
    ax.set_yscale('log')
    ax.set_title('Number of points: %s'%span_desc)
    ax.xaxis.set_label_text('')
    ax.grid()
    ax.set_axisbelow(True)
    ax.set_ylabel('# points')
    tight_layout()
    savefig('{}/{}_num_pts.png'.format(out_slip_path,span_desc))
    close()

    ax = num_smry.plot.bar( x='signal', figsize=(8,6), subplots=True )
    for a in ax:
        a.legend(loc='center left',bbox_to_anchor=(1.0, 0.5))
        a.set_title('')
        a.xaxis.set_label_text('')
        a.grid()
        a.set_axisbelow(True)
        a.set_ylabel('# points')
    ax[0].set_title('Number of points: %s'%span_desc)
    tight_layout()
    savefig('{}/{}_num_pts_subplots.png'.format(out_slip_path,span_desc))
    close()

    ax = pct_smry.plot.bar( x='signal', figsize=(8,6) )
    ax.legend(loc='center left',bbox_to_anchor=(1.0, 0.5))
    ax.set_yscale('log')
    ax.set_title('Percentage of points: %s'%span_desc)
    ax.xaxis.set_label_text('')
    ax.grid()
    ax.set_axisbelow(True)
    ax.set_ylabel('% points')
    tight_layout()
    savefig('{}/{}_percent_pts.png'.format(out_slip_path,span_desc))
    close()


def _plot_dual_freq(d0_svs, out_slip_path, span_desc):
    """Write the dual-freq satellite count time series and its stacked PDF for one span."""
    df_svs = ['GPS','GLN','GAL','BDS']
    _, ax = subplots(len(df_svs),1,sharex=True,sharey=True,figsize=(8,6))
    for n,txt in enumerate(df_svs):
        ax[n].plot(d0_svs.TIME, d0_svs['N_'+txt] )
        ax[n].set_ylabel('# %s SVs'%txt)
        ax[n].grid()
    ax[0].set_title('# dual-freq satellites: %s'%span_desc)
    ax[-1].set_xlabel('GPS secs')
    tight_layout()
    savefig("%s/%s_num_dual.png"%(out_slip_path,span_desc))
    close()

    figure()
    colors = ['#e6194b', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4', '#46f0f0',
              '#f032e6', '#bcf60c', '#fabebe', '#008080', '#e6beff', '#9a6324', '#fffac8',
              '#800000', '#aaffc3', '#808000', '#ffd8b1', '#000075', '#808080', '#ffffff',
              '#000000']
    maxColor = len(colors)

    # Every satellite count seen in any constellation, largest first, so each
    # count gets the same colour (and one legend entry) in every stacked bar.
    n_svs_list = list(reversed(unique(d0_svs[['N_'+x for x in df_svs]])))
    for n_txt,txt in enumerate(df_svs):
        cx,cy = discrete_pdf( d0_svs['N_'+txt] )
        cy *= 100
        bottom = 0
        for i,n_sv in enumerate(n_svs_list):
            i_cy = find( cx==n_sv )
            if len(i_cy) == 0:
                percent_sv = 0
            else:
                percent_sv = cy[i_cy[0]]
            colorIndex = i % maxColor
            # Label only the first constellation's segments to avoid duplicate legend entries.
            if n_txt == 0:
                legend_txt = n_sv
            else:
                legend_txt = None
            bar(txt,
                percent_sv,
                bottom=bottom,
                label=legend_txt,
                zorder=3,
                color=colors[colorIndex])
            bottom += percent_sv

    ylim([0,100.0])
    xticks(rotation=45,ha='right')
    ylabel('Dual-Freq. Obs [%]')
    title('Dual-freq SVs: %s'%span_desc)
    grid(True,zorder=0)
    legend(loc=(1.04,0),title="Dual-Freq SVs", fontsize=8, fancybox=True,title_fontsize=8)
    tight_layout()
    save_compressed_png("%s/%s_pdf_num_dual.png"%(out_slip_path,span_desc),
                        dpi=300)
    close()


def analyze_slip_data( config_span, out_slip_path, filter_span=None ):
    """Create raw data and plots of StingerNavPC's cycle slip analysis.
    config_span = 'span' element from ProcessResults.py : parse_sampleConfig(),
      iterated as (desc, [(start, stop), ...]) pairs
    out_slip_path = directory name where output goes
    filter_span = if not None, only look at these segment types (e.g., "Freeways")

    For every span description with data, writes into out_slip_path:
      <span>_num_pts.png, <span>_num_pts_subplots.png, <span>_percent_pts.png
        (skipped if no signal has >= 100 raw points in the span),
      <span>_num_dual.png and <span>_pdf_num_dual.png.
    """
    print("Loading cycle slip data")
    d_raw,d_chk,d_false,d_miss,d_svs = read_raw_slip_data( out_slip_path )

    # Get all signal types.  (SAT_TYPE, FREQ, TRACK) is packed into one integer
    # so pd.unique() finds distinct combinations; this assumes FREQ and TRACK
    # are each < 1000 -- TODO confirm.
    raw_sigs = pd.unique(d_raw.SAT_TYPE*1000000
                         + d_raw.FREQ*1000
                         + d_raw.TRACK)
    signals = defaultdict(list)
    for r in raw_sigs:
        signals[r//1000000].append((r//1000 % 1000,r%1000))

    # combine timespans:
    #   spans[timespan_desc] = [(start1,end1), (start2,end2), ...]
    spans = defaultdict(list)
    for desc,start_stops in config_span:
        if filter_span is not None and desc not in filter_span:
            continue
        for start,stop in start_stops:
            spans[desc].append( (start,stop) )

    print("Creating cycle slip plots")
    for span_desc, intervals in spans.items():
        d0_raw = d_raw[_span_mask(d_raw, intervals)]
        if len(d0_raw) == 0:
            continue
        d0_chk = d_chk[_span_mask(d_chk, intervals)]
        d0_false = d_false[_span_mask(d_false, intervals)]
        d0_miss = d_miss[_span_mask(d_miss, intervals)]
        d0_svs = d_svs[_span_mask(d_svs, intervals)]

        num_smry = pd.DataFrame(None,
                       columns=['signal',
                                'raw',
                                'checked',
                                'false',
                                'missed'])
        pct_smry = pd.DataFrame(None,
                        columns=['signal',
                                 'drop',
                                 'false',
                                 'missed'])
        for sat_type,sig_list in sorted(signals.items()):
            for freq,track in sorted(sig_list):
                n_raw = _count_signal(d0_raw, sat_type, freq, track)
                if n_raw < 100:
                    continue
                n_check = _count_signal(d0_chk, sat_type, freq, track)
                n_false = _count_signal(d0_false, sat_type, freq, track)
                n_miss = _count_signal(d0_miss, sat_type, freq, track)
                sigstr = get_sub_type(sat_type,freq,track).sigstr
                num_smry.loc[len(num_smry)] = [
                    sigstr,
                    n_raw,
                    n_check,
                    n_false,
                    n_miss]
                pct_smry.loc[len(pct_smry)] = [
                    sigstr,
                    100 - 100.*n_check/n_raw,
                    100.*n_false/n_raw,
                    100.*n_miss/n_raw]

        if len(num_smry) == 0:
            # DataFrame.plot.bar() raises TypeError on a table with no rows.
            print("No signal with >= 100 points in span '%s'; skipping summary plots"
                  % span_desc)
        else:
            _plot_signal_summaries(num_smry, pct_smry, out_slip_path, span_desc)

        _plot_dual_freq(d0_svs, out_slip_path, span_desc)

def main():
    """Stand-alone operation.  See usage at top of file.

    Picks the sample config *.xml in the T04 file's directory (ignoring
    *ALM.xml and *config.xml; first in sorted order), runs the StingerNavPC
    slip check into out_dir and then writes the slip plots there.
    Raises RuntimeError if no config XML is found, or if the config's truth
    format is neither POSPAC_ASCII nor STATIC.
    """
    import argparse
    from ProcessResults import parse_sampleConfig, fill_in_auto_config_spans

    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
                                     description=usage)
    parser.add_argument('t04_filename',
                        help='T04 file (or file pattern)')
    parser.add_argument('out_dir',
                        help='Output directory for plots.  Will get overwritten')
    args = parser.parse_args()

    # dirname() is '' for a bare filename, which means the current directory;
    # appending '/' to '' would search the filesystem root instead.
    dir_name = os.path.dirname(args.t04_filename) or '.'
    config_filename = None
    for tmp in sorted(glob(os.path.join(dir_name, '*.xml'))):
        if tmp.endswith('ALM.xml') or tmp.endswith('config.xml'):
            continue
        config_filename = tmp
        break
    if config_filename is None:
        raise RuntimeError("no sample config .xml found in %s" % dir_name)

    config = parse_sampleConfig( config_filename )
    if config.truth_format not in ["POSPAC_ASCII","STATIC"]:
        raise RuntimeError("no truth")
    if config.truth_format == "POSPAC_ASCII":
        fill_in_auto_config_spans( config, doload(config.truth_file) )

    out_slip_path = args.out_dir
    run_SNPC_slip_check(config.truth, config.truth_file, args.t04_filename,
                        out_slip_path, rm_iono=config.rm_iono )
    analyze_slip_data(config.span, out_slip_path)

if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported (e.g. by ProcessResults.py).
    main()
