为TensorFlow矩阵指数实现自定义梯度

时间:2020-03-10 16:36:17

标签: python tensorflow gradient

我正在尝试使用指数矩阵的系数在TensorFlow中构建自定义层,以便该层具有基本矩阵M1和M2以及拟合系数a和b,因此该层对具有矩阵的输入向量起作用exp(a M1 + b M2)。

对此梯度没有封闭形式的解决方案,而且TensorFlow无论如何都不能采用矩阵指数的梯度,因此我需要在我的图层类中实现针对a和b的自定义梯度。这是该层的代码:

class Matrix(layers.Layer):
    """Class for the linear transformation layer in the network"""

    def __init__(self, dim=1):

        # This won't work for arbitrary dimensionality
        dims = [1, 2, 3]
        assert dim in dims, "Dimensionality {} is not 1, 2, or 3".format(dim)
        super(Matrix, self).__init__()

        A_init = tf.random_normal_initializer()
        self.dim = dim

        # define the basis matrices to generate the Lie transform
        if self.dim == 1:
            M1 = np.array([[0., -1.],
                           [1., 0.]]) # basis 1
            M2 = np.array([[1., 0.],
                           [0.,-1.]]) # basis 2

            self.basis_matrices=[M1, M2]

            self.A = tf.Variable(initial_value=A_init(shape=(2,),
                                                  dtype='float32'),
                             trainable=True)

            # dims 2 and 3 are still pending        

    def compute_matrix_exp(self, zed):

        if self.dim == 1:
            exp_arg = self.A[0]*self.basis_matrices[0] + self.A[1]*self.basis_matrices[1]
            M = tf.linalg.expm(exp_arg)

        return tf.matmul(M, zed)

    def compute_matrix_exp_grad(self, zed):

        if self.dim == 1:
            exp_arg = self.A[0]*self.basis_matrices[0] + self.A[1]*self.basis_matrices[1]
            M = tf.linalg.expm(exp_arg)

            dM_dA0 = tf.matmul(self.basis_matrices[0], M) #approximate
            dM_dA1 = tf.matmul(self.basis_matrices[1], M) #approximate

        return [tf.matmul(dM_dA0, zed), tf.matmul(dM_dA1, zed)]

    @tf.custom_gradient
    def call(self, zed):

        def grad(zed):

            grad = self.compute_matrix_exp_grad(zed)

        return self.compute_matrix_exp(zed), grad

我是TensorFlow的新手,所以我不确定如何最好地实现自定义渐变。非常感谢您的帮助。

编辑:我添加了尝试删除自定义渐变的尝试,发现它没有训练a和b变量,甚至也没有调用{{1 }},让我感到困惑。

0 个答案:

没有答案