#!/usr/bin/env python
"""
Get NoPi results and automatically look for changes. The current method
of detecting changes is pretty basic. It segments the points into two
groups using Jenks natural breaks. If the two groups are significantly
different, then something has changed.

Examples:
1) Get data from last 2 weeks and create plots of anything that seems
   to have changed within the last few days:
 ./nopi_plot_change.py Alloy2_BD992_Zero

2) Plot specified combo over the last 2 weeks, even if no changes were detected:
 ./nopi_plot_change.py Alloy2_BD992_Zero,BD940_BD970_Zero --combo "GPS L2E,carr"

3) Same as #1, but email results:
 ./nopi_plot_change.py Alloy2_BD992_Zero --email will_lentz@trimble.com
"""

import argparse
import datetime
from functools import partial
from multiprocessing import Pool
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.mime.image import MIMEImage
import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from jenkspy import JenksNaturalBreaks
from tqdm import tqdm
import get_nopi
import socket
import getpass

# A combo is only reported/plotted when its change shows up within the last
# N entries (days) of fetched data; used by find_all_changes().
RECENT_DAYS = 3

def mail_on_empty(baseline,to_address):
    """Send an HTML email reporting that no data could be fetched.

    baseline = NoPi baseline name, like Alloy2_BD992_Zero
    to_address = email address to send to
    The body also names the host, script and user (as mail_on_change does),
    so it is clear which job produced the report.
    """
    msg = MIMEMultipart()
    msg['Subject'] = '{} has no data!'.format(baseline)
    msg['From'] = 'nopi_plot_change@trimble.com'
    msg['To'] = to_address

    host, script = get_host_and_script_info()
    txt = 'Missing data on baseline '+baseline
    txt += f"<br/>Host: {host}, Script: {script}, User: {getpass.getuser()}"
    msg.attach(MIMEText(txt, 'html'))

    # The context manager sends QUIT and closes the socket even if
    # send_message() raises, so a failed send does not leak the connection.
    with smtplib.SMTP('smtp.trimble.com') as sender:
        sender.send_message(msg)
    print("Sent email with error")

def get_host_and_script_info():
    """Return ``(host_name, script_path)`` identifying where this script runs.

    host_name is socket.gethostname(); script_path is the absolute path of
    this source file.  Used to label the report emails.
    """
    return socket.gethostname(), os.path.abspath(__file__)

def mail_on_change(baseline,to_address,all_changes):
    """Send an HTML email with one inline plot per detected change.

    baseline = NoPi baseline name, like Alloy2_BD992_Zero
    to_address = email address to send to
    all_changes = list of OneChange; each png_filename is embedded as an
                  inline image and deleted from disk after a successful send
    Does nothing if all_changes is empty.
    """
    if not all_changes:
        return
    msg = MIMEMultipart()
    msg['Subject'] = '{} changed recently'.format(baseline)
    msg['From'] = 'nopi_plot_change@trimble.com'
    msg['To'] = to_address

    host, script = get_host_and_script_info()
    parts = [f"<br/>Host: {host}, Script: {script}, User: {getpass.getuser()}"]
    png_files = []
    for info in all_changes:
        with open(info.png_filename, 'rb') as f_in:
            img = MIMEImage(f_in.read())
        png_files.append(info.png_filename)
        # The Content-ID lets the HTML body reference the image via "cid:".
        img.add_header('Content-ID', '<{}>'.format(info.png_filename))
        msg.attach(img)
        parts.append('<br/>Recent change on {} {}'.format(info.combo,info.meas))
        parts.append('<br/>%s'%info.link)
        parts.append('<br/><img src="cid:%s"/>'%info.png_filename)
    msg.attach(MIMEText(''.join(parts), 'html'))

    # Context manager guarantees QUIT/close even if send_message() raises.
    with smtplib.SMTP('smtp.trimble.com') as sender:
        sender.send_message(msg)
    # Delete the plots only once they were delivered, so a failed send leaves
    # them on disk.  set() guards against the same file listed twice.
    for png in set(png_files):
        os.remove(png)
    print("Sent email with changes")

