import os
import argparse
import sys
import numpy as np
from pylab import unique, subplots, show, figure, subplot, plot, title, yticks, grid, close, arange, isclose, diff, insert, ones, nan, median, xticks
from sympy import true
from mutils import *
import glob
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.ticker import FormatStrFormatter
from pyproj import Geod, Transformer
import matplotlib.pyplot as plt
import pandas as pd 
import zipfile
import folium
from xml.dom import minidom
import contextily as ctx
import geopandas as gpd
from shapely.geometry import LineString, Point
import getTestPlanTimes as plan



def plt_cno_data(path_name, antenna_num = 0, test_sat_type=-1, png=False, return_figs=False, dir_name='.',plt_rel_time=False,start_time=None,end_time=None):
    """Plot C/N0 versus time for every SV on one antenna (record 27).

    One figure is produced per (satellite type, frequency/track) entry listed
    in ``~/gpstools/pythonTools/freq_track_map.mat`` that has data; each
    figure holds one marker series per SV.

    Args:
        path_name: A single ``.T02``/``T04`` file, or a directory whose
            ``*.T04`` files (excluding names containing "spoofed") are loaded.
        antenna_num: Antenna index to keep (compared with ``ANTENNA``).
        test_sat_type: Only plot this ``SAT_TYPE``; -1 plots every type.
        png, dir_name: Not used by this function; kept for call compatibility.
        return_figs: When true, return the created figures.
        plt_rel_time: Plot minutes since the first selected epoch instead of
            GPS time of week.
        start_time, end_time: Optional inclusive bounds on ``TIME``.

    Returns:
        A list of ``(figure, title)`` tuples (one per figure) when
        ``return_figs`` is true, otherwise ``None``.

    Raises:
        FileNotFoundError: If a directory holds no usable ``.T04`` file, or the
            frequency/track ``.mat`` map does not exist.
    """
    import scipy.io

    if not (path_name.endswith('.T02') or path_name.endswith('T04')):
        # Directory: use every .T04 file except those with "spoofed" in the name.
        all_files = glob.glob(os.path.join(path_name, '*.T04'))
        path_name = [f for f in all_files if 'spoofed' not in os.path.basename(f)]
        if not path_name:
            raise FileNotFoundError("No valid .T04 files found in the directory.")
        elif len(path_name) == 1:
            path_name = path_name[0]  # Use the single file directly
        print(path_name)

    d = vd2cls(path_name, rec='-d27')
    if start_time is not None:
        d = d[d.TIME >= start_time]
    if end_time is not None:
        d = d[d.TIME <= end_time]
    print('\n Loaded Rec27\n')

    # Select data for this antenna
    d = d[d.ANTENNA == antenna_num]

    # Unique (sv_id, sat_type) pairs, encoded as SV + 1000*SAT_TYPE then decoded.
    sv_clist = unique(d.SV + 1000*d.SAT_TYPE)
    sv_list = [(int(x) % 1000, int(x/1000)) for x in sv_clist]

    mat_file_path = os.path.expanduser('~/gpstools/pythonTools/freq_track_map.mat')
    if not os.path.exists(mat_file_path):
        raise FileNotFoundError(f"The specified .mat file does not exist: {mat_file_path}")
    mat_data = scipy.io.loadmat(mat_file_path)
    # np.asarray so a missing variable yields [] instead of failing on list.flatten().
    freq_track_keys = np.asarray(mat_data.get('freq_track_keys', [])).flatten().tolist()
    freq_track_values = np.asarray(mat_data.get('freq_track_values', [])).flatten().tolist()

    # One common origin for relative time so all SVs share the same time axis
    # (previously each SV was shifted to its own first sample).
    t0 = np.min(d.TIME) if len(d.TIME) > 0 else 0.0

    cno_figs = []
    for sat_type in unique(d.SAT_TYPE):
        if test_sat_type != -1 and sat_type != test_sat_type:
            continue
        svs = [sv_id for sv_id, st in sv_list if st == sat_type]

        # key[0] holds (sat_type, FREQ, TRACK); value[0] is the figure title.
        for key, value in zip(freq_track_keys, freq_track_values):
            if key[0][0] != sat_type:
                continue

            fig, ax = subplots(1, 1, figsize=(10, 6))
            legend_entries = []
            for sv_id in svs:
                indices = np.where((d.SV == sv_id) & (d.SAT_TYPE == sat_type)
                                   & (d.FREQ == key[0][1]) & (d.TRACK == key[0][2]))
                if indices[0].size == 0:
                    continue
                if plt_rel_time:
                    ax.plot((d.TIME[indices] - t0) / 60.0, d.CNO[indices], 'o', markersize=1.5)
                else:
                    ax.plot(d.TIME[indices], d.CNO[indices], 'o', markersize=1.5)
                legend_entries.append('SV' + str(sv_id))

            if not legend_entries:
                # No SV tracked this signal: don't leave an empty figure open.
                plt.close(fig)
                continue

            ax.set_title(f'{value[0]}')
            ax.set_xlabel('Time from start [min]' if plt_rel_time else 'GPS TOW [sec]')
            ax.set_ylabel('C/N0 [dB-Hz]')
            ax.grid(True)
            ax.legend(legend_entries, loc='upper left', bbox_to_anchor=(1.01, 1), borderaxespad=0.)
            # Each figure is recorded exactly once (the old membership test
            # compared a figure against tuples and appended duplicates).
            cno_figs.append((fig, value[0]))

    if return_figs:
        return cno_figs
  
