import matplotlib
# Allow running headless from the command line
matplotlib.use("agg")
import numpy as np
import pandas as pd
import seaborn as sns
import datetime

from pylab import *
import os
import json
import io
from PIL import Image


def savePlot(directory,name):
  """Save the current matplotlib figure as a palette-compressed PNG.

  Parameters
  ----------
  directory : str
      Output directory; created if it does not already exist.
  name : str
      File name of the PNG inside `directory`.
  """
  # exist_ok avoids the check-then-create race of exists()+makedirs()
  os.makedirs(directory, exist_ok=True)

  # Render the figure into memory first...
  ram = io.BytesIO()
  savefig(ram,format='png',dpi=600)
  ram.seek(0)
  # ...then re-encode via PIL with an adaptive palette before writing to
  # disk - matplotlib's own PNG output is not well compressed
  im = Image.open(ram)
  im2 = im.convert('RGB').convert('P', palette=Image.ADAPTIVE)
  im2.save(os.path.join(directory, name),format='PNG')


def createSlipPlot(minSV,maxSV,directory,elevMask,system,freq,signal,sigStr,plotBase,plotRateAsSize):
  """Create and save a scatter plot of reported cycle-slip rate vs. time.

  Scans the per-SV slip summary files in `directory`, rescales the slip
  rate (log10) and epoch counts into the 0-1 range seaborn expects,
  plots them with sns.relplot, rewrites the legend labels back into
  physical units (PPM and raw epoch counts), and saves the figure into
  the 'History' directory.

  Parameters
  ----------
  minSV, maxSV : int
      Inclusive range of satellite IDs to scan.
  directory : str
      Directory holding the slip data text files.
  elevMask : int
      Elevation mask in degrees; 30 selects the "slips30." file set,
      anything else the "slips." set.
  system, freq, signal : int
      Signal identifiers used to build the input/output file names.
  sigStr : str
      Human-readable signal name used in the plot title.
  plotBase : str
      Prefix for the output PNG file name.
  plotRateAsSize : bool
      If True the slip rate drives marker size and epoch count the hue;
      otherwise the roles are swapped.
  """
  dataMin = 1e10
  dataMax = 0
  frame   = None
  gotData = False

  if(elevMask == 30):
    nameBase = "slips30."
  else:
    nameBase = "slips."

  nameBase += str(system) + '.'
  nameBase += str(freq) + '.'
  nameBase += str(signal) + '.'

  # First we'll scan all the data, we'll use this to set the dynamic range of the circle diameters.
  # We could store the values, but the files are pretty small, so we'll just re-read the data later
  # before we plot it
  for i in range(minSV,maxSV+1):
    SVtype = '0'
    if(system == 10): # BDS
      # PRNs 1-5 and 59+ use the alternate file type (GEO orbits)
      if( (i <= 5) or (i >= 59) ):
        SVtype = '1'

    filename = os.path.join(directory,nameBase + SVtype +  ".SV" + str(i) + ".txt")
    if(os.path.isfile(filename) and (os.path.getsize(filename) > 0)):
      slipData = np.loadtxt(filename,dtype='int',usecols=(0,1,2,3,4,5,6,7))
      # A single-row file loads as a 1-D array; skip those
      if(slipData.ndim == 2):

        # Columns 0-2 are year/month/day; convert to days since 2018-01-01
        DOY = np.array([(datetime.datetime(*x.astype(int)) - datetime.datetime(2018,1,1)).days for x in slipData[:,0:3]])
        # Prevent divide by zero - old script bug wrote 0 into the num
        # of epochs!
        index = np.where(slipData[:,7] > 0)[0]

        if(len(index) > 0):
          gotData = True
          rate = slipData[index,6].astype(float) / slipData[index,7].astype(float)
          SV = np.ones(len(DOY[index])) * i
          data = {'Satellite ID': SV.astype(int), 'rate':rate, 'Num Epochs':slipData[index,7], 'Time [Day since start of 2018]':DOY[index]}
          newframe = pd.DataFrame(data)
          if(frame is None):
            frame = newframe
          else:
            # DataFrame.append() was removed in pandas 2.0 - use concat
            frame = pd.concat([frame, newframe], ignore_index=True)

          # Now compute the min/max data rate, we want to ignore any rate
          # that is 0 as we'll be forming log10() of the rate (log10(0) == -Inf)
          index = np.where((slipData[:,6] > 0) & (slipData[:,7] > 0))[0]
          if(len(index) > 0):
            tmpRate= slipData[index,6].astype(float) / slipData[index,7].astype(float)

            thisMin = np.min(tmpRate)
            if(thisMin < dataMin):
              dataMin = thisMin

            thisMax = np.max(tmpRate)
            if(thisMax > dataMax):
              dataMax = thisMax


  if(gotData):
    # ToDo - improve this - trying to get some distance between the no slip
    # case (which is really -Inf in a log scale).
    ZeroOffset = 0.01
    dataRange = np.log10(dataMax) - np.log10(dataMin) + ZeroOffset

    # Deal with the log10(0) cases
    rateTmp = np.array(frame['rate'])
    index = np.where(rateTmp == 0)[0]
    rateTmp[index] = ZeroOffset # Should be a log10()

    # Deal with the rest of the data: log10, shift to 0 and normalize
    newIndex = list(set(np.arange(len(rateTmp))).difference(index))
    rateTmp[newIndex] = np.log10(rateTmp[newIndex])
    rateTmp[newIndex] -= np.log10(dataMin)
    rateTmp[newIndex] /= dataRange

    # Now adjust all the data.
    scaleFactor = 1.0

    # if the data is in the range 0 - 1.0 the legend is rendered in the
    # range 0-1.2. By scaling by 1.2 we use the maximum range of the scale,
    # we'll adjust the axis labels and remove this scaling as well as
    # converting to a linear PPM - this is a hack and may change with
    # different versions of seaborn
    scaleRate = scaleFactor/np.max(rateTmp)
    rateTmp *= scaleRate
    # Update the data frame
    frame['rate'] = rateTmp

    # Perform similar scaling for the number of epochs. We'll reverse the
    # scaling after plotting and creating the initial legend
    epochsMin = np.min(frame['Num Epochs'])
    frame['Num Epochs'] -= epochsMin
    epochsMax = np.max(frame['Num Epochs'])
    # Guard against divide by zero when every row has the same epoch count
    # (the column is then all zeros, so skipping the divide is equivalent)
    if(epochsMax > 0):
      frame['Num Epochs'] /= epochsMax
    frame['Num Epochs'] *= scaleFactor


    print("rate Data",np.min(frame['rate']),np.max(frame['rate']))
    print("Epoch Data",np.min(frame['Num Epochs']),np.max(frame['Num Epochs']))

    sns.set(style="darkgrid")
    alpha = 0.1
    palette="rainbow"

    if(plotRateAsSize == True):
      # Rate == Size
      # Num Epochs == Hue
      plot = sns.relplot(x='Time [Day since start of 2018]' , 
                         y='Satellite ID',
                         hue='Num Epochs',
                         size="rate",
                         sizes=(5, 200), 
                         legend="brief",
                         alpha=alpha, 
                         data=frame, 
                         palette=palette)
    else:
      # Rate == Hue
      # Num Epochs == Size
      plot = sns.relplot(x='Time [Day since start of 2018]' , 
                         y='Satellite ID',
                         size='Num Epochs',
                         hue="rate",
                         sizes=(5, 200), 
                         legend="brief",
                         alpha=alpha, 
                         data=frame, 
                         palette=palette)

    # Now adjust the legend to remove the scaling, we
    # also convert the rate log to a linear scale.
    # The legend texts arrive as a flat list: a section title ('rate' or
    # 'Num Epochs') followed by its numeric entries; `state` tracks which
    # section we are currently rewriting.
    state = 0
    leg = plot._legend
    for t in leg.texts:
      if(plotRateAsSize == True):
        if(state == 1):
          # Undo the log10/normalize scaling and convert to linear PPM
          PPM  = dataRange * float(t.get_text()) / scaleRate
          PPM += np.log10(dataMin)
          PPM  = (10**PPM) * 1000000
          num = int(PPM)
          t.set_text(f"{num:,d}")
        elif(t.get_text() == 'rate'):
          # We're going to convert from a log10() rate to PPM
          t.set_text('PPM')
          state = 1
        elif(t.get_text() != 'Num Epochs'):
          # Undo the 0-1 normalization of the epoch counts and
          # add commas on the 1,000 bounds
          num  = float(t.get_text())
          num /= scaleFactor
          num *= epochsMax
          num += epochsMin
          num_i = int(num)
          t.set_text(f"{num_i:,d}")
      else:
        if(state == 1):
          num  = float(t.get_text())
          num /= scaleFactor
          num *= epochsMax
          num += epochsMin
          num_i = int(num)
          t.set_text(f"{num_i:,d}")
        elif(t.get_text() == 'Num Epochs'):
          state = 1
        elif(t.get_text() == 'rate'):
          t.set_text('PPM')
        elif(t.get_text() != 'rate'):
          PPM  = dataRange * float(t.get_text()) / scaleRate
          PPM += np.log10(dataMin)
          PPM  = (10**PPM) * 1000000
          num = int(PPM)
          t.set_text(f"{num:,d}")



    # Tweak the location of the legend - reduce the white space
    #leg.set_bbox_to_anchor([0.9, 0.8])


    # Now make the markers transparent for the num epochs. This
    # is a hack as I can't find out how to control the number of
    # markers. Currently it seems to be 4. We change the transparency
    # to match the data transparency.
    # matplotlib >= 3.7 renamed Legend.legendHandles to legend_handles
    # (the old name was later removed); support both.
    handles = getattr(leg, 'legend_handles', None)
    if(handles is None):
      handles = leg.legendHandles
    num = 0
    for lh in handles:
      if(num < 5):
        lh.set_alpha(0.5)
      num += 1

    plot.fig.suptitle(sigStr + ' - Reported Cycle Slip Rate - Elev. Mask = ' + str(elevMask))
    savePlot('History',plotBase + '-' + nameBase + "png")
  close()


