import argparse
import datetime
import json
import logging
import os
import signal
import sys
import threading
import time
from collections import Counter
from logging.handlers import RotatingFileHandler

import xmltodict
from flask import Flask, render_template, render_template_string, Markup, Response, request, jsonify, send_from_directory

import RXTools

# To run from a fresh Anaconda 3.X python install:
# conda install -c conda-forge xmltodict
# conda install -c conda-forge flask-caching
# conda install -c conda-forge flask-compress

# Makes sure that the static data doesn't get reloaded - e.g. the graph
# PNG was reloading and flickering on one browser before this was added
try:
  # Optional dependency: page caching stops static data being
  # reloaded on every request (the graph PNG flickered on one
  # browser before this was added).
  from flask_caching import Cache
  cacheConfig = {
    "DEBUG": True,          # some Flask specific configs
    "CACHE_TYPE": "simple", # Flask-Caching related configs
    "CACHE_DEFAULT_TIMEOUT": 300
  }
  cacheSupported = True
except ImportError:
  # Only a missing module should disable caching; a bare except
  # here would also swallow SystemExit/KeyboardInterrupt.
  cacheSupported = False


# Flask-Compress isn't installed on all systems, so only enable it
# when it is available. It will gzip data over HTTP
try:
  # This is not installed as standard by the anaconda python
  # installation. To install:
  #   conda install -c conda-forge flask-compress
  from flask_compress import Compress
  CompressSupported = True
except ImportError:
  # Only a missing module should disable compression; a bare except
  # would also hide unrelated errors raised during import.
  CompressSupported = False

# Python 3 moved urlopen into urllib.request; fall back to the
# Python 2 urllib2 module on older interpreters.
# NOTE(review): urlopen is not referenced anywhere in this file's
# visible code - possibly used elsewhere or left over; confirm.
try:
  from urllib.request import urlopen
except ImportError:
  from urllib2 import urlopen

# Sleep time in seconds between system data refreshes
sysRefreshTime = 0.5

# Data older than this (seconds) is considered stale and not served
sysDataTimeout = 10

# When no client requests have arrived for this many seconds, stop
# polling the receiver(s)
stationTimeout = 8

# Port the service listens on - don't run a second instance of this
# without changing it!!
webPort = 81
enableServerLogs = True

# Seed the request timestamps with the current time.
# First dimension is the band; second dimension is the
# FFT type (raw, average, max).
now = datetime.datetime.now()
secondsSysRequest = [[now, now, now] for _ in range(3)]

ThreadsActive = True

app = Flask(__name__)

if CompressSupported:
  # The log isn't configured yet, so report to the terminal
  print("HTTP Compression Installed")
  Compress(app)

if cacheSupported:
  print("HTTP Cache Installed")
  cache = Cache(app, config=cacheConfig)

# This intercepts the web requests, it allows us to filter and
# provides a central place for logging
@app.before_request
def limit_remote_addr():
  """Run before every request: filter blocked clients and log the rest."""
  # Block Trimble's scanner which messes up the logs
  if request.remote_addr == '10.1.187.20':
    return "", 403
  # Don't clog the log with the regular sysData.json polling
  if "sysData.json" not in request.url and enableServerLogs:
    app.logger.info("%s %s" % (request.remote_addr, request.url))

# If we get a 404 (page not found) log an error
@app.errorhandler(404)
def handle404(error):
  """Log every page-not-found so broken links show up in the log."""
  if enableServerLogs:
    app.logger.error("%s %s" % (request.remote_addr, request.url))
  return "", 404

@app.route("/")
def main():
  return index()

@app.route('/favicon.ico')
def favicon():
  """Serve the favicon from the app's static directory."""
  staticDir = os.path.join(app.root_path, 'static')
  return send_from_directory(staticDir, 'favicon.ico',
                             mimetype='image/vnd.microsoft.icon')

@app.route("/stations.json")
def stations():
  filteredStations = []
  for i in range(len(stations)):
    if(len(groups) == 0):
      filteredStations.append(stations[i])
    elif( int(stations[i]['group']) == groupID):
      filteredStations.append(stations[i])
  jsonData = jsonify(filteredStations)
  return jsonData

