#!/usr/bin/env python3
"""
GNSS Receiver TTFF Test Script

This script automates Time To First Fix (TTFF) testing for a GNSS receiver using DCOL operations.

Features:
- Command-line argument parsing for receiver/server IPs, ports, credentials, and iteration count
- Disables tracking and clears GPS data via HTTP requests (this causes a reboot)
- Pings the receiver until it is back online
- Enables tracking via an HTTP request and starts a timer
- Sends ephemeris/almanac data using the DCOL protocol (aided runs only)
- Monitors receiver status and pings for availability
- Measures TTFF by polling for an autonomous fix, then stops the timer and records the result
- Repeats for a specified number of iterations (default 30); each iteration runs once without and once with aiding
- Prints summary TTFF statistics for all iterations

Usage:
    python3 test_dcol_import.py [--receiver-ip ...] [--server-ip ...] [--username ...] [--password ...] [--iterations N] [...]
"""

import argparse
import os
import random
import socket
import sqlite3
import time
import xml.etree.ElementTree as ET
from collections import Counter
from concurrent.futures import wait
from contextlib import closing
from datetime import datetime, timezone
from pathlib import Path
from pathlib import Path

import requests
from requests.auth import HTTPBasicAuth

import dcol_client

# SQLite database that accumulates TTFF results across test runs.
TTFF_DB = Path.home().joinpath('A9Test', 'ttff_results.db')

# Satellite "sys" attribute value (as found under position/SvsUsed in the
# receiver's posData.xml) -> constellation name.
GNSS_SYS_MAP = {
    "0": "GPS",
    "2": "GLONASS",
    "3": "Galileo",
    "5": "BeiDou",
}

# fixType / soln strings that are accepted as a valid fix.
VALID_FIX_TYPES = {"Autonomous", "PosAutonString"}

# (column name, SQLite type) pairs for the ``results`` table, in column order:
# UTC timestamp of the fix, aiding flag, TTFF, per-constellation satellite
# counts, then which constellations were enabled for tracking.
TTFF_SCHEMA = [
    *[(field, "INTEGER") for field in ("Year", "Month", "Day", "Hour", "Minute", "Seconds")],
    ("Aided", "BOOLEAN"),
    ("TTFF", "REAL"),
    *[(f"Num{gnss}", "INTEGER") for gnss in ("GPS", "Galileo", "GLONASS", "BeiDou")],
    *[(f"Enable{gnss}", "BOOLEAN") for gnss in ("GPS", "Galileo", "GLONASS", "BeiDou")],
]


def create_ttff_schema(db_path, schema=None):
    """Create the TTFF ``results`` table, creating parent directories as needed.

    Intended for a brand-new database file, but safe on an existing one: the
    table is only created when it is missing.

    Args:
        db_path: Path (``str`` or ``pathlib.Path``) of the SQLite database file.
        schema: Optional sequence of ``(column_name, sqlite_type)`` pairs.
            Defaults to ``TTFF_SCHEMA``.

    Filesystem and SQLite errors are printed rather than raised so the test
    campaign can continue without database storage.
    """
    if schema is None:
        schema = TTFF_SCHEMA
    columns_definition = ',\n            '.join(
        f"{name} {col_type}" for name, col_type in schema
    )

    try:
        # sqlite3.connect() does not create missing directories (e.g.
        # ~/A9Test on a fresh machine); it fails with "unable to open
        # database file" instead.
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        with closing(sqlite3.connect(db_path, timeout=10)) as db:
            db.execute(
                f"""
                    CREATE TABLE IF NOT EXISTS results (
                        {columns_definition}
                    )
                """
            )
            db.commit()
        print(f"Created new database: {db_path}")
    except (OSError, sqlite3.Error) as e:
        print(f"Error creating database schema: {e}")


