import PospacAnalysis as pa 
import matplotlib.pyplot as plt
import numpy as np
import mutils as m
from mutils.PosPacConst import *
import mutils.PosTypeConst as PosTypeConst
import math
from statistics import median
import ITRF as itrf 
from datetime import date
import glob, os
import pickle
import pandas as pd
from copy import copy
import dataframe_image as dfi
import plotly.graph_objects as go
import shutil



"""
This script processes a single dataset and plots comparisons between sources of truth files and cleaned/raw replay data

"""

#CONSTANTS
pdopThresh = 5  # Epochs with PDOP above this are discarded by the fix-type/PDOP filter

# Working directory for the dataset (replay data is read from <workingDir>/ReplayData and
# plots are written under <workingDir>/Plots). Presumably prompts the user for it.
workingDir = pa.getDirFromUser()


#Get the config file and output the filepath for the user to select 
#Use this most of the time:
XMLConfig, truthName = pa.changeXMLFile(workingDir)
pp, RunConfig = pa.generateTruth(XMLConfig)
# Name of the truth file without its directory or extension. os.path applies the platform's
# separator rules and strips only the final extension, so e.g. "truth.v2.txt" -> "truth.v2"
# instead of collapsing to "truth" (which could make different truth files share a plot folder).
plotDirName = os.path.splitext(os.path.basename(truthName))[0]
plotDirPath = workingDir + '/Plots/' + plotDirName
print(plotDirPath)
# Unsuffixed plot path. Strings are immutable and later suffixes use `+=` (rebinding),
# so a plain assignment already preserves this value; no copy is needed.
plotDirPathOG = plotDirPath

#Ask user whether they want ITRF translation
# Do this for truth data on the server. The answer is normalized so "Y" or " y" is not
# silently treated as "no".
itrfAdjust = input("ITRF 2000 to 2014 Adjust? (y/n)?").strip().lower()

#Filter the replay data by fix type and PDOP to get clean measurements
ref_pos, meas = pa.formatMeas(workingDir + '/ReplayData', pp)  #Format the data into k,d concatenated into "meas," as defined in mutils
# Keep only RTK-fixed epochs with PDOP <= pdopThresh. The same row indices are applied to
# both the measurements and the truth so their rows stay paired.
filterIdx = m.find((meas.FIXTYPE == PosTypeConst.dRec29_fixType_RTK_fix) & (meas.PDOP <= pdopThresh))
meas = meas[filterIdx, :]
ref_pos = ref_pos[filterIdx, :]

# Optional ITRF2000 -> ITRF2014 translation of the truth positions (no other adjustment happens here)
if itrfAdjust == 'y':
    plotDirPath += '_itrf'
    # Epoch for the transformation: itrf.dateAsYear of the dataset date, which
    # pa.getDateFromDir derives from the working directory
    yr, month, day = pa.getDateFromDir(workingDir)
    epoch = itrf.dateAsYear(date(yr, month, day))
    ref_pos[:, dPP_LAT:dPP_HGT+1] = pa.itrf2000_itrf2014(ref_pos, epoch)  #Translate truth into ITRF2014 frame

# Split measurements and truth into the "clear" and "overpass" portions returned by pa.getFreewayData
meas_clear, meas_overpasses, ref_clear, ref_overpasses = pa.getFreewayData(RunConfig, ref_pos, meas)

# Ask user whether they want to first do an along/cross track adjustment or a median adjustment.
# The follow-up flags default to 'n'. The mode is normalized so e.g. "AC" or "m " is not
# silently treated as "no adjust".
alongCrossAdjust, medianAdjust = 'n', 'n'
adjustment = input("Along cross first (ac), median first (m), or no adjust (n)? ").strip().lower()

