#!/usr/bin/env -S streamlit run --server.headless true
#
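# Run as:
#  streamlit run view_jam_diag.py --server.headless true
# An optional script argument overrides the default base path (prefix) of the
# hourly T04 files; see base_file_path below.
#
# Plots anti-jam diagnostic information (FIR notch filters, RSSI/AGC
# attenuation, HW FFT bin blanking) for the Auger 5 system.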
# Adapted from:
#  GPSTools/pythonTools/plt_jamming_mitigation.py
#
import streamlit as st
from pylab import *
import mutils as m
import itertools
from plt_jamming_mitigation import rt_band_to_center_freq, rt_band_to_label
from bokeh.plotting import figure
from bokeh.palettes import Category20_19 as palette
from bokeh.layouts import column
from bokeh.models import NumeralTickFormatter
import sys

# Prefix of the hourly T04 files; "YYYYMMDDHH00.T04" is appended per request.
# The first script argument, when given, replaces the built-in default.
base_file_path = (
    sys.argv[1]
    if len(sys.argv) > 1
    else "/mnt/data_drive/Jam_NoPi/Auger5/0000000005"
)

@st.cache_data
def load_t04(file_path):
    """Load the records of *file_path* matched by the ``rec='-d35:30'`` filter.

    Thin wrapper around ``mutils.vd2arr``.  Callers unpack the result as
    ``(d, k)``: a 2-D array indexed ``d[row, column]`` and a map from field
    name to column index.  Decorated with ``st.cache_data`` so Streamlit
    reruns reuse the parsed file for the same path.
    """
    record_filter = '-d35:30'
    return m.vd2arr(file_path, rec=record_filter)

def _count_indexed(k, prefix):
    """Return 1 + the largest numeric ``_NN`` suffix among keys of ``k`` starting with ``prefix``.

    Returns 0 when no such key exists.  Taking the maximum (rather than the
    suffix of the last matching key) makes the count independent of the
    order in which ``k.keys()`` yields field names.
    """
    indices = []
    for key in k.keys():
        if key.startswith(prefix):
            suffix = key.rsplit('_', 1)[-1]
            if suffix.isdigit():
                indices.append(int(suffix))
    return max(indices, default=-1) + 1


def do_FIR(file_path, sel_rt_band):
    """Plot notch-filter placement and FIR J/N vs GPS seconds for one T04 file.

    Only channels whose per-channel RT band equals ``NF_RT_BAND_00`` of the
    first record are drawn.

    Top plot: notch frequency [MHz] in a +/-25 MHz window around that band's
    center frequency.  For record VERSION >= 16 each HW-assist slot is drawn
    as circles (HW_ASSIST_FREQ scaled by HW_ASSIST_CMD_BY_SW, colored by slot)
    plus red x markers (HW_ASSIST_FREQ scaled by HW_ASSIST_IS_ALLOCATED).
    Older records draw FIR_FREQ as colored circles and HW_ASSIST_FREQ as red x.
    Bottom plot: per-FIR J/N [dB]; FIR j uses the same palette color as in
    the top plot.

    Args:
        file_path: path of the T04 file, loaded through ``load_t04``.
        sel_rt_band: unused; kept so every ``do_*`` plotter shares one
            signature.
    """
    try:
        d, k = load_t04(file_path)
    except Exception:
        # Missing/unreadable file.  ``Exception`` rather than a bare except so
        # BaseException-only signals (KeyboardInterrupt, SystemExit, ...) are
        # not swallowed.
        st.write(f"no data for {file_path}")
        return
    if len(d) == 0:
        # Version and RT band below are read from the first record.
        st.write(f"no records in {file_path}")
        return

    n_fir = _count_indexed(k, 'FIR_RT_')
    n_hw_assist = _count_indexed(k, 'HW_ASSIST_RT_')
    gps_sec = d[:, k.GPS_SEC]
    version = d[0, k['VERSION']]
    rt_band = int(d[0, k['NF_RT_BAND_00']])
    center_freq_mhz = int(rt_band_to_center_freq(rt_band))

    def color_of(j):
        # Index the palette directly so channel j gets the same color in both
        # plots, independent of how many colors an earlier loop consumed.
        return palette[j % len(palette)]

    def rows_in_band(band_key):
        # Row indices whose per-channel RT-band column equals rt_band.
        return where(d[:, k[band_key]] == rt_band)[0]

    p1 = figure(
        height=300,
        title='FIR filters - color = HW slot' if version >= 16 else 'FIR filters',
        y_range=(center_freq_mhz - 25, center_freq_mhz + 25),
        y_axis_label='Notch Filter Placement [MHz]')
    p2 = figure(height=300,
                x_range=p1.x_range,
                x_axis_label='GPS TOW [s]',
                y_axis_label='FIR J/N [dB]')

    if version >= 16:
        for j in range(n_hw_assist):
            idx = rows_in_band('HW_ASSIST_RT_BAND_%2.2d' % j)
            if len(idx) == 0:
                continue
            freq_mhz = d[idx, k['HW_ASSIST_FREQ_%2.2d' % j]] * 1e-6
            # Frequencies are multiplied by the CMD_BY_SW / IS_ALLOCATED
            # columns (presumably 0/1 flags -- TODO confirm), so unset
            # samples land at 0 MHz.
            p1.circle(gps_sec[idx], freq_mhz * d[idx, k['HW_ASSIST_CMD_BY_SW_%2.2d' % j]],
                      size=5, color=color_of(j))
            p1.x(gps_sec[idx], freq_mhz * d[idx, k['HW_ASSIST_IS_ALLOCATED_%2.2d' % j]],
                 size=10, color='red')
    else:
        for j in range(n_fir):
            idx = rows_in_band('FIR_RT_BAND_%2.2d' % j)
            if len(idx) > 0:
                p1.circle(gps_sec[idx], d[idx, k['FIR_FREQ_%2.2d' % j]] * 1e-6,
                          color=color_of(j), size=5)
        for j in range(n_hw_assist):
            idx = rows_in_band('HW_ASSIST_RT_BAND_%2.2d' % j)
            if len(idx) > 0:
                p1.x(gps_sec[idx], d[idx, k['HW_ASSIST_FREQ_%2.2d' % j]] * 1e-6,
                     size=10, color='red')

    # Records with VERSION >= 13 provide FIR_J2N_FILT_NN; older ones FIR_J2N_NN.
    j2n_key = 'FIR_J2N_FILT_%2.2d' if version >= 13 else 'FIR_J2N_%2.2d'
    for j in range(n_fir):
        idx = rows_in_band('FIR_RT_BAND_%2.2d' % j)
        if len(idx) > 0:
            p2.circle(gps_sec[idx], d[idx, k[j2n_key % j]], color=color_of(j), size=5)

    for p in (p1, p2):
        p.xaxis[0].formatter = NumeralTickFormatter(format="0[.]0")
    st.bokeh_chart(column(p1, p2, sizing_mode='stretch_both'), use_container_width=True)

def do_RSSI(file_path, sel_rt_band):
    """Plot RSSI, AGC PWM and attenuation vs GPS seconds for one T04 file.

    A field is drawn only if the record VERSION (read from the first record)
    is new enough to carry it: COMM_RSSI needs >= 10, AGC_PWM_00 /
    CHAN_ATTEN_00 / COMM_ATTEN need >= 3, AGC_CTRL_MODE_00 needs >= 12.

    Args:
        file_path: path of the T04 file, loaded through ``load_t04``.
        sel_rt_band: unused; kept so every ``do_*`` plotter shares one
            signature.
    """
    try:
        d, k = load_t04(file_path)
    except Exception:
        # Missing/unreadable file.  ``Exception`` rather than a bare except so
        # BaseException-only signals (KeyboardInterrupt, SystemExit, ...) are
        # not swallowed.
        st.write(f"no data for {file_path}")
        return
    if len(d) == 0:
        # VERSION below is read from the first record.
        st.write(f"no records in {file_path}")
        return

    gps_sec = d[:, k.GPS_SEC]
    version = d[0, k['VERSION']]
    slot = 0  # per-slot fields (AGC_PWM_NN, ...) are read for slot 00 only

    p1 = figure(height=200, y_axis_label='RSSI')
    p2 = figure(height=200, y_axis_label='PWM', x_range=p1.x_range)
    p3 = figure(height=200, y_axis_label='Attenuation', x_range=p1.x_range)

    if version >= 10:
        p1.circle(gps_sec, d[:, k['COMM_RSSI']], color=palette[0])

    if version >= 3:
        p2.circle(gps_sec, d[:, k['AGC_PWM_%2.2d' % slot]], color=palette[0])

    # Series on the attenuation plot take successive palette colors.
    colors = itertools.cycle(palette)
    if version >= 12:
        mode = d[:, k['AGC_CTRL_MODE_%2.2d' % slot]]
        # Linear map of the mode value: 0 -> 210, 4 -> 340 (presumably to
        # overlay it on the attenuation scale -- TODO confirm).
        ctrl_mode = mode * (340 - 210) / 4 + 210
        p3.circle(gps_sec, ctrl_mode, color=next(colors), legend_label='ctrl mode')

    if version >= 3:
        chan_idx = k['CHAN_ATTEN_%2.2d' % slot]
        comm_idx = k['COMM_ATTEN']
        p3.circle(gps_sec, d[:, chan_idx], color=next(colors), legend_label='channel')
        p3.circle(gps_sec, d[:, comm_idx], color=next(colors), legend_label='comm')

    for p in (p1, p2, p3):
        p.xaxis[0].formatter = NumeralTickFormatter(format="0[.]0")
    st.bokeh_chart(column(p1, p2, p3, sizing_mode='stretch_both'), use_container_width=True)

def _blanked_bins_image(d, k, nf_slot=0):
    """Expand HW_BINS_BLANKED_<slot>_<word> bitfields into a (records, 128) float image.

    Each of the 8 words carries 16 bin flags (bit n of word b -> bin 16*b + n).
    A set bit yields pixel value 128 (255 - 127); a clear bit yields 255.
    """
    image = np.zeros((d.shape[0], 128))
    shifts = np.arange(16)
    for word in range(8):
        bits = d[:, k['HW_BINS_BLANKED_%2.2d_%2.2d' % (nf_slot, word)]].astype(np.int64)
        # (records, 1) >> (16,) broadcasts to (records, 16): one column per bit.
        image[:, 16 * word:16 * (word + 1)] = 255 - ((bits[:, None] >> shifts) & 1) * 127
    return image


def do_HW_FFT(file_path, sel_rt_band):
    """Plot the HW FFT bin-blanking map (frequency vs GPS seconds) for NF slot 00.

    Args:
        file_path: path of the T04 file, loaded through ``load_t04``.
        sel_rt_band: RT band whose center frequency anchors the y axis; the
            128 bins span 50 MHz centered on it.
    """
    try:
        d, k = load_t04(file_path)
    except Exception:
        # Missing/unreadable file.  ``Exception`` rather than a bare except so
        # BaseException-only signals (KeyboardInterrupt, SystemExit, ...) are
        # not swallowed.
        st.write(f"no data for {file_path}")
        return
    if len(d) == 0:
        # The plot ranges below use the first and last record.
        st.write(f"no records in {file_path}")
        return

    center_freq_mhz = int(rt_band_to_center_freq(sel_rt_band))
    # NOTE(review): the bins are read from NF slot 00, but the axis is anchored
    # on sel_rt_band rather than NF_RT_BAND_00 (as do_FIR does); the two agree
    # only if slot 00 is on sel_rt_band -- confirm.
    freq_mhz = (np.arange(128) / 128. - 0.5) * 50 + center_freq_mhz
    gps_sec = d[:, k.GPS_SEC]
    image = _blanked_bins_image(d, k)

    p = figure(height=400,
               x_range=(gps_sec[0], gps_sec[-1]),
               y_range=(freq_mhz[0], freq_mhz[-1]))
    # Transposed so image rows are frequency bins and columns are records.
    p.image(image=[image.T],
            x=gps_sec[0], y=freq_mhz[0],
            dw=gps_sec[-1] - gps_sec[0], dh=freq_mhz[-1] - freq_mhz[0])
    p.xaxis[0].formatter = NumeralTickFormatter(format="0[.]0")
    st.bokeh_chart(p, use_container_width=True)

def do_RFI(file_path, sel_rt_band):
    """Placeholder for an RFI plot: loads the file but draws nothing yet.

    Not offered in the "Plot" selector.  A missing/unreadable file is
    reported the same way as in the other ``do_*`` plotters.
    """
    try:
        d, k = load_t04(file_path)
    except Exception:
        # ``Exception`` rather than a bare except so BaseException-only
        # signals (KeyboardInterrupt, SystemExit, ...) are not swallowed.
        st.write(f"no data for {file_path}")
        return
    # TODO: implement the RFI plot from d, k.
    return

def do_SW_FFT(file_path, sel_rt_band):
    """Placeholder for a SW FFT plot: loads the file but draws nothing yet.

    Not offered in the "Plot" selector.  A missing/unreadable file is
    reported the same way as in the other ``do_*`` plotters.
    """
    try:
        d, k = load_t04(file_path)
    except Exception:
        # ``Exception`` rather than a bare except so BaseException-only
        # signals (KeyboardInterrupt, SystemExit, ...) are not swallowed.
        st.write(f"no data for {file_path}")
        return
    # TODO: implement the SW FFT plot from d, k.
    return

######################################################################

# Streamlit documents st.set_page_config as having to be the first Streamlit
# command the script runs, so it precedes the query-parameter handling below.
st.set_page_config(layout="wide")

# Initial widget values come from the URL (?date=YYYY-MM-DD&hour=H), which the
# "Run" button below writes back.  Malformed values fall back to today / hour 0
# with a warning instead of crashing the page.
default_date = datetime.date.today()
default_hour = 0
if "date" in st.query_params:
    try:
        default_date = datetime.datetime.strptime(st.query_params["date"], "%Y-%m-%d").date()
    except ValueError:
        st.warning("Ignoring malformed 'date' query parameter (expected YYYY-MM-DD).")
if "hour" in st.query_params:
    try:
        default_hour = int(st.query_params["hour"]) % 24
    except ValueError:
        st.warning("Ignoring malformed 'hour' query parameter (expected an integer).")

col1, col2, col3, col4 = st.columns(4)
with col1:
    file_date = st.date_input("file date", default_date)
with col2:
    file_hour = st.slider('file hour', value=default_hour, min_value=0, max_value=23)
with col3:
    sel_plot = st.selectbox("Plot",("FIR","RSSI","HW_FFT"))
#with col4:
#    sel_rt_band = st.selectbox("RT band",(0,1,2))
sel_rt_band = 0

# Plot name (as offered in the selectbox above) -> plotting function.
PLOTTERS = {"FIR": do_FIR, "RSSI": do_RSSI, "HW_FFT": do_HW_FFT}

if st.button("Run"):
    # Persist the selection in the URL so a reload restores this view.
    st.query_params["date"] = file_date.isoformat()
    st.query_params["hour"] = str(file_hour)
    # Hourly files are named <base_file_path>YYYYMMDDHH00.T04.
    file_path = f"{base_file_path}{file_date:%Y%m%d}{file_hour:02d}00.T04"
    PLOTTERS[sel_plot](file_path, sel_rt_band)
