#!/usr/bin/env python3

# . ~/Documents/projects/Olympus/montera/scripts/env_montera/bin/activate

import pandas as pd
import numpy as np
import os
import argparse
import sys
import subprocess
import csv
import matplotlib.pyplot as plt

# Generic utilities (json_read, json_write, timer)
import montera.gutil as gutil
import montera.pos_lib as pos_lib
# Module-level configuration of pos_lib, applied as soon as this script is
# loaded (before main() runs).
# NOTE(review): the meaning of ant2ap_mode = 'single_file' is defined in
# montera.pos_lib, which is not visible here — confirm what it selects.
pos_lib.ant2ap_mode = 'single_file'


# Column renames applied to both the truth CSV and each DUT stats CSV so the
# weighted-mean position columns can be read as row.lat / row.lon / row.hgt.
_COLUMN_RENAME = {
    'w_mean_lat': 'lat',
    'w_mean_lon': 'lon',
    'w_mean_hgt': 'hgt',
}

# Antenna indices compared against truth.
_ANTENNAS = (0, 1)


def _parse_args(argv=None):
    """Build the command-line parser and parse ``argv`` (``sys.argv`` if None)."""
    parser = argparse.ArgumentParser(
        description='Analyze swapped positions',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--folder",
        default=".",
        help="Base folder"
    )
    parser.add_argument(
        "--dut_dirs",
        # Other receiver dirs used with this script: KLONDIKE, Z3R.
        default=['P3ENCL'],
        nargs='+',
        help="Receiver specific folders"
    )
    parser.add_argument(
        "--truth_csv",
        default='truth_pos_weighted_stats.csv',
        help="CSV to use for Truth",
    )
    # NOTE(review): --dut_dir is parsed but never read; main() iterates
    # --dut_dirs. Kept so existing command lines still parse.
    parser.add_argument(
        "--dut_dir",
        default="BLDG2BD",
        help="directory with DUT T04 data. It will be used as "
             "FOLDER[n]/data/gnss/DUT_DIR[n]/. If only a single DUT_DIR is "
             "provided, then it will be used for all FOLDERS",
    )
    parser.add_argument(
        "--run_prefix",
        default="2403",
        help="Only sub-directories of FOLDER whose name starts with this "
             "prefix are analyzed",
    )
    outp = parser.add_argument_group('output folder options')
    outp.add_argument(
        "--data_out_dir",
        default="gen_data",
        help="Base filename for file I/O. Inputs are .T04. "
        "Outputs are .png, .csv"
    )
    outp.add_argument(
        "--log_dir",
        default="logs",
        help="where to put the log (.txt) from running pos_in_one_stats. "
             "It will be used as FOLDER/<run dir>/LOG_DIR/",
    )
    # NOTE(review): --fig_dir is parsed but the figure is saved to FOLDER.
    outp.add_argument(
        "--fig_dir",
        default="figures",
        help="where to put the PNG. "
             "It will be used as FOLDER/FIG_DIR/",
    )
    parser.add_argument(
        "--force",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Re-run pos_in_one_stats even if its stats CSV already exists.",
    )
    parser.add_argument(
        "--plot",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Create plots"
    )
    return parser.parse_args(argv)


def _antenna_row(df, ant):
    """Return the first row of ``df`` whose ``antenna`` column equals ``ant``.

    Returns None when no such row exists (rather than raising IndexError).
    """
    rows = df[df['antenna'] == ant]
    return None if rows.empty else rows.iloc[0]


def _run_pos_in_one_stats(run_dir, dut, log_fname):
    """Run ``pos_in_one_stats`` for one run dir / DUT, sending all output to a log.

    A non-zero return code only produces a warning; the caller decides what
    to do if the expected CSV was not produced.
    """
    cmd_list = ['pos_in_one_stats', run_dir, dut, '--force', '--ne_tight_lim']
    print('    ' + ' '.join(cmd_list))
    # Context manager: the log is closed even if the command cannot start.
    with open(log_fname, 'w') as f:
        p = subprocess.run(cmd_list, stderr=subprocess.STDOUT, stdout=f)
    if p.returncode != 0:
        print(f'Warning pos_in_one_stats return code {p.returncode}')


