
from ProcessResults import parse_sampleConfig, cdf_vals, get_rounded_cdf_percents
from mutils import *
from PIL import Image
from mutils.PosPacConst import *
import mutils.PosTypeConst as PosTypeConst
import matplotlib.pyplot as plt
import json
import sys
import os
import datetime
import io

## Disabled headless support. Uncomment to allow running without a
## display from the command line:
##   Headless = False
##   if Headless:
##     from matplotlib import use
##     use("agg")
# Save a plot as a PNG file. Matplotlib isn't very good at 
# compressing PNG image files. Therefore, use PIL functions
# to compress a memory version of the file. The compressed 
# images are close to half the size of the Matplotlib generated
# version
def savePlot(directory, name, dpi=600):
  # Save the current Matplotlib figure as a compressed PNG file.
  #
  # Matplotlib isn't very good at compressing PNG image files, so the
  # figure is first rendered to an in-memory PNG and then re-encoded
  # through PIL. The compressed images are close to half the size of
  # the Matplotlib generated version.
  #
  # directory -- output directory, created if it does not exist
  # name      -- file name of the PNG to write inside 'directory'
  # dpi       -- render resolution passed to Matplotlib (default 600,
  #              matching the original hard-coded value)

  # exist_ok avoids the check-then-create race of exists()/makedirs()
  os.makedirs(directory, exist_ok=True)

  # Render the figure to memory first
  ram = io.BytesIO()
  plt.savefig(ram, format='png', dpi=dpi)
  ram.seek(0)

  # Quantize to an adaptive 256-color palette before saving to the file
  # system; Matplotlib's own PNGs are not well compressed
  im = Image.open(ram)
  im2 = im.convert('RGB').convert('P', palette=Image.ADAPTIVE)
  im2.save(os.path.join(directory, name), format='PNG')


# 10Hz data
dDataRate = 10

# Exactly one argument is required: the JSON configuration file
if len(sys.argv) == 2:
  with open(sys.argv[1], 'r') as f:
    configData = json.load(f)
else:
  # Usage errors go to stderr and exit nonzero so scripts/CI that wrap
  # this tool can detect the failure (a bare sys.exit() returns 0)
  print("Usage:\n  python getStats.py config.json", file=sys.stderr)
  sys.exit(1)

# Receiver (DUT) definitions from the JSON config file
RX = configData['RX']

dataOut = {}      # per-RX, per-environment 2D/3D position error arrays
yieldData = {}    # per-RX, per-environment epoch yield counters
dataSummary = []  # description string of each processed data set

# Root directory holding the RF test data sets
dataRoot = "/net/fermion/mnt/data_drive/SpirentTest/DataDir/"

# We are going to process the latest data set from each sample set
# defined by the config JSON file. The test system uses a symlink, so 
# we just need to look at the "latest-" link. However, if we expand
# the link we can get the ID of each run processed, this will be useful 
# to archive. If there's a problem we may want to look at the 
# individual data sets.
processedDataSet = []  # "RunID dataset.xml" strings, archived at the end

