#!/usr/bin/env python3

########################################################################
# 
# okws.py
#
#
# This library supports generation of option keys using OKWS: SPOK keys,
# Arrowhead timed-option Subscription keys, and SMOK keys may be created.
#
# OKWS Security:
#   OKWS access is enabled through standard AppId/secret client_credentials.
#   These credentials are not stored in CVS for security reasons. They are
#   stored in the JSON files named in authConfigFiles below:
#       secrets/okws-staging.json        (SPOK staging)
#       secrets/okws-production.json     (SPOK production)
#       secrets/okws-smok-staging.json   (SMOK staging)
#
# See the comments in the Okws() methods below for more info about operation.
#   __init__()         - Constructor; loads credentials via loadCredentials().
#   getAuthorization() - Internal method to fetch the access token.
#   getSpok()          - Fetch a SPOK key.
#   getSub()           - Fetch an Arrowhead Subscription key.
#   getSmok()          - Fetch a SMOK key.
#
# Copyright Trimble Inc 2022-2023
########################################################################



import datetime
import http.client
import json
import os
import sys
import time
import types

# Paths (resolved against the current working directory) of the JSON
# credentials file for each supported key type; see Okws.loadCredentials().
authConfigFiles = types.SimpleNamespace(
  spokStaging    = 'secrets/okws-staging.json',
  spokProduction = 'secrets/okws-production.json',
  smokStaging    = 'secrets/okws-smok-staging.json',
)