def plt_tracking(path_name, dir_name='.', track=False, raim=False, antenna_num=0, title_txt=None, start_time=None, end_time=None, box=None, test_segments=None):
    """Produce RAIM and tracking plots for a T0x file.

    Without ``box``: runs ``plt_tracking.py`` in RAIM mode and then tracking
    mode. After each run, the ``*.png`` files left in the current directory
    are moved into ``<dir_name>/RAIM`` and ``<dir_name>/Tracking``
    respectively.

    With ``box``: loads record 35:19 and draws one figure per ``SAT_TYPE``.
    Each figure shows, per SV, the epochs at which each RAIM fault flag is
    set (offset markers) and the tracking timeline (black line broken at
    data gaps). Optional ``test_segments`` are outlined in red. Figures are
    saved to ``<dir_name>/RAIM/raim_<system>.png`` and the per-flag counts of
    flagged epochs are printed.

    Args:
        path_name: T0x file to process.
        dir_name: Output directory root.
        track, raim, antenna_num, title_txt: Not used; kept for compatibility.
        start_time, end_time: Forwarded (as ints) to ``plt_tracking.py`` when
            both are given; not applied in the ``box`` branch.
        box: Selects the in-process plotting branch.
        test_segments: Iterable of dicts with ``start_seconds``,
            ``end_seconds`` and optional ``test_id``.
    """
    if not box:
        script = 'python /home/ltalegh/gpstools/pythonTools/plt_tracking.py '
        time_args = ''
        if start_time is not None and end_time is not None:
            time_args = ' -s ' + str(int(start_time)) + ' -e ' + str(int(end_time))
        # Move each run's PNGs before the next run so the two sets don't mix.
        for mode_args, sub_dir in ((' -r --sub_chan_raim --png', '/RAIM'),
                                   (' -t --png', '/Tracking')):
            os.system(script + path_name + mode_args + time_args)
            os.makedirs(dir_name + sub_dir, exist_ok=True)
            for file in glob.glob('*.png'):
                os.rename(file, os.path.join(dir_name + sub_dir, file))
        return

    from collections import Counter

    # SAT_TYPE -> constellation name used for titles and output file names.
    systemLUT = {0: 'GPS', 1: 'SBAS', 2: 'GLONASS', 3: 'Galileo', 4: 'QZSS', 9: 'NavIC', 10: 'BeiDou'}

    d, k = vd2arr(path_name, rec='-d35:19')
    sat_types = unique(d[:, k.SAT_TYPE])
    sv_flags = d[:, k.SV_FLAGS].astype(int)

    # (legend label, SV_FLAGS fault bit) pairs.
    raims = (('MM', k.dSV_FLAGS_FAULT_MM_RAIM),
             ('WLS', k.dSV_FLAGS_FAULT_WLS_RAIM),
             ('KF', k.dSV_FLAGS_FAULT_KF_RAIM),
             ('SBAS', k.dSV_FLAGS_FAULT_SBAS_MSG),
             ('CNAV', k.dSV_FLAGS_FAULT_CNAV),
             ('FFT', k.dSV_FLAGS_FAULT_FFT_CORR),
             ('RTX', k.dSV_FLAGS_FAULT_RTX_POS),
             ('SYS', k.dSV_FLAGS_FAULT_SYS_RAIM),
             ('EPH', k.dSV_FLAGS_FAULT_EPH),
             ('NMA', k.dSV_FLAGS_FAULT_NMA),
             ('CNO', k.dSV_FLAGS_FAULT_CNO_DETECT_SPOOF),
             ('DUAL_ANT', k.dSV_FLAGS_FAULT_DUAL_ANT_SPOOF))
    raims_cnt = Counter()

    all_results = []
    ax = None
    for sat_type in sat_types:
        fig = figure()
        # Every figure after the first shares the first figure's x axis.
        ax = subplot(111) if ax is None else subplot(111, sharex=ax)
        is_sat_type = d[:, k.SAT_TYPE] == sat_type
        sv_list = unique(d[find(is_sat_type), k.SV])
        tot_len = 0

        # RAIM info: each flag gets its own vertical offset within an SV's row.
        for n, (label, flag) in enumerate(raims):
            i = find(is_sat_type & ((sv_flags & flag) != 0))
            sv_y = d[i, k.SV].copy()
            for sv_pos, sv in enumerate(sv_list):
                sv_y[isclose(sv_y, sv)] = sv_pos
            plot(d[i, k.TIME], sv_y + (n + 1) / (len(raims) + 2), 'o', markersize=2,
                 label=label if len(i) > 0 else '_nolegend_')
            xticks(rotation=45)
            if len(i) > 0:
                # Count each flagged epoch once (previously it was added twice).
                raims_cnt[label] += len(i)
            tot_len += len(i)

        # Tracking info: one row per SV, line broken where sampling gaps occur.
        for sv_pos, sv in enumerate(sv_list):
            i = find((d[:, k.SV] == sv) & is_sat_type)
            t = d[i, k.TIME]
            dt = abs(diff(t))
            # A gap is any step longer than the median step (plus 1 ms slack).
            igap = find(dt > median(dt) + .001) + 1
            if len(igap) > 0:
                sv_xgap = insert(t, igap, nan)
                sv_ygap = insert(sv_pos * ones(i.shape), igap, nan)
            else:
                sv_xgap = t
                sv_ygap = sv_pos * ones(i.shape)
            plot(sv_xgap, sv_ygap, 'k.-', markersize=1)

        sat_type_name = systemLUT.get(sat_type, f'Unknown_{sat_type}')
        title(f'RAIM Flags - {sat_type_name}')
        yticks(arange(len(sv_list)), sv_list.astype(int))
        grid(True)

        # Add test segment boxes (only for segments starting inside the view).
        if test_segments is not None:
            xlim_min, xlim_max = ax.get_xlim()
            ymax = len(sv_list) + 0.5
            ymin = -0.5
            y_range = ymax - ymin
            box_top = ymax - 0.02 * y_range
            label_y = box_top + 0.01 * y_range

            for segment in test_segments:
                start_sec = segment.get('start_seconds')
                end_sec = segment.get('end_seconds')
                test_id = segment.get('test_id', '')
                if start_sec is None or end_sec is None:
                    continue
                if not (xlim_min <= start_sec <= xlim_max):
                    continue

                ax.hlines(y=box_top, xmin=start_sec, xmax=end_sec, color='red', linestyle='-', linewidth=2, alpha=0.7)
                ax.vlines(x=start_sec, ymin=ymin, ymax=box_top, color='red', linestyle='--', linewidth=0.5, alpha=0.7)
                ax.vlines(x=end_sec, ymin=ymin, ymax=box_top, color='red', linestyle='--', linewidth=0.5, alpha=0.7)
                ax.text((start_sec + end_sec) / 2, label_y, test_id, ha='center', va='bottom', fontsize=8, color='red', rotation=45)

                # Leave headroom above the rotated labels.
                current_ylim = ax.get_ylim()
                ax.set_ylim(current_ylim[0], max(current_ylim[1], label_y + 0.15 * y_range))

        if tot_len > 0:
            handles, labels = ax.get_legend_handles_labels()
            ax.legend(handles, labels, bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0.)

        all_results.append((fig, sat_type))

    # Save figures
    if all_results:
        os.makedirs(dir_name + '/RAIM', exist_ok=True)
    for fig, sat_type in all_results:
        sat_type_name = systemLUT.get(sat_type, f'Unknown_{sat_type}')
        filename = os.path.join(dir_name + '/RAIM', f'raim_{sat_type_name}.png')
        fig.savefig(filename, dpi=300, bbox_inches='tight')

    print(raims_cnt)
   
        
