#!/usr/bin/env python
"""
Usage: ./find_ip.py config.xml

Check config.xml for all units. Make sure the serial # matches the config,
and if not scan the network for the IP.

"""

import xmltodict
import RXTools
import sys
import glob
import multiprocessing
import requests
from contextlib import contextmanager
import signal
import socket

def raise_error(signum, frame):
    """SIGALRM handler used by raise_on_timeout to abandon a slow call.

    Python runs signal handlers in the main thread, so the exception raised
    here propagates out of whatever the main thread is executing when the
    handler runs (in this script, socket.gethostbyaddr in check_IPs).

    :param signum: signal number delivered (SIGALRM in practice).
    :param frame: current stack frame (unused).
    :raises TimeoutError: always. It is a subclass of OSError, so existing
        ``except OSError`` handlers still catch it.
    """
    raise TimeoutError("timed out (signal %s)" % signum)

@contextmanager
def set_signal(signum, handler):
    """Install *handler* for *signum* only for the duration of a ``with`` block.

    Whatever handler was installed beforehand is put back when the block
    exits, whether it exits normally or by raising.
    """
    # signal.signal hands back the handler it replaced, so no separate
    # getsignal() lookup is needed.
    previous = signal.signal(signum, handler)
    try:
        yield
    finally:
        signal.signal(signum, previous)

@contextmanager
def set_alarm(time):
    """Arm the real-time interval timer to deliver SIGALRM after *time* seconds.

    The timer is one-shot (no repeat interval) and is disarmed when the
    ``with`` block exits, whether normally or by an exception.
    """
    which = signal.ITIMER_REAL
    signal.setitimer(which, time)
    try:
        yield
    finally:
        # A zero delay disarms the timer.
        signal.setitimer(which, 0)

@contextmanager
def raise_on_timeout(time):
    """Raise OSError inside the ``with`` block if it runs past *time* seconds.

    This works by temporarily making raise_error the SIGALRM handler and
    arming a one-shot real-time timer. Both are undone on exit; the timer is
    disarmed before the old handler is restored. Relies on SIGALRM and
    setitimer, which are Unix-only.
    """
    with set_signal(signal.SIGALRM, raise_error), set_alarm(time):
        yield

def get_upnp(ip, timeout=2):
    """Fetch ``http://<ip>/upnp.xml`` and return the device's friendlyName.

    This is a best-effort probe, run both for single checks and in parallel
    over whole subnets. Any failure returns None instead of raising:
    connection error, timeout, an unparsable body, missing keys, or a
    friendlyName that is not plain text.

    :param ip: address (or hostname) to query.
    :param timeout: requests timeout in seconds.
    :returns: ``(ip, friendly_name)`` tuple, or None on any failure.
    """
    try:
        r = requests.get("http://%s/upnp.xml" % ip, timeout=timeout)
        d = xmltodict.parse(r.text)
        name = d['root']['device']['friendlyName']
    except Exception:
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt/SystemExit still propagate and Ctrl-C can abort.
        return None
    if not isinstance(name, str):
        # An empty <friendlyName/> parses to None, and an element with
        # attributes parses to a dict. Callers search the name as a string,
        # so treat both cases as "no usable answer".
        return None
    return (ip, name)

def scan_IPs(all_units, all_n,
             subnets=('10.1.149', '10.1.150', '10.1.151'), processes=50):
    """Probe every host on *subnets* and match responders to configured units.

    For each unit whose serial number appears in some responder's
    friendlyName, print the match and flag it if the responding IP differs
    from the configured one. Finally, print the unit numbers that matched
    nothing.

    :param all_units: iterable of ``(n, serial_num, config_ip)`` tuples.
    :param all_n: set of every unit number ``n`` (used for the not-found report).
    :param subnets: /24 prefixes to scan; hosts .2 through .255 are probed.
    :param processes: number of worker processes probing in parallel.
    """
    print("Scanning IPs")
    all_ip = ['%s.%d' % (subnet, host)
              for subnet in subnets
              for host in range(2, 256)]
    # The with-block shuts the worker processes down once map() has
    # collected every result, instead of leaving them running.
    with multiprocessing.Pool(processes) as pool:
        scan_data = pool.map(get_upnp, all_ip)

    # Drop silent hosts once, rather than re-checking them for every unit.
    responders = [res for res in scan_data if res is not None]
    found_n = set()
    for n, serial_num, config_ip in all_units:
        for ip, name in responders:
            if serial_num in name:
                print(n, serial_num, (ip, name))
                if ip != config_ip:
                    print(" ******* mismatch IP")
                found_n.add(n)
    print("Not found:", all_n.difference(found_n))

def check_IPs(all_units, all_n):
    """Verify that each configured unit answers at its configured IP.

    Stops at the first unit whose UPnP probe fails or whose friendlyName
    does not contain its serial number. For each unit that passes, also
    prints its reverse-DNS name, giving up on the lookup after 1 second.

    :param all_units: iterable of ``(n, serial_num, config_ip)`` tuples.
    :param all_n: not used here; kept so the signature matches scan_IPs.
    :returns: True if every unit checked out, otherwise False.
    """
    print("Checking IPs")
    for n, serial_num, config_ip in all_units:
        result = get_upnp(config_ip)
        if result is None or serial_num not in result[1]:
            print("Couldn't find", serial_num, config_ip)
            # One failure is enough: main() then falls back to a full scan.
            return False
        try:
            with raise_on_timeout(1):  # Timeout = 1s
                hostname = socket.gethostbyaddr(config_ip)[0]
        except (socket.herror, socket.gaierror):
            # The lookup failed outright (e.g. no reverse-DNS entry).
            # That is not a timeout, so report it separately.
            print("Could not resolve hostname for", config_ip)
            continue
        except OSError:
            # raise_error (the SIGALRM handler) fired before the lookup finished.
            print("Could not gethostbyname in time")
            continue
        # Print outside the timed scope so the alarm cannot interrupt it.
        print(n, config_ip, hostname)
    print("Config file looks OK")
    return True

def main():
    """Command-line entry point: verify the config file named in argv[1].

    Prints the module docstring and exits when no path is given. Otherwise
    it reads every ``<rx>`` entry under ``<data><devices>`` as
    ``(n, serial, ip)``, checks the units at their configured IPs, and
    scans the network if that check fails.
    """
    if len(sys.argv) < 2:
        print(__doc__)
        sys.exit(0)
    config_name = sys.argv[1]
    # Binary mode lets the XML parser honour the document's own encoding
    # declaration; the with-block closes the file promptly.
    with open(config_name, 'rb') as f:
        d = xmltodict.parse(f.read())

    all_units = []
    all_n = set()
    print("Getting serial numbers")
    rx_list = d['data']['devices']['rx']
    # xmltodict yields a single dict (not a one-element list) when there is
    # exactly one <rx>. Iterating that dict would walk its keys instead.
    if not isinstance(rx_list, list):
        rx_list = [rx_list]
    for rx in rx_list:
        n = int(rx['n'])
        all_units.append((n, rx['serial'], rx['ip']))
        all_n.add(n)

    if not check_IPs(all_units, all_n):
        scan_IPs(all_units, all_n)

# Run only when executed as a script, not on import. The guard also stops
# multiprocessing worker processes (which re-import this module under the
# "spawn" start method) from re-running main().
if __name__ == "__main__":
    main()