# If we want to adjust the along/cross track error first
if adjustment == 'ac':
    plotDirPath += '_alongCross'

    # Plots before any adjustment. plotCrossAlongError also returns the x_v/y_v/z_v
    # along/cross-track error series that drive the AC adjustment below.
    (dE_clear, dN_clear, dU_clear) = m.llh2enu(meas_clear[:, meas_clear.k.LAT:meas_clear.k.LAT+3], ref_clear[:, dPP_LAT:dPP_HGT+1])
    err_2d, err_3d = pa.errorBarChart(dE_clear, dN_clear, dU_clear, workingDir, plotDirPath)  # Get the 2D and 3D error and plot <=15cm error
    pa.enuErrorPlots(dE_clear, dN_clear, dU_clear, ref_clear, meas_clear, workingDir, plotDirPath)  # Plot the ENU error
    pa.enuHistograms(dE_clear, dN_clear, dU_clear, workingDir, plotDirPath)  # Plot ENU histograms
    (x_v, y_v, z_v) = pa.plotCrossAlongError(dE_clear, dN_clear, dU_clear, ref_clear, err_2d, err_3d, workingDir, plotDirPath)  # Plot Along/Cross track error
    # Medians of the pre-adjustment along/cross errors (reported in adjustment_params.csv)
    xv = median(x_v)
    yv = median(y_v)
    zv = median(z_v)

    # Do along cross update by replacing the truth LLH with the adjusted LLH
    ref_clear[:, dPP_LAT:dPP_HGT+1] = pa.alongCrossAdjust(x_v, y_v, z_v, ref_clear)

    # Check if we also want to do a median adjustment
    medianAdjust = input("Median Adjust? (y/n)").strip().lower()

    # Do median adjustment if necessary
    if medianAdjust == 'y':
        adjust = input("e, n, u, or enu? ").strip().lower()
        plotDirPath += '_median'
        plotDirPath += adjust.upper()
        ref_clear[:, dPP_LAT:dPP_HGT+1], enuCorr = pa.medianAdjust(meas_clear[:, meas_clear.k.LAT:meas_clear.k.LAT+3], ref_clear[:, dPP_LAT:dPP_HGT+1], adjust)

    # REDO plots after the AC (and optional median) adjustments have been made
    (dE_clear, dN_clear, dU_clear) = m.llh2enu(meas_clear[:, meas_clear.k.LAT:meas_clear.k.LAT+3], ref_clear[:, dPP_LAT:dPP_HGT+1])  # Get ENU error
    err_2d, err_3d = pa.errorBarChart(dE_clear, dN_clear, dU_clear, workingDir, plotDirPath)  # Get the 2D and 3D error and plot <=15cm error
    pa.enuErrorPlots(dE_clear, dN_clear, dU_clear, ref_clear, meas_clear, workingDir, plotDirPath)  # Plot the ENU error
    pa.enuHistograms(dE_clear, dN_clear, dU_clear, workingDir, plotDirPath)  # Plot ENU histograms
    (x_v, y_v, z_v) = pa.plotCrossAlongError(dE_clear, dN_clear, dU_clear, ref_clear, err_2d, err_3d, workingDir, plotDirPath)  # Plot Along/Cross track error
    pa.plotCDF(err_2d, err_3d, workingDir, plotDirPath)  # Plot CDF of 2D and 3D error
    heightError = pa.plotHeightError(meas_clear, ref_clear, err_2d, err_3d, workingDir, plotDirPath)  # Height error

