"""High-rate GNSS receiver FFT WebSocket client.

Receives 2048-point FFT frames (optionally from multiple bands), appends them to
text log files rotated hourly (UTC), and displays a live per-band RSSI bar with
the peak frequency. Designed to stay small, with `websockets` as its only
third-party dependency, for quick field use.
"""

from __future__ import annotations

import argparse
import asyncio
import datetime as dt
import signal
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List
import math

import websockets

# Copyright Trimble Inc 2024 - 2025

# Starting with version 5.70, certain receiver models support a binary websocket that
# provides high-rate (FFT) data in a network-efficient manner. Each 2048-point FFT is
# sent as binary data occupying 2048 bytes before compression. A little extra data
# (around 30 bytes) is sent in ASCII to fully describe the FFT. The websocket protocol
# allows per-message deflate compression, which reduces the data size by about 25%,
# resulting in around 1.5 KB per FFT. FFT data can be sent for multiple bands (by
# default this example requests four bands), each at up to a 5 Hz rate.
#
# This script is a proof of concept showing how to use the binary websocket. All
# received data is logged to a text file, with each FFT converted to ASCII. The file
# is space-delimited for easy loading into MATLAB.
#

FFT_POINTS = 2048  # FFT bins per frame; the payload carries one raw byte per bin
DEFAULT_BANDS = "L1,L2,L5,E6"
MIN_INTERVAL_FASTEST_MS = 200  # Fastest supported interval (ms) by receiver

# RSSI range mapped onto the display bar (readings outside it are clamped).
MIN_RSSI = -45.0
MAX_RSSI = 10.0
RSSI_RANGE = MAX_RSSI - MIN_RSSI  # 55
# ~1 char / dB: derived from the range so the bar keeps 1 char per dB if
# MIN_RSSI / MAX_RSSI are retuned (55 with the values above).
BAR_WIDTH = round(RSSI_RANGE)

def _clamp(v: float) -> float:
    """Limit an RSSI reading to [MIN_RSSI, MAX_RSSI]; NaN is mapped to MIN_RSSI."""
    # NaN compares False against everything, so catch it explicitly first.
    if math.isnan(v) or v < MIN_RSSI:
        return MIN_RSSI
    if v > MAX_RSSI:
        return MAX_RSSI
    return v

def _color_for(norm: float) -> str:
    """Return an ANSI 24-bit color approximating MATLAB's jet colormap.

    jet goes (for increasing norm 0..1): dark blue -> blue -> cyan -> yellow -> orange -> red.
    We implement the classic 'jet' via piecewise tent functions:
        r = clip(1.5 - |4x - 3|)
        g = clip(1.5 - |4x - 2|)
        b = clip(1.5 - |4x - 1|)
    then scale to 0..255.
    """
    x = 0.0 if norm < 0.0 else 1.0 if norm > 1.0 else norm

    def comp(val: float) -> float:
        return 0.0 if val < 0.0 else 1.0 if val > 1.0 else val

    r = comp(1.5 - abs(4 * x - 3))
    g = comp(1.5 - abs(4 * x - 2))
    b = comp(1.5 - abs(4 * x - 1))

    R = int(r * 255)
    G = int(g * 255)
    B = int(b * 255)
    return f"\033[38;2;{R};{G};{B}m"

def _build_bar(rssi: float, peak_freq: float, agc: int, qual: int) -> str:
    """Render one colored RSSI bar followed by RSSI, AGC, quality and peak readouts.

    The bar is BAR_WIDTH characters: '#' for the filled fraction of the
    [MIN_RSSI, MAX_RSSI] range, '-' for the rest, colored via _color_for.
    """
    level = _clamp(rssi)
    fraction = 0.0
    if RSSI_RANGE:
        fraction = (level - MIN_RSSI) / RSSI_RANGE
    # Defensive clamp of the fill fraction (also maps NaN to 0).
    if math.isnan(fraction) or fraction < 0:
        fraction = 0.0
    elif fraction > 1:
        fraction = 1.0
    hashes = int(fraction * BAR_WIDTH)
    bar = "".join(("#" * hashes, "-" * (BAR_WIDTH - hashes)))
    tint = _color_for(fraction)
    return f"{tint}{bar}\033[0m {level:5.1f}dBm AGC:{agc}dB Qual:{qual} Peak:{peak_freq:6.1f}MHz"

