#!/usr/bin/env python
"""
 Plot rec 35:30 anti-jam states and FFT spectrum.

 Example:
  ./plot_aj_and_fft.py /mnt/data_drive/Jam_NoPi/Auger5/0000000005202402130300.T04 183762 183823

 Note: if the time range is too big, this script will get very slow unless you
       use the --fft_60s option.
"""
from pylab import *
from mutils import vd2cls
import argparse

# The FFT frequency axis is built from 2048 bins spread evenly over a 50 MHz
# span centered on the record's CENTERFREQ.
_FFT_NUM_BINS = 2048
_FFT_SPAN_MHZ = 50.


def _count_indexed_keys(keys, prefix):
    """Return (highest numeric suffix + 1) over the keys starting with *prefix*.

    Keys look like '<prefix>..._NN'.  The maximum suffix is used instead of
    "the last matching key" so the result does not depend on key ordering.
    Returns 0 when no key matches.
    """
    indices = []
    for key in keys:
        suffix = key.split('_')[-1]
        if key.startswith(prefix) and suffix.isdigit():
            indices.append(int(suffix))
    return max(indices) + 1 if indices else 0


def _plot_spectrogram(axis, d_fft, min_freq, max_freq, label):
    """Draw the FFT rows of *d_fft* on *axis* as time (x) vs. MHz (y).

    *label* is only used in the warning printed when *d_fft* has no rows.
    Returns True when a mesh was drawn, False when there was nothing to draw.
    """
    axis.set_ylim([min_freq, max_freq])
    if len(d_fft.SEC) == 0:
        # Indexing CENTERFREQ[0] on an empty selection would raise IndexError.
        print("WARNING: no %s FFT data for the requested antenna, "
              "center frequency and time range" % label)
        return False
    # CENTERFREQ is scaled by 1e-6 (Hz -> MHz) to match the MHz bin offsets.
    bin_offset_mhz = (r_[0:_FFT_NUM_BINS]/float(_FFT_NUM_BINS) - 0.5)*_FFT_SPAN_MHZ
    freq_mhz = bin_offset_mhz + d_fft.CENTERFREQ[0]*1e-6
    freq_grid, sec_grid = meshgrid(freq_mhz, d_fft.SEC)
    axis.pcolormesh(sec_grid, freq_grid,
                    d_fft[:, d_fft.k.FREQDATA:d_fft.k.FREQDATA_LAST+1],
                    cmap='rainbow', vmin=10, vmax=80)
    # pcolormesh may autoscale, so re-apply the requested frequency window.
    axis.set_ylim([min_freq, max_freq])
    return True


def _plot_agc(d_aj0, n_agc):
    """Open a second figure with AGC PWM counts (top) and attenuation (bottom).

    Only epochs where the corresponding *_IS_ENB column is non-zero are drawn.
    """
    fig2, ax2 = subplots(2, 1, sharex=True)
    cols = d_aj0.k

    ax2[0].set_ylabel('AGC [PWM count]')
    for n in range(n_agc):
        rt_band_i = cols['NF_RT_BAND_%2.2d' % n]
        enb_i = cols['AGC_IS_ENB_%2.2d' % n]
        pwm_i = cols['AGC_PWM_%2.2d' % n]
        rows = where(d_aj0[:, enb_i] != 0)[0]
        if len(rows) > 0:
            # The legend shows the RT band seen at the first enabled epoch.
            ax2[0].plot(d_aj0[rows].GPS_SEC, d_aj0[rows, pwm_i],
                        '.', label='RT %d' % d_aj0[rows[0], rt_band_i])
    ax2[0].grid()
    ax2[0].legend()

    ax2[1].set_ylabel('Attenuation [dB]')
    # NOTE(review): the channel loop reuses the AGC count; confirm the record
    # always carries as many CHAN_* entries as AGC_* entries.
    for n in range(n_agc):
        rt_band_i = cols['NF_RT_BAND_%2.2d' % n]
        enb_i = cols['CHAN_IS_ENB_%2.2d' % n]
        atten_i = cols['CHAN_ATTEN_%2.2d' % n]
        rows = where(d_aj0[:, enb_i] != 0)[0]
        if len(rows) > 0:
            ax2[1].plot(d_aj0[rows].GPS_SEC, d_aj0[rows, atten_i],
                        '.', label='RT %d' % d_aj0[rows[0], rt_band_i])
    # NOTE(review): the enable flag is spelled COM_ while the attenuation is
    # spelled COMM_; confirm both match the record's field names.
    rows = where(d_aj0.COM_IS_ENB != 0)[0]
    if len(rows) > 0:
        ax2[1].plot(d_aj0[rows].GPS_SEC, d_aj0[rows].COMM_ATTEN,
                    '.', label='common')
    ax2[1].grid()
    ax2[1].legend()
    fig2.tight_layout()


