使用tf.gather_nd获取沿轴的所有可能排列

时间:2017-11-02 02:17:42

标签: tensorflow

我试图从Tensor沿特定轴提取所有可能的排列。我的输入是[B, S, L]张量(B批S长度为L的向量),我想提取这些向量中的所有可能的排列(S!置换),即[B, S!, S, L] Tensor作为输出。 这就是我现在尝试的,但我正在努力获得正确的输出形状。我认为我的错误可能是我创建了一个batch_range,但我也应该创建一个permutation_range。

import tensorflow as tf
import numpy as np
from itertools import permutations

S = 3
B = 5
L = 10

input = tf.constant(np.random.randn(B, S, L))

perms = list(permutations(range(S))) # ex with 3: [0, 1, 2], [0, 2 ,1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])

batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
indicies = tf.concat([batch_range, perms], axis=3)

permutations = tf.gather_nd(tf.tile(tf.reshape(input, [B, 1, S, L]), [1, length_perm, 1, 1]), indicies) # 
# I get a [ B, P, S, S, L] instead of the desired [B, P, S, L]

我发布了一个可能的'解决方案'就在下面,但我认为这个问题仍有问题。我对它进行了测试,如果B> 1,它的进展并不顺利。

1 个答案:

答案 0 :(得分:0)

我刚刚找到答案,如果您认为我错了或者有更简单的方法可以解答,请纠正我:

import tensorflow as tf
import numpy as np
from itertools import permutations

S = 3
B = 5
L = 10

input = tf.constant(np.random.randn(B, S, L))

perms = list(permutations(range(S))) # ex with 3: [0, 1, 2], [0, 2 ,1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])

batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
perm_range = tf.tile(tf.reshape(tf.range(length_perm, dtype=tf.int32), shape=[1, length_perm, 1, 1]), [B, 1, S, 1])
indicies = tf.concat([batch_range, perm_range, perms], axis=3)

permutations = tf.gather_nd(tf.tile(tf.reshape(input, [B, 1, S, L]), [1, length_perm, 1, 1]), indicies) # 
print permutations
相关问题