在张量流中更新张量切片

时间:2019-01-09 10:11:40

标签: python tensorflow

我想用3个维度更新张量的一个切片。在How to do slice assignment in Tensorflow之后,我会做类似的事情

import tensorflow as tf

with tf.Session() as sess:
    init_val = tf.Variable(tf.zeros((2, 3, 3)))
    indices = tf.constant([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]])
    update = tf.scatter_nd_add(init_val, indices, tf.ones(4))

    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(update))

这可行,但是由于我的实际问题更加复杂,因此我想通过定义切片的开始和大小自动生成索引集,例如是否使用tf.slice(...)。你有什么想法?预先感谢!

我正在使用TensorFlow 1.12,它是当前最新版本。

1 个答案:

答案 0 :(得分:1)

tf.strided_slice支持传递一个var参数来指示切片所引用的变量,因此当您传递它时,它将返回一个可分配的对象(我不确定为什么他们不这样做取决于输入的类型,但无论如何)。您可以执行以下操作:

import tensorflow as tf
import numpy as np

var = tf.Variable(np.ones((3, 4), dtype=np.float32))
s = tf.strided_slice(var, [0, 2], [2, 3], var=var, name='var_slice')
s2 = s.assign([[2], [3]])
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(s2))

输出:

[[1. 1. 2. 1.]
 [1. 1. 3. 1.]
 [1. 1. 1. 1.]]

请注意,在tf.strided_slice中,您给出了开始和结束索引(不包括end),而在tf.slice中,您给出了开始和大小。另外,按照当前的代码,您必须在slice或assign操作中提供一个名称值(我认为这应该是一个错误,并且会发生,因为该API的该部分几乎全部在内部使用)。

相关问题