# If we want to start with median adjustment
elif adjustment == 'm':
    plotDirPath += '_median'
    adjust = input("e, n, u, or enu? ").strip().lower()
    plotDirPath += adjust.upper()
    ref_clear[:, dPP_LAT:dPP_HGT+1], enuCorr = pa.medianAdjust(meas_clear[:, meas_clear.k.LAT:meas_clear.k.LAT+3], ref_clear[:, dPP_LAT:dPP_HGT+1], adjust)

    # Plots for the median adjustment
    (dE_clear, dN_clear, dU_clear) = m.llh2enu(meas_clear[:, meas_clear.k.LAT:meas_clear.k.LAT+3], ref_clear[:, dPP_LAT:dPP_HGT+1])  # Get ENU error
    err_2d, err_3d = pa.errorBarChart(dE_clear, dN_clear, dU_clear, workingDir, plotDirPath)  # Get the 2D and 3D error and plot <=15cm error
    (x_v, y_v, z_v) = pa.plotCrossAlongError(dE_clear, dN_clear, dU_clear, ref_clear, err_2d, err_3d, workingDir, plotDirPath)  # Plot Along/Cross track error
    xv = median(x_v)
    yv = median(y_v)
    zv = median(z_v)
    pa.enuErrorPlots(dE_clear, dN_clear, dU_clear, ref_clear, meas_clear, workingDir, plotDirPath)  # Plot the ENU error
    pa.enuHistograms(dE_clear, dN_clear, dU_clear, workingDir, plotDirPath)  # Plot ENU histograms
    pa.plotCDF(err_2d, err_3d, workingDir, plotDirPath)  # Plot CDF of 2D and 3D error
    heightError = pa.plotHeightError(meas_clear, ref_clear, err_2d, err_3d, workingDir, plotDirPath)  # Height error

    # Check if we also want to do an AC adjustment
    alongCrossAdjust = input("Along Cross Adjust? (y/n)").strip().lower()

    # Do along/cross adjustments if necessary
    if alongCrossAdjust == 'y':
        plotDirPath += '_alongCross'

        # Along/cross errors of the median-adjusted truth; these are the inputs to the AC adjustment
        (dE_clear, dN_clear, dU_clear) = m.llh2enu(meas_clear[:, meas_clear.k.LAT:meas_clear.k.LAT+3], ref_clear[:, dPP_LAT:dPP_HGT+1])  # Get ENU error
        err_2d, err_3d = pa.errorBarChart(dE_clear, dN_clear, dU_clear, workingDir, plotDirPath)  # Get the 2D and 3D error and plot <=15cm error
        (x_v, y_v, z_v) = pa.plotCrossAlongError(dE_clear, dN_clear, dU_clear, ref_clear, err_2d, err_3d, workingDir, plotDirPath)  # Plot Along/Cross track error
        xv = median(x_v)
        yv = median(y_v)
        zv = median(z_v)

        # Apply the along/cross adjustment to the truth LLH (mirrors the 'ac' branch).
        # Without it, the redo plots below would be identical to the median-only plots even
        # though the folder name and adjustment_params.csv report an AC adjustment.
        ref_clear[:, dPP_LAT:dPP_HGT+1] = pa.alongCrossAdjust(x_v, y_v, z_v, ref_clear)

        # REDO plots after median and AC adjustments have been made
        (dE_clear, dN_clear, dU_clear) = m.llh2enu(meas_clear[:, meas_clear.k.LAT:meas_clear.k.LAT+3], ref_clear[:, dPP_LAT:dPP_HGT+1])  # Get ENU error
        err_2d, err_3d = pa.errorBarChart(dE_clear, dN_clear, dU_clear, workingDir, plotDirPath)  # Get the 2D and 3D error and plot <=15cm error
        pa.enuErrorPlots(dE_clear, dN_clear, dU_clear, ref_clear, meas_clear, workingDir, plotDirPath)  # Plot the ENU error
        pa.enuHistograms(dE_clear, dN_clear, dU_clear, workingDir, plotDirPath)  # Plot ENU histograms
        (x_v, y_v, z_v) = pa.plotCrossAlongError(dE_clear, dN_clear, dU_clear, ref_clear, err_2d, err_3d, workingDir, plotDirPath)  # Plot Along/Cross track error
        pa.plotCDF(err_2d, err_3d, workingDir, plotDirPath)  # Plot CDF of 2D and 3D error
        heightError = pa.plotHeightError(meas_clear, ref_clear, err_2d, err_3d, workingDir, plotDirPath)  # Height error