def configure_wal(db_path, autocheckpoint_pages=1000):
    """Switch the database to WAL journal mode with sensible defaults.

    ``journal_mode=WAL`` is stored in the database file and persists.
    NOTE: ``wal_autocheckpoint`` and ``synchronous`` are per-connection
    settings in SQLite, so here they only affect this short-lived connection;
    other connections keep SQLite's defaults unless they set them too.

    Args:
        db_path: Path of the SQLite database file.
        autocheckpoint_pages: Value for ``PRAGMA wal_autocheckpoint``.

    Returns:
        The journal mode SQLite reports after the switch (``'wal'`` on
        success), or ``None`` if an SQLite error occurred. A warning is
        printed when the mode is not WAL.
    """
    try:
        with closing(sqlite3.connect(db_path, timeout=10)) as db:
            # The pragma returns the resulting mode; SQLite keeps the old mode
            # (e.g. 'memory' for an in-memory database) when WAL cannot be used.
            mode = db.execute("PRAGMA journal_mode=WAL").fetchone()[0]
            if str(mode).lower() != "wal":
                print(f"Warning: WAL mode not enabled for {db_path} (journal_mode={mode})")

            # Keep checkpoints manageable
            db.execute(f"PRAGMA wal_autocheckpoint={int(autocheckpoint_pages)}")

            # NORMAL reduces fsync frequency but keeps durability acceptable
            db.execute("PRAGMA synchronous=NORMAL")
            db.commit()
            return mode
    except sqlite3.Error as e:
        print(f"Error configuring WAL mode: {e}")
        return None


def optimize_database(db_path):
    """Check integrity, refresh statistics, compact, and checkpoint the database.

    Steps, in order:
      1. ``PRAGMA integrity_check``: a problem is reported but not fatal.
      2. ``ANALYZE``: refresh query-planner statistics.
      3. ``VACUUM``: reclaim space and defragment.
      4. ``PRAGMA optimize``.
      5. ``PRAGMA wal_checkpoint(TRUNCATE)``: flush and truncate the WAL.

    The checkpoint runs last because ANALYZE and VACUUM append new frames to
    the WAL. Checkpointing before them would leave a populated WAL behind
    whenever another connection keeps the database open.

    Errors are printed, not raised.
    """
    try:
        with closing(sqlite3.connect(db_path, timeout=10)) as db:
            integrity = db.execute('PRAGMA integrity_check').fetchone()[0]
            if integrity != 'ok':
                print(f"Database integrity issue in {db_path}: {integrity}")

            # Analyze query planner statistics
            db.execute('ANALYZE')
            db.commit()

            # Reclaim space and defragment
            db.execute('VACUUM')

            # Update internal statistics
            db.execute('PRAGMA optimize')
            db.commit()

            # Final checkpoint: the WAL now holds everything written above.
            checkpoint_stats = db.execute('PRAGMA wal_checkpoint(TRUNCATE)').fetchone()
            if checkpoint_stats and checkpoint_stats[0] != 0:
                print(f"WAL checkpoint reported busy writers for {db_path}: {checkpoint_stats}")

        print(f"Database {db_path} optimized successfully")
    except sqlite3.Error as e:
        print(f"Error optimizing database {db_path}: {e}")


