TensorFlow:指定图层激活的存储

时间:2017-01-03 19:51:24

标签: python tensorflow

动机:很容易想象一个可能会占用大量图层以淹没GPU内存容量的场景。

此问题的潜在解决方案是利用层之间的任何向后可计算性。即图层j的输出是下一个图层输出的函数:

enter image description here

在这里,您可以简单地将最终的这种图层激活存储在内存中,并计算“涓流”后的激活情况。在backprop期间的时尚。

问题:我对TensorFlow中的手动内存管理不熟悉,并且无法找到有关如何指定TensorFlow应保留在内存中的变量的信息。可以丢弃。

实现这一点的另一方面是指定自定义梯度计算;我认为这已经足够记录并且应该是可以实现的。但是,如果有人有任何与此相关的警告,我很感激听到他们。

1 个答案:

答案 0 :(得分:4)

据我了解,您希望通过丢弃中间结果并稍后重新计算来节省内存。我可以看到两种可能的方法。一种是重新连接图形以进行重新计算,另一种是使用持久性张量来获得中间结果并手动控制它们的删除。

对于第一种方法,请考虑以下计算及其梯度图。

您可以使用contrib.graph_editor修改图表,如下所示。

请注意,现在可以选择一个只需要足够内存进行2次激活的执行顺序。但是,TensorFlow通常不会选择此执行顺序,而是在开始时计算两个a2张量,因此需要足够的内存来存储4个峰值激活。 (有关极端示例,请参阅caterpillar graph。)

解决方案是添加控制依赖项以强制执行特定的执行顺序。

这会强制在a2之后计算第二个b3节点。由于TensorFlow在不再需要张量时立即释放内存,因此该图中的所有执行顺序都需要足够的峰值内存来存储2次激活,而不是3次。

这是实现上述示例的notebook

如果进行a2-> a3的计算是可逆的,您可以按如下方式重新连接图表

第二种方法是使用持久性张量。在运行调用完成后,您可以告诉TensorFlow保留某些张量。与变量不同,可以删除这些对象以释放内存。您有更多.run次呼叫的缺点,每次呼叫都会产生额外的200次延迟,但它可能比所有图形重新布线更容易。我没有探索过这条路线,但这里有example使用持久性张量(使用delete_session_tensor命令完成删除)