@dataclass
class RateTracker:
    count: int = 0
    start_ms: int | None = None

    def update(self, timestamp_ms: int, bands_per_message: int) -> float | None:
        if not self.count:
            self.start_ms = timestamp_ms
            self.count = 1
            return None
        self.count += 1
        elapsed = (timestamp_ms - (self.start_ms or timestamp_ms)) / 1000.0
        if elapsed <= 0:
            return None
        return (self.count / elapsed) / float(bands_per_message)

@dataclass
class BandInfo:
    """Most recent display state for one band slot of the bar display.

    A ``freq`` of 0.0 marks a slot that has not been assigned a band yet.
    ``agc`` and ``qual`` are declared here (defaults appended last so positional
    construction is unchanged) instead of being attached dynamically on first use.
    """
    freq: float = 0.0       # band center frequency (MHz); 0.0 = unused slot
    rssi: float = MIN_RSSI  # last reported RSSI
    peak_freq: float = 0.0  # frequency of the strongest FFT bin (MHz)
    sum_db: float = -100.0  # NOTE(review): not read anywhere in the visible file
    agc: int = 0            # last reported AGC value (shown in dB)
    qual: int = 0           # last reported quality value

class BarDisplay:
    """Fixed set of per-band slots, redrawn in place as ANSI bars.

    Each slot is bound to the first center frequency written into it; a slot
    whose ``freq`` is 0.0 is unused and renders as a placeholder line.
    """
    def __init__(self, capacity: int):
        self.slots: List[BandInfo] = [BandInfo() for _ in range(capacity)]
        self.first_print = True

    def update(self, center_freq_mhz: float, rssi: float, peak_freq: float, agc: int, qual: int) -> None:
        """Record the latest readings for a band.

        Reuses the slot already bound to ``center_freq_mhz``, otherwise claims
        the first unused slot. If every slot is bound to another band, the
        readings are dropped silently.
        """
        target = next((s for s in self.slots if s.freq == center_freq_mhz), None)
        if target is None:
            target = next((s for s in self.slots if s.freq == 0.0), None)
            if target is None:
                return
            target.freq = center_freq_mhz
        target.rssi = rssi
        target.peak_freq = peak_freq
        target.agc = agc
        target.qual = qual

    def render(self) -> None:
        """Print one line per slot, overwriting the previous render after the first."""
        if self.first_print:
            self.first_print = False
        else:
            # CSI n F: move the cursor to the start of the line n lines up.
            print(f"\033[{len(self.slots)}F", end="")

        for band in self.slots:
            if band.freq != 0.0:
                gauge = _build_bar(band.rssi, band.peak_freq, band.agc, band.qual)
                line = f" Band {band.freq:.0f} MHz: {gauge}"
            else:
                line = "(waiting for band)" + " " * 10
            # CSI K clears whatever was left of the previous, longer line.
            print(line + "\033[K")