def plt_rec35_16(path_name, plt_rel_time=False,start_time=None,end_time=None):
    """Plot record 35:16 latitude, longitude and height against GPS TOW.

    Args:
        path_name: A single ``.T02``/``T04`` file, or a directory whose
            ``*.T04`` files (excluding names containing "spoofed") are loaded.
        plt_rel_time: Not used; kept for signature parity with ``plt_pos``.
        start_time, end_time: Optional inclusive bounds on ``TIME``.

    Returns:
        The three-panel figure, or ``None`` if loading record 35:16 raised.

    Raises:
        FileNotFoundError: If a directory holds no usable ``.T04`` file.
    """
    if not (path_name.endswith('.T02') or path_name.endswith('T04')):
        # Directory: use every .T04 file except those with "spoofed" in the name.
        # (Validation now lives inside this branch; it used to run on the
        # plain file-name string too and only worked by accident.)
        all_files = glob.glob(os.path.join(path_name, '*.T04'))
        path_name = [f for f in all_files if 'spoofed' not in os.path.basename(f)]
        if not path_name:
            raise FileNotFoundError("No valid .T04 files found in the directory.")
        if len(path_name) == 1:
            path_name = path_name[0]  # Use the single file directly
    print(path_name)

    try:
        d16 = vd2cls(path_name, rec='-d35:16')
    except Exception:
        # Record may be absent from the log; callers treat None as "no plot".
        return None

    if start_time is not None:
        d16 = d16[d16.TIME >= start_time]
    if end_time is not None:
        d16 = d16[d16.TIME <= end_time]

    fig16, axs16 = subplots(3, 1, figsize=(10, 10))
    # (values, y label, fixed 5-decimal y ticks for angles)
    panels = ((d16.LAT, 'Latitude [deg]', True),
              (d16.LON, 'Longitude [deg]', True),
              (d16.HGT, 'Height [m]', False))
    for ax, (values, ylabel, fixed_decimals) in zip(axs16, panels):
        ax.plot(d16.TIME, values, 'o', markersize=1.5)
        ax.set_xlabel('GPS TOW [sec]')
        ax.set_ylabel(ylabel)
        if fixed_decimals:
            ax.yaxis.set_major_formatter(FormatStrFormatter('%.5f'))
        ax.grid(True)
    axs16[0].set_title('Rec35:16 Position')

    fig16.tight_layout()
    return fig16

def plt_pos(path_name,plt_rel_time=False,start_time=None,end_time=None):
    """Plot record 35:2 position in a three-panel figure.

    With ``plt_rel_time`` false, latitude, longitude and height are plotted
    against GPS TOW. With it true, the ECEF X/Y/Z offset from the median
    position of the first 60 s of data is plotted against minutes since the
    first selected epoch.

    Args:
        path_name: A single ``.T02``/``T04`` file, or a directory whose
            ``*.T04`` files (excluding names containing "spoofed") are loaded.
        plt_rel_time: Select the relative-time / ECEF-error view.
        start_time, end_time: Optional inclusive bounds on ``TIME``.

    Returns:
        The matplotlib figure.

    Raises:
        FileNotFoundError: If a directory holds no usable ``.T04`` file.
    """
    if not (path_name.endswith('.T02') or path_name.endswith('T04')):
        # Directory: use every .T04 file except those with "spoofed" in the name.
        all_files = glob.glob(os.path.join(path_name, '*.T04'))
        path_name = [f for f in all_files if 'spoofed' not in os.path.basename(f)]
        if not path_name:
            raise FileNotFoundError("No valid .T04 files found in the directory.")
        if len(path_name) == 1:
            path_name = path_name[0]  # Use the single file directly
    print(path_name)

    d = vd2cls(path_name, rec='-d35:2')  # Position
    if start_time is not None:
        d = d[d.TIME >= start_time]
    if end_time is not None:
        d = d[d.TIME <= end_time]

    # Created for both views: the relative-time branch previously used
    # fig/axs without creating them and raised NameError.
    fig, axs = subplots(3, 1, figsize=(10, 12))

    if not plt_rel_time:
        # (values, y label, y tick format or None)
        panels = ((d.LAT, 'Latitude [deg]', '%.5f'),
                  (d.LON, 'Longitude [deg]', '%.5f'),
                  (d.HGT, 'Height [m]', None))
        for ax, (values, ylabel, yfmt) in zip(axs, panels):
            ax.plot(d.TIME, values, 'o', markersize=1.5)
            ax.set_xlabel('GPS TOW [sec]')
            ax.set_ylabel(ylabel)
            if yfmt is not None:
                ax.yaxis.set_major_formatter(FormatStrFormatter(yfmt))
            ax.xaxis.set_major_formatter(FormatStrFormatter('%.0f'))  # No scientific notation
            ax.grid(True)
            ax.tick_params(axis='x', labelrotation=45)
        axs[0].set_title('Rec35:2 Position')
    else:
        t0 = d.TIME[0]
        time_from_start = d.TIME - t0
        first_minute_mask = time_from_start <= 60

        # Reference point: median LLH over the first minute.
        lat_med = np.median(d.LAT[first_minute_mask])
        lon_med = np.median(d.LON[first_minute_mask])
        hgt_med = np.median(d.HGT[first_minute_mask])

        # WGS84 geodetic (EPSG:4326, lon/lat order) -> ECEF XYZ (EPSG:4978).
        transformer = Transformer.from_crs("epsg:4326", "epsg:4978", always_xy=True)
        X, Y, Z = transformer.transform(d.LON, d.LAT, d.HGT)
        X_med, Y_med, Z_med = transformer.transform(lon_med, lat_med, hgt_med)

        time_from_start_min = time_from_start / 60.0
        errors = ((X - X_med, 'X error [m]'),
                  (Y - Y_med, 'Y error [m]'),
                  (Z - Z_med, 'Z error [m]'))
        for ax, (err, ylabel) in zip(axs, errors):
            ax.plot(time_from_start_min, err, 'o', markersize=1.5)
            ax.set_xlabel('Time from start [min]')
            ax.set_ylabel(ylabel)
            ax.grid(True)

        fig.tight_layout()

    return fig

