import pandas as pd
import matplotlib.pyplot as plt
import os
import sys
import logging
import argparse
import datetime


# Gaps (seconds) between consecutive packets above this are counted as "large"
# and shown in the histogram column.
_LARGE_DT = 2.5
# Only gaps above this are included in the reported mean dt.
# NOTE(review): presumably this ignores back-to-back packets within one epoch -- confirm.
_MEAN_MIN_DT = 0.25
# Histogram bin edges (seconds) for the large-gap column.
_HIST_BINS = (0, 2.5, 4, 6, 10)


def _parse_args(argv=None):
    """Parse command-line arguments (``sys.argv[1:]`` when *argv* is None)."""
    parser = argparse.ArgumentParser(
        description="Report basic CMR/DCOL packet information",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '-d', '--dir',
        default=None,
        help="The directory with _cmr.csv files to process. None for the last directory in the current path.",
    )
    parser.add_argument(
        '--show',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Show the plot',
    )
    parser.add_argument(
        '--verbose', '-v',
        action='count',
        default=0,
    )
    return parser.parse_args(argv)


def _configure_logging(verbosity):
    """Send root logging to stdout; each -v lowers the threshold one level below WARNING."""
    logging.basicConfig(
        # Clamp at DEBUG: three or more -v would otherwise give 0 (NOTSET) or a negative level.
        level=max(logging.DEBUG, logging.WARNING - 10 * verbosity),
        format=("%(asctime)s [%(levelname)-8s]"
                "[%(filename)15s:%(lineno)-5s] "
                "%(funcName)-17s: %(message)s"),
        stream=sys.stdout,
    )


def _find_latest_data_dir(parent='.'):
    """Return the lexicographically last subdirectory name of *parent* starting with '20'.

    Returns None when no such directory exists.
    """
    dirs = sorted(
        name for name in os.listdir(parent)
        if name.startswith('20') and os.path.isdir(os.path.join(parent, name))
    )
    return dirs[-1] if dirs else None


def _list_cmr_files(data_dir):
    """Return the sorted names (not paths) of regular files in *data_dir* ending in 'cmr.csv'."""
    return sorted(
        name for name in os.listdir(data_dir)
        if name.endswith('cmr.csv') and os.path.isfile(os.path.join(data_dir, name))
    )


def _read_cmr_file(path):
    """Read one CMR csv and add a 'dt' column: seconds since the previous row's 'time'.

    The first row's 'dt' is NaN. A zero-byte or header-only file yields an empty
    DataFrame with a datetime 'time' column instead of raising, so the caller can
    still give the file its own (blank) plot row.
    """
    try:
        df = pd.read_csv(path, parse_dates=['time'])
    except pd.errors.EmptyDataError:  # zero-byte file: no header to parse
        df = pd.DataFrame()
    if df.empty:
        # Normalize to a typed zero-row frame so the .dt accessor below works.
        df = pd.DataFrame({'time': pd.Series(dtype='datetime64[ns]')})
    df['dt'] = df['time'].diff().dt.total_seconds()
    return df


def main():
    """Plot packet time deltas for every *cmr.csv file in a data directory.

    Each file gets one figure row. The left panel plots the gap ('dt', seconds)
    between consecutive packets over time. The right panel is a histogram of the
    gaps larger than _LARGE_DT. The figure is saved as
    ``<dir name>_compare_cmr_age.png`` inside the data directory. It is also shown
    interactively unless ``--no-show`` is given.
    """
    args = _parse_args()
    _configure_logging(args.verbose)

    data_dir = args.dir if args.dir is not None else _find_latest_data_dir()
    if data_dir is None:
        sys.exit("No directory starting with '20' found in the current directory")
    print(f'{data_dir=}')

    cmr_files = _list_cmr_files(data_dir)
    print(f'{cmr_files=}')
    if not cmr_files:
        sys.exit(f'No *cmr.csv files found in {data_dir}')

    dfs = []
    for f in cmr_files:
        print(f'Reading {f}')
        df = _read_cmr_file(os.path.join(data_dir, f))
        if df.empty:
            print(f'Error reading {f}')
        print(df)
        dfs.append(df)

    # Shared axis limits across all non-empty files, so the rows are comparable.
    # Empty files are skipped here; otherwise their NaT times would poison the limits.
    max_dt = 0
    max_n_large_dt = 0
    min_t = max_t = None
    for df in dfs:
        if df.empty:
            continue
        # A single-row file has an all-NaN 'dt'; NaN > x is False, so it is ignored.
        if df['dt'].max() > max_dt:
            max_dt = df['dt'].max()
        t_first, t_last = df['time'].min(), df['time'].max()
        min_t = t_first if min_t is None else min(min_t, t_first)
        max_t = t_last if max_t is None else max(max_t, t_last)
        max_n_large_dt = max(max_n_large_dt, int((df['dt'] > _LARGE_DT).sum()))
    if min_t is None:
        sys.exit(f'No data in any of {cmr_files}')

    max_dt = round(max_dt + 1, 0)
    # Pad the time range by one second on each side, then truncate to whole seconds.
    one_second = datetime.timedelta(seconds=1)
    max_t = (max_t + one_second).replace(microsecond=0)
    min_t = (min_t - one_second).replace(microsecond=0)
    print(f'{max_dt=} {min_t=} {max_t=}')

    fig, ax = plt.subplots(
        len(dfs), 2,
        figsize=(10, 5),
        gridspec_kw={'width_ratios': [3, 1]},
        squeeze=False,  # keep `ax` 2-D even when there is only one file
    )
    # Set titles unconditionally: the first file may be empty and skipped below.
    ax[0, 0].set_title(f'CMR Packet time deltas {data_dir}')
    ax[0, 1].set_title(f'Large dt (>{_LARGE_DT}) histogram')

    last_row = len(dfs) - 1
    for n, (fname, df) in enumerate(zip(cmr_files, dfs)):
        if df.empty:
            continue
        ts_ax, hist_ax = ax[n]
        large = df.loc[df['dt'] > _LARGE_DT, 'dt']
        mean = df.loc[df['dt'] > _MEAN_MIN_DT, 'dt'].mean()

        ts_ax.plot(df['time'], df['dt'], label=fname)
        ts_ax.set_xlim([min_t, max_t])
        ts_ax.set_ylabel('dt (seconds)')
        ts_ax.set_ylim([0, max_dt])
        ts_ax.legend()
        if n != last_row:
            # All rows share the same time range; only the bottom row shows tick labels.
            ts_ax.xaxis.set_tick_params(labelbottom=False)

        hist_ax.hist(large, bins=_HIST_BINS, label=fname)
        hist_ax.set_ylabel('bin count')
        # A (0, 0) ylim is singular; keep at least one count of headroom.
        hist_ax.set_ylim([0, max(max_n_large_dt, 1)])

        print(f'{fname} dt mean: {mean:0.3f}, long dt count: {len(large)}')

    ax[-1, 0].set_xlabel('Time (s)')
    ax[-1, 0].tick_params(axis='x', labelrotation=-45)
    ax[-1, 1].set_xlabel('dt (s)')
    fig.tight_layout()

    # Use only the last path component in the file name: data_dir may be a
    # multi-component path or end with a separator.
    dir_name = os.path.basename(os.path.normpath(data_dir))
    png_path = os.path.join(data_dir, f'{dir_name}_compare_cmr_age.png')
    print(f'Saving {png_path}')
    fig.savefig(png_path)
    if args.show:
        plt.show()
    plt.close(fig)


if __name__ == "__main__":
    # Script entry point: run only when executed directly, not when imported.
    main()
