当我使用collect_nd时,梯度全为零

时间:2018-07-06 02:47:27

标签: tensorflow

我正在用rnn执行分类任务,我需要提取每个序列的最后输出,当我使用gather_nd时,我发现所有变量的梯度都为零。但是当我更改为使用gather时,渐变看起来很正常。

那么有人知道原因吗?谢谢

P.S。我的tensorflow版本是1.5。

下面是我的代码

gather_nd版本

  batch_range = tf.range(self.batch_size)
  indices = tf.stack([batch_range, self.seqLen], axis=1)
  self.last_output = tf.gather_nd(self.output, indices)

收集版本

  idx = tf.range(self.batch_size)*tf.shape(self.output)[1] + (self.seqLen - 1)
  self.last_output = tf.gather(tf.reshape(self.output, [-1, self.rnn_size]), idx)

0 个答案:

没有答案