def plot_nma(path_name,plt_rel_time=False,start_time=None,end_time=None,box=None, test_segments=None):
    """Plot the number of NMA pass/fail results per epoch from record 35:19.

    Counted signals (no antenna filtering is applied):
      - RTX-NMA, GPS L1 C/A: SAT_TYPE 0, FREQ 0, TRACK 0
      - RTX-NMA, BeiDou B1I: SAT_TYPE 10, FREQ 6, TRACK 26
      - OSNMA, Galileo E1:   SAT_TYPE 3, FREQ 0, TRACK 20 or 23
    RTX-NMA status is read from NMA_FLAGS bits 0-2 (1 = pass, 2 = fail);
    OSNMA status from bits 3-5 (1 = pass, 2 = fail within that field).

    Args:
        path_name: T0x file to load.
        plt_rel_time: Plot minutes since the first epoch instead of GPS TOW.
        start_time, end_time: Optional inclusive bounds on ``TIME``.
        box: When truthy (and ``test_segments`` is given) outline segments.
        test_segments: Iterable of dicts with ``start_seconds``,
            ``end_seconds`` (same units as ``TIME``) and optional ``test_id``.

    Returns:
        The matplotlib figure; the caller is responsible for saving it.
        Exits the process with status 1 if the record cannot be loaded.
    """
    try:
        (data, key) = vd2arr(path_name, rec='-d35:19')
        print('Data loaded from: ', path_name)
    except Exception:
        print('Error loading data from: ', path_name)
        sys.exit(1)

    if start_time is not None:
        data = data[data[:, key.TIME] >= start_time]
    if end_time is not None:
        data = data[data[:, key.TIME] <= end_time]

    # Keep only GPS, Galileo and BeiDou rows.
    i = find((data[:, key.SAT_TYPE] == 0) | (data[:, key.SAT_TYPE] == 3) | (data[:, key.SAT_TYPE] == 10))

    df = pd.DataFrame({
        'TIME': data[i, key.TIME],
        'SAT_TYPE': data[i, key.SAT_TYPE].astype(int),
        'FREQ': data[i, key.FREQ].astype(int),
        'TRACK': data[i, key.TRACK].astype(int),
        'NMA_FLAGS': data[i, key.NMA_FLAGS].astype(int)
    })

    gps_mask = (df['SAT_TYPE'] == 0) & (df['FREQ'] == 0) & (df['TRACK'] == 0)
    galileo_mask = ((df['SAT_TYPE'] == 3)
                    & (df['FREQ'] == 0)
                    & ((df['TRACK'] == 20) | (df['TRACK'] == 23)))
    bds_mask = (df['SAT_TYPE'] == 10) & (df['FREQ'] == 6) & (df['TRACK'] == 26)

    # RTX-NMA - bits 0-2
    rtx_status = df['NMA_FLAGS'] & 0b111
    df['GPS_PASS'] = (rtx_status == 1) & gps_mask
    df['GPS_FAIL'] = (rtx_status == 2) & gps_mask
    df['BDS_PASS'] = (rtx_status == 1) & bds_mask
    df['BDS_FAIL'] = (rtx_status == 2) & bds_mask

    # OSNMA - bits 3-5 (values 1/2 in that field read as 8/16 after masking)
    osnma_status = df['NMA_FLAGS'] & 0b111000
    df['GALILEO_PASS'] = (osnma_status == 8) & galileo_mask
    df['GALILEO_FAIL'] = (osnma_status == 16) & galileo_mask

    # Count pass/fail flags per epoch.
    results = df.groupby('TIME').agg({
        'GPS_PASS': 'sum',
        'GPS_FAIL': 'sum',
        'GALILEO_PASS': 'sum',
        'GALILEO_FAIL': 'sum',
        'BDS_PASS': 'sum',
        'BDS_FAIL': 'sum'
    }).reset_index()
    results = results.sort_values('TIME')

    fig = figure()
    ax = fig.add_subplot(111)

    systems = {
        'GPS': {'pass_col': 'GPS_PASS', 'fail_col': 'GPS_FAIL', 'pass_label': 'RTX-NMA(GPS) Pass', 'fail_label': 'RTX-NMA(GPS) Fail'},
        'BDS': {'pass_col': 'BDS_PASS', 'fail_col': 'BDS_FAIL', 'pass_label': 'RTX-NMA(BDS) Pass', 'fail_label': 'RTX-NMA(BDS) Fail'},
        'Galileo': {'pass_col': 'GALILEO_PASS', 'fail_col': 'GALILEO_FAIL', 'pass_label': 'OSNMA(Galileo) Pass', 'fail_label': 'OSNMA(Galileo) Fail'}
    }

    if plt_rel_time:
        t0 = results['TIME'].iloc[0] if not results.empty else 0.0
        x_vals = (results['TIME'] - t0) / 60.0
        x_label = 'Time from start [min]'
    else:
        t0 = 0.0
        x_vals = results['TIME']
        x_label = 'GPS TOW [sec]'

    def to_x(t):
        """Map a TIME value (seconds) onto the plotted x axis."""
        return (t - t0) / 60.0 if plt_rel_time else t

    for system, config in systems.items():
        passMean = results[config['pass_col']].mean()
        failMean = results[config['fail_col']].mean()
        ax.plot(x_vals, results[config['pass_col']], '.-', label=config['pass_label'] + f' $\\mu$={passMean:5.2f}')
        ax.plot(x_vals, results[config['fail_col']], '.-', label=config['fail_label'] + f' $\\mu$={failMean:5.2f}')
        print(f"{system} Pass: {passMean:5.2f}")
        print(f"{system} Fail: {failMean:5.2f}")

    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend(fontsize=6, loc='best')
    ax.set_xlabel(x_label)
    ax.set_ylabel('Number of NMA pass/fail per epoch')
    _, name = os.path.split(path_name)
    ax.set_title('NMA Status: ' + name)
    # Rotate labels without freezing the tick locations (the old
    # set_xticks/set_xticklabels pair pinned ticks and widened the x limits).
    ax.tick_params(axis='x', labelrotation=45)

    # Add test segment boxes (only for segments starting inside the view).
    if test_segments is not None and box:
        xlim_min, xlim_max = ax.get_xlim()
        ymin, ymax = ax.get_ylim()
        y_range = ymax - ymin
        box_top = ymax - 0.02 * y_range
        label_y = box_top + 0.01 * y_range

        for segment in test_segments:
            start_sec = segment.get('start_seconds')
            end_sec = segment.get('end_seconds')
            test_id = segment.get('test_id', '')
            if start_sec is None or end_sec is None:
                continue
            # Segments are given in TIME seconds; convert so they also line up
            # with the minutes axis used in relative-time mode.
            x_start, x_end = to_x(start_sec), to_x(end_sec)
            if not (xlim_min <= x_start <= xlim_max):
                continue

            ax.hlines(y=box_top, xmin=x_start, xmax=x_end, color='red', linestyle='-', linewidth=2, alpha=0.7)
            ax.vlines(x=x_start, ymin=ymin, ymax=box_top, color='red', linestyle='--', linewidth=0.5, alpha=0.7)
            ax.vlines(x=x_end, ymin=ymin, ymax=box_top, color='red', linestyle='--', linewidth=0.5, alpha=0.7)
            ax.text((x_start + x_end) / 2, label_y, test_id, ha='center', va='bottom', fontsize=8, color='red', rotation=45)

            # Leave headroom above the rotated labels.
            current_ylim = ax.get_ylim()
            ax.set_ylim(current_ylim[0], max(current_ylim[1], label_y + 0.15 * y_range))

    fig.tight_layout()
    return fig

