import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import mutils as m
from mutils.PosPacConst import *
import mutils.PosTypeConst as PosTypeConst

"""
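Helper functions for SummaryPlotScript: an error-CDF comparison plot (cdfComparisonPlot) and a stacked bar-chart comparison of the percentage of errors below fixed thresholds (barPlotComparison).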
"""

def cdfComparisonPlot(df,plotLabel,plotType,title,maxError=0.2):

    """
    Add one error-CDF curve to the current figure and format the axes.

    Only epochs whose error is <= maxError are passed to the CDF, so the
    curve describes the error distribution within that cutoff.

    Inputs:
        - df: Dataframe containing the dE, dN, dU errors
        - plotLabel: name of the comparison (i.e. what to call it in the legend)
        - plotType: '2D' plots the horizontal (dE, dN) error; any other
          value plots the 3D (dE, dN, dU) error
        - title: plot title
        - maxError: cutoff applied before computing the CDF (default 0.2,
          i.e. 20 cm if the errors are in metres -- TODO confirm units)
    """

    #Parse out the horizontal components (needed for both error types)
    dE = df['dE']
    dN = df['dN']

    #Compute only the error that is actually being plotted
    if plotType == '2D':
        err = np.sqrt( dE**2 + dN**2 )
        yLabel = '2D CDF'
    else:
        dU = df['dU']
        err = np.sqrt( dE**2 + dN**2 + dU**2 )
        yLabel = '3D CDF'

    #Keep errors at or below the cutoff. Use a boolean mask rather than integer
    #positions: indexing a Series with an integer array can be interpreted as
    #index labels when df does not have a default 0..n-1 index.
    cx,cy = m.docdf(err[err <= maxError])
    plt.plot(cx,cy,label = plotLabel)
    plt.ylabel(yLabel)

    #Format the plot (saving is left to the caller)
    plt.xlabel('Error [m]')
    plt.grid(True)
    plt.title(title)
    plt.tight_layout()
    plt.legend(loc = "lower right")




def barPlotComparison(df,filename,plotType,count):
    """
    Add one stacked bar to the current figure showing the percentage of
    epochs whose 2D or 3D error falls at or below 5, 10 and 15 cm.

    This code is part of the commented-out bar plotting section of
    SummaryPlotScript.py; call it once per dataset to build a side-by-side
    comparison between datasets.

    Inputs:
        - df: Dataframe containing the dE, dN, dU errors (the cm labels
          suggest metres -- TODO confirm against the caller)
        - filename: x-axis category label for this dataset's bar
        - plotType: '2D' plots the 2D error; any other value plots the 3D error
        - count: the legend is drawn only when count == 1 (avoids one legend
          entry per dataset)
    Returns:
        - (err_2d, err_3d): the 2D and 3D error Series computed from df
    """

    #Row 0 of the percentage table holds 2D results, row 1 holds 3D results
    offset = 0 if plotType == '2D' else 1

    #Define 2D and 3D error based on dENU input
    dE = df['dE']
    dN = df['dN']
    dU = df['dU']
    err_2d = np.sqrt( dE**2 + dN**2 )
    err_3d = np.sqrt( dE**2 + dN**2 + dU**2 )

    #Count epochs at or below each threshold (NaN errors never count)
    threshold = [0.05, 0.1, 0.15]
    stats2D = [int((err_2d <= thres).sum()) for thres in threshold]
    stats3D = [int((err_3d <= thres).sum()) for thres in threshold]

    #Cumulative percentage below each threshold, relative to all epochs
    #collected; an empty dataset yields an all-zero bar instead of a crash
    denom = len(dE)
    if denom == 0:
        cumulative = np.zeros((2, len(threshold)))
    else:
        cumulative = 100.0 * np.array([stats2D, stats3D], dtype=float) / denom

    #Convert cumulative percentages into per-band heights for stacking:
    #[<=5cm, 5-10cm, 10-15cm]
    data = np.diff(cumulative, axis=1, prepend=0.0)

    #Stack the bands on top of each other
    labels = ['<=5cm', '<=10cm', '<=15cm']
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
    bottom = 0.0
    for height, label, color in zip(data[offset], labels, colors):
        plt.bar(filename, height, bottom=bottom, label=label, zorder=3, color=color)
        bottom += height

    plt.xticks(rotation=45,ha='right')
    #Include the 100% tick (np.arange excludes its stop value)
    plt.yticks(np.arange(0, 101, 10))

    errorLabel = '2D' if offset == 0 else '3D'
    plt.ylabel(errorLabel + ' Error <= Threshold [%]')

    #Never let the percentage axis extend beyond 100
    _, _, ymin, ymax = plt.axis()
    if ymax > 100.0:
        plt.ylim([ymin,100.0])
    plt.grid(True,zorder=0)
    plt.tight_layout()
    if count == 1:
        plt.legend(loc='lower left')
    return err_2d, err_3d



    