def detect_changes(signal, meas_type, scale):
    """Auto-detect points where 'signal' has changed.

    The values are split into a low and a high group with Jenks natural
    breaks; the high group is reported only if it is clearly separated from
    the low group (ratio tests plus an absolute-difference threshold).

    signal = 1-D numpy array of values (one per day)
    meas_type = range/carr/SNR/dopp (selects the absolute-difference threshold)
    scale = threshold scale (bigger == harder to detect changes)
    Returns a numpy array of indexes into signal[] that fall in the high group
    when a change is detected, otherwise an empty list.
    Raises KeyError for an unknown meas_type once the ratio tests pass.
    """
    if len(signal) < 5 or np.ptp(signal)==0:
        return []

    # segment 'signal' into 2 groups
    jnb = JenksNaturalBreaks(2)
    jnb.fit(signal)

    # Are the 2 groups noticeably different?
    thresh = {'range':0.2*scale, 'carr':10*scale, 'SNR': 2*scale, 'dopp':0.1*scale}
    low_median = np.median(jnb.groups_[0])
    high_median = np.median(jnb.groups_[1])
    # breaks_[0] (the data minimum) or the low-group median can be 0.  jenkspy
    # returns plain Python floats, so '/' would raise ZeroDivisionError; numpy
    # division yields inf (or nan for 0/0) instead.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio1 = np.divide(jnb.breaks_[1], jnb.breaks_[0])
        ratio2 = np.divide(high_median, low_median)
    abs_diff = np.abs(high_median - low_median)

    # The following thresholds are kinda arbitrary, so feel free to adjust
    if ratio1 >= 1.2 and ratio2 >= 1.7 and abs_diff > thresh[meas_type]:
        return np.where(jnb.labels_)[0]
    return []

def get_all_data(terrasat, baseline, end_date, ndays):
    """Fetch NoPi results for every day from end_date-ndays to end_date.

    terrasat = if True, fetch TerrasatCycleSlips data; baseline must then be
               "<name>/<id>"
    baseline = NoPi baseline name, like Alloy2_BD992_Zero
    end_date = date() -> last date (inclusive)
    ndays = # of days to look backward from end_date (ndays+1 days in total)
    Days are fetched with a pool of 5 worker processes.  Returns the per-day
    results of get_nopi.get_one_data / get_one_data_terrasat in date order,
    with trailing days whose .df is empty dropped (so the last entry is the
    newest day with data).
    """
    first_day = end_date - datetime.timedelta(days=ndays)
    day_list = [first_day + datetime.timedelta(days=k) for k in range(ndays + 1)]
    with Pool(5) as workers:
        if terrasat:
            name, baseline_id = baseline.split('/')
            fetch_day = partial(get_nopi.get_one_data_terrasat, name, baseline_id)
        else:
            fetch_day = partial(get_nopi.get_one_data, baseline)
        progress = tqdm(workers.imap(fetch_day, day_list), total=len(day_list))
        results = list(progress)
    while results and len(results[-1].df) == 0:
        results.pop()
    return results

class OneChange:
    """Output data from find_all_changes(): one plotted combo/meas.

    date         = date the entry is reported for (find_all_changes uses the
                   newest day in its data)
    combo        = signal name, e.g. "GPS L1CA"
    meas         = measurement type, e.g. "SNR"
    png_filename = file name the plot is saved under
    link         = link for the newest high-group point ('' if none)
    """
    def __init__(self, date : datetime.date, combo:str, meas:str, png_filename:str, link:str ):
        self.date, self.combo, self.meas = date, combo, meas
        self.png_filename, self.link = png_filename, link