def main(args):
    """Load anti-jam and FFT records from a T04 file and plot them.

    A three-panel figure with shared time (x) and frequency-in-MHz (y) axes
    is drawn: the pre-mitigation FFT, then either the post-mitigation FFT or
    the RFI tracking states (show_rfi), then the FIR and HW-assist states of
    the selected RT band.  With show_agc a second figure shows AGC PWM counts
    and attenuations.  Finishes by calling show().

    args -- argparse namespace providing filename, t_start, t_end,
            fft_center (MHz), min_freq and max_freq (MHz), show_rfi,
            show_agc, rt_band, fft_60s and antenna.
    """
    filename = args.filename
    t0 = args.t_start
    t1 = args.t_end
    fft_center = args.fft_center*1e6  # MHz -> Hz, compared with CENTERFREQ
    min_freq = args.min_freq
    max_freq = args.max_freq
    do_show_rfi = args.show_rfi
    do_show_agc = args.show_agc
    target_rt_band = args.rt_band
    do_fft_60s = args.fft_60s
    ant_num = args.antenna

    # NOTE(review): '%d' truncates a fractional start time and '%e' keeps only
    # 7 significant digits of the end time; confirm this is intended.
    time_opts = '-s%d -e%e' % (t0, t1)
    # NOTE(review): the module docstring mentions rec 35:30 while the anti-jam
    # states are read from 35:50 here; confirm which one is right.
    d_aj = vd2cls(filename, '-d35:50 ' + time_opts)
    if do_fft_60s:
        d_fft = vd2cls(filename, '-d35:25 ' + time_opts)
    else:
        try:
            d_fft = vd2cls(filename, '-d35:39 ' + time_opts)
        except Exception:
            # Deliberate best effort: fall back to the 60s FFT record, but do
            # not trap KeyboardInterrupt/SystemExit and do not do it silently.
            print("WARNING: could not read high-rate FFT (35:39); "
                  "falling back to 60s FFT (35:25)")
            d_fft = vd2cls(filename, '-d35:25 ' + time_opts)

    # SAMP_PT 1 is plotted as pre-mitigation, SAMP_PT 2 as post-mitigation.
    fft_sel = ((d_fft.ANTNUM == ant_num) & (d_fft.SEC >= t0) & (d_fft.SEC <= t1)
               & (d_fft.CENTERFREQ == fft_center))
    d_fft0_pre = d_fft[fft_sel & (d_fft.SAMP_PT == 1)]
    d_fft0_post = d_fft[fft_sel & (d_fft.SAMP_PT == 2)]
    d_aj0 = d_aj[(d_aj.GPS_SEC >= t0) & (d_aj.GPS_SEC <= t1)]

    aj_keys = list(d_aj.k.keys())
    n_firs = _count_indexed_keys(aj_keys, 'FIR_RT_')
    n_rfi = _count_indexed_keys(aj_keys, 'RFI_RT_')
    n_agc = _count_indexed_keys(aj_keys, 'AGC_IS_ENB_')
    n_hw_assist = _count_indexed_keys(aj_keys, 'HW_ASSIST_RT_')

    close('all')

    fig, ax = subplots(1, 3, sharex=True, sharey=True, figsize=(14, 6))
    fig.suptitle(filename.split('/')[-1])

    ax[0].set_title('pre-mitigation')
    if _plot_spectrogram(ax[0], d_fft0_pre, min_freq, max_freq, 'pre-mitigation'):
        if d_fft0_pre.SEC[-1] - d_fft0_pre.SEC[0] <= 0:
            print("WARNING: FFT data is too short to display")

    if do_show_agc:
        _plot_agc(d_aj0, n_agc)

    if do_show_rfi:
        ax[1].set_title('RFI states')
        for n in range(n_rfi):
            rt_band_i = d_aj0.k['RFI_RT_BAND_%2.2d' % n]
            freq_i = d_aj0.k['RFI_FREQ_%2.2d' % n]
            rows = where(d_aj0[:, rt_band_i] == target_rt_band)[0]
            if len(rows) > 0:
                ax[1].plot(d_aj0[rows].GPS_SEC, d_aj0[rows, freq_i]*1e-6,
                           '.', label='%d' % n)
        ax[1].set_ylim([min_freq, max_freq])
        ax[1].grid()
        ax[1].legend()
    else:
        ax[1].set_title('post-mitigation')
        _plot_spectrogram(ax[1], d_fft0_post, min_freq, max_freq, 'post-mitigation')

    ax[2].set_title('FIR states')
    for n in range(n_firs):
        rt_band_i = d_aj0.k['FIR_RT_BAND_%3.3d' % n]
        freq_i = d_aj0.k['FIR_FREQ_%3.3d' % n]
        rows = where(d_aj0[:, rt_band_i] == target_rt_band)[0]
        if len(rows) > 0:
            ax[2].plot(d_aj0[rows].GPS_SEC, d_aj0[rows, freq_i]*1e-6, '.')
    for j in range(n_hw_assist):
        band_i = d_aj0.k['HW_ASSIST_RT_BAND_%3.3d' % j]
        rows = where(d_aj0[:, band_i] == target_rt_band)[0]
        if len(rows) > 0:
            alloc_i = d_aj0.k['HW_ASSIST_ALLOC_BY_HW_%3.3d' % j]
            freq_i = d_aj0.k['HW_ASSIST_FREQ_%3.3d' % j]
            # The frequency is multiplied by ALLOC_BY_HW, so epochs where that
            # column is 0 are drawn at 0 MHz (presumably a 0/1 flag, which
            # pushes unallocated entries out of the visible frequency window).
            ax[2].plot(d_aj0[rows].GPS_SEC,
                       (d_aj0[rows, freq_i]*1e-6)*d_aj0[rows, alloc_i], 'rx')
    ax[2].set_ylim([min_freq, max_freq])
    ax[2].grid()
    fig.tight_layout()

    show()


