以相同顺序混排两个张量

时间:2019-06-13 07:58:04

标签: tensorflow

如上所述。我尝试了这些都无济于事:

tf.random.shuffle( (a,b) )
tf.random.shuffle( zip(a,b) )

我以前将它们串联起来并进行改组,然后取消串联/解压缩。但是现在我处于(a)是4D秩张量而(b)是1D的情况下,所以无法串联。

我还尝试将种子参数赋予shuffle方法,以便它重现相同的改组,并且我使用了两次=>失败。还尝试用随机改组的数字对自己进行改组,但是TF在花哨索引上不如numpy灵活,并且东西==>失败了。

我现在正在做的是,将所有内容转换回numpy,然后使用sklearn中的shuffle,然后通过重铸返回张量。这真是愚蠢的方式。这应该发生在图形内部。

1 个答案:

答案 0 :(得分:1)

您可以只对索引进行混洗,然后使用tf.gather()来提取与那些经混洗的索引相对应的值:

import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, (None, 1, 1, 1))
y = tf.placeholder(tf.int32, (None))

indices = tf.range(start=0, limit=tf.shape(x)[0], dtype=tf.int32)
shuffled_indices = tf.random.shuffle(indices)

shuffled_x = tf.gather(x, shuffled_indices)
shuffled_y = tf.gather(y, shuffled_indices)

确保在同一会话运行中计算shuffled_xshuffled_y。否则,它们可能会获得不同的索引顺序。

# Testing
x_data = np.concatenate([np.zeros((1, 1, 1, 1)),
                         np.ones((1, 1, 1, 1)),
                         2*np.ones((1, 1, 1, 1))]).astype('float32')
y_data = np.arange(4, 7, 1)

print('Before shuffling:')
print('x:')
print(x_data.squeeze())
print('y:')
print(y_data)

with tf.Session() as sess:
  x_res, y_res = sess.run([shuffled_x, shuffled_y],
                          feed_dict={x: x_data, y: y_data})
  print('After shuffling:')
  print('x:')
  print(x_res.squeeze())
  print('y:')
  print(y_res)
Before shuffling:
x:
[0. 1. 2.]
y:
[4 5 6]
After shuffling:
x:
[1. 2. 0.]
y:
[5 6 4]