def _dut_records(args, root, run, dut, truth_rows):
    """Return one record per antenna comparing a DUT's position to truth.

    The stats CSV is FOLDER/.../RUN/DATA_OUT_DIR/position/DUT/DUT_pos_weighted_stats.csv.
    It is (re)generated with pos_in_one_stats when missing or when --force is
    given. Returns an empty list if the CSV is still missing afterwards.
    Antennas absent from the CSV are skipped with a warning.
    """
    # Use os.walk's ``root`` (not args.folder) so nested run dirs resolve.
    run_dir = os.path.join(root, run)
    data_dir = os.path.join(run_dir, args.data_out_dir, 'position', dut)
    log_dir = os.path.join(run_dir, args.log_dir)
    os.makedirs(log_dir, exist_ok=True)
    log_fname = os.path.join(log_dir, dut + '_pos_ant_perm.log')
    fname = os.path.join(data_dir, f'{dut}_pos_weighted_stats.csv')

    if args.force or not os.path.isfile(fname):
        _run_pos_in_one_stats(run_dir, dut, log_fname)
    if not os.path.isfile(fname):
        print(f'Warning: {fname} not found; skipping')
        return []

    df = pd.read_csv(fname).rename(_COLUMN_RENAME, axis='columns')
    records = []
    for ant, row_truth in truth_rows.items():
        row = _antenna_row(df, ant)
        if row is None:
            print(f'Warning: no antenna {ant} row in {fname}; skipping')
            continue
        dist = pos_lib.dist_m(row_truth.lat, row_truth.lon, row.lat, row.lon)
        enu = pos_lib.llh2enu(
            np.array([row_truth.lat, row_truth.lon, row_truth.hgt]),
            np.array([row.lat, row.lon, row.hgt]),
            is_rad=False,
            is_ref_rad=False,
        )
        # Assumes llh2enu returns E, N, U as rows with one column each —
        # confirm against pos_lib.
        records.append({
            'path': os.path.join(root, run, dut),
            'antenna': ant,
            'lat': row.lat,
            'lon': row.lon,
            'hgt': row.hgt,
            'easting': enu[0][0],
            'northing': enu[1][0],
            'dh': enu[2][0],
            'distance': dist,
        })
    return records


def _plot(df, folder):
    """Plot per-antenna offsets from truth and save FOLDER/pos_ant_perm.png.

    Top axes: northing vs easting (m), one marker style per antenna.
    Bottom axes: height offset (m) plotted at x = antenna index.
    """
    ant0 = df[df['antenna'] == 0]
    ant1 = df[df['antenna'] == 1]
    # Symmetric horizontal range large enough for every E/N offset.
    m = max(df.easting.abs().max(), df.northing.abs().max())

    fig, ax = plt.subplots(2, 1, figsize=[12, 8])

    ant0.plot('easting', 'northing', ax=ax[0], marker='+', color='b',
              linewidth=0, label='Ant 0')
    ant1.plot('easting', 'northing', ax=ax[0], marker='x', color='r',
              linewidth=0, label='Ant 1')
    if m > 0:
        # Identical low/high limits (all offsets zero) are singular in matplotlib.
        ax[0].set_xlim([-m, m])
    # NOTE(review): fixed +/-0.1 m northing window; larger offsets are clipped.
    ax[0].set_ylim([-0.1, 0.1])
    ax[0].set_xlabel('Easting (m)')
    ax[0].set_ylabel('Northing (m)')
    ax[0].legend()
    ax[0].set_title(os.path.basename(os.path.abspath(folder)))

    ax[1].plot(np.zeros(len(ant0)), ant0.dh,
               marker='+', color='b', linewidth=0, label='Ant 0')
    ax[1].plot(np.ones(len(ant1)), ant1.dh,
               marker='x', color='r', linewidth=0, label='Ant 1')
    ax[1].set_xlabel('Antenna')
    ax[1].set_ylabel('delta height (m)')
    ax[1].legend(loc='upper center')

    fig_fname = os.path.join(folder, 'pos_ant_perm.png')
    print(f"saving {fig_fname}")
    fig.savefig(fig_fname)
    plt.close(fig)


def main():
    """Compare per-DUT antenna positions against truth across run directories.

    Reads the truth CSV, then walks FOLDER for directories starting with
    --run_prefix. For each one and each --dut_dirs entry, it reads (or
    generates via pos_in_one_stats) the weighted position stats and computes
    ENU offsets and horizontal distance to truth per antenna. Results are
    written to FOLDER/pos_ant_perm.csv and, with --plot,
    FOLDER/pos_ant_perm.png.

    Raises:
        ValueError: if the truth CSV lacks a row for antenna 0 or 1.
    """
    total_time = gutil.timer('total')
    args = _parse_args()

    truth_path = os.path.join(args.folder, args.truth_csv)
    df_truth = pd.read_csv(truth_path).rename(_COLUMN_RENAME, axis='columns')
    # Look up each truth antenna once instead of on every DUT iteration.
    truth_rows = {}
    for ant in _ANTENNAS:
        row = _antenna_row(df_truth, ant)
        if row is None:
            raise ValueError(f'{truth_path} has no row for antenna {ant}')
        truth_rows[ant] = row

    # Truth itself is listed first with zero offsets, so it plots at the origin.
    records = [
        {
            'path': 'truth',
            'antenna': ant,
            'lat': row.lat,
            'lon': row.lon,
            'hgt': row.hgt,
            'easting': 0,
            'northing': 0,
            'dh': 0,
            'distance': 0,
        }
        for ant, row in truth_rows.items()
    ]

    for root, dirs, _files in os.walk(args.folder):
        for d in dirs:
            if not d.startswith(args.run_prefix):
                continue
            print(f"============ {d} =====================")
            print(os.path.join(root, d, 'data', 'gnss'))
            for dut in args.dut_dirs:
                print("    ----------------------------------------")
                print(f"    {dut}")
                records.extend(_dut_records(args, root, d, dut, truth_rows))

    df = pd.DataFrame(records).sort_values(by=['antenna', 'path'])
    print(df)

    fout = os.path.join(args.folder, "pos_ant_perm.csv")
    print(f'Saving {fout}')
    df.to_csv(fout, index=False)

    if args.plot:
        _plot(df, args.folder)

    total_time.dt_print()


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not on import.
    main()
