#!/usr/bin/env python
usage="""\
Simple class to test monitor commands.  Assuming the receiver is
in monitor mode, here are some examples:
  ./RXMonitorCmds.py socket://10.1.123.123:9999 --print_product
  ./RXMonitorCmds.py COM1 --get_hw_info
  ./RXMonitorCmds.py /dev/ttyS0 --prog_img=file.img

If a receiver is not already in monitor mode, you can use
RXTools.sendDColGotoMonitor().
"""
import serial
import struct
import time
import os
import sys
import argparse

class RXMonitor:
    """Minimal client for a receiver's monitor-mode command protocol.

    Frames written by :meth:`send_cmd` are laid out as::

        0x02 | cmd (1 byte) | payload length (2 bytes, big-endian)
             | payload | checksum (2 bytes, big-endian) | 0x03

    where the checksum is the 16-bit sum of every byte after the leading
    0x02 up to the end of the payload.  Responses are read assuming the
    same 4-byte header followed by the payload and 3 trailing bytes; the
    response checksum is not verified.
    """

    # Framing / handshake bytes.
    _STX = 2
    _ETX = 3
    _ENQ = 5
    _ACK = 6

    # Command numbers used by this class.
    _CMD_BOOT = 4
    _CMD_READ = 5            # address + length in, memory contents back
    _CMD_PRODUCT = 6
    _CMD_HW_INFO = 20
    _CMD_ERASE_SECTORS = 21  # (22 = erase_status, unused here)
    _CMD_PROGRAM = 23

    # Bytes transferred per read/program command.
    # monitor.c MAX_DATA_SIZE = 2100 = 0x834
    _CHUNK = 0x800

    def __init__(self,port,baudrate=9600,verbose=False):
        """Remember connection settings; no I/O happens until connect().

        port     -- anything accepted by serial.serial_for_url()
        baudrate -- baud rate for real serial ports
        verbose  -- when true, print raw sent/received frames
        """
        self.baudrate = baudrate
        self.port = port
        self.verbose = verbose
        self.s = None
        # Filled in by get_hw_info().
        self.num_blocks = None
        self.size_block = None
        self.prog_base = None

    def connect(self):
        """Open the port given to the constructor."""
        self.s = serial.serial_for_url(self.port,timeout=5,baudrate=self.baudrate)

    def close(self):
        """Close the port if it is open."""
        if self.s is not None:
            self.s.close()
            self.s = None

    def send_enq(self):
        """Send a single ENQ byte and wait up to 1s for an ACK byte.

        Returns True if exactly the ACK byte came back, else False.
        """
        self.s.write(bytearray([self._ENQ]))
        t1 = time.time()
        self.s.timeout = 1.0
        data = self.s.read()
        if len(data) == 1 and data[0] == self._ACK:
            print("ENQ OK after %.3f[s]"%(time.time()-t1))
            return True
        print("ENQ failed",data)
        return False

    def read_result(self,timeout=5.):
        """Read one response frame from the port.

        Reading stops when a complete frame has arrived or when a read
        returns nothing (timeout).  A warning is printed if a frame header
        was seen but the amount of data received does not match it.

        timeout -- per-read timeout in seconds (0 = only what is buffered)

        Returns the received bytes as a bytearray (possibly empty or
        partial).
        """
        response = bytearray()
        expected_len = None
        self.s.timeout = timeout
        while True:
            # Until the header is parsed read byte-by-byte; afterwards ask
            # for the rest of the frame in one call.
            if expected_len is None:
                want = 1
            else:
                want = expected_len - len(response)
            chunk = self.s.read(want)
            if not chunk:
                break
            response += chunk
            if (expected_len is None and len(response) >= 4
                    and response[0] == self._STX):
                # payload length + 4 header bytes + 3 trailer bytes
                expected_len = struct.unpack(">H",response[2:4])[0] + 4 + 3
            if expected_len is not None and len(response) >= expected_len:
                break
        if expected_len is not None and expected_len != len(response):
            print("Warning - response length mismatch",
                  expected_len,
                  len(response))
        if self.verbose:
            print('recv:',response)
        return response

    def send_cmd(self,cmd_num,cmd_data,timeout=5.):
        """Frame and send one command, then return read_result(timeout).

        cmd_num  -- command number (0..255)
        cmd_data -- payload: bytes, bytearray or a list of ints
        timeout  -- passed through to read_result()
        """
        payload = bytes(bytearray(cmd_data))
        frame = struct.pack(">BBH",self._STX,cmd_num,len(payload)) + payload
        # Checksum covers everything after the leading STX byte.
        checksum = sum(frame[1:]) & 0xffff
        frame += struct.pack(">HB",checksum,self._ETX)
        if self.verbose:
            print("send: %s"%' '.join([hex(x) for x in frame]))
        self.s.write(frame)
        return self.read_result(timeout)

    def print_product(self):
        """Query and print the product information record.

        Raises RuntimeError if the response is too short to parse.
        """
        data = self.send_cmd( self._CMD_PRODUCT, [], timeout=10 )
        data = data[4:] # remove header
        if len(data) < 155:
            raise RuntimeError("short product response",len(data))
        print("Serial:  %s"%data[2:34].decode('ascii',errors='ignore'))
        print("Product: %s"%data[34:74].decode('ascii',errors='ignore'))
        print("Owner1:  %s"%data[74:110].decode('ascii',errors='ignore'))
        print("Owner2:  %s"%data[110:146].decode('ascii',errors='ignore'))
        print("HWVerMajor: %d"%data[146])
        print("HWVerMinor: %d"%data[147])
        print("App version: %d %d"%(data[151],data[152]))
        print("App day/month: %d %d"%(data[153],data[154]))

    def get_hw_info(self,verbose=True):
        """Query hardware info and cache the flash layout on the instance.

        Sets self.prog_base, self.num_blocks and self.size_block.

        verbose -- when true (the default) the parsed fields are printed;
                   this is independent of the instance-level verbose flag.

        Raises RuntimeError if the response is too short to parse.
        """
        data = self.send_cmd( self._CMD_HW_INFO, [] )
        if len(data) < 72:
            raise RuntimeError("short hardware info response",len(data))
        flash_base, flash_bytes = struct.unpack(">LL",data[4:12])
        dsp = data[14:30].decode('ascii',errors='ignore')
        self.prog_base, prog_offset = struct.unpack(">LL",data[30:38])
        n_erase = struct.unpack(">H",data[38:40])[0]
        # Four slots are always present; only the first n_erase are valid.
        self.num_blocks = struct.unpack(">LLLL",data[40:56])[:n_erase]
        self.size_block = struct.unpack(">LLLL",data[56:72])[:n_erase]
        if not verbose:
            return
        print('flash base:  0x%x'%flash_base)
        print('flash bytes: 0x%x'%flash_bytes)
        print('HWmaj: %d'%data[12])
        print('HWmin: %d'%data[13])
        print('DSP:   %s'%dsp)
        print('prog base:   0x%x'%self.prog_base)
        print('prog offset: 0x%x'%prog_offset)
        print('# erase: %d'%n_erase)
        for num, size in zip(self.num_blocks,self.size_block):
            print(" num %d size 0x%x"%(num,size))

    def erase_len(self,total_len):
        """Erase enough sectors, starting at prog_base, to hold total_len bytes.

        Each sector erase is attempted up to 5 times.  Returns True on
        success, False if a sector could not be erased or no usable erase
        block size is known.
        """
        if self.num_blocks is None:
            self.get_hw_info()
        if not self.size_block or self.size_block[0] <= 0:
            print("Error erasing - no erase block size reported")
            return False
        # NOTE(review): only the first reported block size is used for every
        # sector - confirm for parts with mixed sector sizes.
        sector_size = self.size_block[0]
        address = self.prog_base
        # TODO: more than 1 sector at once??
        n_sectors = 1
        t1 = time.time()
        print("Erasing...")
        while total_len > 0:
            print("%.1f: Sector 0x%x"%(time.time()-t1,address),end='\r')
            request = bytearray([n_sectors]) + struct.pack(">L",address)
            status = None
            for _ in range(5):
                result = self.send_cmd( self._CMD_ERASE_SECTORS, request )
                # A timed-out/short reply has no status byte: treat as failure.
                status = result[4] if len(result) > 4 else None
                if status == 0:
                    break
                time.sleep(0.1)
            else:
                print("Error erasing 0x%x %s"%(address,status))
                return False
            address += sector_size
            total_len -= sector_size
        print("")
        print("Done erasing")
        return True

    def _verify_block(self,address,expected):
        """Read len(expected) bytes at address and compare with expected.

        Raises RuntimeError describing the first differing offset, or the
        length difference when one side is a prefix of the other.
        """
        reply = self.send_cmd( self._CMD_READ,
                               struct.pack(">LH",address,len(expected)) )
        actual = reply[4:-3]  # strip header and trailer
        if expected == actual:
            return
        for i, (want, got) in enumerate(zip(expected,actual)):
            if want != got:
                raise RuntimeError("mismatch %d %d %d: 0x%x"%
                                   (i, len(expected),
                                    len(actual),
                                    address),
                                   expected[i:i+5],
                                   actual[i:i+5])
        raise RuntimeError("mismatch - length %d != %d: 0x%x"%
                           (len(expected),len(actual),address))

    def _use_raw_offset(self,img_filename):
        """For '.raw' files, switch the programming base address to 0."""
        if img_filename.endswith('.raw'):
            self.prog_base = 0
            print("Raw file - adjust memory to 0 offset")

    def check_erase(self,total_len):
        """Verify that total_len bytes starting at prog_base read as 0xff.

        Raises RuntimeError on the first block that does not match.
        """
        if self.num_blocks is None:
            self.get_hw_info()
        address = self.prog_base
        while total_len > 0:
            curr_len = min(total_len,self._CHUNK)
            self._verify_block(address,bytearray([0xff])*curr_len)
            address += curr_len
            total_len -= curr_len

    def check_img(self,img_filename):
        """Verify that memory starting at prog_base matches the image file.

        Raises RuntimeError on the first block that does not match.
        """
        if self.num_blocks is None:
            self.get_hw_info()
        self._use_raw_offset(img_filename)
        address = self.prog_base
        with open(img_filename,'rb') as f_in:
            while True:
                data = f_in.read(self._CHUNK)
                if not data:
                    break
                self._verify_block(address,data)
                address += len(data)

    @staticmethod
    def _take_prog_acks(pending):
        """Remove complete 8-byte replies from pending; return how many.

        Raises RuntimeError if a reply's status byte (offset 4) is not 1.
        """
        count = 0
        while len(pending) >= 8:
            if pending[4] != 1:
                raise RuntimeError("programming problem",bytearray(pending))
            count += 1
            del pending[:8]
        return count

    def prog_img(self,img_filename):
        """Erase the target area and program the image file into it.

        Program commands are sent with a zero read timeout, so replies
        are collected as they arrive; once the whole file has been sent,
        remaining replies are drained until a 10s read returns nothing.

        Returns True if every program command was acknowledged, False
        otherwise (including an erase failure).
        Raises RuntimeError if a reply reports a programming problem.
        """
        if self.num_blocks is None:
            self.get_hw_info()
        self._use_raw_offset(img_filename)
        file_len = os.path.getsize(img_filename)
        if not self.erase_len(file_len):
            print("Failed erase...")
            return False
        address = self.prog_base
        t1 = time.time()
        pending = bytearray()
        n_prog = 0
        n_OK = 0
        print("Programming...")
        with open(img_filename,'rb') as f_in:
            while True:
                data = f_in.read(self._CHUNK)
                if not data:
                    break
                print("%.1f: Programming addr 0x%x len %d - percent %.1f (%d %d)"%
                      (time.time()-t1,
                       address,
                       len(data),
                       100.*(address-self.prog_base)/file_len,
                       n_prog,
                       n_OK
                      ), end='\r')
                pending += self.send_cmd( self._CMD_PROGRAM,
                                          struct.pack(">L",address) + data,
                                          timeout=0 )
                n_prog += 1
                n_OK += self._take_prog_acks(pending)
                address += len(data)
        while True:
            chunk = self.read_result(timeout=10.)
            if not chunk:
                # Nothing more is coming; stop even if a partial reply is
                # still buffered, otherwise this loop would never end.
                break
            pending += chunk
            n_OK += self._take_prog_acks(pending)
        print("")
        if pending:
            print("Warning - %d unparsed reply bytes"%len(pending))
        print("Done with programming %d/%d"%(n_OK,n_prog))
        return n_OK == n_prog

    def boot(self):
        """Send the boot command and print the raw response."""
        data = self.send_cmd( self._CMD_BOOT, [] )
        print(data)

