从TensorFlow中截断的正态分布中采样

时间:2018-02-17 16:35:28

标签: python tensorflow scipy

我正在将一堆计算从Numpy / Scipy移植到TensorFlow,我需要从truncated normal distribution生成样本,相当于scipy.stats.truncnorm.rvs()

我认为生成这些样本的两种标准方法是通过拒绝采样或将截断的均匀分布样本馈送到逆正态累积分布函数,但前者似乎很难在静态内实现计算图(我们无法知道在生成样本之前要运行多少个拒绝循环)并且我认为标准TensorFlow库中没有反向正态累积分布函数。

我意识到TensorFlow中有一个名为truncated_normal()的函数,但它只是在两个标准偏差处剪辑;你不能指定下限和上限。

有什么建议吗?

1 个答案:

答案 0 :(得分:2)

更简单的方法可能是使用tf.py_func在TensorFlow图中嵌入您习惯使用的函数。如果您需要它,您可以像往常一样使用numpy.random模块设置种子:

import numpy as np
import tensorflow as tf
from scipy.stats import truncnorm

np.random.seed(seed=123)

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
size = tf.placeholder(tf.int32)

f = tf.py_func(lambda x, y, s: truncnorm.rvs(x, y, size=s),
               inp=[a, b, size],
                Tout=tf.float64)


with tf.Session() as sess:
    print(sess.run(f, {a:0, b:1, size:10}))

将打印:

[0.63638154 0.24732635 0.19533476 0.49072188 0.66066675 0.37031253
 0.9732229  0.62423404 0.42385328 0.34206036]