#!/usr/bin/env python3
"""
TTFF Data Visualization Script

This script reads TTFF (Time To First Fix) test results from ttff_results.db,
filters and analyzes data based on aiding state and enabled GNSS systems,
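and generates four visualization plots:

1. Raw TTFF values by configuration (box plots)
2. Mean satellite counts by system and configuration (grouped bar charts)
3. Daily mean TTFF (with a +/- 1 standard deviation band) over time (line plots)
4. Individual TTFF measurements over time (scatter plot)

All plots are saved as PNG files at 600 DPI resolution. The daily summary
statistics behind plot 3 are also written to ttff_daily_summary.csv.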
"""

import pandas as pd
import sqlite3
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Configuration: every input and output lives under ~/A9Test
OUTPUT_DIR = Path.home() / 'A9Test'          # destination for PNG and CSV outputs
DB_PATH = OUTPUT_DIR / 'ttff_results.db'     # SQLite results database (opened read-only)
DPI = 600                                    # resolution used for every saved PNG

def load_database(db_path=None):
    """
    Load TTFF results from the database into a pandas DataFrame.

    The database is opened read-only (SQLite ``mode=ro`` URI), so this script
    never takes a write lock on the results file.

    Args:
        db_path: Optional path (str or Path) to the SQLite database.
            Defaults to the module-level DB_PATH.

    Returns:
        pd.DataFrame: DataFrame containing every row of the ``results`` table

    Errors raised while opening the database or running the query propagate
    to the caller; the connection is closed in either case.
    """
    from contextlib import closing

    if db_path is None:
        db_path = DB_PATH

    print("Loading data from database...")

    # Path.as_uri() percent-encodes characters that are special inside a URI
    # (space, '#', '?', '%'). An f-string 'file:{path}' would truncate the path
    # at the first '#' or '?'. as_uri() requires an absolute path.
    uri = f"{Path(db_path).absolute().as_uri()}?mode=ro"

    # A sqlite3.Connection used as a context manager only commits/rolls back and
    # does not close. closing() guarantees the handle is released even if the
    # query raises.
    with closing(sqlite3.connect(uri, uri=True)) as db:
        df = pd.read_sql_query("SELECT * FROM results", db)

    print(f"Loaded {len(df)} records from database")
    return df


def create_configuration_label(row):
    """
    Build a human-readable name for the GNSS systems enabled in one test.

    Args:
        row: Mapping-like record (e.g. a DataFrame row) exposing the
            EnableGPS, EnableGalileo, EnableGLONASS and EnableBeiDou flags.

    Returns:
        str: "No Systems" when no flag is truthy, "<System> only" when exactly
        one is, otherwise the enabled names joined with '+' in the fixed order
        GPS, Galileo, GLONASS, BeiDou (e.g. "GPS+Galileo").
    """
    flag_to_name = (
        ('EnableGPS', 'GPS'),
        ('EnableGalileo', 'Galileo'),
        ('EnableGLONASS', 'GLONASS'),
        ('EnableBeiDou', 'BeiDou'),
    )
    enabled = [name for flag, name in flag_to_name if row[flag]]

    if len(enabled) == 1:
        return f"{enabled[0]} only"
    if enabled:
        return '+'.join(enabled)
    return "No Systems"


def create_filter_groups(df):
    """
    Label every record with its configuration and aiding state, then split
    the data into one DataFrame per combined label.

    Side effect: adds three columns to ``df`` in place:
    ``Config`` (e.g. "GPS+Galileo"), ``AidedLabel`` ("Aided" / "No Aiding")
    and ``Group`` (e.g. "GPS+Galileo (Aided)").

    Args:
        df: DataFrame with TTFF results; must contain the EnableGPS,
            EnableGalileo, EnableGLONASS, EnableBeiDou and Aided columns.

    Returns:
        tuple: ``(groups, df)``, where ``groups`` is a dict mapping each Group
        label (in order of first appearance) to its subset of rows, and
        ``df`` is the same input DataFrame, now carrying the label columns.
    """
    print("\nCreating filter groups...")

    # Build the labels as plain lists wrapped in an object Series.
    # DataFrame.apply(axis=1) can return a DataFrame (not a Series) for an
    # empty frame, which cannot be assigned to a single column.
    df['Config'] = pd.Series(
        [create_configuration_label(row) for _, row in df.iterrows()],
        index=df.index, dtype=object)
    df['AidedLabel'] = pd.Series(
        ['Aided' if aided else 'No Aiding' for aided in df['Aided']],
        index=df.index, dtype=object)

    # Combined label used as the grouping key by every plot
    df['Group'] = df['Config'] + ' (' + df['AidedLabel'] + ')'

    # One groupby pass instead of one boolean mask per group. sort=False keeps
    # the groups in order of first appearance, matching Series.unique().
    groups = {}
    for group_name, group_df in df.groupby('Group', sort=False):
        groups[group_name] = group_df
        print(f"  {group_name}: {len(group_df)} records")

    return groups, df


def plot_raw_ttff(df, output_path):
    """
    Plot raw TTFF values as one box per configuration group and save it.

    Boxes are light blue for "(Aided)" groups and light coral otherwise.
    Missing TTFF values are dropped per group before plotting (see below).

    Args:
        df: DataFrame with TTFF results and a ``Group`` column
        output_path: Path to save the plot (written at DPI resolution)
    """
    from matplotlib.patches import Patch

    print("\nGenerating Plot 1: Raw TTFF box plots...")

    # Sort groups for consistent ordering
    groups = sorted(df['Group'].unique())

    # One array of TTFF values per group. NaNs are dropped because matplotlib
    # does not filter them: a single NaN turns that group's box statistics into
    # NaN, and its box is silently not drawn.
    data_to_plot = [df.loc[df['Group'] == group, 'TTFF'].dropna().to_numpy()
                    for group in groups]

    # Create figure with larger size for readability
    fig, ax = plt.subplots(figsize=(14, 8))

    # tick_labels= is the matplotlib >= 3.9 name (formerly labels=)
    bp = ax.boxplot(data_to_plot, tick_labels=groups, patch_artist=True)

    # Colour boxes by aiding state: blue = aided, coral = unaided
    for box, group in zip(bp['boxes'], groups):
        box.set_facecolor('lightblue' if '(Aided)' in group else 'lightcoral')

    # Formatting
    ax.set_xlabel('Configuration', fontsize=12, fontweight='bold')
    ax.set_ylabel('TTFF (seconds)', fontsize=12, fontweight='bold')
    ax.set_title('Time To First Fix by Configuration', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')

    # Rotate x-axis labels for readability
    plt.xticks(rotation=45, ha='right')

    # Legend explaining the two box colours
    legend_elements = [
        Patch(facecolor='lightblue', label='With Aiding'),
        Patch(facecolor='lightcoral', label='Without Aiding')
    ]
    ax.legend(handles=legend_elements, loc='upper right')

    plt.tight_layout()
    plt.savefig(output_path, dpi=DPI, bbox_inches='tight')
    print(f"Saved: {output_path}")
    plt.close(fig)


def plot_satellite_counts(df, output_path):
    """
    Draw a grouped bar chart of mean satellite counts per GNSS system.

    Each configuration group gets four adjacent bars (GPS, Galileo, GLONASS,
    BeiDou), each showing the mean of the NumGPS / NumGalileo / NumGLONASS /
    NumBeiDou column over that group's rows. Groups sit on alternating light
    and dark grey bands separated by dashed lines. The figure is saved to
    ``output_path`` at DPI resolution.

    Args:
        df: DataFrame with TTFF results and a ``Group`` column
        output_path: Path to save the plot
    """
    print("\nGenerating Plot 2: Satellite counts by system...")

    # (data column, legend label, bar colour), in left-to-right bar order
    series = [
        ('NumGPS', 'GPS', '#1f77b4'),
        ('NumGalileo', 'Galileo', '#ff7f0e'),
        ('NumGLONASS', 'GLONASS', '#2ca02c'),
        ('NumBeiDou', 'BeiDou', '#d62728'),
    ]
    group_names = sorted(df['Group'].unique())

    # means[column][k] = mean of `column` over the rows of group_names[k]
    means = {column: [] for column, _, _ in series}
    for name in group_names:
        rows = df[df['Group'] == name]
        for column, _, _ in series:
            means[column].append(rows[column].mean())

    fig, ax = plt.subplots(figsize=(14, 8))

    spacing = 1.2      # distance between neighbouring group centres
    bar_width = 0.2
    centres = np.arange(len(group_names)) * spacing
    n_groups = len(centres)

    # Band boundaries: half a spacing outside the first and last centre, and
    # the midpoint between neighbouring centres everywhere else.
    if n_groups:
        midpoints = [(centres[k] + centres[k + 1]) / 2 for k in range(n_groups - 1)]
        boundaries = [centres[0] - spacing / 2] + midpoints + [centres[-1] + spacing / 2]
    else:
        midpoints, boundaries = [], []

    # Alternating background bands behind each group
    for k, (lo, hi) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        shade = 'darkgray' if k % 2 else 'lightgray'
        ax.axvspan(lo, hi, facecolor=shade, alpha=0.3, zorder=0)

    # Four bars per group, centred around the group centre
    for idx, (column, label, colour) in enumerate(series):
        ax.bar(centres + (idx - 1.5) * bar_width, means[column], bar_width,
               label=label, color=colour, alpha=0.8, edgecolor='black', linewidth=0.8)

    # Dashed separators between neighbouring groups
    for boundary in midpoints:
        ax.axvline(x=boundary, color='gray', linestyle='--', linewidth=1, alpha=0.5)

    ax.set_xlabel('Configuration', fontsize=12, fontweight='bold')
    ax.set_ylabel('Mean Number of Satellites', fontsize=12, fontweight='bold')
    ax.set_title('Satellite Counts by System and Configuration', fontsize=14, fontweight='bold')
    ax.set_xticks(centres)
    ax.set_xticklabels(group_names, rotation=45, ha='right')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(output_path, dpi=DPI, bbox_inches='tight')
    print(f"Saved: {output_path}")
    plt.close(fig)


def plot_daily_summary(df, output_path, csv_path):
    """
    Summarise TTFF per calendar day and configuration, save it, and plot it.

    Side effect: adds a ``Date`` column to ``df`` in place, built from
    Year/Month/Day.

    Outputs:
    - A CSV with one row per (Date, Group): TTFF mean, std and count, plus
      the mean satellite count for each system.
    - A plot of each group's daily mean TTFF with a +/- 1 standard deviation
      band. Aided groups use solid lines with circles; unaided groups use
      dashed lines with squares.

    Args:
        df: DataFrame with TTFF results and a ``Group`` column
        output_path: Path to save the plot
        csv_path: Path to save the summary CSV
    """
    print("\nGenerating Plot 3: Daily summary statistics...")

    # Calendar date of each test (time of day discarded)
    df['Date'] = pd.to_datetime(df[['Year', 'Month', 'Day']])

    # Named aggregation yields the flat CSV column names directly
    daily_summary = df.groupby(['Date', 'Group'], as_index=False).agg(
        TTFF_mean=('TTFF', 'mean'),
        TTFF_std=('TTFF', 'std'),
        Count=('TTFF', 'count'),
        NumGPS_mean=('NumGPS', 'mean'),
        NumGalileo_mean=('NumGalileo', 'mean'),
        NumGLONASS_mean=('NumGLONASS', 'mean'),
        NumBeiDou_mean=('NumBeiDou', 'mean'),
    )

    daily_summary.to_csv(csv_path, index=False)
    print(f"Saved summary data: {csv_path}")

    fig, ax = plt.subplots(figsize=(14, 8))

    group_names = sorted(df['Group'].unique())
    palette = plt.cm.tab10(np.linspace(0, 1, len(group_names)))

    for colour, name in zip(palette, group_names):
        per_day = daily_summary[daily_summary['Group'] == name]
        if per_day.empty:
            continue

        aided = '(Aided)' in name
        mean = per_day['TTFF_mean']
        spread = per_day['TTFF_std']

        ax.plot(per_day['Date'], mean,
                label=name, color=colour,
                linestyle='-' if aided else '--',
                marker='o' if aided else 's',
                markersize=4, linewidth=2)

        # Shaded band: daily mean +/- one standard deviation
        ax.fill_between(per_day['Date'], mean - spread, mean + spread,
                        alpha=0.2, color=colour)

    ax.set_xlabel('Date', fontsize=12, fontweight='bold')
    ax.set_ylabel('Mean TTFF (seconds)', fontsize=12, fontweight='bold')
    ax.set_title('Daily TTFF Summary by Configuration', fontsize=14, fontweight='bold')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    ax.grid(True, alpha=0.3)

    plt.xticks(rotation=45, ha='right')

    plt.tight_layout()
    plt.savefig(output_path, dpi=DPI, bbox_inches='tight')
    print(f"Saved: {output_path}")
    plt.close(fig)

    dates = daily_summary['Date']
    print("\nDaily Summary Statistics:")
    print(f"  Total unique dates: {dates.nunique()}")
    print(f"  Date range: {dates.min()} to {dates.max()}")
    print(f"  Total daily group records: {len(daily_summary)}")


def plot_ttff_scatter_timeline(df, output_path):
    """
    Plot every individual TTFF measurement against its test date/time.

    Side effect: adds a ``DateTime`` column to ``df`` in place, assembled
    from the Year, Month, Day, Hour, Minute and Seconds columns.

    Appearance:
    - Each configuration group gets its own colour from the 'jet' colormap.
    - Aided groups are drawn as triangles; unaided groups as squares.
    - Each legend entry shows the group's overall mean TTFF.

    Args:
        df: DataFrame with TTFF results and a ``Group`` column
        output_path: Path to save the plot (written at DPI resolution)
    """
    import matplotlib.dates as mdates

    print("\nGenerating Plot 4: TTFF scatter timeline...")

    # Full timestamp of each test
    df['DateTime'] = pd.to_datetime(df[['Year', 'Month', 'Day', 'Hour', 'Minute', 'Seconds']])

    groups = sorted(df['Group'].unique())

    # One evenly spaced colour per group
    colors = plt.get_cmap('jet')(np.linspace(0, 1, len(groups)))

    fig, ax = plt.subplots(figsize=(16, 8))

    for color, group in zip(colors, groups):
        group_df = df[df['Group'] == group]
        if group_df.empty:
            continue

        marker = '^' if '(Aided)' in group else 's'
        label = f"{group} (μ={group_df['TTFF'].mean():.1f}s)"

        ax.scatter(group_df['DateTime'], group_df['TTFF'],
                   c=[color], label=label, marker=marker,
                   s=50, alpha=0.7, edgecolors=color, linewidth=0.5)

    ax.set_xlabel('Date/Time', fontsize=12, fontweight='bold')
    ax.set_ylabel('TTFF (seconds)', fontsize=12, fontweight='bold')
    ax.set_title('TTFF Measurements Over Time (Individual Tests)', fontsize=14, fontweight='bold')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    ax.grid(True, alpha=0.3)

    # Show both date and time on the x-axis
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))

    # autofmt_xdate() resets tick-label rotation (its default is 30 degrees),
    # which previously overrode an earlier plt.xticks(rotation=45). Pass the
    # 45 degrees used by the other plots here instead.
    fig.autofmt_xdate(rotation=45, ha='right')

    plt.tight_layout()
    plt.savefig(output_path, dpi=DPI, bbox_inches='tight')
    print(f"Saved: {output_path}")
    plt.close(fig)


def main():
    """
    Run the whole pipeline: load results, label groups, write every output.

    Prints an error and returns early if the database has no rows or no
    groups could be formed.
    """
    banner = "=" * 70
    print(banner)
    print("TTFF Data Visualization Script")
    print(banner)

    df = load_database()
    if not len(df):
        print("\nError: No data found in database!")
        return

    groups, df = create_filter_groups(df)
    if not groups:
        print("\nError: No valid filter groups found!")
        return

    # Each output path is defined once and reused for both writing and reporting
    raw_png = OUTPUT_DIR / 'ttff_raw.png'
    sats_png = OUTPUT_DIR / 'ttff_satellite_counts.png'
    daily_png = OUTPUT_DIR / 'ttff_daily_summary.png'
    daily_csv = OUTPUT_DIR / 'ttff_daily_summary.csv'
    scatter_png = OUTPUT_DIR / 'ttff_scatter_timeline.png'

    plot_raw_ttff(df, raw_png)
    plot_satellite_counts(df, sats_png)
    plot_daily_summary(df, daily_png, daily_csv)
    plot_ttff_scatter_timeline(df, scatter_png)

    print("\n" + banner)
    print("All plots generated successfully!")
    print(banner)
    print("\nOutput files:")
    for produced in (raw_png, sats_png, daily_png, daily_csv, scatter_png):
        print(f"  - {produced}")


# Run the visualization pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
