如果未设置tf.stop_gradient会怎样?

时间:2019-05-09 12:06:30

标签: python tensorflow object-detection tensorflow-model-analysis

我正在阅读Tensorflow模型的faster-rcnn代码。我对tf.stop_gradient的使用感到困惑。

考虑以下代码片段:

if self._is_training:
    proposal_boxes = tf.stop_gradient(proposal_boxes)
    if not self._hard_example_miner:
    (groundtruth_boxlists, groundtruth_classes_with_background_list, _,
     groundtruth_weights_list
    ) = self._format_groundtruth_data(true_image_shapes)
    (proposal_boxes, proposal_scores,
     num_proposals) = self._sample_box_classifier_batch(
         proposal_boxes, proposal_scores, num_proposals,
         groundtruth_boxlists, groundtruth_classes_with_background_list,
         groundtruth_weights_list)

更多代码为here。我的问题是:如果没有为tf.stop_gradient设置proposal_boxes,会发生什么?

1 个答案:

答案 0 :(得分:1)

这确实是一个好问题,因为这条简单的行tf.stop_gradient对于训练fast_rcnn模型至关重要。这就是训练期间需要它的原因。

Faster_rcnn模型是两阶段的检测器,损失函数必须满足两个阶段的目标。在fast_rcnn中,rpn损耗和fast_rcnn损耗都需要最小化。

这是论文在3.2节中所说的

  

经过独立培训的RPN和Fast R-CNN都将以不同方式修改其卷积层。因此,我们需要开发一种技术,允许在两个网络之间共享卷积层,而不是学习两个单独的网络。

然后,本文描述了三种训练方案,在原始论文中,他们采用了第一个解决方案-替代训练,即先训练RPN,然后再训练Fast-RCNN。

第二种方案是近似联合训练,易于实施,并且 API采纳了该方案。快速R-CNN接受来自预测边界框的输入坐标(通过rpn),因此快速R-CNN损耗将具有不包含边界框坐标的梯度。但是在此训练方案中,忽略这些梯度,这正是使用tf.stop_gradient的原因。该报告指出,这种培训方案将减少25-50%的培训时间。

第三个方案是非近似联合训练,因此不需要tf.stop_gradient。该论文报告说,拥有一个与盒坐标可区分的RoI池层是一个不小的问题。

但是为什么忽略这些渐变呢?

事实证明,RoI池层是完全可区分的,但是支持方案2的主要原因是方案3。将导致它在训练期间变得不稳定。

API的一位作者给出了很好的答案here

一些further reading关于联合训练。