#!/usr/bin/env python3

# Summary:
#   This script performs DHCP testing and also supports creating a PNG plot of
#   a log file created as a result of a single invocation.
#
# All parameters other than --file and --png are used to configure a data
# collection session. Data collection is terminated using Ctrl-C. A CSV file
# collected by this script may be plotted using a command of this form:
#
#   dhcpTest.py --file CSV_FILE_NAME --png PNG_FILE_NAME

import datetime, os, re, sys, time
import random
from lxml import objectify
import traceback
import matplotlib.dates
import matplotlib.pyplot as plt

from rxclasses import Receiver
from rxclasses import Log

# -------------------------------------------------- timeouts (seconds)
# Longest wait, after a restart, for every configured address to answer a ping.
PING_TIMEOUT        = 10 * 60.0
# Longest wait for a receiver to report a DHCPv6 address.
DHCPV6_ADDR_TIMEOUT = 60 * 60.0
# Longest wait for the DHCPv6 address to answer a ping.
DHCPV6_PING_TIMEOUT = 10 * 60.0

# -------------------------------------------------- class Timeout()
class Timeout():
  ''' Deadline helper: reached() becomes True once more than `interval`
  seconds have elapsed since the object was created.
  '''
  def __init__(self, interval):
    self.interval = interval
    # Wall-clock creation time (naive UTC), kept for callers that read it.
    self.startTime = datetime.datetime.utcnow()
    # Elapsed time is measured on the monotonic clock so that system clock
    # adjustments (e.g. NTP steps) cannot shorten or stretch the timeout.
    self._startMonotonic = time.monotonic()

  def reached(self):
    ''' Return True if more than `interval` seconds have elapsed. '''
    return (time.monotonic() - self._startMonotonic) > self.interval

# -------------------------------------------------- ping()
def ping( host, timeoutSeconds ):
  ''' Send a single echo request to host using the system ping command.

  host           -- bare address or host name to ping (no brackets/port).
  timeoutSeconds -- how long to wait for the reply, in seconds.

  Returns True if ping exits with status 0, False otherwise (including when
  the ping command cannot be started).
  '''
  import subprocess

  # The timeout option uses different units per platform: Windows "-w" and
  # macOS "-W" take milliseconds, Linux (iputils) "-W" takes seconds.
  if os.name == 'nt':
    cmd = [ 'ping', '-n', '1', '-w', str( max( 1, int( timeoutSeconds * 1000 ) ) ), host ]
  elif sys.platform == 'darwin':
    cmd = [ 'ping', '-c', '1', '-W', str( max( 1, int( timeoutSeconds * 1000 ) ) ), host ]
  else:
    cmd = [ 'ping', '-c', '1', '-W', str( max( 1, int( timeoutSeconds ) ) ), host ]

  try:
    # Argument list without a shell: host is never interpreted by a shell.
    result = subprocess.run( cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL )
  except OSError:
    # e.g. no ping executable on PATH; treat like an unanswered ping.
    return False
  return result.returncode == 0

# -------------------------------------------------- sendAlert()
def sendAlert( email, subj, body ):
  ''' Send an email alert by running the external sendmesg.py script.

  email -- recipient, passed as "-e".
  subj  -- subject line, passed as "-s".
  body  -- message body, passed as "-b".

  Best effort: the script's exit status is ignored, and a script that cannot
  be started is reported on stderr instead of raising.
  '''
  import subprocess

  cmd = [ 'sendmesg.py', '-e', str( email ), '-s', str( subj ), '-b', str( body ) ]
  try:
    # On Windows a .py file is started through the shell's file association;
    # elsewhere the list is exec'd directly, so quotes in subj/body are safe.
    subprocess.run( cmd, shell=( os.name == 'nt' ) )
  except OSError as e:
    sys.stderr.write( 'sendAlert: could not run sendmesg.py: %s\n' % ( e ) )