# For the websockets client library used in this script, see:
# https://websockets.readthedocs.io/en/stable/reference/asyncio/client.html
def parse_and_handle(message: bytes, tracker: RateTracker, bands_per_message: int, file_prefix: str, display: BarDisplay) -> None:
    """Parse one combined ASCII+binary FFT message, update the UI and append it to the log.

    The message is parsed as seven comma-terminated ASCII header fields
    (timestamp_ms, center_freq_mhz, pre_post, antenna, rssi, agc, qual) followed
    by ``FFT_POINTS`` raw bytes, one per FFT bin, each converted to dB as
    ``byte / 4 + 10``. Text frames, short messages and headers that fail to
    parse are ignored.

    Args:
        message: One websocket message.
        tracker: Rate tracker fed with every frame's timestamp.
        bands_per_message: Number of requested bands, passed to ``tracker``.
        file_prefix: Prefix of the hourly (UTC) log file name.
        display: Bar display to update and redraw.
    """
    # websockets delivers text frames as str. The layout below only applies to
    # binary frames, so skip anything else instead of raising TypeError (which
    # would tear down the connection and force a reconnect).
    if isinstance(message, str):
        return

    # Split off just the 7 header fields. The binary payload may itself contain
    # b',' bytes, so it stays intact as the 8th element.
    parts = message.split(b',', 7)
    if len(parts) < 8:
        return
    try:
        timestamp_ms = int(parts[0])
        center_freq_mhz = float(parts[1])
        pre_post = int(parts[2])
        antenna = int(parts[3])
        rssi = float(parts[4])
        agc = int(parts[5])
        qual = int(parts[6])
    except (ValueError, IndexError, TypeError):
        return
    if math.isnan(rssi):
        rssi = MIN_RSSI

    fft_raw = parts[7][:FFT_POINTS]
    if len(fft_raw) < FFT_POINTS:
        return

    fft_values = [(b / 4.0) + 10.0 for b in fft_raw]  # Convert to dB

    # Peak frequency, assuming the FFT spans 50 MHz centered on center_freq_mhz.
    peak_index = fft_values.index(max(fft_values))
    peak_freq = (peak_index - (FFT_POINTS / 2)) * (50.0 / FFT_POINTS) + center_freq_mhz

    # Feed the rate tracker, but refresh the display for every valid frame.
    # Gating the display on a rate estimate dropped the first frame(s) of a
    # session and froze the bars whenever timestamps did not advance.
    tracker.update(timestamp_ms, bands_per_message)
    display.update(center_freq_mhz, rssi, peak_freq, agc, qual)
    display.render()

    # Rotate logs hourly (UTC)
    now = dt.datetime.now(dt.timezone.utc)
    path = Path(f"{file_prefix}{now:%Y-%m-%d-%H}.txt")
    header = f"{timestamp_ms} {center_freq_mhz:.3f} {pre_post} {antenna} {rssi:.1f} {agc} {qual}"
    body = "".join(f" {v:.2f}" for v in fft_values)
    with path.open("a", encoding="utf-8") as fh:
        fh.write(f"{header}{body}\n")

