我们可以将非张量传递给tf.py_func输入吗?

时间:2017-03-19 15:19:02

标签: tensorflow

def np_function( np_array1, float_value):
    np_array2 = ...
return np_array2 

#tensorflow customised op
def tf_function( tf_tensor_in_gpu, float_value):
return  \
    tf.py_func(np_function,[tf_tensor_in_gpu, float_value],[tf.float32])

我想使用" tf.py_func"从我的函数中创建自定义张量流op。如何将非张量输入(例如" float_value"在上面的代码中)传递给我的函数。我的代码是否正确?当我调用session.run时,它在运行时给出错误。

1 个答案:

答案 0 :(得分:1)

我通过以下方式解决了这个问题:

def np_function_generator(float_value):
    def np_function(np_array1):
         np_array2 = ...
         ... you can use float_value here ...  
         return np_array2
    return np_function


#tensorflow customised op
def tf_function( tf_tensor_in_gpu, float_value):
       np_function = np_function_generator(float_value):
       return  \
       tf.py_func(np_function,[tf_tensor_in_gpu],[tf.float32])
相关问题