# Process every data set listed in the config, accumulating per-RX,
# per-environment position errors (dataOut) and epoch yields (yieldData).
for dataset in configData['dataSets']:
  # Get the data segments by loading the config file. After each RF test a
  # symlink is generated to the latest data set.
  XMLConfig = dataRoot + 'latest-RX' + str(RX[0]['ID']) + '-' + dataset + '/' + dataset + '.xml'

  # Now break out the run ID from the full path (expand the symlink)
  fullPath = os.path.realpath(XMLConfig)
  RunID = (fullPath.split('/')[-2]).split('-')[1]
  processedDataSet.append(RunID + ' ' + dataset + '.xml')

  # Parse the XML config file
  config = parse_sampleConfig(XMLConfig)

  dataSummary.append(config.desc)

  # Load the truth trajectory
  pp = doload(config.truth_file)

  for DUT in RX:
    thisRX = DUT['label']
    # Which position record do we want to parse (default: record 35:2)
    if 'rec' in DUT:
      rec = DUT['rec']
    else:
      rec = '-d35:2'

    if thisRX not in dataOut:
      dataOut[thisRX] = {}

    if thisRX not in yieldData:
      yieldData[thisRX] = {}

    dataPath = dataRoot + 'latest-RX' + str(DUT['ID']) + '-' + dataset
    print(dataPath)
    try:
      d = vd2cls(dataPath + '/*.T04', rec)
    # Catch only real errors: a bare except would also swallow
    # KeyboardInterrupt/SystemExit and hide the failure reason
    except Exception as e:
      print("Problem with ", dataPath, e)
      continue

    # Match the truth and DUT data time tags, throw away any epochs that
    # don't match
    _, i_pp, i_d = intersect(pp[:, dPP_TIME], d.TIME)
    if len(i_pp) != len(i_d):
      print("Position Error")

    ref_pos = pp[i_pp, dPP_LAT:dPP_velU+1]
    d = d[i_d, :]

    spans = {}      # per-environment boolean mask over the DUT epochs
    envTypes = []   # environments seen in this data set, in order
    for desc, start_stops in config.span:

      # We want per environment data, don't process the "All" data
      if desc == 'All':
        continue
      # Both Freeway and Freeways are used - combine them.
      # Similar issue with downtown San Jose and the suburban runs.
      if desc == 'Freeways':
        desc = 'Freeway'
      elif desc == 'DowntownSJ':
        desc = 'DowntownSanJose'
      elif desc == 'Suburban-WillowGlen':
        desc = 'Urban'
      elif desc == 'Suburban-Sunnyvale':
        desc = 'Urban'

      if desc not in envTypes:
        envTypes.append(desc)
      if desc not in spans:
        spans[desc] = zeros(len(d), dtype=bool)

      if desc not in yieldData[thisRX]:
        # Initialize the yield counters for this environment
        yieldData[thisRX][desc] = {}
        yieldData[thisRX][desc]['expected']  = 0
        yieldData[thisRX][desc]['actual']    = 0
        yieldData[thisRX][desc]['Stinger']   = 0
        yieldData[thisRX][desc]['Precision'] = 0

      for start, stop in start_stops:
        i_d = find((d.TIME >= start) & (d.TIME <= stop))
        spans[desc][i_d] = True
        # Expected epoch count for a contiguous [start, stop] span at
        # dDataRate Hz (inclusive endpoints, hence the +1)
        yieldData[thisRX][desc]['expected'] += (stop - start)*dDataRate + 1
        yieldData[thisRX][desc]['actual']   += len(i_d)

        # Epochs produced by the Stinger engine (SBAS/SLAS/GVBS/KF fix types)
        stingerEpochs = find(   (d[i_d].FIXTYPE <= PosTypeConst.dRec29_fixType_SBAS)
                              | (d[i_d].FIXTYPE >= PosTypeConst.dRec29_fixType_QZSS_SLAS)
                              | (d[i_d].FIXTYPE == PosTypeConst.dRec29_fixType_GVBS)
                              | (   (d[i_d].FIXTYPE >= PosTypeConst.dRec29_fixType_KF_auto)
                                  & (d[i_d].FIXTYPE <= PosTypeConst.dRec29_fixType_KF_SBASplus) ) )

        yieldData[thisRX][desc]['Stinger']   += len(stingerEpochs)
        # For the regression system everything else is the precision engine
        # (aka Astra/Titan). More generally this may not be the case,
        # e.g. OmniSTAR
        yieldData[thisRX][desc]['Precision'] += len(i_d) - len(stingerEpochs)

    for env in envTypes:
      i = find(spans[env])
      print(thisRX, env, len(i))

      # Convert DUT LLH positions to ENU errors relative to truth
      (dE, dN, dU) = llh2enu(d[i, d.k.LAT:d.k.LAT+3], ref_pos[i, :3])

      err_2d = sqrt(dE**2 + dN**2)
      err_3d = sqrt(dE**2 + dN**2 + dU**2)
      # Accumulate errors across data sets for this RX/environment
      if env not in dataOut[thisRX]:
        dataOut[thisRX][env] = {'2d': err_2d, '3d': err_3d}
      else:
        dataOut[thisRX][env]['2d'] = np.concatenate((dataOut[thisRX][env]['2d'], err_2d))
        dataOut[thisRX][env]['3d'] = np.concatenate((dataOut[thisRX][env]['3d'], err_3d))
    
# We have all the data - now form the plots

# Engines (RX labels) in processing order
engineList = list(dataOut.keys())
print(yieldData)

# Union of all environments across engines, sorted so CSV/plot ordering
# is deterministic run to run (list(set(...)) had arbitrary order)
envs = sorted({env for engine in dataOut for env in dataOut[engine]})

# Per-RX trend files accumulate across runs under Trend/
os.makedirs("Trend", exist_ok=True)

# Daily outputs are keyed by today's date, YYYY-MM-DD
now = datetime.datetime.now()
YYYYMMDD = "%04d-%02d-%02d" % (now.year, now.month, now.day)
os.makedirs(YYYYMMDD, exist_ok=True)

# Daily statistics CSV; closed after all plots have been generated
fid = open(YYYYMMDD + '/' + YYYYMMDD + '-stats.csv', 'w')

# RX labels in config order, used to map an engine back to its RX entry
labelList = [DUT['label'] for DUT in RX]