# Command-line entry point: connect to the given port, check that the
# monitor answers an ENQ, then run whichever actions were requested.
if __name__=='__main__':
    parser = argparse.ArgumentParser(description=usage,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('port', help='port to connect, e.g., socket://10.1.x.y:999 or COM1 or /dev/ttyS0')
    parser.add_argument('--print_product', help='print product info',
                        action='store_true')
    parser.add_argument('--get_hw_info', help='get hardware info',
                        action='store_true')
    parser.add_argument('--prog_img', help='program a .img file')
    parser.add_argument('--boot', help='boot to application',
                        action='store_true')
    # A plain flag like the other switches; without store_true argparse
    # would insist on a value after --verbose.
    parser.add_argument('--verbose', help='print raw received data',
                        action='store_true')
    parser.add_argument('--baud', help='for serial ports, what baud rate?',
                        type=int, default=9600)

    args = parser.parse_args()

    rx = RXMonitor( args.port, baudrate=args.baud, verbose=args.verbose )
    rx.connect()
    try:
        # Actions run in a fixed order, and only if the monitor responds.
        if rx.send_enq():
            if args.print_product:
                rx.print_product()
            if args.get_hw_info:
                rx.get_hw_info()
            if args.prog_img:
                rx.prog_img(args.prog_img)
            if args.boot:
                rx.boot()
    finally:
        # Release the port even if a command raised.
        rx.s.close()
