如何检测是否在 pytorch 中触发了回调?

时间:2021-02-09 14:36:38

标签: python nlp pytorch bert-language-model

我正在微调 BERT 模型。首先,我想冻结层并进行一些训练。当某个回调被触发时(比如 ReduceLROnPlateau),我想解冻图层。我该怎么做?

1 个答案:

答案 0 :(得分:1)

恐怕 PyTorch 中的学习率调度程序不提供挂钩。看一下 ReduceLROnPlateau here 的实现,调度器被触发时会重置两个属性(i.e. 当它识别到一个高原并降低学习率时):

    if self.num_bad_epochs > self.patience:
        self._reduce_lr(epoch)
        self.cooldown_counter = self.cooldown
        self.num_bad_epochs = 0

基于此,您可以包装调度程序步骤调用,并通过检查 _reduce_lrscheduler.cooldown_counter == scheduler.cooldown 是否为真来确定 scheduler.num_bad_epochs == 0 是否被触发。

相关问题