def analyzeStation(signals,LocalDir,plotBase, plotRateAsSize=True):
  """Generate the historical slip plots for every signal of one station.

  Walks the satellite-type / signal tables in `signals` and calls
  createSlipPlot() once per signal per elevation mask (0 and 30 deg).
  """
  for thisSat in signals['types']:
    satTypeStr = thisSat['satTypeStr']
    minSV = int(thisSat['min'])
    maxSV = int(thisSat['max'])

    for sigDef in thisSat['signals']:
      system = int(sigDef['satType'])
      freq   = int(sigDef['freq'])
      signal = int(sigDef['sig'])
      sigStr = sigDef['sigStr']

      # Each BDS signal type is defined twice, once for MEO/IGOS
      # and again for GEO. We do this so we can break out the GEOs,
      # which have a shorter PDI. However, for the historical plots
      # we want to combine all BDS onto the same graph. To prevent 
      # processing each signal type twice, if the orbit is "1" (GEO)
      # don't process this data type. For the orbit type of 0 we'll 
      # adjust the file name to make sure we read in GEO data and
      # process all PRNs.
      if(satTypeStr == 'BDS'):
        if(sigDef['orbit'] == '1'):
          continue

        # BDS differentiates GEO and MEO/IGSO - remove this part of the string, 
        # we'll process in one block
        sigStr = sigStr[0:sigStr.find(' (')]

      sigStr = satTypeStr + ' - ' + sigStr

      print(system,freq,signal,sigStr)
      # Produce one plot for each elevation mask
      for elevMask in (0, 30):
        createSlipPlot(minSV,maxSV,LocalDir,elevMask,system,freq,signal,sigStr,plotBase,plotRateAsSize)


#############
#
# Code Start
#
#############

#elevMask = 0
#system = 9
#freq = 2
#signal = 0
#createSlipPlot('Singapore-Alloy',elevMask,system,freq,signal,'test','test')

if __name__ == "__main__":
  # Load the signal LUT JSON file 
  with open('signals.json', 'r') as f:
    signals = json.load(f)

  # Load the receiver list
  with open('receivers.json', 'r') as f:
    receivers = json.load(f)

  # Produce one set of history plots per receiver
  for receiver in receivers['receivers']:
    analyzeStation(signals, receiver['RXStr'], receiver['rxShortName'], plotRateAsSize=False)