def fft_mov(path_name, dir_name='.'):
    """Render one FFT movie per satellite type found in record 35:261.

    For each ``SAT_TYPE`` present, runs ``rec35_261_movie.py`` and moves the
    ``sv.mp4`` it leaves in the current directory to
    ``<dir_name>/fft_movie/<sat_type>.mp4``. If the movie script does not
    exist, a message is printed and nothing is done.
    """
    script = '/home/ltalegh/gpstools/pythonTools/rec35_261_movie.py'
    # Check that the required script exists before loading data and running it.
    if not os.path.exists(script):
        print(f"FFT movie script not found: {script}")
        return

    d, k = vd2arr(path_name, rec='-d35:261')
    out_dir = os.path.join(dir_name, 'fft_movie')
    for s in np.unique(d[:, k.SAT_TYPE]):
        sat = int(s)
        os.system('python ' + script + ' ' + path_name + ' -s*/' + str(sat) + ' -r5000')
        os.makedirs(out_dir, exist_ok=True)
        src = 'sv.mp4'
        if os.path.exists(src):
            os.rename(src, os.path.join(out_dir, f'{sat}.mp4'))

def plot_fft(path_name,dir_name='.',start_time=None,end_time=None):
    """Generate FFT PNGs with ``plt_fft_data.py`` and file them under ``<dir_name>/fft_plots``.

    Record 35:25 is loaded only to choose the layout. If it has a ``SAMP_PT``
    field containing 1 or 2 (treated here as pre/post-mitigation sample
    points), low- and high-rate plots are made for sample point 1
    (``pre_mit``) and sample point 2 (``post_mit``). Otherwise a single
    low-rate and high-rate set is made. ``start_time``/``end_time`` only
    filter the data used for that decision; they are not passed to the
    plotting script.
    """
    d = vd2cls(path_name, rec='-d35:25')
    if start_time is not None:
        d = d[d.SEC >= start_time]
    if end_time is not None:
        d = d[d.SEC <= end_time]

    mit = 'SAMP_PT' in d.k and bool(np.any((d.SAMP_PT == 1) | (d.SAMP_PT == 2)))
    print("MITLIGATION DATA: ", mit)

    base_cmd = 'python /home/ltalegh/gpstools/pythonTools/plt_fft_data.py ' + path_name + ' -s' + ' -a'
    # (extra script arguments, output sub-directory); low rate before high rate.
    if mit:
        runs = ((' --pngs --sample_point 1', 'fft_plots/pre_mit/low_rate'),
                (' -hr --pngs --sample_point 1', 'fft_plots/pre_mit/high_rate'),
                (' --pngs --sample_point 2', 'fft_plots/post_mit/low_rate'),
                (' -hr --pngs --sample_point 2', 'fft_plots/post_mit/high_rate'))
    else:
        runs = ((' --pngs', 'fft_plots/low_rate'),
                (' -hr --pngs', 'fft_plots/high_rate'))

    for extra_args, sub_dir in runs:
        # os.system reports failure through its return status, not by raising.
        status = os.system(base_cmd + extra_args)
        if status != 0:
            print(f"Error generating FFT plots for {sub_dir} (exit status {status})")
        # Move whatever PNGs were produced so they don't mix with the next run.
        out_dir = os.path.join(dir_name, sub_dir)
        os.makedirs(out_dir, exist_ok=True)
        for file in glob.glob('*.png'):
            os.rename(file, os.path.join(out_dir, file))

def plot_power_agc(path_name, dir_name='.',start_time=None,end_time=None):
    """Run ``getPowerAGC.py`` on *path_name* and move its PNGs to ``<dir_name>/power_agc``.

    ``start_time`` / ``end_time``, when given, are forwarded as
    ``-s<start>`` / ``-e<end>``. A non-zero exit status from the script is
    reported, and any PNGs it produced are still moved.
    """
    cmd = 'python ~/gpstools/pythonTools/getPowerAGC.py ' + path_name
    if start_time is not None:
        cmd += ' -s' + str(start_time)
    if end_time is not None:
        cmd += ' -e' + str(end_time)

    # os.system never raises on script failure; check its return status instead.
    status = os.system(cmd)
    if status != 0:
        print("Error generating power agc plots")

    os.makedirs(dir_name + '/power_agc', exist_ok=True)
    for file in glob.glob('*.png'):
        os.rename(file, os.path.join(dir_name, 'power_agc', file))

