使用tensorflow一次切割多个切片

时间:2017-02-02 14:36:42

标签: python python-2.7 tensorflow

我正准备张量流RNN的输入张量 目前我正在做以下

rnn_format = list()
for each in range(batch_size):
    rnn_format.append(tf.slice(input2Dpadded,[each,0],[max_steps,10]))
lstm_input = tf.stack(rnn_format)

是否可以立即执行此操作,无需循环,具有一些张量流函数?

2 个答案:

答案 0 :(得分:2)

正如Peter Hawkins所建议的那样,你可以使用gather_nd和适当的指数来实现目标。

您可以在调用gather_nd之前完成内部维度的均匀裁剪。

示例:

import tensorflow as tf
import numpy as np

sess = tf.InteractiveSession()

# integer image simply because it is more readable to me
im0 = np.random.randint(10, size=(20,20))
im = tf.constant(im0)

max_steps = 3
batch_size = 10

# create the appropriate indices here
indices = (np.arange(max_steps) +
    np.arange(batch_size)[:,np.newaxis])[...,np.newaxis]
# crop then call gather_nd
res = tf.gather_nd(im[:,:10], indices).eval()

# check that the resulting tensors are equal to what you had previously
for each in range(batch_size):
  assert(np.all(tf.slice(im, [each,0],[max_steps,10]).eval() == res[each]))

修改

如果你的切片索引是张量,你只需在创建indices时用tensorflow的操作替换numpy的操作:

# indices stored in a 1D array
my_indices = tf.constant([1, 8, 3, 0, 0])
indices = (np.arange(max_steps) +
    my_indices[:,tf.newaxis])[...,tf.newaxis]

进一步评论:

  • indices是在添加过程中利用broadcasting创建的:数组实际上是平铺的,以便它们的尺寸匹配。 numpy和tensorflow以类似的方式支持广播。
  • 省略号...是标准numpy slicing notation的一部分,它基本上填充了其他切片索引留下的所有剩余维度。因此[..., newaxis]基本上等同于expand_dims(·, -1)

答案 1 :(得分:1)

尝试tf.splittf.split_v。见这里:

https://www.tensorflow.org/api_docs/python/tf/split

这有帮助吗?