async def client(host: str, user: str, password: str,
                 min_rate_ms: int, bands: str, file_prefix: str,
                 fftType: int, antNum: int,
                 display: BarDisplay) -> None:
    """Connect to the receiver's binary FFT websocket and process frames until it closes.

    Requests ``bands`` (comma-separated) no faster than every ``min_rate_ms``
    milliseconds, with FFT type ``fftType`` and antenna ``antNum``. Every
    received message is handed to ``parse_and_handle``.

    Raises:
        SystemExit: with code 0 after a SIGTERM-initiated close, so a caller
            that reconnects in a loop on ``Exception`` stops instead of
            reconnecting.
        Exception: connection and protocol errors from ``websockets``
            propagate to the caller.
    """
    # The server tries to send data at the requested minimum interval (ms) and
    # will not exceed it, but the rate is not guaranteed. 200ms (5Hz) is the
    # fastest rate supported; use e.g. 1000ms for 1Hz.
    uri  = f"ws://{user}:{password}@{host}/ws/rfBinSpectrumAnalyzer?rfBand={bands}&minRate={min_rate_ms}"
    uri += f"&fftType={fftType}&antNum={antNum}"

    print('Get data from:', uri)
    # Note compression='deflate' is default and triggers support of the 'permessage-deflate'
    # extension to web sockets. It is included explicitly here for clarity. Enabling
    # this reduces the over-the-wire bandwidth by about 25%. We get the 2048 point FFT
    # in around 1.5kB per epoch.
    terminated = False
    close_tasks = set()  # strong refs so a pending close() task is not garbage-collected
    async with websockets.connect(uri, compression='deflate') as websocket:
        # Only handle SIGTERM gracefully; let SIGINT raise KeyboardInterrupt immediately.
        loop = asyncio.get_running_loop()

        def _graceful_term() -> None:
            nonlocal terminated
            if terminated:
                return
            terminated = True
            # close() is idempotent, so no state check is needed. The legacy-only
            # `.closed` attribute does not exist on the newer asyncio connection.
            task = asyncio.create_task(websocket.close())
            close_tasks.add(task)
            task.add_done_callback(close_tasks.discard)

        try:
            loop.add_signal_handler(signal.SIGTERM, _graceful_term)
        except NotImplementedError:
            pass  # add_signal_handler is Unix-only
        tracker = RateTracker()
        bands_per_message = len(bands.split(','))
        try:
            async for raw in websocket:
                parse_and_handle(raw, tracker, bands_per_message, file_prefix, display)
        except websockets.ConnectionClosed:
            # A close we requested may still end abnormally. Only propagate
            # (and so trigger a reconnect) if we did not ask for it.
            if not terminated:
                raise
    if terminated:
        # Returning normally would make a reconnect loop connect again.
        raise SystemExit(0)


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the FFT websocket client.

    Option names (including camelCase ``--fftType``/``--antNum``) become the
    matching attributes on the parsed namespace.
    """
    p = argparse.ArgumentParser(description='WebSocket FFT Client')
    p.add_argument('--ip', default='10.1.149.249', help='Receiver IP / host')
    p.add_argument('--prefix', default='webSocketFFT-', help='Log file name prefix')
    p.add_argument('--username', default='admin', help='Username')
    p.add_argument('--password', default='password', help='Password')
    # f-string so the help shows the actual fastest interval, not the placeholder text.
    p.add_argument('--rate', type=int, default=MIN_INTERVAL_FASTEST_MS,
                   help=f'Minimum interval (ms) between FFT sets (fastest={MIN_INTERVAL_FASTEST_MS}ms)')
    p.add_argument('--bands', default=DEFAULT_BANDS, help=f'Comma list of RF bands (default {DEFAULT_BANDS})')
    p.add_argument('--fftType', type=int, default=0, help='Pre-mitigation (0) or Post-mitigation (1) (default 0)')
    p.add_argument('--antNum', type=int, default=0, help='primary antenna (0) or secondary antenna (1)  (default 0)')
    return p


def main(argv: Iterable[str] | None = None) -> int:
    """Parse arguments and keep the websocket client running until interrupted.

    Any ``Exception`` from the client is reported and followed by a reconnect.
    A short back-off precedes every reconnect.

    Args:
        argv: Command-line arguments. None means argparse reads ``sys.argv[1:]``.

    Returns:
        0 when stopped with Ctrl-C (KeyboardInterrupt).
    """
    args = build_parser().parse_args(list(argv) if argv is not None else None)
    # Never request a faster rate than the receiver supports.
    min_rate_ms = max(args.rate, MIN_INTERVAL_FASTEST_MS)
    # One display slot per requested band.
    display = BarDisplay(args.bands.count(',') + 1)
    try:
        while True:
            try:
                asyncio.run(client(args.ip,
                                   args.username, args.password,
                                   min_rate_ms,
                                   args.bands,
                                   args.prefix,
                                   args.fftType,
                                   args.antNum,
                                   display))
            except Exception as exc:  # Broad to keep field tool resilient
                print("WebSocket problem - restarting:", exc)
            # Back off before every reconnect, including after a clean close,
            # so a receiver that closes the stream immediately is not hammered.
            time.sleep(1)
    except KeyboardInterrupt:
        # Covers Ctrl-C both while streaming and during the back-off sleep.
        print("\nInterrupted. Exiting.")
        return 0


if __name__ == '__main__':  # pragma: no cover
    # Use main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)

