使用TF Dataset和Eager创建状态计数器

时间:2018-09-26 13:18:19

标签: python tensorflow tensorflow-datasets

我正在尝试在Tensorflow数据集管道中添加累加器。基本上,我有这个:

  def _filter_bcc_labels(self, labels, labels_table, bcc_count):
        bg_counter = tf.zeros(shape=(), dtype=tf.int32)

        def _add_to_counter():
            tf.add(bg_counter, 1)
            # Here the bg_counter is always equal to 0
            tf.Print(bg_counter, [bg_counter])
            return tf.constant(True)

        return tf.cond(tf.greater_equal(bg_counter, tf.constant(bcc_count, dtype=tf.int32)),
                                        true_fn=lambda: tf.constant(False),
                                        false_fn=_add_to_counter)


ds = ds.filter(lambda file, position, img, lbls: self._filter_bcc_labels(lbls, {"BCC": 0, "BACKGROUND": 1}, 10))

我的目标是在达到bg_counter tf.cond时增加false_fn,但是我的变量始终为0,但实际上从未增加。 有人可以向我解释发生了什么事?

请记住,我正在使用TF eager,不能使用ds.make_initializable_iterator(),然后输入我的bg_counter初始值。 谢谢

3 个答案:

答案 0 :(得分:1)

您可能想将计数器包装在一个类中,因为当超出范围时,Eager中的变量将被删除。

代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
tf.enable_eager_execution()
import tensorflow.contrib.eager as tfe

dataset = tf.data.Dataset.from_tensor_slices(([1,2,3,4,5], [-1,-2,-3,-4,-5]))

class My(object):
    def __init__(self):
        self.x = tf.get_variable("mycounter", initializer=lambda: tf.zeros(shape=[], dtype=tf.float32), dtype=tf.float32
                                 , trainable=False) 

v = My()
print(v.x)
tf.assign(v.x,tf.add(v.x,1.0))
print(v.x)

def map_fn(x,v):
    tf.cond(tf.greater_equal(v.x, tf.constant(5.0))
           ,lambda: tf.constant(0.0)
           ,lambda: tf.assign(v.x,tf.add(v.x,1.0))
           )
    return x

dataset = dataset.map(lambda x,y: map_fn(x,v)).batch(1)

for batch in tfe.Iterator(dataset):
    print("{} | {}".format(batch, v.x))

日志:

<tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=0.0>    
<tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=1.0>    
[1] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=2.0>
[2] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=3.0>
[3] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=4.0>
[4] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=5.0>    
[5] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=5.0>

工作示例: https://www.kaggle.com/mpekalski/tfe-conditional-stateful-counter

答案 1 :(得分:0)

我认为您要执行的操作需要使用assign_add()方法而不是add方法。请注意,该参数必须是变量。

对于tf.cond,请不要在急切的范围之外使用它。 Here是同一讨论。

答案 2 :(得分:0)

由于@MPękalski为我指明了正确的方向,因此我实际上找到了问题的答案。 现在的代码如下所示:

def _filter_bcc_labels(self, bg_counter, labels, labels_table, bcc_count):
        bg_counter = tf.zeros(shape=(), dtype=tf.int32)

        def _add_to_counter():
            nonlocal bg_counter
            bg_counter.assign_add(1)
            # Prints the counter value
            tf.Print(bg_counter, [bg_counter])
            return tf.constant(True)

        return tf.cond(tf.greater_equal(bg_counter, tf.constant(bcc_count, dtype=tf.int32)),
                                        true_fn=lambda: tf.constant(False),
                                        false_fn=_add_to_counter)


bg_counter = tf.get_variable("bg_counter_" + step, initializer=lambda: tf.zeros(shape=[], dtype=tf.int32), dtype=tf.int32, trainable=False)
ds = ds.filter(lambda file, position, img, lbls: self._filter_bcc_labels(bg_counter, lbls, {"BCC": 0, "BACKGROUND": 1}, 10))

请记住,如果对数据集进行两次迭代,则此解决方案将不起作用,因为在这种情况下不会重新初始化计数器。而且,如果您将bg_counter = tf.get_variable("bg_counter_" + step, initializer=lambda: tf.zeros(shape=[], dtype=tf.int32), dtype=tf.int32, trainable=False)移到ds.filter内,那么您会因为急切的模式而得到'Tensor' object has no attribute 'assign_add'

如果您确实想以正确的方式进行操作,则在数据集管道之外的批次中进行迭代时必须创建一个计数器。

相关问题