#!/usr/bin/env python

# Simple script to plot MPx for all satellites in a T0x (T02/T04) file

import argparse
# Command-line interface. Options are parsed at import time because this is a
# stand-alone script.
parser = argparse.ArgumentParser(description='Plot MPx for all SVs')
parser.add_argument('filename', help='T02/T04 filename')
parser.add_argument('-s', '--smooth',
                    help='Low-pass filter time constant [s] (default -1 for no smoothing)',
                    default=-1.,
                    type=float)
parser.add_argument('-p', '--png',
                    help='Generate PNG files in headless operation',
                    action="store_true")
parser.add_argument('--sat_types',
                    help='Comma-separated list of sat types to process (e.g., 1,10)',
                    default=None)
# A value that starts with '-' must be attached with '=', otherwise argparse
# parses it as a separate option.
parser.add_argument('--viewdat',
                    help='Additional arguments to pass to viewdat '
                         '(e.g. --viewdat=--dec=1000)',
                    default="")
parser.add_argument('--elev',
                    help='Minimum elevation [deg] (default 0)',
                    default=0., type=float)
parser.add_argument('--min_pts',
                    help='Minimum # of epochs in a segment (default 1000, must be >= 60)',
                    default=1000, type=int)
parser.add_argument('--freq_plot',
                    help='Also plot the averaged FFT of MPx',
                    action="store_true")
args = parser.parse_args()
if args.min_pts < 60:
    parser.error("--min_pts must be >= 60")

# Satellite types to process (None = all types), parsed from the
# comma-separated --sat_types option.
do_sat_types = None
if args.sat_types is not None:
    # Ignore empty items so a stray or trailing comma ("1,10,") is harmless.
    sat_type_items = [x for x in args.sat_types.split(',') if x.strip()]
    try:
        do_sat_types = [int(x) for x in sat_type_items]
    except ValueError:
        do_sat_types = []
    if not do_sat_types:
        # Report malformed input as a usage error instead of a traceback.
        parser.error("--sat_types must be a comma-separated list of integers, got %r"
                     % args.sat_types)

if args.smooth < 0.:
    print("No filtering")
else:
    print("Low-pass filter time constant %.1f [s]" % args.smooth)
print("Elevation mask %.1f deg"%args.elev)

if args.png:
    import matplotlib
    # Allow running headless from the command line: select the
    # non-interactive Agg backend before `from mutils import *` below
    # (which presumably pulls in pyplot) so no display is required.
    matplotlib.use("agg")

from mutils import *
from collections import defaultdict

# Load the file only if `d` is not already defined. Presumably this lets a
# re-run inside an existing interpreter namespace (e.g. IPython `%run -i`)
# reuse the previously loaded data.
if 'd' not in globals():
    viewdat_rec = '-d27,35:19'
    extra_viewdat = args.viewdat.strip()
    if extra_viewdat:
        print("Extra viewdat args:" + args.viewdat)
        # Separate the extra arguments from the record list. Plain
        # concatenation would glue e.g. "--dec=1000" onto "35:19".
        viewdat_rec += ' ' + extra_viewdat
    d = vd2cls(args.filename, rec=viewdat_rec)

# Per signal description:
#   'MPx' -- all MPx samples concatenated.
#   'HF'/'LF' -- one FFT level [dB] per segment, used by the summary.
MPx_by_desc = defaultdict(lambda: defaultdict(list))
# desc -> (figure, axes) for the MPx time series and, with --freq_plot, its FFT.
all_figs = {}
all_fft_figs = {}
# Number of bins in the averaged FFT of each MPx segment.
N = 40

# Column holding the channel number. Record 35:19 provides `channel`; fall
# back to the record-27 `FDMA` column when it is absent. `d.k` does not change
# inside the loop, so resolve this once instead of once per SV.
# NOTE(review): assumes a missing column raises AttributeError or KeyError;
# confirm against mutils.
try:
    k_chans = d.k.channel  # rec 35:19
except (AttributeError, KeyError):
    k_chans = d.k.FDMA  # rec 27 (for non-GLN)

