Pytorch:我什么时候以及为什么应该使用缓冲区?

时间:2021-04-12 12:09:00

标签: pytorch lstm recurrent-neural-network

我正在使用缓冲区来传递 LSTM 网络的隐藏状态。

def __init__(self, model, hidden_state1=None, ...somethine else...):
    self.register_buffer('hidden_state1', hidden_state1)
    self.hidden_state1 = hidden_state1
    ....#other codes

为了避免错误:

RuntimeError: Trying to backward through the graph a second time, 
but the buffers have already been freed. 
Specify retain_graph=True when calling backward the first time.

我使用 .clone().detach() 来分离缓冲区。

无论如何我都需要手动分离它们,我还需要在 Pytorch 中使用缓冲区而不是普通参数吗?

带有“requires_grad=False”的普通参数是否足以替代缓冲区的使用?

(其实我也不知道这样传递隐藏状态是不是好方法)

0 个答案:

没有答案