从Tensorflow中的张量中随机选择元素

时间:2018-03-08 16:08:14

标签: tensorflow

给定形状为Nx2的张量,如何从这个张量中选择类似于k的{​​{1}}元素(概率相等)?另一点需要注意的是np.random.choice的值在执行期间会动态变化。这意味着我正在处理一个动态大小的张量。

2 个答案:

答案 0 :(得分:0)

您可以将np.random.choice包裹为tf.py_func。例如,见answer。在您的情况下,您需要展平张量,因此它是一个长度为2*N的数组:

import numpy as np
import tensorflow as tf

a = tf.placeholder(tf.float32, shape=[None, 2]) 
size = tf.placeholder(tf.int32)
y = tf.py_func(lambda x, s: np.random.choice(x.reshape(-1),s), [a, size], tf.float32)
with tf.Session() as sess:
    print(sess.run(y, {a: np.random.rand(4,2), size:5}))

答案 1 :(得分:0)

我遇到了类似的问题,我想从点云中对点进行子采样以实现 PointNet。我的输入维度是 [None, 2048, 3],我正在使用以下自定义层子采样到 [None, 1024, 3]

class SubSample(Layer):
  def __init__(self,num_samples):
    super(SubSample, self).__init__()
    self.num_samples=num_samples

  def build(self, input_shape):
    self.shape = input_shape #[None,2048,3]

  def call(self, inputs, training=None):
    k = tf.random.uniform([self.shape[1],]) #[2048,]
    bl = tf.argsort(k)<self.num_samples #[2048,]
    res = tf.boolean_mask(inputs, bl, axis=1) #[None,1024,3]
    # Reshape needed so that channel shape is passed when `run_eagerly=False`, otherwise it returns `None`
    return tf.reshape(res,(-1,self.num_samples,self.shape[-1])) #[None,1024,3]

SubSample(1024)(tf.random.uniform((64,2048,3))).shape

>>> TensorShape([64, 1024, 3])

据我所知,这适用于 TensorFlow 2.5.0

请注意,这不是手头问题的直接答案,而是我偶然发现这个问题时正在寻找的答案。

相关问题