# -------------------------------------------------- getDhcpv6Address()
def getDhcpv6Address( r ):
  ''' Query receiver r for the address in its DHCPv6NAAddress element.

  Fetches 'ipv6.xml' via r.getXml() and parses it with lxml.objectify.

  Returns the address string, or None when the XML is unavailable or cannot
  be parsed, the element is missing or empty, or it holds '::' (the
  unspecified IPv6 address, treated as "no address").
  '''
  xml = r.getXml( 'ipv6.xml' )
  if xml is None:
    return None

  try:
    addr = objectify.fromstring( xml ).DHCPv6NAAddress.text
  except Exception:
    # Best effort: callers poll repeatedly, so a malformed reply or a missing
    # element simply means "no address yet".
    return None

  addr = ( addr or '' ).strip()
  if not addr or addr == '::':
    return None

  sys.stderr.write('  Got DHCPv6 address: %s\n' % ( addr ))
  return addr


# -------------------------------------------------- printWithDate()
def printWithDate( s ):
  ''' Write s to stdout, prefixed by the current UTC time as "YYYY/MM/DD HH:MM:SS". '''
  stamp = datetime.datetime.utcnow().strftime( "%Y/%m/%d %H:%M:%S" )
  # print() joins its arguments with a single space and applies str() to s.
  print( stamp, s )

# -------------------------------------------------- getDateTime()
def getDateTime( m ):
  ''' Build a datetime from a regular-expression match of a log timestamp.

  m -- match object whose groups 1-6 are year, month, day, hour, minute and
       second, and whose group 7 holds the digits after the decimal point
       (e.g. "453" for 00:45:58.453). Up to 6 fraction digits are used; a
       group 7 of None (optional group not matched) means no fraction.

  Returns a naive datetime.datetime.
  Raises ValueError if the fields do not form a valid date/time.
  '''
  year, month, day, hour, minute, second = ( int( m.group( i ) ) for i in range( 1, 7 ) )
  # Group 7 is a decimal fraction, so "453" is 453 ms = 453000 us: right-pad
  # to microsecond precision rather than using the digits as microseconds.
  frac  = m.group( 7 ) or ''
  micro = int( frac[:6].ljust( 6, '0' ) )
  return datetime.datetime( year, month, day, hour, minute, second, micro )

# -------------------------------------------------- genPngPlot()
# Example data line:
# 2021/03/12 00:45:58.453,1,10.1.107.57,8.3,10.1.107.57,OK,24.2,[2001:559:7f1:1:821f:12ff:fed3:91d6],OK,26.6,2001:559:7f1:1:665f:48cf:ed91:6977,OK,29.8
# Fields:
#    0 DateTime
#    1 Test number
#    2 IP address used to trigger shutdown
#    3 Time delay until shutdown (5 missed pings)
#    4 IPv4 address
#    5 OK or FAIL
#    6 Shutdown to IPv4 ping delay
#    7 IPv6(SLAAC) address
#    8 OK or FAIL
#    9 Shutdown to IPv6(SLAAC) ping delay
#   10 IPv6(DHCPv6) address
#   11 OK or FAIL
#   12 Shutdown to IPv6(DHCPv6) ping delay
#
# A data line has 7 fields (IPv4 only), 10 fields (IPv4 plus one IPv6
# address) or 13 fields (IPv4, IPv6 SLAAC and IPv6 DHCPv6). A failed ping is
# logged with a delay of -1.

