import sys

# Requires Python 2 (the script imports the Python 2 "Tkinter" module).
# To install:
#  install Anaconda
#  pip install visvis (using Anaconda's pip)
#  (conda's "vtk" package may conflict if installed)
# For headless/scripting operation:
#  yum install xorg-x11-server-Xvfb

import datetime

_USAGE = """
 Used for making a 3D plot of T04 diagnostic FFT data.
 Basic usage:
    viewdat -d35:24 -mb -oFFT.txt files_for_2015_1_2.T04
    xvfb-run --server-args="-screen 0 1200x900x24" python plot_3d_fft.py 2015 1 2
  Arguments = year month day
  For larger images we may want to increase the font sizes.
 Environment:
  Tested on donald with Anaconda in path (e.g.,:
   export PATH=/opt/anaconda/bin:${PATH}
  )
 Note: edit file and enable 'use_pickle' for faster graphics iteration.
"""


def _usage_exit():
    """Print the usage text and terminate the script with exit status 1."""
    print(_USAGE)
    sys.exit(1)


if len(sys.argv) != 4:
    _usage_exit()

# Arguments are: year month day.  The date only ends up in plot titles and
# output file names.  Reject non-integers and impossible calendar dates
# (e.g. month 13) with the usage text instead of a traceback.
try:
    year = int(sys.argv[1])
    month = int(sys.argv[2])
    day = int(sys.argv[3])
    datetime.date(year, month, day)
except ValueError:
    _usage_exit()

import pickle, time
import visvis as vv
from numpy import *
from mutils import *
import os.path
import Tkinter

# config settings:
use_pickle = False # Save data in binary format for quick retest?
Interval = 120. # Average FFT over this many seconds

# Screen size, used below to size the visvis figure.  Query it through a
# single temporary Tk root that is hidden and then destroyed, so no stray,
# never-closed Tk windows stay open for the rest of the run.
_tk_root = Tkinter.Tk()
_tk_root.withdraw()
fig_width = _tk_root.winfo_screenwidth()
fig_height = _tk_root.winfo_screenheight()
_tk_root.destroy()
del _tk_root

# Data columns in FFT.txt file:
dT = 0          # time stamp [s]; presumably seconds into the GPS day
dFreq = 1       # channel center frequency [Hz]
dBinWidth = 2   # FFT bin width; scaled by 1e-3 into MHz, so presumably kHz
dFFT1 = 4       # first FFT bin; columns dFFT1.. hold the spectrum (column 3 is not read)

def freq_to_band_str( freq_Hz ):
    """Return a short band label (e.g. "L1") for a center frequency in Hz.

    The nominal centers below are checked in order and the first one within
    6 MHz of *freq_Hz* wins.  If none matches, "Freq<MHz>" is returned with
    the frequency truncated to whole MHz.
    """
    tolerance_Hz = 6e6
    nominal_bands = (
        (1590e6, "L1"),
        (1240e6, "L2"),
        (1190e6, "L5"),
        (1290e6, "E6"),
        (1572e6, "B1"),
        (2504e6, "S1"),
    )
    for center_Hz, label in nominal_bands:
        if abs(freq_Hz - center_Hz) <= tolerance_Hz:
            return label
    mhz = int(freq_Hz*1e-6)
    return "Freq%d" % mhz

def _nan_row_like(template):
    """Return a new float array shaped like *template* and filled with NaN.

    NaN rows mark averaging intervals that have no usable data.
    """
    row = empty(template.shape)
    row.fill(float('nan'))
    return row


def _flush_interval(rec):
    """Append the interval currently accumulated in *rec* to its output lists.

    *rec* is one per-frequency record of ``all_d``.  The interval midpoint [s]
    goes to ``rec.TimeDec`` and the number of FFT rows summed goes to
    ``rec.AvgCnt``.  Their mean goes to ``rec.FFTMean``.  Points with only one
    average tend to have problems, so such intervals are stored as all-NaN.
    """
    rec.TimeDec.append((rec.tstart + rec.tend)/2.)
    rec.AvgCnt.append(rec.count)
    if rec.count > 1:
        rec.FFTMean.append(rec.myavg/rec.count)
    else:
        rec.FFTMean.append(_nan_row_like(rec.myavg))


