#!/usr/bin/env python

# Help text printed by --help (passed as the argparse description in main()).
usage = (
    "To use, enable FE_SPA_ENABLE (and optionally FE_SPA_MUTEX_ENABLE) in firmware.\n"
    "\n"
    "Examples:\n"
    " ./plot_spa.py file.T04 3200 3200.2 # show task switch data from time 3200-3200.2\n"
    " ./plot_spa.py file.T04 3200 3200.2 -l # show task switch+mutex data from time 3200-3200.2\n"
)

from pylab import *
import mutils as m
import argparse

# Mutex display modes, passed around as "mutex_ID".  They are negative so
# they cannot be confused with a real mutex ID: a positive value selects
# that single mutex instead.
mode_long_mutexes, mode_all_mutexes, mode_one_task_mutexes, mode_no_mutexes = (
    -1,  # auto mode - only mutexes with any delays > 5us
    -2,  # auto mode - every mutex
    -3,  # only mutexes for one given task
    -4,  # skip mutex data entirely
)

# Quantity plotted on the y axis of the task-switch figure.
mode_task_id, mode_priority = 0, 1

def main():
    """Command-line entry point: parse options, load SPA records, plot them.

    Returns:
        (tsk, mut) as returned by load_spa(); the __main__ guard binds them
        to module-level names.
    """
    parser = argparse.ArgumentParser(
        description=usage,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("filename", help="File with rec35:271 (and 274?) data")
    parser.add_argument("start_secs", type=float, help="Start time [GPS secs]")
    parser.add_argument("end_secs", type=float, help="End time [GPS secs]")
    parser.add_argument("-l", "--long_mutex", action="store_true",
                        help="Show only mutexes with >5us delays.")
    parser.add_argument("-a", "--all_mutex", action="store_true",
                        help="Show all mutexes.")
    parser.add_argument("-m", "--mutex", type=int,
                        help="Show only given mutex ID.")
    parser.add_argument("-t", "--task_mutex", type=int,
                        help="Show only mutexes for given task ID.")
    parser.add_argument("-p", "--priority", action="store_true",
                        help="plot priority instead of the default task ID.")
    opts = parser.parse_args()

    # Pick the mutex display mode.  The options are checked in this order and
    # the first one set wins; note a value of 0 for -m / -t is falsy and is
    # therefore treated as "not given".
    mutex_sel, task_sel = mode_no_mutexes, -1
    if opts.long_mutex:
        mutex_sel = mode_long_mutexes
    elif opts.all_mutex:
        mutex_sel = mode_all_mutexes
    elif opts.mutex:
        mutex_sel = opts.mutex
    elif opts.task_mutex:
        mutex_sel, task_sel = mode_one_task_mutexes, opts.task_mutex

    data_type = mode_priority if opts.priority else mode_task_id

    tsk, mut = load_spa(opts.filename, opts.start_secs, opts.end_secs, mutex_sel)

    close('all')
    do_plots(tsk, mut, mutex_sel, task_sel, data_type)
    show()
    return tsk, mut

def load_spa(filename, t0, t1, mutex_ID):
    """Load SPA task-switch (and optionally mutex) records from a T04 file.

       filename = T04 filename with SPA data
       t0 = start GPS seconds
       t1 = end GPS seconds. Make t1-t0 very short, or it takes a long time.
       mutex_ID = mode_long_mutexes
                  mode_all_mutexes
                  mode_one_task_mutexes
                  mode_no_mutexes
               or positive integer -> single mutex ID to look at

       Returns (tsk, mut):
         tsk = record 35:271 (task switch) data with t0 <= MICROSECONDS*1e-6 <= t1
         mut = record 35:274 (mutex) data with t0 <= gps_secs <= t1, or None
               when mutex_ID is mode_no_mutexes (mutex records are not read).
    """
    # Request a window padded to whole seconds on both sides from vd2cls
    # (presumably -s/-e are its start/end seconds), then trim exactly below.
    window = '-s%d -e%d' % (int(t0 - 1), int(t1 + 2))

    mut = None
    if mutex_ID != mode_no_mutexes:
        mut = m.vd2cls(filename, '-d35:274 ' + window)
        mut = mut[(mut.gps_secs >= t0) & (mut.gps_secs <= t1)]

    tsk = m.vd2cls(filename, '-d35:271 ' + window)
    # Task-switch records are timestamped in MICROSECONDS, not gps_secs.
    tsk = tsk[(tsk.MICROSECONDS >= t0 * 1e6) & (tsk.MICROSECONDS <= t1 * 1e6)]
    return tsk, mut

# Only the first this-many mutexes in one plot get a legend entry.
_MAX_LEGEND_ENTRIES = 20


def _extend_prop_cycle():
    """Lengthen a short property cycle so more than 10 lines stay distinct.

    A cycle of <= 10 entries is repeated twice and zipped with markers
    ('1' for the first pass, '2' for the second).  This modifies the global
    plt.rcParams, so it affects Axes created afterwards.
    """
    cyc = plt.rcParams['axes.prop_cycle']
    n = len(cyc)
    # Skip when the cycle is already long, or already has a 'marker' key
    # (cycler refuses to add two cycles that share a key).
    if n <= 10 and 'marker' not in cyc.keys:
        plt.rcParams['axes.prop_cycle'] = (
            cyc * 2 + cycler(marker=['1'] * n + ['2'] * n))


def _share_axes(ax_ref, ax, share_y=True):
    """Link ax's x (and optionally y) axis to ax_ref's."""
    if hasattr(ax, 'sharex'):
        # Axes.sharex/sharey exist since Matplotlib 3.3; joining via
        # get_shared_x_axes() was deprecated in 3.6 and removed in 3.8.
        ax.sharex(ax_ref)
        if share_y:
            ax.sharey(ax_ref)
    else:
        ax_ref.get_shared_x_axes().join(ax_ref, ax)
        if share_y:
            ax_ref.get_shared_y_axes().join(ax_ref, ax)


def _plot_task_switches(ax, tsk, data_type):
    """Draw task switches on ax: one line per core, one segment per record."""
    for core in unique(tsk.CORE_NUM):
        tska = tsk[tsk.CORE_NUM == core]
        # Each record k becomes 3 points: (t_k, y_k), (t_k+1, y_k), NaN --
        # a horizontal segment from this switch to the next one on the core.
        x = roll(repeat(tska.MICROSECONDS * 1e-6, 3), -2)
        y = repeat(tska.TASK_ID if data_type == mode_task_id else tska.PRI, 3)
        x[2::3] = nan
        # The last record has no following switch; drop the wrapped-around time.
        x[-2:] = nan
        ax.plot(x, y)
    ax.grid()
    ax.set_xlabel('GPS secs')
    ax.set_title('Task switches')
    ax.set_ylabel('Task ID' if data_type == mode_task_id else 'Task Priority')


def _plot_mutex_lines(ax, recs, mut_list):
    """Plot one labelled line per mutex ID in mut_list, using records recs."""
    n_mut = len(mut_list)
    for n, mut_id in enumerate(mut_list):
        muta = recs[recs.mutex_ID == mut_id]
        # Order by task ID, then time.  NOTE(review): assumes gps_secs*1e-7
        # stays small enough not to reorder distinct task IDs.
        muta = muta[argsort(muta.task_ID + muta.gps_secs * 1e-7), :]
        # A NaN is inserted before every acq==0 record, breaking the line there.
        i = m.find(muta.acq == 0)
        x = insert(muta.gps_secs, i, nan)
        # Small per-mutex vertical offset (within +/-0.25) so lines for the
        # same task ID do not overlap.
        y = insert(muta.task_ID + (n - n_mut / 2) / (2 * n_mut), i, nan)
        if n < _MAX_LEGEND_ENTRIES:
            ax.plot(x, y, label='%d' % mut_id)
        else:
            ax.plot(x, y)

    if n_mut > _MAX_LEGEND_ENTRIES:
        print("Truncated mutex legend box - too many labels.")
    if n_mut:
        ax.legend(bbox_to_anchor=(1.01, 1.0), loc='upper left')


def _plot_task_mutexes(ax1, tsk, mut, task_ID):
    """Plot mutexes used by task_ID, plus that task's priority over time."""
    fig2, ax2 = subplots(1, 1)
    _share_axes(ax1, ax2)
    # Mutexes for which this task has at least one acq==0 record.
    i = m.find((mut.acq == 0) & (mut.task_ID == task_ID))
    _plot_mutex_lines(ax2, mut, unique(mut.mutex_ID[i]))
    ax2.grid(True)
    ax2.set_title('All mutexes ever locked by task %d' % task_ID)
    ax2.set_xlabel('GPS secs')
    ax2.set_ylabel('Mutex task ID')
    fig2.tight_layout()

    fig3, ax3 = subplots(1, 1)
    _share_axes(ax1, ax3, share_y=False)
    tsk0 = tsk[tsk.TASK_ID == task_ID]
    ax3.plot(tsk0.MICROSECONDS * 1e-6, tsk0.PRI)
    ax3.set_xlabel("GPS secs")
    ax3.set_ylabel("Task %d priority" % task_ID)
    ax3.grid()


def _plot_single_mutex(ax1, mut, mutex_ID):
    """Plot lock records of one mutex, one line per core."""
    fig2, ax2 = subplots(1, 1)
    _share_axes(ax1, ax2)
    for core in unique(mut.core_num):
        muta = mut[(mut.core_num == core) & (mut.mutex_ID == mutex_ID)]
        muta = muta[argsort(muta.task_ID + muta.gps_secs * 1e-7), :]
        i = m.find(muta.acq == 0)
        ax2.plot(insert(muta.gps_secs, i, nan), insert(muta.task_ID, i, nan))
    ax2.set_title('Mutex locks for mutex ID %d' % mutex_ID)
    ax2.grid(True)
    ax2.set_xlabel('GPS secs')
    ax2.set_ylabel('Mutex task ID')


def _plot_mutexes_per_core(ax1, mut, mutex_ID):
    """Plot the selected mutexes on one figure per core present in mut."""
    # Long mode keeps mutexes having any acq==1 record; all mode keeps
    # mutexes having any acq==0 record.
    if mutex_ID == mode_long_mutexes:
        i = m.find(mut.acq == 1)
    else:
        i = m.find(mut.acq == 0)
    mut_list = unique(mut.mutex_ID[i])

    # One figure per core actually present (previously a figure was only
    # added for core > 0, leaving the first one blank if core 0 was absent).
    for core in unique(mut.core_num):
        fig2, ax2 = subplots(1, 1)
        _share_axes(ax1, ax2)
        _plot_mutex_lines(ax2, mut[mut.core_num == core], mut_list)
        ax2.grid(True)
        if mutex_ID == mode_long_mutexes:
            ax2.set_title('Core %d: Mutex with any wait > 5us' % core)
        else:
            ax2.set_title('Core %d: Mutex' % core)
        ax2.set_xlabel('GPS secs')
        ax2.set_ylabel('Mutex task ID')
        fig2.tight_layout()


def do_plots( tsk, mut, mutex_ID, task_ID, data_type ):
    """Plot data from load_spa().

       tsk = task switch data
       mut = mutex lock/unlock data (unused when mutex_ID is mode_no_mutexes)
       mutex_ID = mode_*_mutexes
               or positive integer -> single mutex ID to look at
       task_ID = if > 0, only show mutexes for this task (checked before
                 mutex_ID, unless mutex_ID is mode_no_mutexes)
       data_type = mode_task_id or mode_priority: y axis of the task plot

       Always creates a task-switch figure; then, unless mutex_ID is
       mode_no_mutexes:
         task_ID > 0  -> mutexes for that task + that task's priority figure
         mutex_ID > 0 -> one figure for that mutex (one line per core)
         otherwise    -> one mutex figure per core present in mut
    """
    # Must run before subplots() so the new Axes pick up the longer cycle.
    _extend_prop_cycle()

    _, ax1 = subplots(1, 1)
    _plot_task_switches(ax1, tsk, data_type)

    if mutex_ID == mode_no_mutexes:
        return

    if task_ID > 0:
        _plot_task_mutexes(ax1, tsk, mut, task_ID)
    elif mutex_ID > 0:
        _plot_single_mutex(ax1, mut, mutex_ID)
    else:
        _plot_mutexes_per_core(ax1, mut, mutex_ID)


if __name__ == "__main__":
    # Keep the loaded records as module-level names (presumably for
    # inspection in an interactive session after the plots are shown).
    (tsk, mut) = main()
