加载操作后Pytorch GPU内存增加

时间:2019-09-23 10:56:19

标签: pytorch

我有一个pytorch模型,大小为386MB,但是当我加载模型时

state = torch.load(f, flair.device)

我的GPU内存占用了900MB,为什么会发生这种情况,并且有解决方法?

这就是我保存模型的方式

model_state = self._get_state_dict()

# additional fields for model checkpointing
model_state["optimizer_state_dict"] = optimizer_state
model_state["scheduler_state_dict"] = scheduler_state
model_state["epoch"] = epoch
model_state["loss"] = loss

torch.save(model_state, str(model_file), pickle_protocol=4)

1 个答案:

答案 0 :(得分:2)

可能是optimizer_state占用了额外的空间。一些优化器(例如Adam)跟踪每个可训练参数的统计信息,例如一阶和二阶矩。如您所知,此信息会占用空间。

您可以先加载到CPU:

state = torch.load(f, map_location=torch.device('cpu'))