class Okws():
  ''' Client for the OKWS option-key web service.

      Construct with one of the keyType constants below. The matching JSON
      credentials file (named in the module-level authConfigFiles) supplies
      the OAuth token endpoint, client credentials, and key endpoints. Each
      get*() call fetches a fresh access token, POSTs a JSON request, and
      returns the key blob string, or None on any error.

      Set the `debug` attribute to True to trace requests on stdout;
      Authorization header values and access tokens are redacted.
  '''
  SPOK_STAGING    = 1
  SPOK_PRODUCTION = 2
  SMOK_STAGING    = 3
  #SMOK_PRODUCTION = 4

  # keyType -> authConfigFiles attribute naming its JSON credentials file.
  _CREDENTIAL_ATTRS = {
    SPOK_STAGING:    'spokStaging',
    SPOK_PRODUCTION: 'spokProduction',
    SMOK_STAGING:    'smokStaging',
  }

  # Banner written to the debug log around error responses.
  _ERROR_HRULE = "%s ERROR %s ERROR %s" % ( '-'*10, '-'*10, '-'*10 )

  # -------------------------------------------------- __init__()
  def __init__(self, keyType):
    ''' Constructor for Okws(). Loads the OKWS client credentials, token
        endpoint, and key endpoints for keyType (SPOK_STAGING,
        SPOK_PRODUCTION or SMOK_STAGING) from its JSON credentials file.

        An unsupported keyType prints an error and leaves every endpoint /
        credential attribute set to None; the get*() methods then return None.
        Raises OSError, ValueError (bad JSON) or KeyError (missing field) if
        the credentials file cannot be used.
    '''
    self.keyType = keyType
    # Set before loading so debugLog() is usable during and after the load.
    self.debug = False
    self.loadCredentials()

  # -------------------------------------------------- loadCredentials()
  def loadCredentials(self):
    ''' (Re)load endpoints and credentials for self.keyType.

        Sets tokenEndpoint, spokEndpoint, spokEndpointPath, bearerToken and
        scope from the keyType's credentials file. subsEndpoint and
        subsEndpointPath are read only for SPOK_STAGING and are otherwise None.
    '''
    # Reset first so an unsupported keyType leaves explicit None values
    # instead of missing attributes (which would surface later as a
    # confusing AttributeError).
    self.tokenEndpoint    = None
    self.spokEndpoint     = None
    self.spokEndpointPath = None
    self.subsEndpoint     = None
    self.subsEndpointPath = None
    self.bearerToken      = None
    self.scope            = None

    attrName = Okws._CREDENTIAL_ATTRS.get( self.keyType )
    if attrName is None:
      print("ERROR: Unhandled keyType in loadCredentials()")
      return

    credsPath = getattr( authConfigFiles, attrName )
    with open( credsPath, 'r', encoding='utf-8' ) as fp:
      creds = json.load( fp )

    self.tokenEndpoint    = creds["tokenEndpoint"]

    self.spokEndpoint     = creds["spokEndpoint"]
    self.spokEndpointPath = creds["spokEndpointPath"]

    # Currently, only the Staging endpoint generates Arrowhead subscriptions.
    if self.keyType == Okws.SPOK_STAGING:
      self.subsEndpoint     = creds["subsEndpoint"]
      self.subsEndpointPath = creds["subsEndpointPath"]

    # Sent verbatim as the HTTP Basic credential to the token endpoint
    # (presumably base64 of AppId:secret -- confirm against the secrets file).
    self.bearerToken      = creds["bearer"]
    self.scope            = creds["scope"]

  # -------------------------------------------------- debugLog()
  def debugLog(self, fcn, log):
    ''' Print "OKWS: <fcn>() <log>" to stdout when self.debug is True. '''
    if not self.debug:
      return
    print("OKWS: %s() %s" % ( fcn, log ))

  # -------------------------------------------------- _redactHeaders()
  @staticmethod
  def _redactHeaders(headers):
    ''' Return a copy of headers with the Authorization credential replaced
        by "<redacted>" (the auth scheme is kept), for safe debug logging.
    '''
    safe = dict( headers )
    if 'Authorization' in safe:
      scheme = safe['Authorization'].split( ' ', 1 )[0]
      safe['Authorization'] = '%s <redacted>' % ( scheme )
    return safe

  # -------------------------------------------------- getAuthorization()
  def getAuthorization(self):
    ''' Note: Internal-only method not called by the using application.
        Called before every transaction to obtain the bearer token enabling
        OKWS API access, using an OAuth client_credentials grant.

        Returns the access_token string, or None if no token endpoint is
        configured or the response lacks access_token / expires_in.
        Network and JSON-decode errors propagate to the caller.
    '''
    fcnName = 'getAuthorization'
    if self.tokenEndpoint is None:
      self.debugLog( fcnName, 'no token endpoint configured' )
      return None

    self.debugLog( fcnName, 'connect' )
    start = time.monotonic()

    path    = "/oauth/token?scope=%s" % ( self.scope )
    payload = 'grant_type=client_credentials&scope=%s' % ( self.scope )
    headers = {
      'Content-Type': 'application/x-www-form-urlencoded',
      'Authorization': 'Basic %s' % ( self.bearerToken )
    }
    self.debugLog( fcnName, 'POST %s payload=%s headers=%s' % (
                     path, payload, str( self._redactHeaders( headers ) ) ) )

    conn = http.client.HTTPSConnection( self.tokenEndpoint )
    try:
      conn.request( "POST", path, payload, headers )
      res = conn.getresponse()
      data = res.read()
    finally:
      # Always release the socket, even if the request fails.
      conn.close()

    text = data.decode( "utf-8" )
    dataDict = json.loads( text )
    if ( isinstance( dataDict, dict )
         and 'access_token' in dataDict and 'expires_in' in dataDict ):
      # The raw response contains the token itself, so log only its lifetime.
      self.debugLog( fcnName,
                     'HTTP %s, access token received, expires_in: %s' % (
                       res.status, dataDict['expires_in'] ) )
      self.debugLog( fcnName,
                     'elapsed: %.1f sec.' % ( time.monotonic() - start ) )
      return dataDict['access_token']

    self.debugLog( fcnName, Okws._ERROR_HRULE )
    self.debugLog( fcnName, 'HTTP %s: %s' % ( res.status, text ) )
    self.debugLog( fcnName, Okws._ERROR_HRULE )
    return None

  # -------------------------------------------------- _authorize()
  def _authorize(self, fcnName):
    ''' Fetch an access token via getAuthorization(), logging the outcome
        under fcnName. Returns the token, or None on failure.
    '''
    accessToken = self.getAuthorization()
    if accessToken is None:
      self.debugLog( fcnName, 'getAuthorization() - FAILED' )
    else:
      self.debugLog( fcnName, 'getAuthorization() - OK' )
    return accessToken

  # -------------------------------------------------- _postJson()
  def _postJson(self, fcnName, host, path, payloadDict, accessToken):
    ''' POST payloadDict as JSON to https://<host><path> with a Bearer token.

        Returns (responseText, decodedJson). Network and JSON-decode errors
        propagate to the caller.
    '''
    payload = json.dumps( payloadDict )
    headers = {
      'Content-Type': 'application/json',
      'Authorization': 'Bearer %s' % ( accessToken )
    }
    self.debugLog( fcnName,
                   'headers: %s' % ( str( self._redactHeaders( headers ) ) ) )
    self.debugLog( fcnName, 'payload: %s' % ( payload ) )

    self.debugLog( fcnName, 'connect endpoint: %s' % ( host ) )
    conn = http.client.HTTPSConnection( host )
    try:
      self.debugLog( fcnName, 'POST request path: %s' % ( path ) )
      conn.request( "POST", path, payload, headers )
      self.debugLog( fcnName, 'get response' )
      res = conn.getresponse()
      data = res.read()
      self.debugLog( fcnName, 'HTTP status: %s' % ( res.status ) )
    finally:
      # Always release the socket, even if the request fails.
      conn.close()

    text = data.decode( "utf-8" )
    return text, json.loads( text )

  # -------------------------------------------------- _extractBlob()
  def _extractBlob(self, fcnName, d, text, blobKey):
    ''' Return d[blobKey] from a decoded OKWS response, or None if absent.

        On failure, any OKWS 'errorCode' / 'errorDetails' fields are printed
        and the full response is written to the debug log.
    '''
    if isinstance( d, dict ) and blobKey in d:
      self.debugLog( fcnName, 'response: %s' % ( text ) )
      self.debugLog( fcnName, '%s: %s' % ( blobKey, d[blobKey] ) )
      return d[blobKey]

    self.debugLog( fcnName, Okws._ERROR_HRULE )
    self.debugLog( fcnName, 'OKWS Response: %s' % ( str(d) ) )
    if isinstance( d, dict ):
      if 'errorCode' in d:
        print("OKWS-errorCode: %s" % ( d['errorCode'] ))
      if 'errorDetails' in d:
        print("OKWS-errorDetails: %s" % ( d['errorDetails'] ))
    self.debugLog( fcnName, Okws._ERROR_HRULE )
    return None

  # -------------------------------------------------- getSpok()
  def getSpok(self, email, serialNumber, legacyOptionKey,
              receiverTypes=None,
              newSerial=None, hwIdToken=None, newSerialType = "0"):
    ''' Fetch a SPOK key from the OKWS SPOK endpoint.

        Given the user's email address, serial number, and a legacy format
        permanent options key, forms a POST request and returns the 'poBlob'
        string from the response, or None on any error (errors are printed).

        receiverTypes defaults to an empty list, which allows the key to be
        used with any receiver model. If a list of model numbers is given,
        SPOK key operation is limited to those receivers.

        newSerial, hwIdToken, and newSerialType are used only when the
        receiver S/N is being set or changed. hwIdToken is split on '_' into
        hardwareID[_pcbSerialNumber[_permanentHardwareID]].
    '''
    fcnName = 'getSpok'
    start = time.monotonic()
    # A fresh list per call; never share a mutable default between calls.
    if receiverTypes is None:
      receiverTypes = []

    try:
      self.debugLog( fcnName, 'enter' )
      accessToken = self._authorize( fcnName )
      if accessToken is None:
        return None

      payloadDict = {
        "userTID": email,
        "userName": "",
        "userEmail": email,
        "serialNumber": serialNumber,
        "receiverTypes": receiverTypes,
        "optionKey": legacyOptionKey
      }

      # Special case: setting or changing the serial number.
      if newSerial is not None:
        self.debugLog( fcnName,
                       "toOKWS: newSerialType(okws): %s" % ( newSerialType ) )
        payloadDict["NT"] = newSerialType
        payloadDict["newSerialNumber"] = newSerial
        payloadDict["FS"] = "0" if "0" == newSerialType else "1"

      if hwIdToken is not None:
        fields = hwIdToken.split( '_' )
        self.debugLog( fcnName, 'toOKWS: hwIdToken: %s' % ( hwIdToken ) )
        payloadDict["hardwareID"] = fields[0]
        if len(fields) > 1:
          payloadDict["pcbSerialNumber"] = fields[1]
        if len(fields) > 2:
          payloadDict["permanentHardwareID"] = fields[2]

      text, d = self._postJson( fcnName, self.spokEndpoint,
                                self.spokEndpointPath, payloadDict,
                                accessToken )
      # Example return from OKWS:
      #   {'creationDateTime': '2023-02-07T07:31:49Z',
      #    'poBlob': 'sXtCduk1WkOiChN6...',
      #    'receiverTypes': [],
      #    'serialNumber': '6046F00001'}
      blob = self._extractBlob( fcnName, d, text, 'poBlob' )
      if blob is not None:
        self.debugLog( fcnName,
                       'elapsed: %.1f sec.' % ( time.monotonic() - start ) )
      return blob

    except Exception as e:
      # Best-effort API: report and return None rather than raise.
      print("ERROR: OKWS-Exception: %s" % ( str(e) ))
      return None

  # -------------------------------------------------- getSub()
  def getSub( self, email, serialNumber, legacyOptionKey,
              utcExpire, receiverTypes ):
    ''' Fetch an Arrowhead timed-option Subscription key.

        - fetch a bearer token
        - form a request and submit it to the OKWS subscription endpoint
        - return the 'licenseBlob' string, or None on error

        utcExpire is formatted with a literal 'Z' suffix and no timezone
        conversion, so the caller must pass a UTC datetime.
        Only SPOK_STAGING credentials define a subscription endpoint.
    '''
    fcnName = 'getSub'
    start = time.monotonic()
    try:
      self.debugLog( fcnName, 'enter' )

      # Fail fast, before fetching a token we could not use.
      if self.subsEndpoint is None:
        print("ERROR: OKWS subscription endpoint not configured for keyType %s"
              % ( self.keyType ))
        return None

      accessToken = self._authorize( fcnName )
      if accessToken is None:
        return None

      expirationStr = utcExpire.strftime( "%Y-%m-%dT%H:%M:%SZ" )
      payloadDict = {
        "userTID": email,
        "userName": "",
        "userEmail": email,
        "expirationDateTime": expirationStr,
        "serialNumber": serialNumber,
        "receiverTypes": receiverTypes,
        "licenseType": 'TL',
        "optionKey": legacyOptionKey
      }

      text, d = self._postJson( fcnName, self.subsEndpoint,
                                self.subsEndpointPath, payloadDict,
                                accessToken )
      # Example return from OKWS:
      #   { "creationDateTime": "2023-09-06T23:36:21Z",
      #     "expirationDateTime": "2025-04-21T00:00:00Z",
      #     "licenseBlob": "6Vv6Qv.Hn00krvQa...",
      #     "licenseType": "TL",
      #     "receiverTypes": [],
      #     "serialNumber": "2222222222" }
      blob = self._extractBlob( fcnName, d, text, 'licenseBlob' )
      if blob is not None:
        self.debugLog( fcnName,
                       'elapsed: %.1f sec.' % ( time.monotonic() - start ) )
      return blob

    except Exception as e:
      # Best-effort API: report and return None rather than raise.
      print("ERROR: OKWS-Exception: %s" % ( str(e) ))
      return None

  # -------------------------------------------------- getSmok()
  def getSmok(self, email, serialNumber,
              firmwareWarranty ="",        # Format: YYYY-MM-DD
              downgradeLimit   ="",          # Format: 566
              securityToken    ="",
              unlock           = False,
              unbrick          = False):
    ''' Fetch a SMOK key.

        POSTs to spokEndpoint / spokEndpointPath as loaded from the keyType's
        credentials file. unlock and unbrick send "true" for the
        unlockHiddenCommand / unbrickReceiver fields, otherwise "".
        Returns the 'smokBlob' string, or None on any error.
    '''
    fcnName = 'getSmok'
    start = time.monotonic()

    try:
      self.debugLog( fcnName, 'enter' )
      accessToken = self._authorize( fcnName )
      if accessToken is None:
        return None

      payloadDict = {
        "userTID":   email,
        "userName":  "",
        "userEmail": email,

        "serialNumber":        serialNumber,
        "firmwareWarranty":    firmwareWarranty,
        "downgradeLimit":      downgradeLimit,
        "unlockHiddenCommand": "true" if unlock else "",
        "unbrickReceiver":     "true" if unbrick else "",
        "securityToken":       securityToken
      }

      text, d = self._postJson( fcnName, self.spokEndpoint,
                                self.spokEndpointPath, payloadDict,
                                accessToken )
      # Example return from OKWS:
      # {
      #     "creationTime": "20240807T222310",
      #     "serialNumber": "6000R00000",
      #     "smokBlob": "534D4B30010066B3F3CE0015000A3630..."
      # }
      blob = self._extractBlob( fcnName, d, text, 'smokBlob' )
      if blob is not None:
        self.debugLog( fcnName,
                       'elapsed: %.1f sec.' % ( time.monotonic() - start ) )
      return blob

    except Exception as e:
      # Best-effort API: report and return None rather than raise.
      print("ERROR: OKWS-Exception: %s" % ( str(e) ))
      return None
