如何使用自定义梯度下降来加快训练速度?

时间:2019-07-02 09:38:05

标签: tensorflow2.0 tf.keras

我正在尝试实现从tf.keras.Sequential继承的模型类,并且该类具有自定义的梯度下降功能(使用梯度检查点)以优化内存消耗。

我的课程中有一个函数,当调用该函数时,可以为给定的数据点/批处理计算渐变并更新模型中所有图层的权重。由于我在训练循环中反复调用了此函数,因此我不想使用model.fit()的{​​{1}}函数。

我面临的主要问题是每个纪元需要很长时间才能完成(超过tf.keras),我怀疑这是因为我在调用{{1}之前将每个批处理移至GPU }函数,导致延迟。

是否可以重载我的model.fit()类中的某些函数,以便使用model.update_weights来调用我的自定义model函数?

或者,是否有一种方法可以将数据集更有效地获取到GPU? (我目前正在使用model.fit(),它实际上是update_weights。)

0 个答案:

没有答案
相关问题