加载数据时删除/跳过记录

时间:2016-06-16 23:59:24

标签: tensorflow

我在训练集中发现了一些错误的数据(错误标记的示例),虽然我已经修复了源代码,但我还是想继续尝试使用相同的数据集,所以我需要跳过这些记录。

我使用TFRecordReader并加载parse_single_example& shuffle_batch。我可以在某处提供过滤器吗?

1 个答案:

答案 0 :(得分:4)

使用tf.train.shuffle_batch()enqueue_many=Truedocs中对如何执行此操作进行简短介绍。如果您可以确定示例是否使用图形操作进行了错误标记,那么您可以像这样过滤结果(改编自another SO answer):

X, y = tf.parse_single_example(...)
is_correctly_labelled = correctly_labelled(X, y)
X = tf.expand_dims(X, 0)
y = tf.expand_dims(y, 0)
empty = tf.constant([], tf.int32)
X, y = tf.cond(is_correctly_labelled,
               lambda: [X, y],
               lambda: [tf.gather(X, empty), tf.gather(y, empty)])
Xs, ys = tf.train.shuffle_batch(
    [X, y], batch_size, capacity, min_after_dequeue,
    enqueue_many=True)

tf.gather只是一种获得零大小切片的方法。在numpy中它只是X[[], ...]

相关问题