import subprocess
import pandas as pd
import matplotlib.pyplot as plt
import argparse
import sys
import os

#import viewdat_cno_lib as vdl

# Command-line options. Each elevation-mask run is expected in a folder named
# "<folder>_el<mask>" containing outputs/ant0_num_slips.csv and
# outputs/ant1_num_slips.csv.
parser = argparse.ArgumentParser(
    description='Plot cycle-slip statistics for several elevation-mask runs.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    '-F', "--folder",
    default='240918',
    help="Base folder name."
)
parser.add_argument(
    '--dev_list',
    default=["EB-5081", "EB1-5093", "EB1-5095", "P4-05"],
    nargs='+',
    help="Devices averaged into the 'Judo_mean' series."
)
parser.add_argument(
    '--el_args',
    nargs='+',
    # Parse as int so command-line values have the same type as the default
    # list; otherwise they stay strings and never match the numeric 'el'
    # comparisons made later in the script.
    type=int,
    default=[0, 10, 12, 15, 20],
    help="Elevation Mask dir postfixes."
)
args = parser.parse_args()

# Load the per-antenna cycle-slip tables of every elevation-mask run and tag
# each row with the mask ('el') and antenna ('ant') it was read for.
frames = []
for el in args.el_args:
    run_outputs = os.path.join(f'{args.folder}_el{el}', 'outputs')
    for ant in (0, 1):
        csv_path = os.path.join(run_outputs, f'ant{ant}_num_slips.csv')
        print(f'Reading {csv_path}')
        frames.append(pd.read_csv(csv_path).assign(el=el, ant=ant))

df = pd.concat(frames)

# Aggregate cycle-slip and observation counts per (antenna, elevation mask):
#   * one row per device in args.dev_list,
#   * a 'Judo_mean' row with the per-device average of those totals,
#   * a 'BX992' row for the reference receiver (not part of dev_list).
# Rows are selected with boolean masks instead of DataFrame.query() so device
# names are never spliced into an expression string (a quote in a name would
# break the query), and so the 'el' comparison works whatever type the
# elevation values were loaded with.
dlist = []
for ant in [0, 1]:
    for el in args.el_args:
        # Rows for this antenna and elevation mask, shared by every device.
        sel = (df['ant'] == ant) & (df['el'] == el)
        judo_sum_slips = 0
        judo_sum_obs = 0
        for dev in args.dev_list:
            dfq = df[sel & (df['dev'] == dev)]
            n_slips = dfq.num_slips.sum()
            n_obs = dfq.num_obs.sum()
            judo_sum_slips += n_slips
            judo_sum_obs += n_obs
            dlist.append({
                'ant': ant,
                'dev': dev,
                'el': el,
                'num_slips': n_slips,
                'num_obs': n_obs,
            })
        # Per-device average of the totals collected above.
        n_dev = len(args.dev_list)
        dlist.append({
            'ant': ant,
            'dev': 'Judo_mean',
            'el': el,
            'num_slips': judo_sum_slips / n_dev,
            'num_obs': judo_sum_obs / n_dev,
        })

        dfq = df[sel & (df['dev'] == 'BX992')]
        dlist.append({
            'ant': ant,
            'dev': 'BX992',
            'el': el,
            'num_slips': dfq.num_slips.sum(),
            'num_obs': dfq.num_obs.sum(),
        })
df = pd.DataFrame(dlist)
# Shared y-axis limit for the per-antenna bar charts (taken over every row,
# including the individual devices).
max_slips = df.num_slips.max()
# Slip rate in percent. Rows without observations get NaN instead of inf so
# they cannot turn an average of these percentages into inf.
df['percent'] = (df.num_slips / df.num_obs.where(df.num_obs != 0)) * 100.0

for ant in [0, 1]:
    # Bar chart of slip counts for the reference receiver and the device
    # average: one bar group per device, one bar per elevation mask.
    dfq = df.query(f'ant=={ant} and dev in ["BX992", "Judo_mean"]')[['el', 'dev', 'num_slips']]
    df_stats = dfq.groupby(['dev', 'el'], dropna=False)['num_slips'].sum()
    df_stats = df_stats.unstack()
    df_stats = df_stats.reset_index()

    print('df_stats:')
    print(df_stats)

    fig, ax = plt.subplots(figsize=(8,4), constrained_layout=True)
    df_stats.plot(x='dev', kind='bar', rot=0, ax=ax)
    ax.set_title(f'ant{ant}')
    ax.set_ylabel('Number of Cycle Slips')
    # Same y-range for both antennas so the plots are directly comparable;
    # avoid a degenerate [0, 0] range when there are no slips at all.
    ax.set_ylim(0, max_slips * 1.1 if max_slips > 0 else 1)
    for p in ax.patches:
        height = p.get_height()
        if pd.isna(height):
            # No data for this (dev, el) pair: leave the bar unlabelled.
            continue
        # Whole counts print without a trailing ".0"; averages keep 2 decimals.
        label = f'{height:.0f}' if float(height).is_integer() else f'{height:.2f}'
        # Centre the value just above its bar.
        ax.annotate(label, (p.get_x() + p.get_width() / 2, height),
                    ha='center', va='bottom')
    # Move legend to the side
    ax.legend(bbox_to_anchor=(1.22, 1.05), title='Elevation Mask')

    png_fname = f'slips_stats_ant{ant}.png'
    print(f'Saving {png_fname}')
    fig.savefig(png_fname)
    plt.close(fig)


#--------------------------------------
# Percent Plot
#--------------------------------------
# Slip rate (slips / observations, in %) versus elevation mask for the
# reference receiver and the device average. Each point is the mean of the
# ant0 and ant1 percentages (a mean of ratios, not a pooled ratio).
print('df:')
print(df)

dfq = df.query('dev in ["BX992", "Judo_mean"]')[['el', 'dev', 'percent']]
df_stats = dfq.groupby(['el', 'dev'], dropna=False)['percent'].mean()
df_stats = df_stats.unstack()
df_stats = df_stats.reset_index()

print('df_stats:')
print(df_stats)

fig, ax = plt.subplots(figsize=(8,4), constrained_layout=True)
df_stats.plot(x='el', marker='o', linewidth=0.3, ax=ax)
plt.ylabel('Cycle Slips as Percent of Number of Observations')
plt.xlabel('Elevation (degrees)')
# Ticks every 5 degrees from 0 up to at least 20, extended (to the next
# multiple of 5) when a larger elevation mask was processed.
max_el = max(float(e) for e in args.el_args)
tick_top = max(20, int(-(-max_el // 5)) * 5)
plt.xticks(range(0, tick_top + 1, 5))
plt.grid(True)

png_fname = 'slips_stats_percent.png'
print(f'Saving {png_fname}')
plt.savefig(png_fname)

plt.show()
plt.close(fig)