def plot_kml_map(path_name, dir_name='.', start_time=None, end_time=None):
    """Draw the receiver track on a satellite-imagery web map.

    Steps:
      1. Convert *path_name* to ``outputkmz.kmz`` with ``t012kml``. The
         conversion is limited to ``start_time``/``end_time`` seconds, using
         the week number of the first record 35:2 epoch.
      2. Extract the first ``.kml`` file from the archive.
      3. Draw all Placemark coordinates as a red folium polyline.
      4. Save ``kml_map.html`` and move every ``*.html`` in the current
         directory to ``<dir_name>/position/map``.

    Temporary ``*.kml``/``*.kmz`` files in the current directory are removed
    once extraction has been attempted, including on error paths.

    Returns:
        None. Prints a message and returns early if the KMZ cannot be produced
        or read, or contains no coordinates.

    Raises:
        FileNotFoundError: If the KMZ archive contains no ``.kml`` entry.
    """
    t012kml = '~/t01_tools/colossal/t01/build64/t012kml/t012kml'

    # Convert only when given: the old unconditional int() raised TypeError
    # for None, which made the "no time window" path unreachable.
    if start_time is not None:
        start_time = int(start_time)
    if end_time is not None:
        end_time = int(end_time)

    cmd = t012kml + ' ' + path_name + ' outputkmz.kmz'
    try:
        if start_time is not None or end_time is not None:
            # Limits are passed as "<week>,<seconds>"; week from the first position epoch.
            d = vd2cls(path_name, rec='-d35:2')  # Position
            weekNum = str(int(d.WEEK[0]))
            if start_time is not None:
                cmd += ' -is' + weekNum + ',' + str(start_time)
            if end_time is not None:
                cmd += ' -es' + weekNum + ',' + str(end_time)
            print(cmd)
        os.system(cmd)
    except Exception:
        print("Error generating KMZ file")
        return

    try:
        kml_filename = None  # stays None if the archive has no .kml entry
        try:
            # A KMZ is a zip archive; extract its first .kml file.
            with zipfile.ZipFile('outputkmz.kmz', 'r') as kmz:
                for name in kmz.namelist():
                    if name.endswith('.kml'):
                        kmz.extract(name)
                        kml_filename = name
                        break
        except Exception:
            print("Error extracting KML from KMZ file")
            return

        if not kml_filename:
            raise FileNotFoundError("No KML file found inside the KMZ archive.")

        # Collect lon/lat from every <coordinates> element ("lon,lat[,alt] ...").
        doc = minidom.parse(kml_filename)
        lons, lats = [], []
        for placemark in doc.getElementsByTagName('Placemark'):
            for coord in placemark.getElementsByTagName('coordinates'):
                if coord.firstChild is None:  # empty <coordinates/> element
                    continue
                for point in coord.firstChild.data.strip().split():
                    lon, lat, *_ = map(float, point.split(','))
                    lons.append(lon)
                    lats.append(lat)

        if not lats:
            print("No coordinates found in KML file")
            return

        # Center the map on the mean position.
        center = [sum(lats) / len(lats), sum(lons) / len(lons)]
        m = folium.Map(
            location=center,
            zoom_start=15,
            tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
            attr='Esri'
        )
        folium.PolyLine(list(zip(lats, lons)), color='red').add_to(m)
        m.save('kml_map.html')

        os.makedirs(dir_name + '/position/map', exist_ok=True)
        for file in glob.glob('*.html'):
            os.rename(file, os.path.join(dir_name, 'position/map', file))
    finally:
        # Remove .kml and .kmz files generated during the process
        for ext in ('*.kml', '*.kmz'):
            for file in glob.glob(ext):
                try:
                    os.remove(file)
                except Exception as e:
                    print(f"Could not remove {file}: {e}")

def _draw_test_segments(ax, test_segments):
    """Overlay test-segment brackets and labels on a time-series axis.

    For every segment whose ``start_seconds`` lies inside the axis' current
    x-limits, draw:

    * a red horizontal bar near the top of the plot from ``start_seconds`` to
      ``end_seconds``;
    * dashed red vertical lines at both ends;
    * the segment's ``test_id`` (rotated 45 degrees) above the bar.

    The y-axis upper limit is then raised so the labels are not clipped.

    Args:
        ax: Matplotlib axes to annotate.
        test_segments: Iterable of dicts with ``start_seconds``,
            ``end_seconds`` and optional ``test_id`` keys, or ``None`` to draw
            nothing.
    """
    if test_segments is None:
        return

    xlim_min, xlim_max = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    # Geometry comes from the limits before any annotation is added, so every
    # bracket on this axis sits at the same height.
    y_range = ymax - ymin
    box_top = ymax - 0.02 * y_range
    label_y = box_top + 0.01 * y_range

    for segment in test_segments:
        start_sec = segment.get('start_seconds')
        end_sec = segment.get('end_seconds')
        if start_sec is None or end_sec is None:
            continue
        # NOTE: only the segment *start* is tested against the x-limits; a
        # segment that begins before the plotted window is not drawn.
        if not xlim_min <= start_sec <= xlim_max:
            continue

        ax.hlines(y=box_top, xmin=start_sec, xmax=end_sec, color='red', linestyle='-', linewidth=2, alpha=0.7)
        ax.vlines(x=start_sec, ymin=ymin, ymax=box_top, color='red', linestyle='--', linewidth=0.5, alpha=0.7)
        ax.vlines(x=end_sec, ymin=ymin, ymax=box_top, color='red', linestyle='--', linewidth=0.5, alpha=0.7)

        mid_time = (start_sec + end_sec) / 2
        ax.text(mid_time, label_y, segment.get('test_id', ''), ha='center', va='bottom', fontsize=8, color='red', rotation=45)

        # Raise the y-axis upper limit so the rotated label is not cut off.
        current_ylim = ax.get_ylim()
        new_ymax = max(current_ylim[1], label_y + 0.15 * y_range)
        ax.set_ylim(current_ylim[0], new_ymax)


def plot_covariance(path_name, dir_name='.', start_time=None, end_time=None, test_segments=None):
    """Plot the covariance terms from record 35:2 on a 2x3 grid of panels.

    Top row: SigN, SigE, SigU [m]. Bottom row: cross terms SigEN, SigEU,
    SigNU [m^2]. All terms are plotted against GPS time of week.

    Args:
        path_name: T04 file (or whatever ``vd2cls`` accepts) to read.
        dir_name: Unused; kept for interface compatibility. The caller saves
            the returned figure.
        start_time: Optional GPS TOW [s]; earlier epochs are dropped.
        end_time: Optional GPS TOW [s]; later epochs are dropped.
        test_segments: Optional list of test-segment dicts (``start_seconds``,
            ``end_seconds``, ``test_id``) to overlay on every panel.

    Returns:
        The matplotlib figure.
    """
    d = vd2cls(path_name, rec='-d35:2')
    if start_time is not None:
        d = d[d.TIME >= start_time]
    if end_time is not None:
        d = d[d.TIME <= end_time]

    # (record field, y-axis label) for each panel, in row-major order.
    panels = (
        ('SigN', 'SigN [m]'),
        ('SigE', 'SigE [m]'),
        ('SigU', 'SigU [m]'),
        ('SigEN', 'SigEN [m^2]'),
        ('SigEU', 'SigEU [m^2]'),
        ('SigNU', 'SigNU [m^2]'),
    )

    fig, axs = subplots(2, 3, figsize=(16, 8))
    for ax, (field, ylabel) in zip(axs.flat, panels):
        ax.plot(d.TIME, getattr(d, field), 'x', markersize=1.5)
        ax.set_xlabel('GPS TOW [sec]')
        ax.set_ylabel(ylabel)
        ax.grid(False)
        ax.tick_params(axis='x', labelrotation=45)

    fig.tight_layout()

    # Overlay after tight_layout so the bracket geometry uses the final limits.
    for ax in axs.flat:
        _draw_test_segments(ax, test_segments)

    return fig