class FindConfig:
    """Config for find_all_changes()

    baseline   = NoPi baseline name (used in plot titles)
    email      = address results are mailed to, or None
    show_plots = show plots on screen (otherwise save them to png)
    sat_type   = satellite-system prefix (e.g. "QZSS") whose change-detection
                 threshold is scaled; '' means no system is scaled
    mult       = threshold multiplier applied to that system
    """
    def __init__(self, baseline:str, email:str, show_plots:bool, sat_type:str, mult:float ):
        self.baseline = baseline
        self.email = email
        self.show_plots = show_plots
        self.thresh_sat_type = sat_type
        self.thresh_mult = mult

    def get_sat_thresh_mult( self, combo:str ):
        '''Does combo (e.g., "GPS L1CA") satellite system match thresh_sat_type?
        If so, return the increased detection threshold multiplier, else 1.0.
        '''
        # Every string starts with '', so an empty sat_type must be treated
        # as "no system selected" rather than "every system".
        if self.thresh_sat_type and combo.startswith(self.thresh_sat_type):
            return self.thresh_mult
        return 1.0

def changed_recently( x_date_high, recent_dates ):
    """Is there variation in the results for the latest few days?

    Input: x_date_high = array of date() in high group
    recent_dates = list of recent dates, oldest first
    Returns True when some, but not all, of recent_dates are in the high
    group, i.e. the low/high split happens inside the recent window.
    High dates that are not in recent_dates are ignored; an empty
    recent_dates gives False.
    """
    high_dates = set(x_date_high)
    membership = {day in high_dates for day in recent_dates}
    # Both True and False present -> the recent days straddle the two groups.
    return len(membership) == 2

def _collect_signal(all_d, combo):
    """Build the per-day (dates, values) series for one row of all_combos.

    For each day in all_d that has data, take the column-wise max over the
    rows matching combo.combo/combo.meas with vld_epochs > 100.  Days whose
    max vld_mav is NaN (e.g. no matching rows) or whose max vld_epochs is
    below 1000 are skipped.  Returns two numpy arrays of equal length.
    """
    dates = []
    values = []
    for data in all_d:
        df = data.df
        if len(df) == 0:
            continue
        # NOTE: DataFrame.max() is column-wise, so vld_mav and vld_epochs may
        # come from different rows.
        curr_max = df[(df.combo==combo.combo)
                      &(df.meas==combo.meas)
                      &(df.vld_epochs>100)].max()
        if np.isnan(curr_max.vld_mav) or curr_max.vld_epochs<1000:
            continue
        dates.append(data.date)
        values.append(curr_max.vld_mav)
    return np.array(dates), np.array(values)

def find_all_changes(config, all_d, all_combos):
    """Detect, and plot, changes in "vld_mav" for every combo/meas pair.

    config = FindConfig (baseline name for titles, show_plots, threshold scaling)
    all_d = list of get_nopi.get_one_data() data, oldest day first
    all_combos = pandas dataframe of unique signals/meas (e.g., "GPS L1CA"+"SNR")

    A combo is plotted when its high group changed within the last
    RECENT_DAYS entries of all_d, or unconditionally when all_combos has a
    single row.  If config.show_plots is False each figure is saved to
    "<combo>_<meas>.png" and closed; otherwise it stays open for plt.show().
    Returns list of OneChange(), one per plotted combo.
    """
    all_changes = []
    recent_dates = [x.date for x in all_d[-RECENT_DAYS:]]
    # A single requested combo is always plotted, even with no change.
    single_combo = len(all_combos) == 1
    for _,combo in all_combos.iterrows():
        x_date, signal = _collect_signal(all_d, combo)
        if len(signal) == 0:
            continue

        result = detect_changes(signal,
                                combo.meas,
                                config.get_sat_thresh_mult(combo.combo))

        if len(result) == 0 and not single_combo:
            continue

        print(" %d high elements in %s %s."%(len(result),combo.combo,combo.meas))
        changed = changed_recently( x_date[result], recent_dates )

        if changed:
            print("  ******* Changed recently")

        if changed or single_combo:
            fig,ax=plt.subplots(1,1)
            ax.plot(x_date,signal,'.-')
            # Mark the high-group points in red
            ax.plot(x_date[result],signal[result],'x',color='r')
            ax.set_title('%s: %s %s'%(config.baseline, combo.combo,combo.meas))
            ax.grid()
            fig.autofmt_xdate()
            png_filename = '%s_%s.png'%(combo.combo.replace(' ','_'),combo.meas)
            link_latest = ''
            if len(result) > 0:
                # Link of the newest day in the high group
                latest_high = x_date[result[-1]]
                link_latest = [data.link for data in all_d if data.date==latest_high][-1]
            all_changes.append( OneChange(all_d[-1].date,combo.combo, combo.meas,
                                          png_filename, link_latest) )
            if not config.show_plots:
                fig.savefig(png_filename)
                plt.close(fig)
    return all_changes

