KeyError: unexpected key "module.encoder.embedding.weight" in state_dict

Date: 2017-05-28 18:55:06

Tags: pytorch

I get the following error when trying to load a saved model.

KeyError: 'unexpected key "module.encoder.embedding.weight" in state_dict'

This is the function I use to load the saved model.

import os

import torch

def load_model_states(model, tag):
    """Load a previously saved model's states."""
    # `args` is a module-level namespace holding the command-line arguments.
    filename = os.path.join(args.save_path, tag)
    with open(filename, 'rb') as f:
        model.load_state_dict(torch.load(f))

The model is a sequence-to-sequence network, whose __init__ function (constructor) looks like this.

def __init__(self, dictionary, embedding_index, max_sent_length, args):
    """"Constructor of the class."""
    super(Sequence2Sequence, self).__init__()
    self.dictionary = dictionary
    self.embedding_index = embedding_index
    self.config = args
    self.encoder = Encoder(len(self.dictionary), self.config)
    self.decoder = AttentionDecoder(len(self.dictionary), max_sent_length, self.config)
    self.criterion = nn.NLLLoss()  # Negative log-likelihood loss

    # Initializing the weight parameters for the embedding layer in the encoder.
    self.encoder.init_embedding_weights(self.dictionary, self.embedding_index, self.config.emsize)

When I print the model (the sequence-to-sequence network), I get the following.

Sequence2Sequence (
  (encoder): Encoder (
    (drop): Dropout (p = 0.25)
    (embedding): Embedding(43723, 300)
    (rnn): LSTM(300, 300, batch_first=True, dropout=0.25)
  )
  (decoder): AttentionDecoder (
    (embedding): Embedding(43723, 300)
    (attn): Linear (600 -> 12)
    (attn_combine): Linear (600 -> 300)
    (drop): Dropout (p = 0.25)
    (out): Linear (300 -> 43723)
    (rnn): LSTM(300, 300, batch_first=True, dropout=0.25)
  )
  (criterion): NLLLoss (
  )
)

So, module.encoder.embedding is the embedding layer and module.encoder.embedding.weight is its associated weight matrix. Why, then, does loading fail with unexpected key "module.encoder.embedding.weight" in state_dict?

1 Answer:

Answer 0 (score: 8)

I solved the problem. I had saved the model while it was wrapped in nn.DataParallel, which stores the model under a module attribute, and then I tried to load it without DataParallel. So I can either temporarily wrap my network in nn.DataParallel while loading, or I can load the weights file, build a new ordered dict without the module prefix, and load that instead.
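
A minimal sketch of the first workaround, assuming the checkpoint is named myfile.pth.tar (the filename here is illustrative):

import torch
import torch.nn as nn

# Wrapping the model in DataParallel reintroduces the `module.` prefix,
# so the saved keys match the wrapped model's parameter names.
model = nn.DataParallel(model)
model.load_state_dict(torch.load('myfile.pth.tar'))

# Unwrap afterwards to keep working without DataParallel.
model = model.module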

The second workaround is shown below.

from collections import OrderedDict

import torch

# Original file was saved with DataParallel, so every key
# carries a `module.` prefix.
state_dict = torch.load('myfile.pth.tar')

# Create a new OrderedDict whose keys do not contain `module.`.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # strip `module.` (len('module.') == 7)
    new_state_dict[name] = v

# Load the renamed parameters into the plain (non-DataParallel) model.
model.load_state_dict(new_state_dict)
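
Slicing with k[7:] assumes every key carries the prefix. A slightly safer variant (a sketch, not part of the original answer) strips the prefix only when it is actually present:

new_state_dict = OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in state_dict.items()
)
model.load_state_dict(new_state_dict)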

Reference: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686