# Generate one CDF plot plus CSV/trend rows per position type and environment.
# posType 0 -- 2D error, posType 1 -- 3D error
for posType in range(2):
  typeStr = '2D' if posType == 0 else '3D'
  errKey = '2d' if posType == 0 else '3d'

  for thisEnv in envs:
    print(thisEnv)

    fid.write(typeStr + '\n')
    fid.write(thisEnv + ',')

    # CSV header: CDF percentage points followed by the yield columns
    for val in cdf_vals:
      fid.write(str(val) + ',')
    fid.write('expected,actual,yield,Stinger Yield,Astra/Titan Yield,hours\n')

    fig, ax = plt.subplots(1, 1)

    # engineList preserves dataOut's iteration order, so 'engine' here is
    # what the original code indexed as engineList[i]
    for engine in dataOut:
      cx2, cy2 = docdf(dataOut[engine][thisEnv][errKey])

      plt.plot(cx2, cy2, label=engine + ' (' + "{0:.1f}".format(len(cx2)/dDataRate) + ' Secs)')
      stats = get_rounded_cdf_percents(cx2, cy2)

      fid.write(engine + ',')

      # Map the engine label back to its RX entry to recover the ID and
      # whether the Stinger record (-d35:16) was parsed
      index = labelList.index(engine)
      RXID = RX[index]['ID']
      stinger = RX[index].get('rec') == '-d35:16'

      trendStr = 'Trend/RX' + str(RXID) + '-'
      trendStr += 'Rec35_16-' if stinger else 'Rec35_2-'
      trendStr += thisEnv + '.' + typeStr + '.csv'

      # Append this run's stats to the per-RX trend file; the same row
      # (minus the date prefix) also goes into the daily CSV. The context
      # manager guarantees the trend file is closed even on error.
      with open(trendStr, 'a') as trend:
        for handle in (fid, trend):
          if handle is trend:
            # Trend rows are keyed by date so they can be plotted over time
            handle.write("%d,%d,%d," % (now.year, now.month, now.day))

          for j in range(len(stats)):
            print(engine, cdf_vals[j], stats[j], len(cx2))
            handle.write(str(stats[j]) + ',')

          # Yield statistics: expected vs actual epoch counts and the split
          # between the Stinger and precision (Astra/Titan) engines.
          # NOTE(review): divides by 'expected'/'actual' -- an environment
          # with zero epochs would raise ZeroDivisionError (same as before)
          handle.write(str(yieldData[engine][thisEnv]['expected']) + ',')
          handle.write(str(yieldData[engine][thisEnv]['actual']) + ',')
          handle.write("%.3f," % (yieldData[engine][thisEnv]['actual'] / yieldData[engine][thisEnv]['expected']))
          handle.write("%.3f," % (yieldData[engine][thisEnv]['Stinger'] / yieldData[engine][thisEnv]['actual']))
          handle.write("%.3f," % (yieldData[engine][thisEnv]['Precision'] / yieldData[engine][thisEnv]['actual']))
          handle.write("%.3f\n" % (yieldData[engine][thisEnv]['actual'] / (3600 * dDataRate)))

    plt.grid(True)
    plt.legend(fontsize=8)
    if thisEnv == 'Trimble':
      # Make the environment a bit more obvious
      plt.title(typeStr + ': Trimble Campus')
    else:
      plt.title(typeStr + ': ' + thisEnv)
    # Anchor the CDF plot at the origin
    xRange = ax.get_xlim()
    plt.xlim(0, xRange[1])
    plt.ylim(0, 1)
    plt.xlabel('Error [m]')
    plt.ylabel(typeStr + ' CDF')
    plt.tight_layout()
    savePlot(YYYYMMDD, YYYYMMDD + '_CDF_' + thisEnv + '_' + typeStr + '.png')

    # Zoomed-in versions when the full range extends past 1 m / 20 cm
    if xRange[1] > 1.0:
      plt.xlim(0, 1)
      plt.tight_layout()
      savePlot(YYYYMMDD, YYYYMMDD + '_CDF_' + thisEnv + '_' + typeStr + '-Max1m.png')

    if xRange[1] > 0.20:
      plt.xlim(0, 0.2)
      plt.tight_layout()
      savePlot(YYYYMMDD, YYYYMMDD + '_CDF_' + thisEnv + '_' + typeStr + '-Max20cm.png')

    plt.show()
    plt.close()
    fid.write("\n\n")

fid.close()

# Report which data sets went into the combined results, in sorted order
# (builtin sorted() instead of the star-imported sort; same ordering for
# a list of strings)
print("The following data sets were combined")
for dataSet in sorted(dataSummary):
  print(dataSet)

# Archive the run IDs of the individual data sets that were processed
with open(YYYYMMDD + '/' + YYYYMMDD + '-datasets.txt', 'w') as fid:
  for dataSet in processedDataSet:
    fid.write("%s\n" % dataSet)