def _readDhcpCsv( csvName ):
  ''' Parse the data lines of a CSV file written by the DHCP test loop.

  Returns (dates, tShutdown, t4, t6a, t6b, haveIp6Slaac, haveIp6DHCPv6):
    dates         -- datetime of each data line
    tShutdown     -- field 3 of each data line (seconds)
    t4            -- field 6 of each data line (seconds)
    t6a           -- field 9, or NaN for lines with fewer than 10 fields
    t6b           -- field 12, or NaN for lines without 13 fields
    haveIp6Slaac  -- True if a header line had an IPv6Result column
    haveIp6DHCPv6 -- True if a header line had a DHCPv6Result column
  All value lists have the same length as dates.
  '''
  # 2021/03/12 00:45:58.453
  isDate = re.compile( r"^(\d{4})/(\d{2})/(\d{2}) (\d{2}):(\d{2}):(\d{2})\.(\d{3})" )
  nan = float( 'nan' )

  dates, tShutdown, t4, t6a, t6b = [], [], [], [], []
  haveIp6Slaac  = False
  haveIp6DHCPv6 = False

  with open( csvName, "r" ) as f:
    for line in f:
      line = line.rstrip()
      fields = line.split( ',' )
      nFields = len( fields )
      if nFields not in ( 7, 10, 13 ):
        continue

      # Header lines name the columns; note which IPv6 columns the file has.
      isHeader = False
      if 'IPv6Result' in line:
        haveIp6Slaac = True
        isHeader = True
      if 'DHCPv6Result' in line:
        haveIp6DHCPv6 = True
        isHeader = True
      if isHeader:
        continue

      m = isDate.match( fields[0] )
      if not m:
        continue

      try:
        dt            = getDateTime( m )
        shutdownDelay = float( fields[3] )
        t4delay       = float( fields[6] )
        t6delay1      = float( fields[9] )  if nFields >= 10 else nan
        t6delay2      = float( fields[12] ) if nFields == 13 else nan
      except ValueError:
        # Non-numeric columns, e.g. the header line of an IPv4-only file.
        print( "Ignoring: %s" % ( line ) )
        continue

      dates.append( dt )
      tShutdown.append( shutdownDelay )
      t4.append( t4delay )
      t6a.append( t6delay1 )
      t6b.append( t6delay2 )

  return dates, tShutdown, t4, t6a, t6b, haveIp6Slaac, haveIp6DHCPv6


def genPngPlot( csvName, pngName ):
  ''' Read a CSV file generated by this script and save a plot of it as PNG.

  csvName -- CSV file written by the DHCP test loop.
  pngName -- output PNG file name.

  Draws one subplot per delay column present in the data (shutdown, IPv4 and
  up to two IPv6 columns) on a shared time axis. Prints an error and writes
  nothing if the file contains no data lines.
  '''
  import math

  ( dates, tShutdown, t4, t6a, t6b,
    haveIp6Slaac, haveIp6DHCPv6 ) = _readDhcpCsv( csvName )

  if len(dates) < 1:
    print("Error: No DHCP test data found in file")
    return

  def hasData( values ):
    # NaN marks lines that did not have this column.
    return any( not math.isnan( v ) for v in values )

  # (values, y-axis label, matplotlib format string) for each subplot.
  series = [ ( tShutdown, "Shutdown (sec.)", 'k-' ),
             ( t4,        "IPv4 (sec.)",     'b-' ) ]

  if hasData( t6a ):
    # The first IPv6 column is SLAAC when the header had IPv6Result; a file
    # whose header only had DHCPv6Result stores the DHCPv6 delay there.
    if haveIp6Slaac:
      label6 = "IPv6(SLAAC) (sec.)"
    elif haveIp6DHCPv6:
      label6 = "IPv6(DHCPv6) (sec.)"
    else:
      label6 = "IPv6 (sec.)"   # no header line seen; column type unknown
    series.append( ( t6a, label6, 'g-' ) )

  if hasData( t6b ):
    series.append( ( t6b, "IPv6(DHCPv6) (sec.)", 'm-' ) )

  fig, ax = plt.subplots( len(series), 1, sharex=True, figsize=(7,10) )

  for i, ( values, label, style ) in enumerate( series ):
    # Axes.plot accepts datetime x values directly (plot_date is deprecated).
    ax[i].plot( dates, values, style, linewidth=0.5 )
    ax[i].set_ylabel( label )

  ax[0].set_title( "DHCP Delays - %d Restarts (%s)" % ( len(dates), csvName ) )

  plt.xticks(rotation=-30)
  plt.tight_layout()
  plt.savefig( pngName, dpi=150 )
  plt.close( fig )


