如何使用scatter_update更新张量流对角线权重矩阵

时间:2019-06-09 12:25:21

标签: tensorflow

我尝试使用scatter_update更新tensorflow对角线权重矩阵,但到目前为止没有任何运气。它要么提示形状不匹配,要么仅沿第一行更新。这是非常奇怪的API行为。有人可以帮我吗?谢谢

Example:
dia_mx = tf.Variable(initial_value=np.array([[1.,0.,0.],
                                             [0.,1.,0.],
                                             [0.,0.,1.]]))
new_diagonal_values = np.array([2., 3., 4.])
tf.scatter_update(dia_mx, [[0,0],[1,1],[2,2]], new_diagonal_values)

Get error:
InvalidArgumentError: shape of indices ([3,2]) is not compatible with the shape of updates ([3]) [Op:ResourceScatterUpdate]

Expect new diagonal matrix:
dia_mx = [[2.,0.,0.],
          [0.,3.,0.],
          [0.,0.,4.]]

1 个答案:

答案 0 :(得分:0)

要使用张量更新特定索引,请使用 tf.scatter_nd_update()

import tensorflow as tf
import numpy as np

dia_mx = tf.Variable(initial_value=np.array([[1.,0.,0.],
                                             [0.,1.,0.],
                                             [0.,0.,1.]]))
updates = [tf.constant(2.), tf.constant(3.), tf.constant(4.)]
indices = tf.constant([[0, 0], [1, 1], [2, 2]])
update_tensor = tf.scatter_nd_update(dia_mx, indices, updates)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(update_tensor.eval())
# [[2. 0. 0.]
#  [0. 3. 0.]
#  [0. 0. 4.]]

tf.scatter_update() 沿张量的第一个维度应用更新。在这种特殊情况下,这意味着立即将更新应用于矩阵的整个行:

dia_mx = tf.Variable(initial_value=np.array([[1.,0.,0.],
                                             [0.,1.,0.],
                                             [0.,0.,1.]]), dtype=tf.float32)
updates = tf.constant([[2., 0., 0.], [0., 3., 0.], [0., 0., 4.]], dtype=tf.float32)
indices = tf.constant([0, 1, 2])
update_tensor = tf.scatter_update(dia_mx, indices, updates)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(update_tensor.eval())
# [[2. 0. 0.]
#  [0. 3. 0.]
#  [0. 0. 4.]]