if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
                                     description=__doc__)
    parser.add_argument('filename', help='T04 file to analyze')
    parser.add_argument('t_start', type=float, help='start time in GPS seconds')
    parser.add_argument('t_end', type=float, help='end time in GPS seconds')
    # The frequency limits are only used as axis limits, so fractional MHz
    # values are accepted as well as whole numbers.
    parser.add_argument('--min_freq', type=float, default=1573,
                        help='low freq MHz (default 1573)')
    parser.add_argument('--max_freq', type=float, default=1578,
                        help='high freq MHz (default 1578)')
    parser.add_argument('--fft_center', type=float, default=1590,
                        help='FFT center freq MHz (default 1590)')
    parser.add_argument('--rt_band', type=int, default=0,
                        help='RT band to observe (default 0 = GPS L1)')
    parser.add_argument('--show_rfi', action="store_true",
                        help='show RFI tracking states instead of post-miti FFT')
    parser.add_argument('--show_agc', action="store_true",
                        help='show AGC as an additional plot')
    parser.add_argument('--fft_60s', action="store_true",
                        help='Use 60s FFT instead of high-rate FFT. Do this for large time ranges.')
    parser.add_argument('--antenna', type=int, default=1,
                        help='antenna num (default 1=vector)')
    args = parser.parse_args()
    # A reversed window selects no data at all; reject it with a usage error
    # instead of failing later while indexing an empty selection.
    if args.t_end < args.t_start:
        parser.error('t_end must not be earlier than t_start')
    main(args)
