

from ProcessResults import parse_sampleConfig, cdf_vals, get_rounded_cdf_percents
from mutils import *
from mutils.PosPacConst import *
import matplotlib.pyplot as plt
import json
import sys
import mutils.PosTypeConst as PosTypeConst

## Optional headless mode (currently disabled): uncomment the lines below to
## allow running from the command line without a display by switching
## matplotlib to the non-interactive "agg" backend.
##Headless = False
##if(Headless == True):
##  from matplotlib import use
##  # Allow running headless from the command line
##  use("agg")

if __name__ == "__main__":
  # Default sample rate in Hz - can be overridden per DUT via the config JSON
  # ('dataRate' key on a DUT entry).
  dDataRate = 10

  # Exactly one argument is expected: the path to the run configuration JSON.
  if len(sys.argv) == 2:
    with open(sys.argv[1], 'r') as f:
      configData = json.load(f)
  else:
    print("Usage:\n  python getStats.py config.json")
    # Non-zero exit status so calling scripts can detect the usage error
    sys.exit(1)

  dataOut = {}      # per-receiver 2D error samples, keyed by label then environment
  yieldData = {}    # per-receiver epoch-yield counters, keyed by label then environment
  dataSummary = []  # human-readable descriptions of every data set processed

  prefix = configData['prefix']

  for dataset in configData['dataSets']:
    print(dataset)

    # Get the data segments by loading the config file
    config = parse_sampleConfig(dataset['config'])

    dataSummary.append(config.desc)

    # Load the truth trajectory
    pp = doload(dataset['truth'])

    for DUT in dataset['DUTs']:
      thisRX = DUT['label']
      # Per-DUT data rate override, otherwise assume the default
      dataRate = float(DUT['dataRate']) if 'dataRate' in DUT else dDataRate

      # Which position record do we want to parse
      rec = DUT['rec'] if 'rec' in DUT else '-d35:2'

      if thisRX not in dataOut:
        dataOut[thisRX] = {'data': {}, 'dataRate': dataRate}

      if thisRX not in yieldData:
        yieldData[thisRX] = {}

      d = vd2cls(DUT['file'] + '/*.T04', rec)

      # Match the truth and DUT data time tags, throw away any epochs that
      # don't match
      _, i_pp, i_d = intersect(pp[:, dPP_TIME], d.TIME)
      if len(i_pp) != len(i_d):
        print("Position Error")

      ref_pos = pp[i_pp, dPP_LAT:dPP_velU+1]
      d = d[i_d, :]

      spans = {}     # per-environment boolean masks over the matched epochs
      envTypes = []  # environments seen for this data set, first-seen order
      for desc, start_stops in config.span:

        # We want per environment data, don't process the "All" data
        if desc == 'All':
          continue
        # Both Freeway and Freeways are used - combine them; similar
        # aliasing for downtown San Jose and the two suburban runs.
        if desc == 'Freeways':
          desc = 'Freeway'
        elif desc == 'DowntownSJ':
          desc = 'DowntownSanJose'
        elif desc in ('Suburban-WillowGlen', 'Suburban-Sunnyvale'):
          desc = 'Urban'

        if desc not in envTypes:
          envTypes.append(desc)
        if desc not in spans:
          spans[desc] = zeros(len(d), dtype=bool)

        if desc not in yieldData[thisRX]:
          yieldData[thisRX][desc] = {'expected': 0, 'actual': 0,
                                     'Stinger': 0, 'Precision': 0}

        for start, stop in start_stops:
          i_d = find((d.TIME >= start) & (d.TIME <= stop))
          spans[desc][i_d] = True
          # Expected epoch count over the closed interval at this data rate
          yieldData[thisRX][desc]['expected'] += (stop - start)*dataRate + 1
          yieldData[thisRX][desc]['actual']   += len(i_d)

          # Epochs produced by the Stinger engine, selected by fix type range
          stingerEpochs = find(   (d[i_d].FIXTYPE <= PosTypeConst.dRec29_fixType_SBAS)
                                | (d[i_d].FIXTYPE >= PosTypeConst.dRec29_fixType_QZSS_SLAS)
                                | (d[i_d].FIXTYPE == PosTypeConst.dRec29_fixType_GVBS)
                                | (   (d[i_d].FIXTYPE >= PosTypeConst.dRec29_fixType_KF_auto)
                                    & (d[i_d].FIXTYPE <= PosTypeConst.dRec29_fixType_KF_SBASplus) ) )
          yieldData[thisRX][desc]['Stinger']   += len(stingerEpochs)
          # For the regression system everything else is the precision engine (aka Astra/Titan). More
          # generally this may not be the case, e.g. OmniSTAR
          yieldData[thisRX][desc]['Precision'] += len(i_d) - len(stingerEpochs)

      # Convert each environment's masked epochs to 2D horizontal errors and
      # accumulate them per receiver so multiple data sets combine.
      for env in envTypes:
        i = find(spans[env])
        print(thisRX, env, len(i))

        (dE, dN, dU) = llh2enu(d[i, d.k.LAT:d.k.LAT+3], ref_pos[i, :3])

        err_2d = sqrt(dE**2 + dN**2)
        if env not in dataOut[thisRX]['data']:
          dataOut[thisRX]['data'][env] = err_2d
        else:
          dataOut[thisRX]['data'][env] = np.concatenate((dataOut[thisRX]['data'][env], err_2d))

  # We have all the data - now form the plots

  print(yieldData)

  # Union of environments seen across all receivers
  envs = set()
  for engine in dataOut:
    envs.update(dataOut[engine]['data'].keys())
  envs = list(envs)

  # Context manager so the stats file is closed even if plotting raises
  with open(prefix + '-stats.csv', 'w') as fid:

    for thisEnv in envs:
      print(thisEnv)

      # CSV header row: CDF percentiles followed by the yield columns
      fid.write(thisEnv + ',')
      for val in cdf_vals:
        fid.write(str(val) + ',')
      fid.write('expected,actual,yield,Stinger Yield,Astra/Titan Yield,hours\n')

      fig, ax = plt.subplots(1, 1)

      for engine in dataOut:
        dataRate = dataOut[engine]['dataRate']
        cx2, cy2 = docdf(dataOut[engine]['data'][thisEnv])
        plt.plot(cx2, cy2, label=engine + ' (' + "{0:.1f}".format(len(cx2)/dataRate) + ' Secs)')
        stats = get_rounded_cdf_percents(cx2, cy2)

        fid.write(engine + ',')
        for j in range(len(stats)):
          print(engine, cdf_vals[j], stats[j], len(cx2))
          fid.write(str(stats[j]) + ',')

        thisYield = yieldData[engine][thisEnv]
        fid.write(str(thisYield['expected']) + ',')
        fid.write(str(thisYield['actual']) + ',')
        fid.write("%.3f," % (thisYield['actual'] / thisYield['expected']))
        fid.write("%.3f," % (thisYield['Stinger'] / thisYield['actual']))
        fid.write("%.3f," % (thisYield['Precision'] / thisYield['actual']))
        # Hours of actual data at this receiver's data rate
        fid.write("%.3f\n" % (thisYield['actual'] / (3600 * dataRate)))

      plt.grid(True)
      plt.legend(fontsize=8)
      if thisEnv == 'Trimble':
        # Make the environment a bit more obvious
        plt.title('Trimble Campus')
      else:
        plt.title(thisEnv)
      xRange = ax.get_xlim()
      plt.xlim(0, xRange[1])
      plt.ylim(0, 1)
      plt.xlabel('Error [m]')
      plt.ylabel('2D CDF')
      plt.tight_layout()
      plt.savefig(prefix + '_CDF_' + thisEnv + '.png', format='png', dpi=600)

      # Zoomed variants so dense low-error regions are readable
      if xRange[1] > 1.0:
        plt.xlim(0, 1)
        plt.tight_layout()
        plt.savefig(prefix + '_CDF_' + thisEnv + '-Max1m.png', format='png', dpi=600)

      if xRange[1] > 0.20:
        plt.xlim(0, 0.2)
        plt.tight_layout()
        plt.savefig(prefix + '_CDF_' + thisEnv + '-Max20cm.png', format='png', dpi=600)

      plt.show()
      plt.close()
      fid.write("\n\n")

  print("The following data sets were combined")
  for dataSet in sort(dataSummary):
    print(dataSet)