def plot_signals_used(path_name, dir_name='.', start_time=None, end_time=None):
    """Plot the number of signals used per constellation and save it as PNG.

    Reads record 35:2 and draws one panel per constellation (GPS, GLONASS,
    Galileo, BeiDou, QZSS, NavIC) against GPS time of week. The figure is
    written to ``<dir_name>/signals_used/signals_used.png``; the directory is
    created if needed.

    Args:
        path_name: T04 file (or whatever ``vd2cls`` accepts) to read.
        dir_name: Output directory root.
        start_time: Optional GPS TOW [s]; earlier epochs are dropped.
        end_time: Optional GPS TOW [s]; later epochs are dropped.

    Returns:
        The matplotlib figure, so callers can close or reuse it. Earlier
        versions returned ``None``, which left the figure open with no handle.
    """
    d = vd2cls(path_name, rec='-d35:2')
    if start_time is not None:
        d = d[d.TIME >= start_time]
    if end_time is not None:
        d = d[d.TIME <= end_time]

    # (record field, y-axis label) for each panel, in row-major order.
    panels = (
        ('GPSUsd', 'Number of GPS L1 C/A signals used'),
        ('GLOUsd', 'Number of GLONASS L1 signals used'),
        ('GALUsd', 'Number of Galileo E1 signals used'),
        ('BDSUsd', 'Number of BeiDou B1I signals used'),
        ('QZSSUsd', 'Number of QZSS L1 C/A signals used'),
        ('IRNSSUsd', 'Number of NavIC L5 signals used'),
    )

    fig, axs = subplots(2, 3, figsize=(15, 10))
    for ax, (field, ylabel) in zip(axs.flat, panels):
        ax.plot(d.TIME, getattr(d, field), 'o', markersize=1.5)
        ax.set_xlabel('GPS TOW [sec]')
        ax.set_ylabel(ylabel)
        ax.grid(True)
        ax.tick_params(axis='x', labelrotation=45)

    fig.tight_layout()
    os.makedirs(dir_name + '/signals_used', exist_ok=True)
    filename = dir_name + '/signals_used/' + 'signals_used.png'
    fig.savefig(filename, dpi=300)
    return fig

   


def build_output_dir_name(path_name, dir_name='.', start=None, end=None):
    """Return the directory figures are written to.

    If ``dir_name`` is ``'.'`` it is replaced by the base name of
    ``path_name`` without its extension. A trailing path separator is ignored,
    so ``'data/run/'`` yields ``'run'`` rather than ``''``. The integer start
    and/or end times are then appended, each preceded by an underscore.

    Args:
        path_name: Input T04 file or directory.
        dir_name: User-supplied output directory; ``'.'`` means "derive it".
        start: Optional start time in GPS TOW [s].
        end: Optional end time in GPS TOW [s].

    Returns:
        The output directory name, e.g. ``'run1_100_200'``.
    """
    if dir_name == '.':
        dir_name = os.path.splitext(os.path.basename(os.path.normpath(path_name)))[0]
    suffix = '_'.join(f"{int(t)}" for t in (start, end) if t is not None)
    if suffix:
        dir_name += '_' + suffix
    return dir_name


def vehicle_test_areas(weekday, vehicle):
    """Return the ``(morning, afternoon)`` JammerTest area numbers for a vehicle.

    Args:
        weekday: Day of the month (September 2025) of the test, 15-19.
        vehicle: Vehicle name. ``'stella'`` (case-insensitive) uses the
            Stella schedule; any other name uses the Bleik schedule.

    Returns:
        Tuple ``(area_am, area_pm)``. Days outside the table default to ``(1, 1)``.
    """
    # Stella's morning/afternoon test area for each day.
    stella_areas = {15: (2, 2), 16: (3, 3), 17: (2, 1), 18: (1, 3), 19: (1, 1)}
    # Bleik location is always (1,1) (in the hall).
    bleik_areas = {15: (1, 1), 16: (1, 1), 17: (1, 1), 18: (1, 1), 19: (1, 1)}
    areas = stella_areas if vehicle.lower() == 'stella' else bleik_areas
    return areas.get(weekday, (1, 1))