# Build all_d: {center frequency [Hz]: record}.  Each record holds the FFT
# rows averaged over consecutive, Interval-aligned time windows.
t1 = time.time()
if 'all_d' in globals():
    # Data already present (e.g. re-run in the same interactive namespace):
    # skip the slow parse.
    print('using data already in memory')
elif use_pickle and os.path.isfile('test_all_d.p'):
    print('loading pickle file')
    # Only load a cache this script wrote itself: unpickling untrusted files
    # can execute arbitrary code.
    with open('test_all_d.p', 'rb') as pickle_file:
        all_d = pickle.load(pickle_file)
else:
    secs_in_day = 24.*60.*60.
    day_offset = {}  # per-frequency time offset [s], grows by a day at each roll-over
    all_d = {}
    print('Reading data...')
    with open('FFT.txt') as fft_file:
        for line in fft_file:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            d = array(fields).astype(float)
            centerFreq = d[dFreq]
            if centerFreq not in day_offset:
                day_offset[centerFreq] = 0.
            tcurr = d[dT] + day_offset[centerFreq]
            rec = all_d.get(centerFreq)
            # A time stamp more than half a day before the current interval
            # presumably means the time column wrapped at midnight: shift this
            # frequency onto the next day.
            if rec is not None and tcurr + secs_in_day*.5 < rec.tstart:
                day_offset[centerFreq] += secs_in_day
                tcurr = d[dT] + day_offset[centerFreq]
                print('freq %.1f day roll' % (centerFreq*1e-6))
            if rec is None:
                # First row for this frequency: start its first interval.
                rec = all_d[centerFreq] = dotdict({})
                rec.tstart = floor(tcurr/Interval) * Interval
                rec.tend = rec.tstart + Interval
                rec.binWidth = d[dBinWidth]*1e-3
                rec.count = 1
                rec.myavg = copy(d[dFFT1:])
                rec.TimeDec = []
                rec.FFTMean = []
                rec.AvgCnt = []
            elif tcurr >= rec.tend:
                # Row belongs to a later interval: store the finished one first.
                next_tstart = floor(tcurr/Interval) * Interval
                _flush_interval(rec)

                if len(rec.FFTMean) > 1.5*secs_in_day/Interval:
                    print('probable error - is data set really longer than 1.5 days?')
                    break

                # Fill skipped intervals with NaN rows so rows stay evenly
                # spaced in time.
                for tstart in r_[rec.tend:next_tstart:Interval]:
                    rec.TimeDec.append(tstart + Interval/2.)
                    rec.AvgCnt.append(0)
                    rec.FFTMean.append(_nan_row_like(rec.myavg))
                if next_tstart >= rec.tend + Interval:
                    print('freq %.1fMHz data gap %.1f-%.1f' % \
                        (centerFreq*1e-6, rec.tend, tcurr))

                # Start the next interval with the current row.
                rec.tstart = next_tstart
                rec.tend = rec.tstart + Interval
                rec.myavg = copy(d[dFFT1:])
                rec.count = 1
            elif rec.tstart <= tcurr < rec.tend:
                # Row inside the current interval: accumulate it.
                rec.count += 1
                rec.myavg += d[dFFT1:]
            # Rows earlier than the current interval (and not a day roll) are ignored.
        else:
            # End of file reached without aborting.  Store each frequency's
            # last (possibly partial) interval; otherwise it would be dropped.
            for rec in all_d.values():
                _flush_interval(rec)
    if use_pickle:
        with open('test_all_d.p', 'wb') as pickle_file:
            pickle.dump(all_d, pickle_file)
print('read data in %.1f secs' % (time.time()-t1))