def ignore_combos(all_combos,ignore_list):
    """Drop the combo/meas rows named in ignore_list.

    all_combos = dataframe with 'combo' and 'meas' columns
    ignore_list = list of "combo,meas" strings, e.g. ['Beidou B2,carr']
    Returns the remaining rows (original index kept).  An entry without
    exactly one comma raises ValueError.
    """
    if not ignore_list:
        return all_combos
    drop = np.zeros(len(all_combos), dtype=bool)
    for entry in ignore_list:
        name, meas = entry.split(',')
        drop |= ((all_combos.combo == name) & (all_combos.meas == meas)).to_numpy()
    return all_combos[~drop]


def main():
    """Main entry point: parse arguments, then fetch, analyse and report each baseline."""
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('baseline',
                        help='NoPi baseline name (or comma-separated list)')
    parser.add_argument('--terrasat','-t',
                        help='Instead of normal NoPi, get data from TerrasatCycleSlips',
                        action='store_true', default=False)
    parser.add_argument('--email',
                        help='Email plots to given address instead of showing')
    parser.add_argument('--end_date',
                        type=lambda s: datetime.datetime.strptime(s, '%Y-%m-%d').date(),
                        default=datetime.datetime.now().date(),
                        help='End date in YYYY-MM-DD format')
    parser.add_argument('--ndays','-n',
                        help='# of days in past to look at',
                        type=int,
                        default=14)
    parser.add_argument('--combo','-c',
                        help='only look at this one combo, e.g. "GPS L1CA,carr"')
    parser.add_argument('--ignore','-i',
                        action='append',
                        default=[],
                        help='ignore this combo, e.g. "QZSS SAIF,carr"')
    parser.add_argument('--sat_thresh_scale',
                        help='scale up change-detect thresh for given system, .e.g. require QZSS changes to be 3x bigger "3.0,QZSS"')
    args = parser.parse_args()
    show_plots = True
    if args.email:
        # Allow running headless from the command line
        matplotlib.use("agg")
        show_plots = False

    # Combo explicitly requested with --combo (applies to every baseline)
    fixed_combos = None
    if args.combo:
        combo_txt = args.combo.split(',')
        if len(combo_txt) != 2:
            parser.error('--combo must look like "GPS L1CA,carr"')
        fixed_combos = pd.DataFrame({'combo':[combo_txt[0]],
                                     'meas':[combo_txt[1]]})
    scale_sat,scale_mult = '',1.0
    if args.sat_thresh_scale:
        scale_txt = args.sat_thresh_scale.split(',')
        try:
            scale_mult = float(scale_txt[0])
            scale_sat = scale_txt[1]
        except (ValueError, IndexError):
            parser.error('--sat_thresh_scale must look like "3.0,QZSS"')

    for baseline in args.baseline.split(','):
        print("Processing baseline",baseline)
        all_d = get_all_data(args.terrasat, baseline, args.end_date, args.ndays)
        if len(all_d) == 0:
            print("WARNING - couldn't get any data")
            if args.email:
                mail_on_empty(baseline,args.email)
            continue
        if fixed_combos is not None:
            all_combos = fixed_combos
        else:
            # Baselines can report different signals, so derive the combo
            # list from each baseline's own newest day of data.
            all_combos = get_nopi.get_unique_combo_meas(all_d[-1].df)
            all_combos = ignore_combos(all_combos, args.ignore)
        config = FindConfig(baseline, args.email, show_plots, scale_sat, scale_mult)
        all_changes = find_all_changes(config, all_d, all_combos)
        if args.email:
            mail_on_change(baseline,args.email,all_changes)
    if show_plots:
        print("Showing graphs")
        plt.show()

if __name__ == '__main__':
    main()