# No adjustment (also reached for any unrecognized mode)
else:
    # All plots with no adjustments
    (dE_clear, dN_clear, dU_clear) = m.llh2enu(meas_clear[:, meas_clear.k.LAT:meas_clear.k.LAT+3], ref_clear[:, dPP_LAT:dPP_HGT+1])  # Get ENU error
    err_2d, err_3d = pa.errorBarChart(dE_clear, dN_clear, dU_clear, workingDir, plotDirPath)  # Get the 2D and 3D error and plot <=15cm error
    pa.enuErrorPlots(dE_clear, dN_clear, dU_clear, ref_clear, meas_clear, workingDir, plotDirPath)  # Plot the ENU error
    pa.enuHistograms(dE_clear, dN_clear, dU_clear, workingDir, plotDirPath)  # Plot ENU histograms
    (x_v, y_v, z_v) = pa.plotCrossAlongError(dE_clear, dN_clear, dU_clear, ref_clear, err_2d, err_3d, workingDir, plotDirPath)  # Plot Along/Cross track error
    pa.plotCDF(err_2d, err_3d, workingDir, plotDirPath)  # Plot CDF of 2D and 3D error
    heightError = pa.plotHeightError(meas_clear, ref_clear, err_2d, err_3d, workingDir, plotDirPath)  # Height error

# Save to CSV for ease of future plots.
# Pick the adjustment parameters to record. An adjustment that was not applied is
# recorded as +inf so that every column is always present in adjustment_params.csv.
_notApplied = float('inf')

if medianAdjust == 'y' or adjustment == 'm':
    # First row of the ENU correction returned by pa.medianAdjust: east, north, up
    eastAppend, northAppend, upAppend = enuCorr[0, 0], enuCorr[0, 1], enuCorr[0, 2]
else:
    eastAppend = northAppend = upAppend = _notApplied

if alongCrossAdjust == 'y' or adjustment == 'ac':
    # Medians of the along/cross error series computed before the AC adjustment
    xvAppend, yvAppend, zvAppend = xv, yv, zv
else:
    xvAppend = yvAppend = zvAppend = _notApplied

# Make a per-epoch dataframe of the clear data: measured position, ENU error and heading.
# meas_clear.k.LAT is a column *index* (it is used as a slice bound elsewhere), so the
# coordinates must be read out of meas_clear. The previous code stored the index itself,
# leaving Lat/Lon/Height without the actual positions.
latCol = meas_clear.k.LAT
df = pd.DataFrame()
df['Lat'] = meas_clear[:, latCol]
df['Lon'] = meas_clear[:, latCol + 1]
df['Height'] = meas_clear[:, latCol + 2]

df['dE'] = dE_clear
df['dN'] = dN_clear
df['dU'] = dU_clear
# NOTE(review): column 12 of the truth is assumed to be heading (per this column label);
# confirm, and prefer a named dPP_* constant if PosPacConst defines one.
df['Heading'] = ref_clear[:, 12]

# Make sure the (possibly suffixed) output folder exists before writing; this is a no-op
# if the plotting helpers already created it.
os.makedirs(plotDirPath, exist_ok=True)
df.to_csv(os.path.join(plotDirPath, 'enu_errors.csv'), index=False)

# One-row dataframe of the adjustment parameters (+inf marks an adjustment that was not applied)
df_adjust = pd.DataFrame({
    'Median East': [eastAppend],
    'Median North': [northAppend],
    'Median Up': [upAppend],
    'AlongCross X': [xvAppend],
    'AlongCross Y': [yvAppend],
    'AlongCross Z': [zvAppend],
})
df_adjust.to_csv(os.path.join(plotDirPath, 'adjustment_params.csv'), index=False)

# Table view for summary plots: one header entry per parameter and, after transposing,
# one cell list per parameter holding its value rounded to 6 decimals.
tableHeader = dict(values=list(df_adjust.columns), fill_color='powderblue', align='left')
tableCells = dict(values=df_adjust.round(6).T.values.tolist(), fill_color='whitesmoke', align='left')
fig = go.Figure(data=[go.Table(header=tableHeader, cells=tableCells)])
fig.update_layout(title_font_family='Arial Black')
fig.write_image(plotDirPath + "/AdjustParams.png")