def store_ttff_result(db_path, end_time, aided, ttff_seconds, sat_counts, enabled_systems):
    """Store a successful TTFF result to the database.

    Args:
        db_path: Path to the database file (must already contain ``results``).
        end_time: datetime object of when fix was achieved (UTC). Its
            year..second fields are stored; sub-second precision is dropped.
        aided: Boolean indicating if aiding data was used (stored as 0/1).
        ttff_seconds: Time to first fix in seconds.
        sat_counts: Dictionary mapping constellation names ('GPS', 'Galileo',
            'GLONASS', 'BeiDou') to satellite counts; missing names store 0.
            ``None`` is treated as an empty dict.
        enabled_systems: Dictionary with keys 'gps', 'galileo', 'glonass',
            'beidou' (boolean values, stored as 0/1); ``None`` means none.

    SQLite errors are printed as a warning, not raised, so a database problem
    never aborts the test campaign.
    """
    sat_counts = sat_counts or {}
    enabled_systems = enabled_systems or {}

    def as_flag(key):
        return 1 if enabled_systems.get(key, False) else 0

    row = (
        end_time.year,
        end_time.month,
        end_time.day,
        end_time.hour,
        end_time.minute,
        end_time.second,
        1 if aided else 0,
        ttff_seconds,
        sat_counts.get('GPS', 0),
        sat_counts.get('Galileo', 0),
        sat_counts.get('GLONASS', 0),
        sat_counts.get('BeiDou', 0),
        as_flag('gps'),
        as_flag('galileo'),
        as_flag('glonass'),
        as_flag('beidou'),
    )

    try:
        # closing() releases the connection; `with db:` commits the insert
        # (or rolls it back on error). A bare `with sqlite3.connect(...)`
        # only does the latter and leaves the connection open.
        with closing(sqlite3.connect(db_path, timeout=10)) as db:
            with db:
                db.execute(
                    """
                    INSERT INTO results 
                    (Year, Month, Day, Hour, Minute, Seconds, Aided, TTFF, NumGPS, NumGalileo, NumGLONASS, NumBeiDou,
                     EnableGPS, EnableGalileo, EnableGLONASS, EnableBeiDou)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    row,
                )
    except sqlite3.Error as e:
        print(f"Warning: Failed to store TTFF result to database: {e}")


def _positive_int(text):
    """argparse ``type=`` helper: parse *text* as an integer >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        ``argparse.Namespace`` with the receiver/server addresses, credentials,
        iteration count, and the per-constellation enable/disable flags.

    Notes:
        ``--enable-gps``, ``--enable-galileo`` and ``--enable-beidou`` default
        to True, so passing them changes nothing; use the matching
        ``--disable-*`` flag to turn a constellation off. ``--iterations`` must
        be at least 1; 0 would otherwise crash the summary with a division by
        zero.
    """
    parser = argparse.ArgumentParser(
        description='GNSS Receiver TTFF Test - Performs iterations of ephemeris/almanac injection and monitors time to first fix'
    )

    parser.add_argument('--receiver-ip', default='10.1.150.XXX',
                        help='IP address of the GNSS receiver (default: 10.1.150.XXX)')
    parser.add_argument('--receiver-port', type=int, default=28001,
                        help='A9 DCOL port on receiver (default: 28001)')
    parser.add_argument('--server-ip', default='10.1.150.YYY',
                        help='IP address of the aiding data server (default: 10.1.150.YYY)')
    parser.add_argument('--server-port', type=int, default=6001,
                        help='DCOL port on aiding data server (default: 6001)')
    parser.add_argument('--username', default='admin',
                        help='Username for receiver web interface (default: admin)')
    parser.add_argument('--password', default='password',
                        help='Password for receiver web interface (default: password)')
    parser.add_argument('--iterations', type=_positive_int, default=30,
                        help='Number of test iterations (default: 30)')
    parser.add_argument('--enable-gps', action='store_true', default=True,
                        help='Enable GPS tracking (default: enabled)')
    parser.add_argument('--disable-gps', action='store_true',
                        help='Disable GPS tracking')
    parser.add_argument('--enable-galileo', action='store_true', default=True,
                        help='Enable Galileo tracking (default: enabled)')
    parser.add_argument('--disable-galileo', action='store_true',
                        help='Disable Galileo tracking')
    parser.add_argument('--enable-glonass', action='store_true',
                        help='Enable GLONASS tracking (not yet supported)')
    parser.add_argument('--enable-beidou', action='store_true', default=True,
                        help='Enable BeiDou tracking (default: enabled)')
    parser.add_argument('--disable-beidou', action='store_true',
                        help='Disable BeiDou tracking')
    return parser.parse_args(argv)

def safe_requests_get(url, auth=None, timeout=10, description=None):
    """Issue a GET request, printing its status instead of raising.

    Args:
        url: URL to fetch.
        auth: Optional ``requests`` auth object.
        timeout: Request timeout in seconds.
        description: Optional line printed before the request is made.

    Returns:
        The ``requests`` response, or ``None`` if any exception occurred
        (the exception message is printed).
    """
    if description:
        print(description)
    try:
        reply = requests.get(url, auth=auth, timeout=timeout)
        print(f"  HTTP {reply.status_code}")
    except Exception as err:
        print(f"  Error: {err}")
        reply = None
    return reply

def ping_until_alive(host, timeout=60, interval=1, port=80):
    """Wait until *host* accepts a TCP connection on *port*, or *timeout* elapses.

    Despite the name this is not an ICMP ping. It probes the receiver's web
    server (port 80 by default) with a TCP connect attempt every *interval*
    seconds; each attempt itself times out after 2 s.

    Args:
        host: Host name or IP address to probe.
        timeout: Overall time budget in seconds.
        interval: Pause between failed attempts, in seconds.
        port: TCP port to connect to (default 80).

    Returns:
        True as soon as a connection succeeds, otherwise False after printing
        a warning.
    """
    # Monotonic clock so wall-clock adjustments cannot stretch or cut the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # create_connection() closes its socket when connecting fails and
            # the with-block closes it on success, so nothing leaks between
            # attempts. socket.timeout and gaierror are both OSError subclasses.
            with socket.create_connection((host, port), timeout=2):
                pass
        except OSError:
            time.sleep(interval)
            continue
        print(f"  Receiver {host} is alive")
        return True
    print(f"  Warning: Receiver {host} did not respond within {timeout}s")
    return False

def build_tracking_url(receiver_ip, enabled_systems):
    """Build the trackingPage.xml URL that enables the selected constellations.

    Args:
        receiver_ip: IP address of the receiver.
        enabled_systems: Mapping with optional truthy flags under 'gps',
            'galileo', 'glonass' and 'beidou'. GLONASS is accepted but is not
            translated into any query parameter yet.

    Returns:
        ``(url, description)``. When no supported system is enabled, the URL
        has no query string and the description is "No systems enabled".
    """
    # (flag key, query parameters, signal name), in URL parameter order.
    # GLONASS support reserved for future use (no confirmed parameter name).
    signal_table = (
        ('gps', ("L1CA_enable=on",), "GPS L1C/A"),
        ('galileo', ("E1_enable=on", "E1_mode=D_P", "E1_mboc=on"), "Galileo E1 (D+P/MBOC)"),
        ('beidou', ("B1_enable=on",), "BeiDou B1I"),
    )
    selected = [
        (query, signal)
        for key, query, signal in signal_table
        if enabled_systems.get(key, False)
    ]

    url = f"http://{receiver_ip}/cgi-bin/trackingPage.xml"
    query_string = "&".join(part for query, _ in selected for part in query)
    if query_string:
        url = f"{url}?{query_string}"

    if selected:
        description = "Enabling tracking: " + ", ".join(signal for _, signal in selected)
    else:
        description = "No systems enabled"
    return url, description

def _xml_text(root, path):
    """Return the stripped text of the first element matching *path*, or None if absent."""
    # findtext() yields '' (not None) for an element without text, so an empty
    # tag such as <fixType/> cannot raise AttributeError here.
    text = root.findtext(path)
    return text.strip() if text is not None else None


def wait_for_autonomous_fix(receiver_ip, auth=None, timeout=600, poll_interval=1):
    """Poll the receiver's position XML until it reports a valid fix or *timeout* elapses.

    A fix is accepted when either ``position/fixType`` or ``position/soln``
    holds one of ``VALID_FIX_TYPES``. Request, HTTP and XML errors are printed
    and polling continues.

    Args:
        receiver_ip: IP address of the receiver.
        auth: Optional ``requests`` auth object for the web interface.
        timeout: Maximum time to keep polling, in seconds.
        poll_interval: Pause between polls, in seconds.

    Returns:
        Tuple of (success: bool, sat_counts: dict) where sat_counts maps
        constellation name (via ``GNSS_SYS_MAP``) to the number of satellites
        listed under ``position/SvsUsed``; ``(False, {})`` on timeout.
    """
    url = f"http://{receiver_ip}/xml/dynamic/posData.xml"
    # Monotonic clock: immune to wall-clock adjustments during a long poll.
    deadline = time.monotonic() + timeout

    while time.monotonic() < deadline:
        try:
            response = requests.get(url, auth=auth, timeout=5)
            if response.status_code == 200:
                root = ET.fromstring(response.content)

                fix_type = _xml_text(root, './/position/fixType')
                soln_type = _xml_text(root, './/position/soln')

                print(f"  Fix type: {fix_type}, Solution: {soln_type}")

                if fix_type in VALID_FIX_TYPES or soln_type in VALID_FIX_TYPES:
                    # Get satellites used in the fix
                    used_sats = root.findall('.//position/SvsUsed/sv')
                    print(f"  Total satellites used: {len(used_sats)}")

                    # Map system IDs to constellation names and count them
                    sats_used_count = dict(Counter(
                        GNSS_SYS_MAP.get(sv.get('sys'), f"Unknown (ID: {sv.get('sys')})")
                        for sv in used_sats
                    ))

                    for name, count in sats_used_count.items():
                        print(f"    {name}: {count}")

                    return True, sats_used_count
        except Exception as e:
            # Best effort: a transient network/parse error must not end the poll.
            print(f"  Error polling position: {e}")
        time.sleep(poll_interval)

    print(f"  Timeout: Did not achieve Autonomous fix within {timeout}s")
    return False, {}

def ttff_test_loop(with_aiding, auth, receiver_ip, receiver_a9_dcol_port, server_ip,
                   server_54_dcol_port, iteration, enabled_systems,
                   position_params=None, fix_timeout=600):
    """Run one TTFF measurement cycle and return its result record.

    Sequence:
      1. Disable all tracking.
      2. Wait a randomized delay.
      3. Clear GNSS data on the receiver, which reboots it.
      4. Wait for its web server to answer again.
      5. Re-enable the requested constellations and start the timer.
      6. Optionally send aiding data over DCOL.
      7. Poll for a fix for up to *fix_timeout* seconds.

    Args:
        with_aiding: When True, aiding data is fetched from the aiding server
            and sent to the receiver via ``dcol_client.run_dcol_operations``.
        auth: ``requests`` auth object for the receiver's web interface.
        receiver_ip: Receiver IP address (web interface and DCOL target).
        receiver_a9_dcol_port: DCOL port on the receiver that aiding is sent to.
        server_ip: Aiding data server IP address.
        server_54_dcol_port: DCOL port on the aiding data server.
        iteration: Zero-based iteration index (printed and returned 1-based).
        enabled_systems: Dict of 'gps'/'galileo'/'glonass'/'beidou' flags.
        position_params: Optional last-known-position dict for the DCOL send;
            defaults to the hard-coded position below.
        fix_timeout: Seconds to wait for a fix before giving up (default 600).

    Returns:
        Dict with 'iteration', 'start' and 'end' (UTC datetimes),
        'ttff_seconds', 'success', 'sat_counts' and 'enabled_systems'. When no
        fix is achieved, 'ttff_seconds' is the time spent before giving up.
    """
    print(f"\n{'='*70}")
    label = "With Aiding" if with_aiding else "No Aiding"
    print(f"Iteration {iteration + 1} ({label})")
    print(f"{'='*70}")

    # trackingPage.xml without query parameters disables all tracking.
    safe_requests_get(
        f"http://{receiver_ip}/cgi-bin/trackingPage.xml",
        auth=auth,
        timeout=10,
        description="\n[0] Disabling tracking of all signals on the receiver..."
    )

    # Each GPS frame is 30s long. To avoid sync issues, wait a minimum of 5s
    # (to allow CM to flush) plus a random 0-30s.
    wait_time = random.randint(0, 30) + 5
    print(f"\n[0.5] Waiting randint(0,30) + 5 = {wait_time}s before next iteration to avoid GPS frame sync issues...")
    time.sleep(wait_time)

    safe_requests_get(
        f"http://{receiver_ip}/cgi-bin/resetPage.xml?doClearGPS=1",
        auth=auth,
        timeout=10,
        description="\n[1] Clearing GNSS data on the receiver..."
    )

    print("\n[2] Waiting 15 seconds for receiver to start the reboot...")
    time.sleep(15)

    print("\n[3] Waiting for receiver to come back online...")
    if not ping_until_alive(receiver_ip, timeout=60):
        print("  Warning: continuing anyway; this iteration's TTFF may be unreliable")

    # The receiver is reachable again: turn tracking on and start the timer.
    tracking_url, tracking_desc = build_tracking_url(receiver_ip, enabled_systems)
    response = safe_requests_get(
        tracking_url,
        auth=auth,
        timeout=10,
        description=f"\n[4] {tracking_desc}..."
    )
    if response is None or response.status_code != 200:
        print("  Warning: tracking may not have been enabled; this iteration's TTFF may be unreliable")

    t_start = datetime.now(timezone.utc)
    print(f"\n[5] Start time: {t_start.isoformat()}")
    if with_aiding:
        print("\n[6] Sending ephemeris/almanac data via DCOL...")
        if position_params is None:
            # Hard-coded last known position
            position_params = {
                'reserved': 0,
                'latitude': 37.384634178,
                'longitude': -122.006504915,
                'height': 0.0  # Default height to 0
            }
        try:
            result = dcol_client.run_dcol_operations(
                get_server_ip=server_ip,
                get_port=server_54_dcol_port,
                send_server_ip=receiver_ip,
                send_port=receiver_a9_dcol_port,
                do_get=True,
                do_send=True,
                do_settime=True,
                # Earlier notes: [11, 12] = Galileo, [21, 22] = BeiDou;
                # [3, 1, 7] presumably GPS - confirm against dcol_client.
                subtypes=[3, 1, 11, 7, 12, 21, 22],  # GPS, Galileo, BeiDou
                prn="All",
                debug=True,
                radians=True,
                datfile="Eph.dat",
                position_params=position_params
            )
            print(f"  DCOL Success: {result['success']}")
            print(f"  GET results: {len(result['get_results'])}, SEND results: {len(result['send_results'])}")
        except Exception as e:
            print(f"  Error in DCOL operations: {e}")
    else:
        print("\n[6] Skipping DCOL aiding data send...")
    print("\n[7] Monitoring for Autonomous fix...")
    fix_achieved, sat_counts = wait_for_autonomous_fix(receiver_ip, auth=auth, timeout=fix_timeout)
    t_end = datetime.now(timezone.utc)
    ttff_seconds = (t_end - t_start).total_seconds()

    print(f"\n[8] Results for iteration {iteration + 1} ({label}):")
    print(f"  Start: {t_start.isoformat()}")
    print(f"  End:   {t_end.isoformat()}")
    print(f"  TTFF:  {ttff_seconds:.1f} seconds")
    print(f"  Fix achieved: {fix_achieved}")
    if fix_achieved and sat_counts:
        print("  Satellites used by constellation:")
        for constellation, count in sat_counts.items():
            print(f"    {constellation}: {count}")

    return {
        'iteration': iteration + 1,
        'start': t_start,
        'end': t_end,
        'ttff_seconds': ttff_seconds,
        'success': fix_achieved,
        'sat_counts': sat_counts,
        'enabled_systems': enabled_systems
    }


def _print_ttff_stats(results, label, num_iterations=None):
    """Print a TTFF summary for *results* and append the same text to ttff_summary.txt.

    Args:
        results: Result dicts as returned by ``ttff_test_loop`` (keys used:
            'iteration', 'ttff_seconds', 'success').
        label: Heading shown in the summary banner.
        num_iterations: Denominator for the success rate. Defaults to
            ``len(results)``, so an interrupted run reports only completed
            iterations. A value of 0 reports the rate as "n/a" instead of
            dividing by zero.
    """
    if num_iterations is None:
        num_iterations = len(results)
    timestamp = datetime.now(timezone.utc).isoformat()

    output_lines = [
        f"\nTimestamp: {timestamp}",
        f"\n{'='*70}\nSUMMARY - {label}\n{'='*70}",
        "\nIndividual Results:",
    ]
    for result in results:
        status = "✓" if result['success'] else "✗"
        output_lines.append(f"  {status} Iteration {result['iteration']:2d}: {result['ttff_seconds']:6.1f}s")

    successful_ttffs = [r['ttff_seconds'] for r in results if r['success']]
    if successful_ttffs:
        mean_ttff = sum(successful_ttffs) / len(successful_ttffs)
        output_lines.append(f"\nStatistics (n={len(successful_ttffs)} successful):")
        output_lines.append(f"  Mean TTFF: {mean_ttff:.1f} seconds")
        output_lines.append(f"  Min TTFF:  {min(successful_ttffs):.1f} seconds")
        output_lines.append(f"  Max TTFF:  {max(successful_ttffs):.1f} seconds")
    else:
        output_lines.append("\nNo successful fixes achieved!")

    if num_iterations:
        success_rate = len(successful_ttffs) / num_iterations * 100
        output_lines.append(f"\nSuccess Rate: {success_rate:.0f}% ({len(successful_ttffs)}/{num_iterations})")
    else:
        output_lines.append("\nSuccess Rate: n/a (0 iterations)")
    output_lines.append(f"{'='*70}")

    for line in output_lines:
        print(line)

    # Explicit UTF-8: a platform default such as cp1252 cannot encode the
    # ✓/✗ status marks, so the append failed on such systems.
    try:
        with open('ttff_summary.txt', 'a', encoding='utf-8') as f:
            f.write('\n'.join(output_lines) + '\n')
    except OSError as e:
        print(f"Warning: Failed to write to ttff_summary.txt: {e}")


def _mean_ttff(results):
    """Return the mean TTFF of the successful *results*, or None if there are none."""
    vals = [r['ttff_seconds'] for r in results if r['success']]
    return sum(vals) / len(vals) if vals else None


def main():
    """Run the paired TTFF test campaign and report the results.

    Each iteration takes one unaided and then one aided TTFF measurement via
    ``ttff_test_loop``. Successful fixes are stored in ``TTFF_DB`` when the
    database could be initialized. Summaries are printed and appended to
    ttff_summary.txt, and the mean TTFF with and without aiding is compared.
    Ctrl-C stops the campaign early; the completed runs are still reported.
    """
    args = parse_args()
    receiver_ip = args.receiver_ip
    receiver_a9_dcol_port = args.receiver_port
    server_ip = args.server_ip
    server_54_dcol_port = args.server_port
    num_iterations = args.iterations
    auth = HTTPBasicAuth(args.username, args.password)

    # The --enable-* flags for GPS/Galileo/BeiDou default to True, so only the
    # matching --disable-* flag turns a constellation off. GLONASS is opt-in.
    enabled_systems = {
        'gps': args.enable_gps and not args.disable_gps,
        'galileo': args.enable_galileo and not args.disable_galileo,
        'glonass': args.enable_glonass,
        'beidou': args.enable_beidou and not args.disable_beidou
    }

    enabled_list = [name.upper() for name, enabled in enabled_systems.items() if enabled]
    print(f"\nEnabled GNSS systems: {', '.join(enabled_list) if enabled_list else 'None'}\n")

    # Initialize database
    db_ready = True
    try:
        # sqlite3 cannot create the missing parent directory (~/A9Test) itself.
        TTFF_DB.parent.mkdir(parents=True, exist_ok=True)
        if not os.path.exists(TTFF_DB):
            create_ttff_schema(TTFF_DB)
        configure_wal(TTFF_DB)
    except Exception as e:
        db_ready = False
        print(f"Warning: Database initialization failed: {e}")
        print("Continuing test without database storage...")

    # Turn off the receiver's NTP client; this only needs doing once at the
    # start of the test.
    safe_requests_get(
        f"http://{receiver_ip}/cgi-bin/ntp.xml",
        auth=auth,
        timeout=10,
        description="\nDisable the NTP client on the receiver..."
    )

    ttff_no_aiding = []
    ttff_with_aiding = []
    try:
        for i in range(num_iterations):
            # Adjacent unaided/aided runs within each iteration.
            for with_aiding, results in ((False, ttff_no_aiding), (True, ttff_with_aiding)):
                result = ttff_test_loop(with_aiding, auth, receiver_ip, receiver_a9_dcol_port,
                                        server_ip, server_54_dcol_port, i, enabled_systems)
                results.append(result)
                if db_ready and result['success']:
                    store_ttff_result(TTFF_DB, result['end'], with_aiding, result['ttff_seconds'],
                                      result['sat_counts'], enabled_systems)
    except KeyboardInterrupt:
        print("\nInterrupted: summarizing completed iterations...")

    # Per-list counts equal --iterations on a full run; they are smaller after Ctrl-C.
    _print_ttff_stats(ttff_no_aiding, "Time to First Fix WITHOUT Aiding Data")
    _print_ttff_stats(ttff_with_aiding, "Time to First Fix WITH Aiding Data")

    # Compare results
    print(f"\n{'='*70}\nTTFF COMPARISON\n{'='*70}")
    mean_no_aiding = _mean_ttff(ttff_no_aiding)
    mean_with_aiding = _mean_ttff(ttff_with_aiding)
    print(f"Mean TTFF without aiding: {mean_no_aiding:.1f} seconds" if mean_no_aiding is not None else "No successful fixes without aiding.")
    print(f"Mean TTFF with aiding:    {mean_with_aiding:.1f} seconds" if mean_with_aiding is not None else "No successful fixes with aiding.")
    if mean_no_aiding is not None and mean_with_aiding is not None:
        diff = mean_no_aiding - mean_with_aiding
        print(f"TTFF improvement with aiding: {diff:.1f} seconds")

    # Optimize database before exit
    if db_ready:
        try:
            optimize_database(TTFF_DB)
        except Exception as e:
            print(f"Warning: Database optimization failed: {e}")


if __name__ == "__main__":
    main()
