#!/usr/bin/python3

import pandas as pd
import matplotlib.pyplot as plt
import argparse
import sys
import os
import json
import logging
from copy import deepcopy


def read_next_epoch(fp, one_sec=False):
    """Read the next satellite-usage epoch from an open log file.

    Scans forward from the current position of *fp* for a header line that
    contains both ``[HRPROC]`` and ``MetaData::SatUsage`` (and, when
    *one_sec* is true, also ``.000``).  The second whitespace-separated token
    of that header is parsed as the epoch time.  The data lines that follow
    it, e.g. ``L1C [G]++-- [R]-+``, are parsed until the first line that is
    not a data line; that line is left unread so the next call sees it.

    Args:
        fp: text file object opened for reading; must support ``tell`` and
            ``seek``.
        one_sec: if True, skip headers that do not contain ``.000``.

    Returns:
        ``(epoch, data)`` where ``epoch`` is a float and ``data`` maps each
        band token (e.g. ``'L1C'``) to ``{system_letter: usage_string}``.
        ``(None, None)`` when no further header exists, or when the file
        ends while the data lines of an epoch are still being read.
    """
    line = ""
    while not ('[HRPROC]' in line and
               'MetaData::SatUsage' in line and
               (not one_sec or '.000' in line)):
        line = fp.readline()
        if not line:
            logging.info('end of file')
            return None, None
    epoch = float(line.split()[1])
    logging.info('--------------------')
    logging.info(f'Found epoch {epoch}')

    data = {}
    while True:
        last_pos = fp.tell()
        line = fp.readline()
        if not line:
            logging.error('Unexpected end of file')
            return None, None

        # Data lines look like "L1C [G]...", "L8X [R]...": 'L', a digit, and
        # a one-letter system id bracketed at columns 4-6.  Both brackets
        # must be present, since the parsing below relies on that layout.
        is_data = (len(line) > 6
                   and line[0] == 'L'
                   and line[1].isnumeric()
                   and line[4] == '['
                   and line[6] == ']'
                   and 'MetaData::SatUsage' not in line)
        if not is_data:
            # Rewind so this line is seen again by the next call; otherwise a
            # header directly following the data lines would be consumed
            # here and its whole epoch skipped.
            fp.seek(last_pos)
            break

        freq = line.split()[0]
        # Each token is "[<sys>]<usage>": system letter at [1], usage from [3].
        d_sys = {sp[1]: sp[3:] for sp in line[4:].split()}
        data[freq] = d_sys
        logging.info('%s %s', freq, d_sys)
    return epoch, data

def combine_sats_in_epoch(d):
    """Merge the per-band usage strings of one epoch into one string per system.

    Args:
        d: ``{band: {system_letter: usage_string}}``.

    Returns:
        ``{system_letter: usage_string}`` in first-seen system order.  A
        position is ``'+'`` if it is ``'+'`` in any band; every other
        position keeps the character from the first band that supplied that
        position.  The input is not modified.
    """
    d_all_sys = {}
    for k_band, d_band in d.items():
        logging.info(f'   sys {k_band}: {d_band}')
        for k_sys, d_sats in d_band.items():
            current = d_all_sys.get(k_sys)
            if current is None:
                d_all_sys[k_sys] = d_sats
                continue
            # A longer string from this band extends the merged one instead
            # of raising IndexError on the extra positions.
            merged = list(current) + list(d_sats[len(current):])
            for n, c in enumerate(d_sats):
                if c == '+':
                    merged[n] = '+'
            d_all_sys[k_sys] = ''.join(merged)
    return d_all_sys


def num_sats_used_in_epoch(d):
    """Return how many satellites are used in one epoch, over all systems.

    The bands are merged first with ``combine_sats_in_epoch``, so a satellite
    marked ``'+'`` in several bands is counted only once.
    """
    merged = combine_sats_in_epoch(d)

    num_sat_changes = 0
    for usage in merged.values():
        n = usage.count('+')
        num_sat_changes = num_sat_changes + n
        logging.info(f'   {n=}, {num_sat_changes=}')
    return num_sat_changes


def diff_sats_by_epoch(a, b):
    """Return the satellites whose usage changed between epoch data *a* and *b*.

    Each epoch's bands are merged per system first (a satellite counts as
    used when it is ``'+'`` in any band); the merged maps are then compared
    with ``diff_sys`` on deep copies.
    """
    merged = [combine_sats_in_epoch(epoch_data) for epoch_data in (a, b)]
    return diff_sys(*(deepcopy(per_sys) for per_sys in merged))


def num_sats_used_in_band(d):
    """Return the total count of ``'+'`` (used) entries over all systems.

    Args:
        d: mapping of system letter to usage string for one band.
    """
    num_sat_changes = 0
    for sys_letter in d:
        usage = d[sys_letter]
        logging.info(f'      {sys_letter}: {usage}')
        n = usage.count('+')
        num_sat_changes = num_sat_changes + n
        logging.info(f'      {n=}, {num_sat_changes=}')
    return num_sat_changes