def select_test_segments(tests_am, tests_pm):
    """Build the list of test segments to overlay on the plots.

    The morning plan contributes only tests listed before the lunch entry
    (test ID ``"0.1.1"``). The afternoon plan contributes only tests from lunch
    onwards. The following entries are skipped:

    * IDs starting with ``"0."`` (briefings/lunch);
    * entries with missing or non-numeric GPS week/seconds;
    * entries whose end is not after their start.

    Args:
        tests_am: Test-plan entries (dicts) for the morning test area.
        tests_pm: Test-plan entries (dicts) for the afternoon test area.

    Returns:
        List of dicts with keys ``test_id``, ``start_week``, ``end_week``,
        ``start_seconds``, ``end_seconds`` and ``AM_or_PM``.
    """
    lunch_id = "0.1.1"
    segments = []
    for session, tests in (("AM", tests_am), ("PM", tests_pm)):
        after_lunch = False
        for test in tests:
            test_id = test.get("test_id", "")
            if test_id == lunch_id:
                after_lunch = True
            # The morning plan stops at lunch; the afternoon plan starts there.
            in_session = after_lunch if session == "PM" else not after_lunch
            if not in_session or test_id.startswith("0."):
                continue

            start_gps = test.get("start_time_gps") or {}
            end_gps = test.get("end_time_gps") or {}
            try:
                start_week = int(start_gps.get("week"))
                end_week = int(end_gps.get("week"))
                start_seconds = float(start_gps.get("seconds"))
                end_seconds = float(end_gps.get("seconds"))
            except (TypeError, ValueError):
                continue

            if end_seconds <= start_seconds:
                continue

            segments.append(
                {
                    "test_id": test_id,
                    "start_week": start_week,
                    "end_week": end_week,
                    "start_seconds": start_seconds,
                    "end_seconds": end_seconds,
                    "AM_or_PM": session,
                }
            )
    return segments


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
                                     description=__doc__)

    parser.add_argument('path_name',
                        help='name of or path to your T04 file(s)')
    parser.add_argument('-dir_name',
                        help='Directory name to save figures', default='.')

    # Simple on/off switches, in the order they appear in --help.
    bool_flags = (
        ('-raim', 'Plot RAIM and tracking figures'),
        ('-cno', 'Plot C/N0 figures'),
        ('-pos', 'Plot position figures'),
        ('-rec35_16', 'Plot 35:16 position figures'),
        ('-nma', 'Plot NMA figures'),
        ('-all_figs', 'Plot all figures'),
        ('-pdf', 'Save all PNG figures in dir_name to a single PDF'),
        ('-plt_rel_time', 'Plot absolute time on x-axis instead of time from start'),
        ('-fft_mov', 'Create a moving FFT of the C/N0 data (not implemented yet)'),
        ('-fft_plots', 'Create high rate and low rate FFT plots'),
        ('-agc', 'Plot power and AGC data'),
        ('-kml_map', 'Generate a map of the position data on Google Maps using KML/KMZ file'),
        ('-cov', 'Plot covariance figures'),
        ('-signals_used', 'Plot signals used figures'),
    )
    for flag, help_text in bool_flags:
        parser.add_argument(flag, action='store_true', default=False, help=help_text)

    # Optional start and end time of data; each plotting function filters on them.
    parser.add_argument('-s', type=float, default=None, help='Start time in GPS TOW')
    parser.add_argument('-e', type=float, default=None, help='End time in GPS TOW')

    # JammerTest
    parser.add_argument('-jammerTest', action='store_true', default=False, help='Whether to plot jammer test segments on the figures (only works for position and covariance currently)')
    parser.add_argument('-day', type=int, default=None, help='Day of the test (15-19) to determine jammer test segment times (only needed if -jammerTest is set)')
    parser.add_argument('-r', type=str, default=None, help='Vehicle name (Stella or Bleik) to determine jammer test segment times (only needed if -jammerTest is set)')

    args = parser.parse_args()

    # Fail early with a usage message instead of crashing mid-run on None.
    if args.jammerTest and (args.day is None or args.r is None):
        parser.error('-jammerTest requires both -day and -r')

    if args.all_figs:
        args.raim = True
        args.cno = True
        args.pos = True
        args.rec35_16 = True
        args.nma = True
        args.fft_plots = False  # Made false for now for speed
        args.agc = True
        # NOTE(review): covariance plots are excluded from -all_figs; confirm intended.
        args.cov = False
        args.signals_used = True
        # -kml_map is not enabled by -all_figs; pass it explicitly.

    args.dir_name = build_output_dir_name(args.path_name, args.dir_name, args.s, args.e)
    print(args.dir_name)

    if args.signals_used:
        plot_signals_used(path_name=args.path_name, dir_name=args.dir_name, start_time=args.s, end_time=args.e)

    if args.jammerTest:
        box = True
        args.areaAM, args.areaPM = vehicle_test_areas(args.day, args.r)
        testsAM = plan.get_test_plan_times(2025, 9, int(args.day), location=f'Test Area {args.areaAM}')
        testsPM = plan.get_test_plan_times(2025, 9, int(args.day), location=f'Test Area {args.areaPM}')
        test_segments = select_test_segments(testsAM, testsPM)
    else:
        box = False
        test_segments = None

    # Position
    if args.pos:
        pos_fig = plt_pos(path_name=args.path_name, plt_rel_time=args.plt_rel_time, start_time=args.s, end_time=args.e)
        os.makedirs(args.dir_name + '/position/', exist_ok=True)
        filename = args.dir_name + '/position/' + 'llh.png'
        pos_fig.savefig(filename, dpi=300)
        plt.close(pos_fig)

    if args.rec35_16:
        fig16 = plt_rec35_16(path_name=args.path_name, plt_rel_time=args.plt_rel_time, start_time=args.s, end_time=args.e)
        if fig16 is not None:
            os.makedirs(args.dir_name + '/position/', exist_ok=True)
            filename_16 = args.dir_name + '/position/' + 'llh_rec35_16.png'
            fig16.savefig(filename_16, dpi=300, bbox_inches='tight')
            plt.close(fig16)
        else:
            print('No 35:16 data found, continuing without it')

    # NMA
    if args.nma:
        nma_fig = plot_nma(path_name=args.path_name, plt_rel_time=args.plt_rel_time, start_time=args.s, end_time=args.e, box=box, test_segments=test_segments)
        os.makedirs(args.dir_name + '/NMA', exist_ok=True)
        filename = args.dir_name + '/NMA/' + 'nma.png'
        nma_fig.savefig(filename, dpi=300)
        plt.close(nma_fig)

    # RAIM and tracking figures
    if args.raim:
        plt_tracking(path_name=args.path_name, dir_name=args.dir_name, antenna_num=0, start_time=args.s, end_time=args.e, box=box, test_segments=test_segments)

    if args.cno:
        cno_figs = plt_cno_data(path_name=args.path_name, return_figs=True, plt_rel_time=args.plt_rel_time, start_time=args.s, end_time=args.e)

        # Save each C/N0 figure under a filesystem-safe name, then free it.
        for fig, value in cno_figs:
            os.makedirs(args.dir_name + '/cno_data', exist_ok=True)
            filename = (args.dir_name + '/cno_data/' + 'cno-%s.png' % value.replace(' ', '-').replace('(', '').replace(')', '').replace('/', '-'))
            fig.savefig(filename, dpi=300)
            plt.close(fig)

    if args.kml_map:
        plot_kml_map(path_name=args.path_name, dir_name=args.dir_name, start_time=args.s, end_time=args.e)

    if args.agc:
        plot_power_agc(path_name=args.path_name, dir_name=args.dir_name, start_time=args.s, end_time=args.e)

    if args.fft_plots:
        plot_fft(path_name=args.path_name, dir_name=args.dir_name, start_time=args.s, end_time=args.e)

    if args.fft_mov:
        fft_mov(path_name=args.path_name, dir_name=args.dir_name)

    if args.cov:
        cov_fig = plot_covariance(path_name=args.path_name, dir_name=args.dir_name, start_time=args.s, end_time=args.e, test_segments=test_segments)
        os.makedirs(args.dir_name + '/covariance', exist_ok=True)
        filename = args.dir_name + '/covariance/' + 'covariance.png'
        cov_fig.savefig(filename, dpi=300)
        plt.close(cov_fig)

    if args.pdf:
        # The output directory may not exist if no figure flags were given.
        os.makedirs(args.dir_name, exist_ok=True)
        pdf_path = os.path.join(args.dir_name, f"{os.path.basename(args.dir_name)}.pdf")
        with PdfPages(pdf_path) as pdf:
            for root, dirs, files in os.walk(args.dir_name):
                # Sorting in place makes os.walk visit sub-directories in a
                # stable order, so PDF page order is reproducible.
                dirs.sort()
                for file in sorted(files):
                    if not file.endswith('.png'):
                        continue
                    img_path = os.path.join(root, file)
                    fig = plt.figure(figsize=(12, 8), dpi=300)
                    img = plt.imread(img_path)
                    plt.imshow(img)
                    plt.axis('off')
                    pdf.savefig(fig)
                    plt.close(fig)
        print(f"Saved all PNG figures to {pdf_path}")


    