def _save_3d_plot(fig, X, Y, Z, title_text, png_name):
    """Render Z over the (X, Y) grid as a visvis surface and save it to *png_name*.

    X holds frequencies [MHz] and Y holds times [hours].  *fig* is cleared
    afterwards so it can be reused for the next frequency.
    """
    surf = vv.surf(X, Y, Z)
    surf.colormap = vv.CM_JET
    cbar = vv.colorbar()
    vv.title(title_text)
    ax = surf.GetAxes()
    ax.axis.showGridX = True
    ax.axis.showGridY = True
    ax.axis.showGridZ = True
    ax.axis.xLabel = 'Frequency [MHz]'
    ax.axis.yLabel = 'Time [Hours into GPS day]'

    # Make the plot box square by scaling every axis to the smallest range.
    # The negated Y component presumably reverses the time axis direction.
    # Skip this when a range is zero or NaN (e.g. a single time row), which
    # would otherwise divide by zero.
    ranges = [lim.range for lim in ax.GetLimits()]
    if all(r > 0 for r in ranges):
        min_lim = min(ranges)
        ax.daspect = (min_lim/ranges[0], -min_lim/ranges[1], min_lim/ranges[2])
    zoom = ax.GetView()['zoom']
    ax.SetView(elevation=45, azimuth=-45, zoom=zoom*.65)

    # Make sure the plot is fully generated before saving.
    fig.DrawNow()
    vv.screenshot(png_name, bg='w')

    # Clear the figure for the next plot.  Unbinding the colorbar handler
    # first avoids a warning when the figure is cleared.
    cbar.parent.eventPosition.Unbind()
    fig.Clear()


def _save_2d_plot(hrs, freq, Z, title_text, png_name, half_hr, half_bin):
    """Save Z as a 1200x900 pixel time/frequency image to *png_name*.

    Uses the pylab-style functions (figure, imshow, ...) available in this
    module, presumably via the star imports at the top of the file.
    """
    fig2d = figure(figsize=(12,9)) # 1200x900 final output size
    # imshow's extent is the outer edge of the pixels, while hrs (interval
    # midpoints) and freq (treated as bin centres) are centres, so widen by
    # half a bin on each side.  flipud puts the last frequency at the top.
    imshow(flipud(Z.T), aspect='auto',
           extent=[hrs[0] - half_hr, hrs[-1] + half_hr,
                   freq[0] - half_bin, freq[-1] + half_bin])
    xlabel('Time [Hours into GPS day]')
    ylabel('Frequency [MHz]')
    title(title_text)
    colorbar()
    savefig(png_name, dpi=100) # 12*100 x 9*100 final pixel size
    close(fig2d)


t1 = time.time()
fig = vv.figure()
fig.position = [0,0,fig_width,fig_height]
for centerFreq, v in all_d.items():
    # Preformatted single string: Python 2 would print a tuple otherwise.
    print('Processing freq %s' % centerFreq)
    t2 = time.time()

    # Set up X, Y, Z: one FFTMean row per averaging interval.
    v.FFTMean = array(v.FFTMean)
    v.TimeDec = array(v.TimeDec)
    v.AvgCnt = array(v.AvgCnt)
    if v.FFTMean.ndim != 2 or v.FFTMean.shape[0] == 0:
        # No intervals, or rows of differing length: nothing sensible to plot.
        print('  no plottable data for this frequency - skipped')
        continue
    cols = v.FFTMean.shape[1]
    centerIdx = round(cols/2.)
    freq = (r_[0:cols] - centerIdx)*v.binWidth + centerFreq * 1e-6
    hrs = v.TimeDec/3600.
    hrs -= floor(hrs[0]/24.)*24.  # hours relative to the start of the first day
    X, Y = meshgrid(freq, hrs)
    Z = v.FFTMean

    band = freq_to_band_str(centerFreq)
    title_text = '%s %d-%.2d-%.2d - %d Av. - Amplitude [dB] versus Time & Frequency' % \
                 (band, year, month, day, Interval)
    _save_3d_plot(fig, X, Y, Z, title_text,
                  '%s-%d-%.2d-%.2d-Freq3D-%ds.png' % (band, year, month, day, Interval))
    _save_2d_plot(hrs, freq, Z, title_text,
                  '%s-%d-%.2d-%.2d-Freq2D-%ds.png' % (band, year, month, day, Interval),
                  Interval/3600./2., v.binWidth/2.)
    print('  time for figure %.2f secs' % (time.time() - t2))

print('total time %s' % (time.time() - t1))