def diff_letters(a, b):
    """Return the number of positions at which sequences *a* and *b* differ.

    Positions past the end of the shorter input count as differences.
    Unequal lengths therefore neither raise IndexError (shorter *b*) nor
    silently ignore *b*'s extra characters (longer *b*).
    """
    return sum(x != y for x, y in zip(a, b)) + abs(len(a) - len(b))


def diff_sys(a, b):
    """List the satellites whose usage differs between two system maps.

    *a* and *b* map a one-letter system id to a usage string in which ``'+'``
    marks a used satellite.  The comparison works as follows:

    - For systems present in both maps, every position whose character
      differs is reported.
    - Positions that exist in only one of the two strings are reported where
      that string has ``'+'``.
    - Systems present in only one map are reported for each ``'+'``.

    A satellite id is ``<system letter><number>``, where the number is the
    character position minus ``position // period``.  The subtraction
    presumably skips a separator character repeated at a fixed period in the
    usage string (TODO confirm against a real log).

    Args:
        a: usage strings for the earlier epoch/band, ``{sys: usage_str}``.
        b: usage strings for the later epoch/band, same shape as *a*.

    Returns:
        List of satellite id strings, in this order: systems in both maps
        (in *a*'s order), then systems only in *a*, then systems only in
        *b*.  Neither input is modified.
    """
    def sat_id(sys_letter, pos, period=10):
        # NOTE(review): the shared-system and b-only paths use a period of 10
        # while the a-only path uses 11 (both kept from the previous code).
        # One of them is probably wrong for the real log format; confirm
        # before unifying.
        return sys_letter + str(pos - pos // period)

    logging.info(f"{a=}")
    logging.info(f"{b=}")
    sat_change_list = []

    for sys_letter in [s for s in a if s in b]:
        a_sats, b_sats = a[sys_letter], b[sys_letter]
        logging.info(f'   Found sys {sys_letter} in both')
        logging.info(f'      {a_sats}')
        logging.info(f'      {b_sats}')
        for n, (ca, cb) in enumerate(zip(a_sats, b_sats)):
            if ca != cb:
                sat_change_list.append(sat_id(sys_letter, n))
        # Positions beyond the shorter string: a '+' there is a change, the
        # same rule applied below to systems missing from one side.
        overlap = min(len(a_sats), len(b_sats))
        longer = a_sats if len(a_sats) >= len(b_sats) else b_sats
        for n, c in enumerate(longer[overlap:], start=overlap):
            if c == '+':
                sat_change_list.append(sat_id(sys_letter, n))
        logging.debug('%s changes so far: %s', sys_letter, sat_change_list)
    logging.info(f'   {sat_change_list=}')

    only_a = {s: v for s, v in a.items() if s not in b}
    logging.info(f'   leftover sys in a: {only_a}')
    for sys_letter, a_sats in only_a.items():
        for n, c in enumerate(a_sats):
            if c == '+':
                sat_change_list.append(sat_id(sys_letter, n, period=11))
    logging.info(f'      {sat_change_list=}')

    only_b = {s: v for s, v in b.items() if s not in a}
    logging.info(f'   leftover sys in b: {only_b}')
    for sys_letter, b_sats in only_b.items():
        for n, c in enumerate(b_sats):
            if c == '+':
                sat_change_list.append(sat_id(sys_letter, n))
    logging.info(f'      {sat_change_list=}')

    return sat_change_list


#def diff_sys(a, b):
#    num_sat_changes = 0
#    for a_sys in list(a):
#        a_sats = a[a_sys]
#        for b_sys in list(b):
#            b_sats = b[b_sys]
#            if a_sys == b_sys:
#                logging.info(f'   Found sys {a_sys} in both')
#                logging.info(f'      {a_sats}')
#                logging.info(f'      {b_sats}')
#                n = diff_letters(a_sats, b_sats)
#                num_sat_changes += n
#                logging.info(f'      {n=}, {num_sat_changes=}')
#                del a[a_sys]
#                del b[b_sys]
#                break
#
#    logging.info(f'   leftover sys in a:')
#    n = num_sats_used_in_band(a)
#    num_sat_changes += n
#    logging.info(f'      {n=}, {num_sat_changes=}')
#
#    logging.info(f'   leftover sys in b:')
#    n = num_sats_used_in_band(b)
#    num_sat_changes += n
#    logging.info(f'      {n=}, {num_sat_changes=}')
#
#    return num_sat_changes


def diff_sats_by_sys_and_band(a, b):
    """Count satellite usage changes band by band between epochs *a* and *b*.

    Bands present in both epochs are compared system by system with
    ``diff_sys``.  Every ``'+'`` in a band that exists in only one epoch also
    counts as a change.  The inputs are deep-copied first and are not
    modified.

    Returns:
        Total number of changes as an int.
    """
    a, b = deepcopy(a), deepcopy(b)
    num_changes = 0

    shared_bands = [band for band in a if band in b]
    for band in shared_bands:
        sys_a = a.pop(band)
        sys_b = b.pop(band)
        logging.info(f'Found band {band} in both')
        logging.info(f'   {sys_a}')
        logging.info(f'   {sys_b}')
        num_changes += len(diff_sys(sys_a, sys_b))
        logging.info(f'   {num_changes=}')

    # Whatever is left exists in only one of the two epochs.
    for label, leftovers in (('a', a), ('b', b)):
        for band, sys_map in leftovers.items():
            logging.info(f'   leftover band {band} in {label}')
            n = num_sats_used_in_band(sys_map)
            num_changes += n
            logging.info(f'   {n=}, {num_changes=}')
    return num_changes


def calc_changes_and_plot(filename, one_sec, plot=True):
    """Compute per-epoch satellite usage changes, save them, optionally plot.

    For every epoch after the first, one row records:

    - ``Time``: the epoch time.
    - ``n_sats_used``: the number of satellites used, merged across bands.
    - ``n_sat_changes``: the number of satellites whose usage changed
      relative to the previous epoch.
    - ``change_list``: the list of those satellites.

    Args:
        filename: path of the log file to parse.
        one_sec: if True, only headers containing ``.000`` are used.
        plot: if True, also save ``<stem>_sat_changes.png`` and show it.

    Returns:
        pandas.DataFrame with columns ``Time``, ``n_sats_used``,
        ``n_sat_changes``, ``change_list`` (plus ``dT`` when *plot* is true).
        The same table (without ``dT``) is written to
        ``<stem>_sat_changes.csv``, where ``<stem>`` is *filename* without
        its extension.
    """
    sat_changes = _collect_sat_changes(filename, one_sec)
    # Explicit columns keep the CSV header and the plotting code working
    # even when the log holds fewer than two epochs.
    df = pd.DataFrame(
        sat_changes,
        columns=['Time', 'n_sats_used', 'n_sat_changes', 'change_list'],
    )

    base = os.path.splitext(filename)[0]
    fname_out = f'{base}_sat_changes.csv'
    print(f'Saving {fname_out}')
    df.to_csv(fname_out, index=False)

    if plot:
        _plot_sat_changes(df, filename, f'{base}_sat_changes.png')

    return df


def _collect_sat_changes(filename, one_sec):
    """Return one result dict per epoch after the first, read from *filename*."""
    sat_changes = []
    with open(filename) as fp:
        _, epoch_sats_prev = read_next_epoch(fp, one_sec)
        while True:
            epoch, epoch_sats = read_next_epoch(fp, one_sec)
            # Compare against None explicitly: an epoch time of 0.0 is valid.
            if epoch is None:
                break
            change_list = diff_sats_by_epoch(epoch_sats_prev, epoch_sats)
            sat_changes.append({
                'Time': epoch,
                'n_sats_used': num_sats_used_in_epoch(epoch_sats),
                'n_sat_changes': len(change_list),
                'change_list': change_list,
            })
            epoch_sats_prev = epoch_sats
    return sat_changes


def _plot_sat_changes(df, title, png_fname):
    """Plot changes, satellites used and epoch spacing; save to *png_fname* and show."""
    df['dT'] = df.Time.diff()
    pd.options.display.min_rows = 30

    fig, ax = plt.subplots(3, 1,
                           figsize=(20, 7),
                           height_ratios=[3, 1, 1],
                           layout="constrained",
                           )
    ax[0].plot(df.Time, df.n_sat_changes, '.', markersize=4)
    ax[1].plot(df.Time, df.n_sats_used, '.', markersize=4)
    ax[2].plot(df.Time, df.dT, '.', markersize=4)

    ax[0].set_title(title)
    ax[0].set_ylabel('Num Sats Changed')
    ax[1].set_ylabel('Num Sats Used')
    ax[2].set_ylabel('dT (s)')
    ax[2].set_xlabel('Time (s)')

    # Sub-second spacing is easier to read on a fixed 0-1 s axis.
    if df.dT.max() < 1.0:
        ax[2].set_ylim([0, 1])

    print(f'Saving {png_fname}')
    fig.savefig(png_fname)
    plt.show()
    plt.close(fig)


def main():
    """Command-line entry point: parse arguments, configure logging and run."""
    parser = argparse.ArgumentParser(
        description='Count and plot per-epoch satellite usage changes '
                    'in a log file.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '-f', "--filename",
        default='TitanPlayer_log.txt',
        help="File to process"
    )
    parser.add_argument(
        "-1", "--one_sec",
        action='store_true',
        help="only process precise solutions that are on one second boundaries."
    )
    parser.add_argument(
        "--plot",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Create plots"
    )
    parser.add_argument(
        "--verbose", "-v",
        action='count',
        default=0,
        help="use multiple -v for more detailed messages."
    )

    logger_format = (
            "[%(levelname)-8s]"
            "[%(filename)11s:%(lineno)-5s] "
            "%(funcName)-10s: %(message)s"
            )
    args = parser.parse_args()
    # Each -v lowers the threshold one level (WARNING -> INFO -> DEBUG);
    # clamp so extra -v flags never produce a level below DEBUG.
    logging.basicConfig(
        level=max(logging.DEBUG, logging.WARNING - (10 * args.verbose)),
        format=logger_format,
        stream=sys.stdout
    )
    logging.debug(args)
    calc_changes_and_plot(args.filename, args.one_sec, args.plot)


if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()
