# -*- coding: utf-8 -*-
"""
Created on Mon Sep 16 23:52:12 2024

@author: JZhao
"""

import re
import glob
import matplotlib.pyplot as plt
import numpy as np
from numpy.lib import recfunctions as rfn

from mutils import *
import allantools
from bitarray import bitarray

import argparse
import os


def beidou_parity_check(word_in):
    """
    Compute the BCH(15,11,1) syndrome of a 15-bit word, following the
    decoder described in the BeiDou ICD v1.0 (December 2013).

    The low 15 bits of ``word_in`` are shifted, most significant bit first
    (bit 14 down to bit 0), through a 4-stage feedback shift register whose
    top stage is fed back into stages 0 and 1.  Any bits above bit 14 are
    ignored.

    Returns the 4-bit register content (the syndrome); 0 means no error.
    """
    syndrome = 0
    for shift in range(14, -1, -1):
        data_bit = (word_in >> shift) & 0x1
        feedback = (syndrome >> 3) & 0x1
        # Fold the feedback into stage 0 before shifting so it lands in
        # stage 1, then load the new stage 0 with data XOR feedback.
        advanced = ((syndrome ^ feedback) << 1) & 0xF
        syndrome = advanced | (data_bit ^ feedback)
    return syndrome

def compare_gps_l1ca_dbits_entries(arr, arr_ref):
    """
    Compare two arrays by TOW and PRN.
    For each entry in arr:
        - If a matching (TOW, PRN) exists in arr_ref and all values are equal, result=10
        - If a matching (TOW, PRN) exists but values differ, result=11
        - If (TOW, PRN) only in arr, result=1
        - If (TOW, PRN) only in arr_ref, result=2 (not shown in output, but can be added)
    Returns: result array, same length as arr
    """
    tow_col = arr.k['TOW']
    prn_col = arr.k['PRN']
    flags_col = arr.k['FLAGS']

    # Build lookup for reference array
    ref_dict = {}
    for row in arr_ref:
        key = (((row[tow_col] // 6) * 6), row[prn_col])
        ref_dict[key] = row
        #print("Key", key)

    result = np.ones(len(arr), dtype=int)  # default: only in arr
    class Stats:
        def __init__(self):
            self.same_entry_count = 0
            self.crc_inverted_count = 0
            self.parity_errors = 0
            self.unexplained_bad_data = 0
            self.total_usable_data = 0

    stats = Stats()

    for i, row in enumerate(arr):
        #print("Row: ", row)
        key = (((row[tow_col] // 6) * 6), row[prn_col])

        #Account for the potential of an extra flag byte (check if last bit is set)
        if row[8] & 1:
            N1 = 10
        else:
            N1 = 9
        
        if key in ref_dict:
            stats.total_usable_data += 1
            # Compare all values
            if np.array_equal(row[N1:], ref_dict[key][N1:]):
                stats.same_entry_count += 1
                result[i] = 0 #6  # same entry
            else:
                result[i] = 6  # different values
        else:
            result[i] = 0  # only in arr

        # If the data bits are different, check if the CRC is inverted
        if (result[i] == 6):
            bits = bitarray()
            bits.frombytes(row[N1:48].astype(np.uint8).tobytes())
            ref_bits = bitarray()
            ref_bits.frombytes(ref_dict[key][N1:48].astype(np.uint8).tobytes())


            diff = False
            for j in range(10):
        
                if bits[j*30:j*30+24] != ref_bits[j*30:j*30+24]:
                    print("Different bits at segment", j)
                    stats.unexplained_bad_data += 1
                    diff = True
                    break
                if (bits[j*30+24:(j+1)*30] != ref_bits[j*30+24:(j+1)*30] 
                    and bits[j*30+24:(j+1)*30] != (ref_bits[j*30+24:(j+1)*30]^bitarray('111111'))):
                    print("Different CRC at segment", j)
                    stats.parity_errors += 1
                    diff = True
                    break

            if not diff:
                #print("No difference at segment", i)
                stats.crc_inverted_count += 1
                result[i] = 0 #7  # CRC inverted


        if (result[i] == 6):
            print("Key", key, "Result", result[i], " for TOW:", row[tow_col], " PRN:", row[prn_col], " Data Source:", ref_dict[key][flags_col] & 0b11)
            print("ZCOUNT: ", int((bits[18:26] + bits[30:42]).to01(), 2), "SUBF: ", int(ref_bits[15:18].to01(), 2), "PAGE: ", int(bits[43:50].to01(), 2))
            print("REF: ", int((bits[18:26] + bits[30:42]).to01(), 2), "SUBF: ", int(ref_bits[15:18].to01(), 2), "PAGE: ", int(ref_bits[43:50].to01(), 2))
            print(' '.join(f"{b:02X}" for b in row[N1:48].astype(np.uint8).tobytes()))
            print(' '.join(f"{b:02X}" for b in ref_dict[key][N1:48].astype(np.uint8).tobytes()))
            
        
    print("Comparison Stats (Percent of GPS L1CA Data):")
    total = stats.total_usable_data
    print("  Unexplained bad data: {:.2f}%".format(100 * stats.unexplained_bad_data / total))
    print("  Same entries: {:.2f}%".format(100 * stats.same_entry_count / total))
    print("  Entry ignored due to parity error: {:.2f}%".format(100 * stats.parity_errors / total))
    print("  Entry ok bc CRC inverted: {:.2f}%".format(100 * stats.crc_inverted_count / total))

    return result

def compare_bds_b1_dbits_entries(arr, arr_ref):
    """
    Compare two arrays by TOW and PRN.
    For each entry in arr:
      - If a matching (TOW, PRN) exists in arr_ref and all values are equal, result=0
      - If a matching (TOW, PRN) exists but values differ, result=6
      - If (TOW, PRN) only in arr, result=0 (ignored)
      - If (TOW, PRN) only in arr_ref, result=0 (not shown in output, but can be added)
    Returns: result array, same length as arr
    """
    tow_col = arr.k['TOW']
    prn_col = arr.k['PRN']
    flags_col = arr.k['FLAGS']

    # Build lookup for reference array
    ref_dict = {}
    for row in arr_ref:
        key = (row[tow_col], row[prn_col])
        ref_dict[key] = row
        #print("Key", key)

    result = np.ones(len(arr), dtype=int)  # default: only in arr
    
    # Statistics to print at the end to understand what data was rejected
    class Stats:
        def __init__(self):
            self.data_source_count = 0
            self.same_entry_count = 0
            self.page_number_count = 0
            self.b1_count = 0
            self.crc_inverted_count = 0
            self.parity_errors = 0
            self.d1_d2_mismatch_count = 0
            self.unexplained_bad_data = 0

    stats = Stats()

    # Iterate through each row in arr and compare with arr_ref
    for i, row in enumerate(arr):
        key = (row[tow_col], row[prn_col])
        
        
        if key in ref_dict: # Matching (TOW, PRN) found in reference data
            
            if (row[flags_col] & 0b11) != 0: #Only process B1I data, ignore others
                result[i] = 0  # Ignore because not B1I data 
                continue
            else:
                stats.b1_count += 1

            #Account for the potential of an extra flag byte (check if last bit is set)
            if row[8] & 1:
                N1 = 10
            else:
                N1 = 9

            # Compare all values
            if np.array_equal(row[N1:], ref_dict[key][N1:]):
                stats.same_entry_count += 1
                result[i] = 0 #6  # same entry
                continue
            else:
                if row[flags_col] & 0b11 == ref_dict[key][flags_col] & 0b11: #Make sure data is from the data source. i.e. B1I)
                    result[i] = 0 #1  
                else:
                    stats.data_source_count += 1
                    result[i] = 0  # Ignore because ref data is not B1I data
                    continue
        else:
            result[i] = 0  # only in arr
            continue



        # If the data bits are different, check if the CRC is inverted
        if (result[i] == 6):
            bits = bitarray()
            bits.frombytes(row[N1:48].astype(np.uint8).tobytes())
            ref_bits = bitarray()
            ref_bits.frombytes(ref_dict[key][N1:48].astype(np.uint8).tobytes())

            #Data pages (trying bot D1 and D2) both don't match so ignore this
            if bits[43:50] != ref_bits[43:50] and bits[43:47] != ref_bits[43:47]: 
                stats.page_number_count += 1
                result[i] = 0 #Different page number, ignore
                continue

            diff = False

     

            #Go through each data word
            for j in range(10):

                subf = int(bits[15:18].to01(),2)
                ref_subf = int(ref_bits[15:18].to01(),2)

                if j==0: #Word 1 has 6 parity bits
                    #First check that it passes parity

                    #Parity check the second 1/2 of the first word. 
                    # The first 1/2 of this word doesn't have any parity.
                    # Parity check the second half (15 bits) of the first word.
                    # bits[j*30+15:j*30+30] are the 15 bits to check
                    word_bits = bits[j*30+15:j*30+30]
                    word_int = int(word_bits.to01(), 2)
                    if not beidou_parity_check(word_int):
                        stats.parity_errors += 1
                        break
                        # You may want to count or handle parity errors here
                    #Check if the 24 information bits are the same
                    if bits[j*30:j*30+24] != ref_bits[j*30:j*30+24]:            
                        # Print which bits are different
                        diff_bits = [k for k in range(j*30, j*30+24) if bits[k] != ref_bits[k]]

                        stats.unexplained_bad_data += 1
                        print("Different bits at segment", j)
                        print("Different bit positions:", diff_bits) 
                        diff = True
                        break

                    #Check if the 6 parity bits are different and NOT inverted
                    if (bits[j*30+24:(j+1)*30] != ref_bits[j*30+24:(j+1)*30] 
                        and bits[j*30+24:(j+1)*30] != (ref_bits[j*30+24:(j+1)*30]^bitarray('111111'))):
                        print("Different CRC at segment", j)
                        # Print which bits are different
                        diff_bits = [k for k in range(j*30+24, (j+1)*30) if bits[k] != ref_bits[k]]
                        print("Different bit positions (CRC):", diff_bits)
                        diff = True
                        stats.unexplained_bad_data += 1
                        break
                else: #Words 2-10 have 8 parity bits
                    #Special cases for Subframe 5, Pages 8,9, and 10 but those are ignored here for simplicity
                    #First check parity
                    # Parity check the first 1/2 word (bits j*30 to j*30+15)
                    first_half_bits = bits[j*30:j*30+15]
                    first_half_int = int(first_half_bits.to01(), 2)
                    if beidou_parity_check(first_half_int):
                        stats.parity_errors += 1
                        break

                    # Parity check the second 1/2 word (bits j*30+15 to (j+1)*30)
                    second_half_bits = bits[j*30+15:(j+1)*30]
                    second_half_int = int(second_half_bits.to01(), 2)
                    if beidou_parity_check(second_half_int):
                        stats.parity_errors += 1
                        break

                    #Check if the 22 information bits are the same
                    if bits[j*30:j*30+22] != ref_bits[j*30:j*30+22]:
                        print("Different bits at segment", j)
                        diff_bits = [k for k in range(j*30, j*30+22) if bits[k] != ref_bits[k]]
                        print("Different bit positions:", diff_bits)
                        diff = True
                        stats.unexplained_bad_data += 1
                        break
                    #Check if the 8 parity bits are different and NOT inverted
                    if (bits[j*30+22:(j+1)*30] != ref_bits[j*30+22:(j+1)*30] 
                        and bits[j*30+22:(j+1)*30] != (ref_bits[j*30+22:(j+1)*30]^bitarray('11111111'))):

                        print("Different CRC at segment", j)
                        diff_bits = [k for k in range(j*30+22, (j+1)*30) if bits[k] != ref_bits[k]]
                        print("Different bit positions (CRC):", diff_bits)                       
                        diff = True
                        stats.unexplained_bad_data += 1
                        break

            if not diff:
                #print("No difference at segment", i)
                result[i] = 0 #7  # CRC inverted
                stats.crc_inverted_count += 1


        if (result[i] == 6):
            print("Key", key, "Result", result[i], " for TOW:", row[tow_col], " PRN:", row[prn_col], " Data Source:", ref_dict[key][flags_col] & 0b11)
            print("ZCOUNT: ", int((bits[18:26] + bits[30:42]).to01(), 2), "SUBF: ", int(ref_bits[15:18].to01(), 2), "PAGE: ", int(bits[43:50].to01(), 2), "or PAGE: ", int(bits[43:47].to01(), 2))
            print("REF: ", int((bits[18:26] + bits[30:42]).to01(), 2), "SUBF: ", int(ref_bits[15:18].to01(), 2), "PAGE: ", int(ref_bits[43:50].to01(), 2), "or PAGE: ", int(ref_bits[43:47].to01(), 2))
            print(' '.join(f"{b:02X}" for b in row[N1:48].astype(np.uint8).tobytes()))
            print(' '.join(f"{b:02X}" for b in ref_dict[key][N1:48].astype(np.uint8).tobytes()))

            #break
        
        
    print("Comparison Stats (Percent of B1 Data):")
    total = stats.b1_count
    if total == 0:
        print("No B1 data in test file")
    else:
        print("  Unexplained bad data: {:.2f}%".format(100 * stats.unexplained_bad_data / total))
        print("  Same entries: {:.2f}%".format(100 * stats.same_entry_count / total))
        print("  Entry ignored due to parity error: {:.2f}%".format(100 * stats.parity_errors / total))
        print("  Different entries ignored due to different data source: {:.2f}%".format(100 * stats.data_source_count / total))
        print("  Entry ok bc CRC inverted: {:.2f}%".format(100 * stats.crc_inverted_count / total))
    return result

def compare_gal_e1_dbits_entries(arr, arr_ref):
    """
    Compare Galileo I/NAV E1-B data-bit records in ``arr`` against ``arr_ref``.

    Rows are paired on the key (TOW, PRN).  For each row of ``arr``:
      - no matching (TOW, PRN) in arr_ref                      -> result=0
      - data source (FLAGS & 0b11) is not 0                    -> result=0
      - payload columns (column N1 onwards) are identical      -> result=0
      - reference row comes from a different data source       -> result=0
      - CRC_PASSED is 0 for the test row or the reference row  -> result=0
      - data field (bits 10..79) is identical                  -> result=0
      - data field differs                                     -> result=6
    Rows that exist only in arr_ref are not reported.

    A summary, in percent of the E1-B rows that had a reference match, is
    printed at the end.

    Returns: result array, same length as arr
    """

    N = 30  # payload bits are taken from columns N1 up to (not including) N
    tow_col = arr.k['TOW']
    prn_col = arr.k['PRN']
    flags_col = arr.k['FLAGS']

    # Build lookup for reference array.  A later row with the same key
    # replaces an earlier one.
    ref_dict = {}
    for row in arr_ref:
        ref_dict[(row[tow_col], row[prn_col])] = row

    result = np.ones(len(arr), dtype=int)

    # Statistics to print at the end to understand what data was rejected
    class Stats:
        def __init__(self):
            self.data_source_count = 0
            self.same_entry_count = 0
            self.e1_count = 0
            self.crc_inverted_count = 0
            self.parity_errors = 0
            self.unexplained_bad_data = 0
            self.spare_bits_mismatch = 0

    stats = Stats()

    # Iterate through each row in arr and compare with arr_ref
    for i, row in enumerate(arr):
        key = (row[tow_col], row[prn_col])

        if key not in ref_dict:
            result[i] = 0  # only in arr
            continue
        ref_row = ref_dict[key]

        if (row[flags_col] & 0b11) != 0:  # Only process E1-B data, ignore others
            result[i] = 0
            continue
        stats.e1_count += 1

        # Account for the potential of an extra flag byte: the payload starts
        # at column 10 when bit 0 of column 8 is set, otherwise at column 9.
        # NOTE(review): column 8 is also the CRC_PASSED column; confirm that
        # this is the intended selector.
        N1 = 10 if row[8] & 1 else 9

        if np.array_equal(row[N1:], ref_row[N1:]):
            stats.same_entry_count += 1
            result[i] = 0  # same entry
            continue

        if (row[flags_col] & 0b11) != (ref_row[flags_col] & 0b11):
            stats.data_source_count += 1
            result[i] = 0  # Ignore because ref data is not E1-B data
            continue

        # Only compare further when the CRC passed for BOTH entries; a row
        # that failed CRC on either side cannot be trusted.
        crc_col = arr.k['CRC_PASSED']
        if row[crc_col] == 0 or ref_row[crc_col] == 0:
            stats.parity_errors += 1
            result[i] = 0
            continue

        # Convert just the subframe data bytes to bits
        bits = bitarray()
        bits.frombytes(row[N1:N].astype(np.uint8).tobytes())
        ref_bits = bitarray()
        ref_bits.frombytes(ref_row[N1:N].astype(np.uint8).tobytes())

        if bits[10:80] != ref_bits[10:80]:
            # A difference in the data field is never explained away by the
            # CRC or spare bits.
            print("Different bits in data field")
            stats.unexplained_bad_data += 1
            result[i] = 6

            print("Key", key, "Result", result[i], " for TOW:", row[tow_col], " PRN:", row[prn_col], " Data Source:", ref_row[flags_col] & 0b11)
            print("TEST PAGE: ", int(bits[2:8].to01(), 2))
            print("REF PAGE: ", int(ref_bits[2:8].to01(), 2))
            print(' '.join(f"{b:02X}" for b in row[N1:N].astype(np.uint8).tobytes()))
            print(' '.join(f"{b:02X}" for b in ref_row[N1:N].astype(np.uint8).tobytes()))
            continue

        # Data field is identical: the row is accepted.  Record which of the
        # known harmless differences were seen.
        # Check for CRC inverted (24 bits)
        if bits[82:106] == (ref_bits[82:106] ^ bitarray('1' * 24)):
            stats.crc_inverted_count += 1
        # Check and ignore spare bits
        if bits[80:82] != ref_bits[80:82]:
            stats.spare_bits_mismatch += 1
        result[i] = 0

    print("Comparison Stats (Percent of GAL I/NAV E1-B Data):")
    total = stats.e1_count
    if total == 0:
        print("No E-1B data in test file")
    else:
        print("  Unexplained bad data: {:.2f}%".format(100 * stats.unexplained_bad_data / total))
        print("  Same entries: {:.2f}%".format(100 * stats.same_entry_count / total))
        print("  Entry ignored due to parity error: {:.2f}%".format(100 * stats.parity_errors / total))
        print("  Entry ignored due to spare bits mismatch: {:.2f}%".format(100 * stats.spare_bits_mismatch / total))
        print("  Different entries ignored due to different data source: {:.2f}%".format(100 * stats.data_source_count / total))
        print("  Entry ok bc CRC inverted: {:.2f}%".format(100 * stats.crc_inverted_count / total))
    return result

def generate_dbits_reference(ref_list, output_file, sat_type='GPS'):
    """
    Merge the data-bit records of several reference files into one table.

    Every path in ``ref_list`` is loaded with ``get_gps_dbits`` for the given
    ``sat_type``.  The tables are concatenated and reduced to one row per
    unique (PRN, TOW) pair, keeping the first occurrence of each pair; rows
    come back in the order ``np.unique`` sorts the pairs.

    ``output_file`` is accepted but not used: nothing is written to disk.

    Returns: the de-duplicated rows of the concatenated table.
    """
    loaded = []
    for ref_path in ref_list:
        print("file :", ref_path)
        loaded.append(get_gps_dbits(ref_path, sat_type=sat_type))

    # Column positions are taken from the first table.
    column_index = loaded[0].k
    prn_col = column_index['PRN']
    tow_col = column_index['TOW']

    combined_dbits = np.concatenate(loaded)

    prn_values = combined_dbits[:, prn_col]
    tow_values = combined_dbits[:, tow_col]
    keys = np.array(list(zip(prn_values, tow_values)))

    # return_index yields the first row at which each unique key appears.
    first_rows = np.unique(keys, axis=0, return_index=True)[1]

    return combined_dbits[first_rows]

def dbits_arr_to_cls(dbits):
    """
    Wrap a table of data-bit records in a ``VdCls`` with named columns.

    The first ten columns are named, in order: SUB, TOW, WN, PRN, CHAN,
    FLAGS, NAV_TYPE, SOURCE, CRC_PASSED, DATA_NUM.
    """
    column_names = (
        'SUB',
        'TOW',
        'WN',
        'PRN',
        'CHAN',
        'FLAGS',
        'NAV_TYPE',
        'SOURCE',
        'CRC_PASSED',
        'DATA_NUM',
    )
    k = dotdict({})
    for position, name in enumerate(column_names):
        setattr(k, name, position)
    return VdCls(dbits, keys=k, flags=[])
            
def get_lines_from_t0x(filename, record_type):
    """
    Run ``viewdat`` on ``filename`` for one record type and return its output.

    The command line is ``viewdat -d<record_type> -mb <filename>`` with every
    backslash in the file name doubled.  The command is handed to
    ``cmd2arr_raw`` and whatever that returns is passed back to the caller.
    """
    cmd = 'viewdat' + ' -d' + record_type + ' -mb ' + filename.replace("\\", "\\\\")

    print("CMD: " + cmd)
    lines = cmd2arr_raw(cmd)
    # Report completion only after the command has actually been run.
    print("Loading record from %s is done!" % (filename))
    return lines

def get_gps_dbits(path_name, sat_type='GPS'):
    """
    Load the data-bit records of one file for the given constellation.

    ``sat_type`` selects the viewdat record type and the number of payload
    columns: 'BDS' -> '26:11' with 38 columns, 'GAL' -> '26:3' with 30
    columns, anything else -> '26:6' (GPS) with 38 columns.

    Each output line is split on whitespace.  The first ten fields are parsed
    as decimal integers and the following payload fields as hexadecimal; a
    trailing 0 is appended as a placeholder column for the comparison result.

    Returns: the parsed rows wrapped by ``dbits_arr_to_cls``.
    """
    if sat_type == 'BDS':
        record_type, N = '26:11', 38  # BDS
    elif sat_type == 'GAL':
        record_type, N = '26:3', 30   # GAL I/NAV
    else:
        record_type, N = '26:6', 38   # GPS

    lines = get_lines_from_t0x(path_name, record_type)

    column_count = 10 + N
    gps_dbits = []
    for line in lines:
        fields = line.split()
        parsed = [
            int(fields[col], 10) if col < 10 else int(fields[col], 16)
            for col in range(column_count)
        ]
        parsed.append(0)  # add dummy column for result
        gps_dbits.append(parsed)

    return dbits_arr_to_cls(gps_dbits)


# Command-line entry point: load the test file and the reference files,
# compare them for the selected constellation and save a per-satellite plot.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process file paths for GPS data bits.")
    parser.add_argument('filepath', type=str, help='Path to the main test file')
    parser.add_argument('filepath_ref_list', nargs='+', help='List of reference file paths')
    parser.add_argument('-sat_type', type=str, default='GPS', help='Satellite type')

    args = parser.parse_args()

    args.sat_type = args.sat_type.upper()

    filepath = args.filepath
    filepath_ref_list = args.filepath_ref_list

    print("File list: ",  filepath_ref_list)

    # Load the databits in test file
    cls_gps_dbits = get_gps_dbits(filepath, args.sat_type)

    # Load and merge the reference files into one table
    combined_dbits = generate_dbits_reference(filepath_ref_list, "gps_dbits_reference.txt", args.sat_type)

    cls_gps_dbits_ref = dbits_arr_to_cls(combined_dbits)

    if args.sat_type == 'BDS':
        result = compare_bds_b1_dbits_entries(cls_gps_dbits, cls_gps_dbits_ref)
    elif args.sat_type == 'GAL':
        result = compare_gal_e1_dbits_entries(cls_gps_dbits, cls_gps_dbits_ref)
    else:  # GPS default
        result = compare_gps_l1ca_dbits_entries(cls_gps_dbits, cls_gps_dbits_ref)

    # Store the comparison result in the placeholder last column
    cls_gps_dbits[:, -1] = result

    sv_list = np.unique(cls_gps_dbits.PRN)
    fig = plt.figure()
    ax = fig.add_subplot(111)

    for sv in sv_list:
        d = cls_gps_dbits[cls_gps_dbits.PRN == sv]

        # Keep only the rows whose SOURCE column is 0
        d = d[d.SOURCE == 0]
        t = d.TOW
        sv_result = d[:, -1]

        # The red cross sits on the PRN line for result 0 and is lifted by
        # result/10 (0.6 for a mismatch) so it separates from the circle.
        ax.plot(t, d.PRN + sv_result / 10, 'x', markersize=3, color='red')
        ax.plot(t, d.PRN, 'o', markersize=3)

    # Decorate and save once, after every satellite has been drawn, instead
    # of re-rendering the PNG on each loop iteration.  As before, nothing is
    # saved when the test file contains no satellites.
    if len(sv_list) > 0:
        ax.set_title(f"Nav message comparison to reference ({args.sat_type})")
        ax.legend(["Mismatched data", "Sat data received"], loc="best", bbox_to_anchor=(1, 1), borderaxespad=0.)
        plt.tight_layout(rect=[0, 0, 0.85, 1])  # Make room for legend

        # Name the image after the test file and the last reference file
        input_filename = os.path.splitext(os.path.basename(filepath))[0]
        second_input_filename = os.path.splitext(os.path.basename(filepath_ref_list[-1]))[0]
        plt.savefig(f"{args.sat_type}_{input_filename}_compare_dbits_{second_input_filename}.png", dpi=300, bbox_inches='tight')