@app.route("/index.html")
def index():
  webStr  = ("{% extends 'index.html' %}")
  return render_template_string(Markup(webStr))


def signal_handler(signum, frame):
  """SIGINT handler: stop the worker threads, join them, then exit.

  Args:
    signum: signal number delivered by the OS (unused).
    frame:  current stack frame at delivery time (unused).

  Renamed the first parameter from `signal` to `signum` so it no
  longer shadows the `signal` module; handlers are invoked
  positionally so callers are unaffected.
  """
  global ThreadsActive
  print('You pressed Ctrl+C!')
  # Workers poll this flag and fall out of their loops once cleared
  ThreadsActive = False

  # Give the workers time to notice the flag before joining
  time.sleep(5)
  for worker in threadsSYSdata:
    worker.join()

  print("Exiting")
  sys.exit(0)

def workerSys(num):
  """Worker thread: poll station `num` for spectrum-analyzer XML.

  For each band/FFT-type combination a client has requested within the
  last `stationTimeout` seconds, fetch rfSpectrumAnalyzer.xml from the
  receiver, convert it to a dict and store it (with a timestamp) in the
  shared jsonSYSData / jsonSYSDataTime tables. Runs until the global
  ThreadsActive flag is cleared.

  Args:
    num: index of this worker's station in the `stations` list.
  """
  global secondsSysRequest

  bandNames = ('L1', 'L2', 'L5')
  filterModes = ('&filterMode=NoFilter',  # raw
                 '&filterMode=Filter',    # filtered/average
                 '&filterMode=MaxHold')   # max hold

  while ThreadsActive:
    now = datetime.datetime.now()
    # Seconds since a client last requested each band/FFT-type combo
    delta = []
    for band in range(3):
      delta.append([(now - secondsSysRequest[band][k]).total_seconds()
                    for k in range(3)])

    try:
      for band in range(3):       # L1, L2 and L5
        for fftType in range(3):  # raw, average, max
          if delta[band][fftType] >= stationTimeout:
            # Nobody is watching this combination - skip the fetch
            continue

          # Reset per request: previously `resp` was cleared only once
          # per outer loop, so a failed fetch re-stored the previous
          # response under the wrong band/FFT-type slot.
          resp = None
          try:
            resp = RXTools.SendHttpGet(
                stations[num].get("addr"),
                '/xml/dynamic/rfSpectrumAnalyzer.xml?rfBand='
                + bandNames[band] + filterModes[fftType],
                stations[num].get("user"),
                stations[num].get("pw"),
                verbose=False)
          except Exception:
            app.logger.info("Exception Task(SYS) %d:%s" % (num, stations[num].get("long")))

          if resp:
            sysData = xmltodict.parse(resp)
            # Store the dict() for later use (served as JSON by the
            # /sysData.json route)
            jsonSYSData[num][band][fftType] = sysData
            jsonSYSDataTime[num][band][fftType] = datetime.datetime.now()

      time.sleep(sysRefreshTime)
    except Exception:
      # Try again in a second
      time.sleep(sysRefreshTime)
  return

@app.route("/sysData.json")
def sysData():
  global secondsSysRequest

  select = 0 # Assume L1

  band = request.args.get('L2')
  if(band is not None):
    select = 1
  else:
    band = request.args.get('L5')
    if(band is not None):
      select = 2

  fftArg = request.args.get('Type')
  fftType = 0 # Default to raw FFT
  if(fftArg is not None):
    if(fftArg == 'Filtered'):
      fftType = 1
    elif(fftArg == 'Max'):
      fftType = 2

  output = []

  for i in range(len(jsonSYSData)): # Loop for each station
    delta = (datetime.datetime.now() - jsonSYSDataTime[i][select][fftType]).total_seconds()
    if( (len(jsonSYSData[i][select][fftType]) > 0)  and (delta < sysDataTimeout) ):
      jsonSYSData[i][select][fftType]['name'] = stations[i]['long']
      output.append( jsonSYSData[i][select][fftType] )

  # Get the time of the request
  secondsSysRequest[select][fftType] = datetime.datetime.now()  

  jsonData = jsonify(output)
  return jsonData