sv_list = get_sv_list(d, d.k)
for sv, sat_type in sv_list:
    if do_sat_types is not None and sat_type not in do_sat_types:
        continue

    # Epochs of this SV above the elevation mask; used only to find the
    # channels that tracked it.
    i = find((d.SV == sv) & (d.SAT_TYPE == sat_type) & (d.EL >= args.elev))
    for chan in unique(d[i, k_chans]):
        # All epochs of this SV on this channel. The elevation mask and the
        # minimum segment length are passed to get_LxLy_idx below.
        ii = find((d.SV == sv) & (d.SAT_TYPE == sat_type) & (d[:, k_chans] == chan))
        d0 = d[ii]
        info = sv_to_LxLy_info(d0, d0.k, sv, sat_type)
        for freq1, track1, freq2, track2, desc in info:
            i1, i2 = get_LxLy_idx(d0, d0.kf, sv, sat_type,
                                  freq1, track1,
                                  freq2, track2,
                                  args.elev, min_seg=args.min_pts)
            if len(i1) == 0:
                print('%s SV %d chan %d - too few points' %
                      (desc, sv, chan))
                continue

            # One figure (plus optional FFT figure) per signal description,
            # shared by every SV/channel with that description.
            if desc not in all_figs:
                fig1, ax1 = subplots()
                ax1.set_title(desc)
                all_figs[desc] = (fig1, ax1)
                if args.freq_plot:
                    fig2, ax2 = subplots()
                    ax2.set_title('FFT of %s' % desc)
                    all_fft_figs[desc] = (fig2, ax2)
            fig1, ax1 = all_figs[desc]
            if args.freq_plot:
                fig2, ax2 = all_fft_figs[desc]

            MPx, _, _ = get_MPx(d0, d0.kf, i1, i2, filt_Tc=args.smooth)
            label = '%d#%d' % (sv, chan)
            ax1.plot(d0.TIME[i1], MPx, label=label)
            # Averaged N-bin spectrum in dB. It is always computed because
            # the HF/LF summary needs it even when it is not plotted.
            FFT_MPx = 20*log10(do_1d_avg_fft(array(MPx), N))
            if args.freq_plot:
                # Frequency axis built from the first sample interval:
                # f_k = (k/N - 0.5)/dt for k = 0..N-1, i.e. -1/(2*dt) up to
                # just below +1/(2*dt). Assumes do_1d_avg_fft returns a
                # centred (fftshift-ed) spectrum -- TODO confirm.
                dt = d0.TIME[i1[1]] - d0.TIME[i1[0]]
                f = ((r_[0.:N])/N - 0.5) * 1./dt
                ax2.plot(f, FFT_MPx, label=label)
            MPx_by_desc[desc]['MPx'] += MPx.tolist()
            # On the axis above, the last bin is the highest frequency and
            # bin N//2 is f = 0.
            HF_FFT = FFT_MPx[-1]
            LF_FFT = FFT_MPx[N//2]
            MPx_by_desc[desc]['HF'].append(HF_FFT)
            MPx_by_desc[desc]['LF'].append(LF_FFT)
            print('%17s %2d chan %d len %d std %.3f HF %4.1f LF %4.1f' %
                  (desc, sv, chan, len(i1), std(MPx), HF_FFT, LF_FFT))

# One line per signal description, aggregated over every SV/channel segment:
# total sample count, overall MPx standard deviation, and median per-segment
# HF/LF FFT levels.
print("\nSummary:")
for desc, stats in MPx_by_desc.items():
    # Named `stats` rather than `MPx` so the loop does not rebind the
    # module-level MPx array left over from the last processed segment.
    print('%17s: len %d std %.3f HF %4.1f LF %4.1f' %
          (desc,
           len(stats['MPx']),
           std(stats['MPx']),
           median(stats['HF']),
           median(stats['LF'])
          ))

def finalize_fig(prefix,fig_info):
    for desc,(fig,ax) in fig_info.items():
        figure(fig.number)
        figlegend()
        make_legend_interactive(fig)
        if args.png:
            desc = desc.replace(' ','_')
            desc = desc.replace('-','_')
            savefig('%s_%s.png'%(prefix,desc))

# Add legends (and, with --png, write PNG files) for the MPx time-series
# figures first, then for the FFT figures.
for fig_prefix, fig_set in (('mpx', all_figs), ('fft_mpx', all_fft_figs)):
    finalize_fig(fig_prefix, fig_set)

# Only open the interactive windows when not running headless.
if not args.png:
    show()
