import numpy as np

# TensorFlow has a relatively small number of basic operations:
# https://www.tensorflow.org/versions/master/api_docs/python/math_ops.html#math
# This file maps TensorFlow operations to their NumPy equivalents.
# See tfconvert_ops.h

def myop_Identity(a):
    """TensorFlow Identity: hand back the very same object, without copying."""
    result = a
    return result

def myop_Add(a, b):
    """TensorFlow Add: element-wise sum of `a` and `b` via the `+` operator."""
    total = a + b
    return total

def myop_BiasAdd(a, b):
    """TensorFlow BiasAdd: add bias `b` to `a` using the `+` operator.

    With NumPy arrays, broadcasting aligns a 1-D bias with the last axis of `a`.
    """
    biased = a + b
    return biased

def myop_MatMul(a, b, transpose_a=False, transpose_b=False):
    """TensorFlow MatMul: np.dot of `a` and `b`.

    Each operand is replaced by its `.T` first when the matching
    transpose flag is set.
    """
    if transpose_a:
        a = a.T
    if transpose_b:
        b = b.T
    return np.dot(a, b)

def myop_Softmax(a):
    """TensorFlow Softmax: normalized exponentials over the last axis of `a`.

    The per-row maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged (softmax is shift-invariant). It also
    keeps np.exp from overflowing to inf for large logits, which would
    otherwise turn the output into NaN (inf / inf).
    """
    shifted = a - np.max(a, axis=-1, keepdims=True)
    e = np.exp(shifted)
    return np.divide(e, np.sum(e, axis=-1, keepdims=True))

def myop_Tanh(a):
    """TensorFlow Tanh: element-wise hyperbolic tangent."""
    activated = np.tanh(a)
    return activated

def myop_Reshape(a, shape):
    """TensorFlow Reshape: return `a` reshaped to `shape` via ndarray.reshape."""
    reshaped = a.reshape(shape)
    return reshaped

def myop_ZerosLike(a):
    """TensorFlow ZerosLike: a new zero-filled array with `a`'s shape and dtype."""
    target_shape = a.shape
    target_dtype = a.dtype
    return np.zeros(target_shape, dtype=target_dtype)

def myop_Sum(a, reduction_indices, keep_dims=False):
    """TensorFlow Sum: add up the elements of `a` over the given axes.

    Args:
      a: input array.
      reduction_indices: a single axis (int or 0-d array), a sequence of axes,
        or None to reduce over every axis. An empty sequence performs no
        reduction.
      keep_dims: if True, reduced axes are retained with length 1.

    Returns:
      The reduced array (or scalar).
    """
    if reduction_indices is None:
        axis = None
    else:
        # np.ravel turns a scalar axis into a 1-element array, so scalar and
        # sequence forms are handled uniformly; tuple() alone rejects scalars.
        axis = tuple(int(i) for i in np.ravel(reduction_indices))
    return np.sum(a, axis=axis, keepdims=keep_dims)

def myop_ExpandDims(a, dim):
    """TensorFlow ExpandDims: insert a length-1 axis into `a` at position `dim`.

    Args:
      a: input array of rank R.
      dim: insertion position in [-R-1, R]. Negative values count from the
        end, so -1 appends a trailing axis.

    Returns:
      `a` reshaped to rank R+1.

    Raises:
      ValueError: if `dim` is outside [-R-1, R].
    """
    rank = len(a.shape)
    # .item() accepts plain ints as well as NumPy scalars / 1-element arrays.
    dim = int(np.asarray(dim).item())
    if not -rank - 1 <= dim <= rank:
        # list.insert would silently clamp an out-of-range index and produce
        # a wrong shape; fail loudly instead.
        raise ValueError('ExpandDims: dim %d out of range for rank %d' % (dim, rank))
    if dim < 0:
        dim += rank + 1
    return a.reshape(tuple(a.shape[:dim]) + (1,) + tuple(a.shape[dim:]))

def myop_Tile(a, n):
    """TensorFlow Tile: repeat `a` according to the per-axis repetition counts `n`."""
    return np.tile(a, reps=n)

def myop_Mul(a, b):
    """TensorFlow Mul: element-wise product of `a` and `b` (with broadcasting)."""
    product = np.multiply(a, b)
    return product

def myop_Concat(dim, *inputs):
    """TensorFlow Concat: join `inputs` along axis `dim`.

    The axis is the first argument, and the arrays to join follow it.
    """
    joined = np.concatenate(inputs, dim)
    return joined

def myop_Transpose(a, perm=None):
    """TensorFlow Transpose: permute the axes of `a`.

    When `perm` is None, the axis order is reversed.
    """
    permuted = np.transpose(a, perm)
    return permuted

def myop_Unpack(a, num=None, axis=0):
    """TensorFlow Unpack: split `a` along `axis` into a list of sub-arrays.

    Given an array of shape (A, B, C):
      axis=0        -> list of A arrays of shape (B, C)
      axis=1        -> list of B arrays of shape (A, C)
      axis=2 or -1  -> list of C arrays of shape (A, B)

    Args:
      a: array to split; must have at least one dimension.
      num: number of slices to return. If None, it is inferred as the length
        of `axis`.
      axis: the dimension to unpack along; negative values count from the end.

    Returns:
      A list of `num` arrays, each with `axis` removed.
    """
    # Move the unpacked axis to the front so that plain integer indexing
    # removes it, whatever the rank of `a` (the old a[i,:,:] needed rank 3).
    moved = np.moveaxis(a, axis, 0)
    if num is None:
        num = moved.shape[0]
    # range (not xrange) works on both Python 2 and Python 3.
    return [moved[i] for i in range(num)]

def myop_Slice(a, begin, size):
    """TensorFlow Slice: extract a contiguous block of `a`.

    Args:
      a: input array.
      begin: start offset for each leading dimension.
      size: number of elements to take in each dimension; -1 means
        "from begin to the end of that dimension". Dimensions beyond
        len(begin) are taken whole.

    Returns:
      The sliced array.
    """
    # NumPy needs a *tuple* of slices for multidimensional indexing. A list
    # is treated as a fancy index, which is deprecated and rejected by modern
    # NumPy.
    index = tuple(
        slice(b, None) if s == -1 else slice(b, b + s)
        for b, s in zip(begin, size)
    )
    return a[index]

def myop_Minimum(a, b):
    """TensorFlow Minimum: element-wise minimum of `a` and `b`."""
    smaller = np.minimum(a, b)
    return smaller

def myop_Maximum(a, b):
    """TensorFlow Maximum: element-wise maximum of `a` and `b`."""
    larger = np.maximum(a, b)
    return larger

def myop_Sub(a, b):
    """TensorFlow Sub: element-wise difference `a - b` (with broadcasting)."""
    difference = np.subtract(a, b)
    return difference