if __name__ == "__main__":
  ######################################################################
  # Parse arguments
  parser = argparse.ArgumentParser(description='Web Application that plots FFTs from a set of GNSS receivers')
  parser.add_argument('-s','--stations', help='Filename of the station JSON e.g. --stations stations.json')
  parser.add_argument('-p','--port', help='Optional HTTP port, by default 81 e.g. --port 81')
  parser.add_argument('-f','--fftRate', help='Optionally control how frequent in seconds we request FFT data (default 0.5) --sysRate 0.5')
  args = parser.parse_args()
  ######################################################################

  if(args.fftRate):
    sysRefreshTime = float(args.fftRate)

  # Load the list of stations
  if(args.stations):
    with open(args.stations,'r') as f: 
      data = json.load(f)

    stations = []
    groups = []
    gotStations = False
    # New JSON format
    for i in range(len(data)):
      if('groupDef' in data[i]):
        groups = data[i]['groupDef']
      elif('RX' in data[i]):
        stations = data[i]['RX']
        gotStations = True
    
    # Old format
    if(gotStations == False):
      stations = data
  else:
    print('require a station JSON file')
    sys.exit(1)

  longName = []
  for i in range(len(stations)):
    longName.append(stations[i]['long'])

  # By changing to a set and back to a list 
  # we get a unique list
  longUnique = list(set(longName))
  if(len(longUnique) != len(longName)):
    # long name should be unique
    print('Exiting - Long name needs to be unique')
    # When we create a unique list the order will change.
    # Therefore sort both the original and unique lists
    # so they are in the same order and it is easy to
    # find the duplicates
    print('List of input names\n',sorted(longName))
    print('List of unique input names\n',sorted(longUnique))
    duplicates = list(set([x for x in longName if longName.count(x) > 1]))
    print('List of duplicate names\n',sorted(duplicates))
    sys.exit(10)

  now = datetime.datetime.now()
  groupRequestTime = []
  if(len(groups) > 0):
    groupID = int(groups[0]['group'])

    # Find the maximum group ID
    maxID = 0
    for i in range(len(groups)):
      if(int(groups[i]['group']) > maxID):
        maxID = int(groups[i]['group'])

    # Setup the list so we can directly index via the group setting
    for i in range(maxID + 1):
      groupRequestTime.append(now)
  else:
    groupID = None
    groupRequestTime.append(now)

  if(args.port):
    webPort = int(args.port)

  signal.signal(signal.SIGINT, signal_handler)
  print("Setup CTRL-C handler")
  threadsSYSdata = []

  # [station][band][fft type]
  jsonSYSData = []
  jsonSYSDataTime = []
  # Kick off the threads
  for i in range(len(stations)):
    # Holds the XML (converted to JSON) from the receiver
    jsonSYSData.append([ [dict(),dict(),dict()],
                         [dict(),dict(),dict()],
                         [dict(),dict(),dict()] ])
      
    # Holds the time we received the XML from the receiver
    # set to a time in the past
    defTime = datetime.datetime(2010,1,1,0,0,0)
    jsonSYSDataTime.append([ [defTime,defTime,defTime],
                             [defTime,defTime,defTime],
                             [defTime,defTime,defTime]])

    t = threading.Thread(target=workerSys, args=(i,))
    threadsSYSdata.append(t)
    t.start()
 
  if(enableServerLogs == True):
    formatter = logging.Formatter(
          "[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s")
    handler = RotatingFileHandler('webApp.log', maxBytes=(1024*1024), backupCount=128)
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    app.logger.addHandler(handler)
    app.logger.setLevel(logging.INFO)
    log = logging.getLogger('werkzeug')
    log.setLevel(logging.INFO)
  
  try:
    app.run(host='0.0.0.0',port=webPort,threaded=True,debug=True)
  except:
    sys.exit(1)

