pymc3中的自定义多变量Dirichlet先验

时间:2018-05-17 06:52:22

标签: pymc3

是否可以在pymc3中创建自定义多变量分布?在下文中,我尝试创建Dirichlet分布的线性变换。关于此的所有变体都返回了许多错误,可能与theano数据类型有关?任何帮助将不胜感激。

import numpy as np
import pymc3 as pymc
import theano.tensor as tt


# data
n = 5
prior_params = np.ones(n - 1) / (n - 1)
mx = np.array([[0.25 , 0.5  , 0.75 , 1.   ],    
              [0.25 , 0.333, 0.25 , 0.   ],
              [0.25 , 0.167, 0.   , 0.   ],
              [0.25 , 0.   , 0.   , 0.   ]])
# Note that the matrix mx takes the unit simplex into the unit simplex.

# custom log-liklihood
def generate_function(mx, prior_params):
    def log_trunc_dir(x):
        return pymc.Dirichlet.dist(a=prior_params).logp(mx.dot(x.T)).eval()
    return log_trunc_dir

#model
with pymc.Model() as simple_model:
    x = pymc.Dirichlet('x', a=np.ones(n - 1))
    q = pymc.DensityDist('q', generate_function(mx, prior_params), observed={'x': x})

1 个答案:

答案 0 :(得分:1)

感谢PyMC3开发社区的大力帮助,我可以发布 以下是在PyMC3中之前定制的Dirichlet的工作示例。

import pymc3 as pm
import numpy as np
import scipy.special as special
import theano.tensor as tt
import matplotlib.pyplot as plt


n = 4

with pm.Model() as model:
    prior = np.ones(n) / n

    def dirich_logpdf(value=prior):
        return -n * special.gammaln(1/n) + (-1 + 1/n) * tt.log(value).sum()

    stick = pm.distributions.transforms.StickBreaking()
    probs = pm.DensityDist('probs', dirich_logpdf, shape=n, 
                testval=np.array(prior), transform=stick)
    data = np.array([5, 7, 1, 0])
    sfs_obs = pm.Multinomial('sfs_obs', n=np.sum(data), p=probs, observed=data)

with model:
    step = pm.Metropolis()
    trace = pm.sample(100000, tune=10000, step=step)

print('MLE = ', data / np.sum(data))
print(pm.summary(trace))

pm.traceplot(trace, [probs])
plt.show()
相关问题