Tensorflow自定义渐变中的名称解析

时间:2019-05-03 10:36:38

标签: tensorflow gradient new-operator tensorflow-gradient

我正在使用各自的渐变在Tensorflow中创建几个自定义操作。一切都可以单独很好地工作,但是当我的两个包用相同的名称定义两个不同的操作(不同的输入)时,我面临一个问题。

为了简化我的问题,假设在两个软件包中定义了一个matmul操作。可以很容易地在以下代码中使用它:

import tensorflow as tf
my_ops_a = tf.load_op_library('libpackage_a.so')
my_ops_b = tf.load_op_library('libpackage_b.so')

x, y = tf.random.uniform(10,10), tf.random.uniform(10,10)
my_ops_a.matmul(x, y)
my_ops_b.matmul(x, y)

其梯度可以通过以下方式通知Tensorflow:

from tensorflow.python.framework import ops as tf_ops

@tf_ops.RegisterGradient("Matmul")
def _mat_mul_grad(op, grad):
    return my_ops_a.mat_mul_grad(grad, op.inputs[0], op.inputs[1])

@tf_ops.RegisterGradient("Matmul")
def _mat_mul_grad(op, grad):
    return my_ops_b.mat_mul_grad(grad, op.inputs[0], op.inputs[1])

但是,@tf_ops.RegisterGradient无法识别我所指的matmul

实际上,当我尝试运行通知的代码时,出现以下错误:

KeyError: "Registering two gradient with name 'Matmul'! (Previous registration was in <module> ...)

如何通知Tensorflow我所指的是特定软件包的操作?

谢谢。

0 个答案:

没有答案