来自Gumbel Softmax分布的Pyro样本值

时间:2020-01-31 02:29:23

标签: python pytorch

sampled_indexes = pyro.sample(f"{address}_{index}", pyro.distributions.RelaxedOneHotCategoricalStraightThrough(1, logits=char_dist), obs=observed[index])

我正在使用VAE来生成名称的seq2seq,但我做得不好,因为我是从无法支持的分类分布中采样的。因此,我正在尝试实现Gumbel Softmax。我有这段代码,其中logits是按字符分类的分布,而index是被观察到的字符,它正在尝试seq2seq生成。我给人的印象是RelaxedOneHotCategoricalStraightThrough可以将类别转换为Gumbel Softmax,然后从中进行采样,但是它只是通过将一个字符设置为1来对类别分布进行采样。任何人都知道如何从A可以在pyro.sample语句中使用Gumbel Softmax吗? PyTorch具有我可以使用的Gumbel Softmax发行版本,但不适用于pyro.sample。有什么建议吗?

0 个答案:

没有答案
相关问题