from pylab import *

# Simulate Gauss-Markov process that RTK uses for multipath modelling
# Simulate Gauss-Markov process that RTK uses for multipath modelling
def sim_GM( random_noise, corr_noise, corr_time, delta_time, num_epochs ):
    """Simulate a first-order Gauss-Markov (GM(1)) process with added white noise.

    Parameters
    ----------
    random_noise : float
        Sigma of the uncorrelated (white) noise added to every sample.
    corr_noise : float
        Steady-state sigma of the correlated (GM) component.
    corr_time : float
        Correlation time of the GM process (same units as delta_time).
    delta_time : float
        Sample interval.
    num_epochs : int-like
        Number of samples to generate.

    Returns
    -------
    (t, x) : tuple of ndarray
        Sample times t[k] = k * delta_time, and the simulated samples.
    """
    num_epochs = int(num_epochs)

    # white (uncorrelated) noise sequence added to every sample
    w = random_noise * randn(num_epochs)

    # driving source for the correlated part; scaled so the GM state
    # reaches steady-state sigma == corr_noise
    c = corr_noise * sqrt(1 - exp(-(2.0 * delta_time) / corr_time)) * randn(num_epochs)

    t = zeros(num_epochs)
    x = zeros(num_epochs)
    exp_factor = exp(-delta_time/corr_time)
    # start at 1 (the original started at 2, which skipped sample 1 and
    # left t[1] == 0, duplicating the first timestamp); range works on
    # both Python 2 and 3, unlike xrange
    for k in range(1, num_epochs):
        t[k] = t[k-1] + delta_time
        x[k] = exp_factor * x[k-1] + c[k]

    # now add in the random noise to the samples
    x = x + w

    return t, x

# statsmodels estimate GM(1) - doesn't estimate added white noise
def sm_estimate_GM( deltaTime, x ):
    """Fit an AR(1) model to *x* via statsmodels and print the implied
    Gauss-Markov parameters.

    Prints the correlation time Tc = -deltaTime / ln(phi) and the driving
    noise sigma. Unlike RTK_estimate_GM, any added white-noise component
    is not estimated separately.
    """
    from statsmodels.tsa import ar_model

    fit = ar_model.AR(x).fit(maxlag=1)
    # GM(1) mapping: phi = exp(-dt/Tc)  =>  Tc = -dt / ln(phi)
    print('Tc', -deltaTime / log(fit.params[1]))
    print('corrNse', sqrt(fit.sigma2))


# RTK estimate of Gauss-Markov parameters
def RTK_estimate_GM( deltaTime, process, maxCorrTime, version=2 ):
    maxCorrTime = int(maxCorrTime)

    # remove NaNs
    process = process[isfinite(process)]

    ## initialise the output quantities
    corrTime    = 0
    corrNoise   = 0
    uncorNoise  = 0

    numSamples  = len(process)

    ## Compute the autocorrelation out to the maxLag
    maxLag = maxCorrTime;
    # maxLag = numSamples / 2;

    cor = zeros(maxLag);
    corSamples = numSamples - maxLag - 1

    if corSamples < 1:
        raise Exception('too few samples %d' % (corSamples) )

    #if corSamples < maxCorrTime:
    #    # Not many samples left for integration
    #    raise Exception('too few samples #2: %d %d' % (corSamples, maxCorrTime) )

    t  = r_[1:maxLag]
    t2 = t * deltaTime

    # compute the epoch-to-epoch value differences and pad the end
    epochToEpochDiff = diff(process)
    epochToEpochDiffSqr = 0.5 * epochToEpochDiff * epochToEpochDiff

    # version 0 & 1 are identical
    if version == 1:
        cor = correlate( process[:corSamples], process, 'valid' )[:1:-1]
        acf = cor / corSamples
    elif version == 2:
        # assume that data end repeats?
        cor = correlate( process, hstack((process,process[-maxLag:])) )[::-1]
        acf = cor/(len(process)-maxLag)
    elif version == 0:
        # outer loop considers the correlation shifts
        buf1   = process
        buf2 = buf1[:corSamples]
        for lag in xrange(maxLag):
            # extract a portion of data to use in the formation of the acf
            buf3 = buf1[lag:(corSamples + lag)]

            # compute the product of the samples
            # Two samples buf2 and buf3 of size corSamples are multiplied. buf2 contains the first
            # corSamples elements of the input values. buf3 is shifted by lag.

            # finalise the computation of the autocorrelation
            cor[lag] = sum( buf3 * buf2 );

        ## Normallize the autocorrelation function
        #  Divide it by the number of samples
        acf = cor / corSamples;

    ## Compute the autocorrelation value at zero lag
    #  This is equal to the sum of the uncorrelated and correlated
    #  parts of the signal
    acf0  = acf[0];
    varUncor    = sum(epochToEpochDiffSqr) / (numSamples - 1);

    # Sometimes the uncorrelated variance component is so small that it
    # cannot be accurately determined, in which case, just set it to zero.
    if ( acf0 > varUncor ):
        corrNoise   = sqrt( acf0 - varUncor );
        uncorNoise  = sqrt( varUncor );
    else:
        corrNoise   = sqrt( acf0 );
        uncorNoise  = 0.0;

    # get the correlated variance
    corrVar = corrNoise * corrNoise

    #plot(acf/corrVar)
    #plot(acf/acf0)
    #title('acf0 %g corrVar %g' % (acf0, corrVar) )
    #show()

    ## Generate estimates of the correlation time for each lag
    T       = zeros(maxLag);
    weightT = zeros(maxLag);
    if version >= 2:
        r_all   = abs(acf / corrVar) # NOTE: added abs() here
    else:
        r_all   = acf / corrVar
    n_valid = 0
    for lag in xrange(1,maxLag):
        # time delta for this lag
        dt = deltaTime * lag

        # compute the correlation for this lag
        r = r_all[lag]

        # correlation needs to be valid
        if ( r > 0.0 and r < 1.0 and abs(dt-1) > 1e-3 ):
            n_valid += 1
            # estimate the correlation time for each lag and store estimates
            tc              = -dt / log( r );
            T[lag]          = tc;

            # compute the variance of the correlation time values
            varT            = (tc * tc * tc * tc)  \
                              / ( (r * (dt - 1)) * (r * (dt - 1)) * maxLag);

            # formulate a weight for each correlation time observation
            weightT[lag]    = 1.0 / varT;

    ## Get the weighted mean correlation time
    sum_weightT = sum( weightT )
    if abs(sum_weightT) < 1e-9:
        raise Exception('weightT is all zero')
    meanT = sum( T * weightT ) / sum( weightT )
    corrTime = meanT

    process_time = deltaTime*len(process)
    if corrTime > process_time:
        if process_time < 250.0:
            raise Exception('Warning: corrTime %.1f > data time %.1f' % (corrTime, process_time))
        else:
            corrTime = process_time

    if n_valid < maxLag*.1:
        raise Exception('Warning: estimate using only a small sample of points: %d/%d' % (n_valid,maxLag))
    return corrTime, corrNoise, uncorNoise