# -------------------------------------------------- main()
# Usage text shown by argparse for the two modes of this script.
Usage="""dhcpTest.py [-h] --ip4 IP4 ...
  Performs DHCP testing: triggers a restart, waits for 5 missed ping
  responses, waits until all addresses respond to ping, then logs
  times, IP addresses, and delays to a CSV file.
  IPv4 address is required. All optional parameters apply except
  --file or --png, which are only for plotting.

usage: dhcpTest.py --file CSV_FILE --png PNG_FILE_NAME
  Generate PNG plot from CSV file generated by DHCP testing.
"""

if __name__ == "__main__":
  import argparse

  parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter,
                                    description=None,
                                    usage=Usage )

  parser.add_argument("--user", help="Username")
  parser.add_argument("--password", help="Password")
  parser.add_argument("--rxname", help="Name of receiver (used in the log file name).")
  parser.add_argument("--ip4", help="IPv4 address. (Must be reserved/fixed.)")
  parser.add_argument("--ip6", help="IPv6 SLAAC address. (Prefix assumed not to change.)")
  # type=int so the "!= 80" checks below compare numbers, not strings.
  parser.add_argument("--http_port", type=int, help="Port for the HTTP connection")
  parser.add_argument("--dhcpv6", action='store_true', help="Also verify DHCPv6 address is acquired.")
  parser.add_argument("--dhcpv6stoponfailure", action='store_true', help="Stop script if DHCPv6 fails.")
  parser.add_argument("--minsec", type=float, help="Minimum sleep time after completion of test cycle.")
  parser.add_argument("--maxsec", type=float, help="Maximum sleep time after completion of test cycle.")
  parser.add_argument("--alertemail", help="Send email alerts on startup and on errors. (Requires sendmesg.py script.)")

  parser.add_argument("--file", help="CSV file to be plotted.")
  parser.add_argument("--png", help="PNG plot file output.")

  args = parser.parse_args()

  #
  # Plot mode needs --file and --png; test mode needs --ip4. Otherwise print usage.
  #
  doTest = args.ip4 is not None
  doPng  = args.file is not None and args.png is not None

  alertEmail = args.alertemail

  if doPng:
    genPngPlot( args.file, args.png )
    sys.exit(0)

  elif not doTest:
    parser.print_help()
    sys.exit(1)

  #
  # Test configuration: command line values, or defaults.
  #
  http_port = args.http_port if args.http_port is not None else 80
  username  = args.user      if args.user      is not None else "admin"
  password  = args.password  if args.password  is not None else "tr.imble"
  prefix    = args.rxname

  minSec = args.minsec if args.minsec is not None else 120.0
  maxSec = args.maxsec if args.maxsec is not None else 120.0

  testDhcpv6          = args.dhcpv6
  dhcpv6StopOnFailure = args.dhcpv6stoponfailure

  # pinghosts[0] is the IPv4 address and pinghosts[1], if present, the IPv6
  # SLAAC address. ping needs bare addresses, while the HTTP host names built
  # below need IPv6 literals in brackets.
  ip4host   = args.ip4
  ip6host   = None
  pinghosts = [ ip4host ]

  if args.ip6 is not None:
    pinghosts.append( args.ip6.strip( '[]' ) )
    ip6host = args.ip6
    # Make sure an IPv6 address includes the brackets.
    if ':' in ip6host:
      if not '[' in ip6host:
        ip6host = '[' + ip6host
      if not ']' in ip6host:
        ip6host = ip6host + ']'
  nPingHosts = len(pinghosts)
  testSlaac  = nPingHosts > 1

  if http_port != 80:
    ip4host = ip4host + ':' + str(http_port)
    if ip6host is not None:
      ip6host = ip6host + ':' + str(http_port)

  # Receivers reset in rotation. When testing DHCPv6 the last slot belongs to
  # the DHCPv6 address; it stays None (and is skipped) until an address has
  # been obtained. rxPingHosts holds the bare address to ping for each slot,
  # since a receiver's host name may carry brackets and/or a port.
  receivers = [ Receiver( ip4host, username, password ) ]
  if ip6host is not None:
    receivers.append( Receiver( ip6host, username, password ) )
  rxPingHosts = list( pinghosts )
  if testDhcpv6:
    receivers.append( None )
    rxPingHosts.append( None )

  log = Log( prefix, "csv" )

  rxIndex = 0

  nTests = 0
  headerPrinted = False

  if alertEmail is not None:
    msg = "%s starting %s" % ( os.path.basename(__file__), pinghosts[0] )
    sendAlert( alertEmail, msg, msg )

  #
  # main loop
  #
  while True:
    try:
      ###
      ### Reset receiver.
      ###
      rxIpInUse = receivers[rxIndex]

      # DHCPv6 slot without a known address: move on without counting a test.
      if rxIpInUse is None:
        rxIndex = ( rxIndex + 1 ) % len(receivers)
        continue

      nTests += 1

      t0 = datetime.datetime.utcnow()
      receiverShutdown = False
      while not receiverShutdown:
        ### Reset then wait for 5 consecutive failed pings.
        ### Allow up to 60 successful pings (about a minute) to shut down,
        ### otherwise send reset again.

        # Use programmatic interface to reset.
        sys.stderr.write('Reset...\n')
        rxIpInUse.reset()

        sys.stderr.write('Wait for restart...\n')
        successes = 0
        failures = 0
        while failures < 5 and successes < 60:
          if ping(rxPingHosts[rxIndex], 1):
            successes += 1
            failures = 0
            time.sleep(1.0)
          else:
            failures += 1

        if failures >= 5:
          receiverShutdown = True
        else:
          log.log("Error: Receiver did not shut down in 60 seconds. Reset again.")

      # Time at which the receiver stopped answering pings.
      t1 = datetime.datetime.utcnow()

      # Allow up to PING_TIMEOUT seconds for every address to answer a ping.
      sys.stderr.write('Wait for ping...\n')

      responseTimes = [None]*nPingHosts
      nAddressesResponding = 0

      timeout = Timeout( PING_TIMEOUT )
      while nAddressesResponding < nPingHosts and not timeout.reached():
        for i in range(nPingHosts):
          if responseTimes[i] is None:
            if ping( pinghosts[i], 1 ):
              responseTimes[i] = (datetime.datetime.utcnow() - t1).total_seconds()
              sys.stderr.write('  Got response: %s in %.1f sec.\n' % ( pinghosts[i], responseTimes[i] ))
              nAddressesResponding += 1
            else:
              time.sleep(0.5)

      # If testing DHCPv6 availability, get address and check for ping.
      dhcpv6Address = None
      dhcpv6Time    = None
      if testDhcpv6:
        log.log( "DHCPv6 test: Get address." )

        # Ask the fixed-address receivers (IPv4/SLAAC) in turn.
        timeout = Timeout( DHCPV6_ADDR_TIMEOUT )
        iGetDhcpv6 = 0
        while dhcpv6Address is None and not timeout.reached():
          log.log("getDhcpv6Address(1)[%d]: %s" % ( iGetDhcpv6, receivers[iGetDhcpv6].host ))
          dhcpv6Address = getDhcpv6Address( receivers[iGetDhcpv6] )
          iGetDhcpv6 = ( iGetDhcpv6 + 1 ) % nPingHosts

          time.sleep(1.0)

        # Ping DHCPv6 address if we got it.
        if dhcpv6Address is not None:
          printWithDate( "DHCPv6 test: Wait for ping response... %s" % ( dhcpv6Address ) )
          timeout = Timeout( DHCPV6_PING_TIMEOUT )
          while dhcpv6Time is None and not timeout.reached():
            if ping( dhcpv6Address, 1 ):
              dhcpv6Time = (datetime.datetime.utcnow() - t1).total_seconds()
              sys.stderr.write('  Got DHCPv6 ping response in %.1f sec.\n' % (
                               dhcpv6Time ))
            else:
              # Check for a change in the IPv6 address.
              printWithDate("getDhcpv6Address(2)[%d]: %s" % ( iGetDhcpv6, receivers[iGetDhcpv6].host ))
              checkAddress = getDhcpv6Address( receivers[iGetDhcpv6] )
              iGetDhcpv6 = ( iGetDhcpv6 + 1 ) % nPingHosts

              # If the address changed, update it for ping tests.
              if checkAddress is not None and checkAddress != dhcpv6Address:
                log.log( "WARNING: DHCPV6 address changed: %s to %s" % (
                         dhcpv6Address, checkAddress ) )
                dhcpv6Address = checkAddress
              time.sleep(1.0)
        else:
          printWithDate( "DHCPv6 test: No address found." )

      if not headerPrinted:
        hdrStr = "TestNum,ShutdownIp,ShutdownTime,IPv4Address,IPv4Result,IPv4Time"
        if testSlaac:
          hdrStr += ",IPv6Address,IPv6Result,IPv6Time"
        if testDhcpv6:
          hdrStr += ",DHCPv6Address,DHCPv6Result,DHCPv6Time"
        log.log(hdrStr)
        headerPrinted = True

      #
      # Log results as a single line in the CSV file.
      #
      logStr  = "%d,%s,%.1f" % ( nTests, rxIpInUse.host, (t1 - t0).total_seconds() )

      # IPv4 result, then the IPv6 SLAAC result when that address is tested.
      for i in range(nPingHosts):
        if responseTimes[i] is None:
          logStr += ",%s,FAIL,-1" % ( pinghosts[i] )
        else:
          logStr += ",%s,OK,%.1f" % ( pinghosts[i], responseTimes[i] )

      # DHCPv6 result
      if testDhcpv6:
        if dhcpv6Time is None:
          logStr += ",::,FAIL,-1"
        else:
          logStr += ",%s,OK,%.1f" % ( dhcpv6Address, dhcpv6Time )

      log.log( logStr )

      if testDhcpv6 and dhcpv6Time is None:
        logStr = "ERROR: DHCPv6 failed"
        if dhcpv6Address is None:
          logStr += " dhcpv6Addr: None"
        else:
          logStr += " dhcpv6Addr: %s" % ( dhcpv6Address )
        log.log( logStr )

        if alertEmail is not None:
          sendAlert( alertEmail, "ERROR: DHCPv6 failure on %s" % ( pinghosts[0] ), "no body" )

        if dhcpv6StopOnFailure:
          printWithDate("Exit on DHCPv6 failure.")
          sys.exit(1)

      # Wait before restarting a receiver again: uniformly random between
      # minSec and maxSec (exactly minSec when the two are equal).
      sleepSec = random.uniform( minSec, maxSec )
      sys.stderr.write('Sleep for %.0f seconds...\n' % ( sleepSec ) )
      time.sleep( sleepSec )

      # Next time round, reset via the DHCPv6 address if one was obtained;
      # otherwise leave the slot empty so it is skipped.
      if testDhcpv6:
        d6Index = len(receivers) - 1
        if dhcpv6Address is not None:
          dhcpv6Host = "[%s]" % ( dhcpv6Address )
          if http_port != 80:
            dhcpv6Host = dhcpv6Host + ':' + str(http_port)
          receivers[ d6Index ] = Receiver( dhcpv6Host,
                                           username,
                                           password )
          rxPingHosts[ d6Index ] = dhcpv6Address
        else:
          receivers[ d6Index ] = None
          rxPingHosts[ d6Index ] = None

      # Increment the receiver index.
      rxIndex = ( rxIndex + 1 ) % len(receivers)

    except KeyboardInterrupt:
      print("Terminating on Ctrl-C")
      sys.exit(1)

    except Exception:
      # Exception (not a bare except) lets SystemExit from the
      # stop-on-failure path above end the script without a bogus alert.
      traceback.print_exc()

      if alertEmail is not None:
        print("------------------------- Sending email alert")
        sendAlert( alertEmail, "ERROR: General socket failure on %s" % ( pinghosts[0] ), "no body" )

      sys.exit(1)