加速Python / NumPy中的多项随机样本

时间:2016-02-01 15:01:02

标签: python numpy optimization scipy vectorization

我通过一组概率probs从多项分布生成绘制向量,其中每个绘制是probs中所选项的索引:

import numpy as np
def sample_mult(K, probs):
    result = np.zeros(num_draws, dtype=np.int32)
    for n in xrange(K):
        draws = np.random.multinomial(1, probs)
        result[n] = np.where(draws == 1)[0][0]
    return result

这可以加快吗?一遍又一遍地调用np.random.multinomial似乎效率低下(np.where也可能很慢。)

timeitThe slowest run took 6.72 times longer than the fastest. This could mean that an intermediate result is being cached 100000 loops, best of 3: 18.9 µs per loop

2 个答案:

答案 0 :(得分:6)

您可以使用size选项与np.random.multinomial一起使用随机样本行而不是仅使用默认size=1的一行输出,然后使用.argmax(1)来模拟np.where()[0][0] {3}}行为。

因此,我们将有一个矢量化解决方案,如此 -

result = (np.random.multinomial(1,probs,size=K)==1).argmax(1)

答案 1 :(得分:0)

“ =”的p =参数可以做到这一点(并避免使用argmax):

result = np.random.choice(len(probs), K, p=probs)