import os, errno
import time
import numpy as np
import pprint
import itertools
import networkx as nx
from shutil import copyfile
#
# Converts TensorFlow trained models to Numpy/Python and C++
#
#   The main goal is to simplify deploying a trained TensorFlow model.
# A standard TensorFlow deployment on an embedded system involves
# porting the engine to your platform.  The TensorFlow library has lots
# of features not needed when deploying a trained model, so the minimum
# binary executable size is ~13MB.
#
# TensorFlow compiles down to a relatively small number of fundamental
# operations - for example see tfconvert_ops.py
# More details at:
# https://www.tensorflow.org/versions/master/api_docs/python/math_ops.html#math

def make_sure_path_exists(path):
    """Create directory 'path' (including parents) if it doesn't exist.

    Idempotent: calling it on an existing directory is a no-op.
    Raises OSError for real failures (e.g. permission denied, or the
    path exists but is a regular file).
    """
    # exist_ok=True replaces the old catch-EEXIST idiom and, unlike it,
    # still reports the error when 'path' exists but is not a directory.
    os.makedirs(path, exist_ok=True)

class Model(object):
    """Internal representation of a trained TensorFlow graph plus code
    generators that emit the model as a standalone Numpy script or as
    C++ (Eigen Tensor) source, so deployment needs no TensorFlow runtime.

    Typical usage:
        m = Model(sess)
        m.build(x, y)                      # x/y = TF input/output tensors
        m.output_numpy('mymodel')
        m.output_cpp('mymodel', 'MyModel')
    """

    def __init__(self, sess, row_major=True):
        """sess = TensorFlow session holding the trained weights.
        row_major - default Eigen Tensor layout"""
        self.dag = None                  # networkx DiGraph of op dependencies
        self.ops = {}                    # op name -> TF operation
        self.inputs = None               # TF input tensor (set by build())
        self.outputs = None              # TF output tensor (set by build())
        self.variables = {}              # tensor name -> (tensor, op) for Variable/Const
        self.placeholders = {}           # tensor name -> (tensor, op)
        self.connections = {}            # producer op name -> {consumer op name: 1}
        self.sess = sess
        self.sorted_list = None          # all op names, topologically sorted
        self.input_descendants = None    # op names that depend on the input
        self.sorted_list_precalc = None  # ops computable once, at construction
        self.sorted_list_calc = None     # ops recomputed on every eval_func() call
        self.test_module = None          # module name used by output_numpy()
        self.layout = 'RowMajor' if row_major else 'ColMajor'

        self.include_file = None         # header file name produced by output_cpp()
        self.func_signature = None       # C++ eval_func signature string

    def build(self, x, y):
        """Build an internal model of a TensorFlow trained model.
        x = TensorFlow input
        y = TensorFlow output"""
        self.inputs = x
        self.outputs = y
        print('parsing model...')
        # Walk the graph backwards from the output, filling self.ops /
        # self.variables / self.placeholders / self.connections.
        self.parse_val( y )
        print('building graph nodes...')
        unsorted_dict = {}
        for o in self.ops:
            unsorted_dict[o] = {}
        print('building graph edges...')
        # connections maps producer -> consumers; invert into
        # unsorted_dict[consumer][producer] so edges point producer->consumer.
        for k,v in self.connections.items():
            for o in v:
                try:
                    unsorted_dict[o][k] = 1
                except KeyError:
                    print('missing:',k,o)
        self.unsorted_dict = unsorted_dict # for debugging
        print('sorting graph...')
        self.dag = nx.DiGraph()
        for k in unsorted_dict:
            self.dag.add_node(k)
        for k,v in unsorted_dict.items():
            for v0 in v:
                self.dag.add_edge(v0,k)
        self.sorted_list = list(nx.topological_sort( self.dag ))
        # Anything not downstream of the input can be computed once at
        # construction time; everything else must run per evaluation.
        self.input_descendants = nx.descendants( self.dag, self.inputs.op.name )
        self.sorted_list_precalc = [x for x in self.sorted_list if x not in self.input_descendants]
        self.sorted_list_calc = [x for x in self.sorted_list if x in self.input_descendants]

    def sanitize_var(self,var):
        """Make TensorFlow variable name 'var' usable in Python/C++ scripts"""
        return 'var_' + var.replace(':','__').replace('/','__')

    def get_out_index(self,op_input):
        """Usually a TensorFlow operation has one output.  In the case of many
        outputs we need to find which # the operation input (op_input)
        is connected to.
        """
        for i,op_output in enumerate(op_input.op.outputs):
            if op_input == op_output:
                return i
        raise Exception("Couldn't find matching input")

    def get_extra_args(self,op,do_lower=False):
        """Some TensorFlow operations (op) have attributes.  If so, return
        them in a list (of strings, lowercased for C++ if do_lower).
        TODO - need an operation registry.  Put all Numpy/C++/extra_arg information
        in one place.
        """
        extra_args = []
        if op.type == "MatMul":
            extra_args = [str(op.get_attr('transpose_a')), str(op.get_attr('transpose_b'))]
        elif op.type == "Sum":
            extra_args = [str(op.get_attr('keep_dims'))]
        elif op.type == "Unpack":
            # Older graphs may omit these attributes; fall back to defaults.
            try:
                num = op.get_attr('num')
            except ValueError:
                num=None
            try:
                axis = op.get_attr('axis')
            except ValueError:
                axis=0
            extra_args = [str(num),str(axis)]
        if do_lower:
            # C++ booleans are spelled 'true'/'false'.
            return [x.lower() for x in extra_args]
        else:
            return extra_args

    def format_numpy_op(self, op):
        """Format a TensorFlow operation (op) as a Numpy-script call
        'myop_<Type>(args...)'.
        """
        op_args = []
        for x in op.inputs:
            if len(x.op.outputs) == 1:
                var_name = self.sanitize_var(x.op.name)
                # The model input is a plain function argument; everything
                # else lives on the generated Model instance.
                if x.op.name != self.inputs.op.name:
                    var_name = 'self.'+var_name
                op_args.append(var_name)
            else:
                # Multi-output producer: index into its output list.
                num = self.get_out_index( x )
                op_args.append("%s[%d]" % ('self.'+self.sanitize_var(x.op.name),num))
        extra_args = self.get_extra_args(op)
        return 'myop_' + op.type + '(' + ', '.join(op_args + extra_args) + ')'

    def output_numpy(self,test_module):
        """Output a Numpy trained model to gen/'test_module'.py
        """
        self.test_module = test_module

        # Print full arrays (no '...' summarization) while emitting weights;
        # restore the global setting even if generation fails part-way.
        thresh = np.get_printoptions()['threshold']
        np.set_printoptions(threshold=np.inf)
        try:
            make_sure_path_exists('gen/')
            with open('gen/' + test_module + '.py','w') as f:
                f.write('from tfconvert_ops import *\n')
                f.write('from numpy import *\n')
                f.write('class Model(object):\n')
                f.write('  def __init__(self):\n')
                # Trained weights, baked in as literal arrays.
                for k,(v,op) in self.variables.items():
                    data=v.eval(session=self.sess)
                    f.write("    self.%s = %s\n" % \
                            (self.sanitize_var(op.name), \
                             pprint.pformat(data)))
                # Ops independent of the input: evaluate once in __init__.
                for o in self.sorted_list_precalc:
                    op = self.ops[o]
                    if op.type in ("Variable","VariableV2","Const","Placeholder"):
                        continue
                    f.write("    self.%s = %s\n" % \
                            (self.sanitize_var(o), \
                             self.format_numpy_op(op)))
                for o in self.sorted_list_calc:
                    f.write("    self.%s = None\n" % self.sanitize_var(o) )

                f.write('  def eval_func(self,%s):\n' % self.sanitize_var(self.inputs.op.name))
                for o in self.sorted_list_calc:
                    op = self.ops[o]
                    if op.type in ("Variable","VariableV2","Const","Placeholder"):
                        continue
                    f.write("    self.%s = %s\n" % \
                            (self.sanitize_var(o), \
                             self.format_numpy_op(op)))
                # BUGFIX: return the declared model output.  The old code
                # returned the last op iterated above, which is only correct
                # when the output happens to sort last topologically.
                f.write('    return self.%s\n' % self.sanitize_var(self.outputs.op.name))
        finally:
            np.set_printoptions(threshold=thresh)

    def output_numpy_validation(self,out_file,test_input,test_result):
        """Output a Numpy test case to gen/'out_file'
        test_input/test_result = a known good input/output pair.
        Must be called after output_numpy()
        """
        make_sure_path_exists('gen/')

        # output test data to binary file
        inbin_name = self.test_module + '_in.npy'
        outbin_name = self.test_module + '_out.npy'
        np.save( 'gen/' + inbin_name, test_input )
        np.save( 'gen/' + outbin_name, test_result )

        with open('gen/' + out_file,'w') as f:
            f.write('import numpy as np\n')
            f.write('from time import time\n')
            f.write('import sys\n')
            f.write('test_input = np.load("%s")\n' % inbin_name )
            f.write('desired_output = np.load("%s")\n' % outbin_name )

            f.write('import %s\n' % self.test_module)
            f.write('mval = %s.Model()\n' % self.test_module)
            f.write('n_runs = int(sys.argv[1])\n')
            f.write('t1 = time()\n')
            f.write("for i in range(n_runs):\n")
            f.write("  code_output = mval.eval_func(test_input)\n")
            f.write('t2 = time()\n')
            f.write('print("Run time %.1f us, n_runs=%d" % ((t2-t1)*1e6,n_runs))\n')
            f.write("print('truth',desired_output)\n")
            f.write("print('diff',code_output - desired_output)\n")
            f.write("print('valid=',not np.any( abs(desired_output/code_output-1) > .001 ))\n")

        copyfile('tfconvert_ops.py','gen/tfconvert_ops.py')
        print('Run verification to make sure conversion worked:')
        print('  cd gen/')
        print('  python %s 1' % out_file)

    def format_cpp_var(self, var_name, op):
        """Output a C++ variable declaration.
        var_name = variable name
        op = TensorFlow operation
        """
        n_dims = len(op.outputs[0].get_shape())
        # Eigen Tensors need at least rank 1; promote scalars.
        if n_dims == 0:
            n_dims = 1
        if op.type == "Unpack":
            # Unpack produces 'num' outputs -> declare a C array of Tensors.
            result = "Tensor<float,%d,dMODEL_LAYOUT> %s[%d];\n" % \
                     (n_dims, \
                      var_name, \
                      op.get_attr('num')
                     )
        else:
            result = "Tensor<float,%d,dMODEL_LAYOUT> %s;\n" % \
                     (n_dims, \
                      var_name)
        return result

    def format_cpp_op(self, var_name, op):
        """Output a C++ TensorFlow operation call; the result is written
        into var_name (first argument of the myop_* call).
        """
        result = ''
        op_args = []
        for x in op.inputs:
            if len(x.op.outputs) == 1:
                op_args.append(self.sanitize_var(x.op.name))
            else:
                num = self.get_out_index( x )
                op_args.append("%s[%d]" % (self.sanitize_var(x.op.name),num))
        extra_args = self.get_extra_args(op,do_lower=True)
        result += 'myop_' + op.type + '(' + ', '.join([var_name] + op_args + extra_args) + ') ;\n'
        return result

    def format_cpp_var_init(self,var_name,data,binary_out=None):
        """Output initialization for C++ TensorFlow variable.
        var_name = variable name
        data = initialization data (Numpy array or a float)
        binary_out = output into a binary file for C++ reading?
        """
        var_name = self.sanitize_var(var_name)
        # BUGFIX: accept a plain Python float as documented (a float has
        # no .shape attribute); numpy arrays pass through unchanged.
        data = np.asarray(data)
        if len(data.shape) >= 1:
            myshape = data.shape
        else:
            # 0-d scalar -> treat as a length-1 tensor.
            myshape = [1]
            data = np.array([data])

        if binary_out is not None:
            # Declare a TensorMap over externally loaded data ("init_<name>")
            # and dump the values to 'binary_out' (int32 count + float32 data).
            result = "TensorMap< Tensor<float,%d,dMODEL_LAYOUT> > %s( %s, %s ) ;\n" % \
                     (len(myshape),\
                      var_name, \
                      "init_" + var_name, \
                      ', '.join([str(x) for x in myshape]))
            import struct
            with open(binary_out,'wb') as f:
                if self.layout == 'RowMajor':
                    flat_data = data.flatten()
                else:
                    # ColMajor layout expects transposed flattening.
                    flat_data = data.T.flatten()
                f.write(struct.pack('i',len(flat_data)))
                for x in flat_data:
                    f.write(struct.pack('f',x))
        else:
            # Inline element-by-element initialization.
            result = ' {\n'
            result += '  Eigen::array<ptrdiff_t,%d> dim;\n' % len(myshape)
            for i,x in enumerate(myshape):
                result += '  dim[%d] = %d;\n' % (i,myshape[i])
            result += '  %s.resize( dim );\n' % var_name
            result += ' }\n'
            flat_data = data.flatten()
            # pformat of the index tuple gives "(i, j)"; strip the trailing
            # comma of 1-tuples so it reads as a C++ call: var(i).
            for i,x in enumerate(itertools.product(*[range(y) for y in myshape])):
                result += " %s%s = %.9g ;\n" % (var_name, pprint.pformat(x).replace(',)',')'), flat_data[i])
        return result


    def output_cpp(self,file_name,cpp_classname):
        """Output a C++ trained model to gen/'file_name' (and header)
        """
        make_sure_path_exists('gen/')
        self.include_file = file_name + '.h'
        self.func_signature = 'void eval_func(Tensor<float,%d,dMODEL_LAYOUT> &out, const Tensor<float,%d,dMODEL_LAYOUT> &%s)' \
                              % (len(self.outputs.get_shape()), \
                                 len(self.inputs.get_shape()), \
                                 self.sanitize_var(self.inputs.op.name) \
                             )

        with open('gen/' + self.include_file,'w') as f:
            f.write('// GENERATED FILE - DO NOT EDIT\n')
            f.write('#include "Eigen/Core"\n')
            f.write('#include "unsupported/Eigen/CXX11/Tensor"\n')
            f.write('#define dMODEL_LAYOUT %s\n' % self.layout)
            f.write('using namespace Eigen;\n')
            f.write('class %s {\n' % cpp_classname)
            f.write(' public:\n')
            f.write('  %s();\n' % cpp_classname)
            f.write('  '+self.func_signature + ';\n' )
            f.write(' private:\n')
            for o in self.sorted_list:
                op = self.ops[o]
                var_name = self.sanitize_var(o)
                f.write('  '+self.format_cpp_var(var_name,op))
            f.write('};\n')

        with open('gen/' + file_name + '.cpp','w') as f:
            f.write('// GENERATED FILE - DO NOT EDIT\n')
            f.write('#include "stinger_config.h"\n')
            f.write('#if ENABLE_EL_SLOPES\n')
            f.write('#include "%s"\n' % self.include_file)
            f.write('#include "tfconvert_ops.h"\n')
            # Constructor: bake the trained weights in.
            f.write('%s::%s() {\n' % (cpp_classname,cpp_classname) )
            for k,(v,op) in self.variables.items():
                data=v.eval(session=self.sess)
                f.write( self.format_cpp_var_init( op.name, data ) )
            f.write('}\n')

            f.write(self.func_signature.replace('eval_func','%s::eval_func'%cpp_classname) + ' {\n' )
            for o in self.sorted_list:
                # Skip dead ends that aren't the model output.
                if len(self.dag.edges(o)) == 0 and o != self.outputs.op.name:
                    print('DEBUG: skip',o)
                    continue
                op = self.ops[o]
                if op.type in ("Variable","VariableV2","Const","Placeholder"):
                    continue
                var_name = self.sanitize_var(o)
                f.write(' '+self.format_cpp_op(var_name,op))
            # BUGFIX: copy the declared model output.  The old code used the
            # loop variable left over from the iteration above, which is only
            # the output when it happens to sort last topologically.
            f.write(' out = %s;\n' % self.sanitize_var(self.outputs.op.name))
            f.write('}\n')
            f.write('#endif // ENABLE_EL_SLOPES\n')

    def output_cpp_validation(self,out_file,cpp_classname,test_input,test_result):
        """Output test for the C++ trained model under gen/
        out_file = test file name (e.g., test.cpp)
        test_input/test_result = a known good input/output pair.
        TODO - use a template file (like T04 library)
        """
        make_sure_path_exists('gen/')
        with open('gen/stinger_config.h','w') as f:
            f.write('#define ENABLE_EL_SLOPES 1\n')
        with open('gen/' + out_file,'w') as f:
            f.write('#include <iostream>\n')
            f.write('#include <cmath>\n')
            f.write('#include <limits>\n')
            f.write('#include <stdlib.h>\n')
            f.write('using namespace std ;\n')
            f.write('#include "%s"\n' % self.include_file)
            f.write('#include "tfconvert_test.h"\n')
            f.write('#include <ctime>\n')
            f.write('int main(int argc,char *argv[]) {\n')
            f.write('  int n_runs = 1;\n')
            f.write('  if(argc>=2) n_runs=atoi(argv[1]);\n')
            bin_in_filename = 'gen/' + self.test_module+'_in.bin'
            bin_out_filename = 'gen/' + self.test_module+'_out.bin'
            f.write('  float *init_var_input = load_float("%s");\n' % bin_in_filename)
            f.write('  float *init_var_desired_output = load_float("%s");\n' % bin_out_filename)
            f.write(self.format_cpp_var_init('input',test_input,binary_out=bin_in_filename))
            f.write(self.format_cpp_var_init('desired_output',test_result,binary_out=bin_out_filename))

            f.write('  Tensor<float,2,dMODEL_LAYOUT> var_output;\n')
            f.write('  %s x = %s();\n' % (cpp_classname,cpp_classname))
            f.write('  cout << "n_runs: " << n_runs << endl;\n')
            f.write('  clock_t t_start = clock();\n')
            f.write('  for(int i=0; i < n_runs; i++)\n')
            f.write('    x.eval_func(var_output, var_input);\n')
            f.write('  clock_t t_end = clock();\n')
            f.write('  cout << "CPS: " << CLOCKS_PER_SEC << endl;\n')
            f.write('  cout << "Time: " << (t_end - t_start)*1e6/(double)CLOCKS_PER_SEC << " us" << endl;\n')
            f.write("  cout << \"output: \\n\" << var_output << endl ;\n")
            f.write("  cout << \"desired_output: \\n\" << var_desired_output << endl ;\n")
            #f.write("  cout << \"diff: \" << var_desired_output - var_output << endl ;\n")
            f.write("  bool valid=true;\n")
            f.write("  for(int i=0;i<var_output.dimensions()[0];i++) {\n")
            f.write("    for(int j=0;j<var_output.dimensions()[1];j++) {\n")
            f.write("      float diff = fabs(var_output(i,j) - var_desired_output(i,j));\n")
            f.write("      if( diff > fabs(var_desired_output(i,j))*1e-5 ) {\n")
            f.write("        valid = false;\n")
            f.write('        cout << "err: " << i << ", " << j << " = " << diff << endl;\n')
            f.write("      }\n")
            f.write("    }\n")
            f.write("  }\n")
            f.write("  cout << \"valid: \" << valid << endl ;\n")
            f.write("  return 0 ;\n")
            f.write("}\n")
        print('Do: g++ -g -std=c++11 -O -I. -Igen/ -Wall -W tfconvert_test.cpp %s %s' %
              ('gen/'+out_file, 'gen/'+self.test_module+'.cpp'))
        print('(or for older g++ "-std=gnu++0x" will work with a couple warnings...)')

    def parse_val(self, val):
        """Go through the TensorFlow object recursively and build up a
        list of all variables/operations/etc.
        val = TensorFlow connection
        """
        op = val.op
        if op.name in self.ops:
            # Don't add same op twice and possibly end up in infinite loop...
            return

        self.ops[op.name] = op

        # Leaf nodes: record and stop recursing.
        if op.type in ("Variable","VariableV2", "Const"):
            self.variables[val.name] = (val,op)
            return
        elif op.type == "Placeholder":
            self.placeholders[val.name] = (val,op)
            return

        for i in op.inputs:
            # Record producer -> consumer edge, then recurse upstream.
            self.connections.setdefault(i.op.name,{})[op.name] = 1
            self.parse_val( i )

