在Tensorflow的数据集API中使用flat_map

时间:2018-05-25 13:54:07

标签: tensorflow tensorflow-datasets

我使用数据集API,按如下方式读取数据:

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))

我现在想要使用flat_map来过滤掉一些,同时在训练时动态复制一些其他样本(这是导致我的模型的输入函数)。

flat_map的API需要返回Dataset个对象,但我不知道如何创建它。这是我想要实现的伪代码实现:

def flat_map_impl(tf_example):
    # Pseudo-code:
    # if tf_example["a"] == 1:
    #     return []
    # else:
    #     return [tf_example, tf_example]

dataset.flat_map(flat_map_impl)

如何在flat_map函数中实现它?

注意:我想通过py_func可以实现这一点,但我更愿意避免这种情况。

1 个答案:

答案 0 :(得分:1)

tf.data.Dataset返回时,创建Dataset.flat_map()的最常用方法可能是使用Dataset.from_tensors()Dataset.from_tensor_slices()。在这种情况下,由于tf_example是字典,因此最简单的方法是使用Dataset.from_tensors()Dataset.repeat(count)的组合,其中conditional expression计算count:< / p>

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))

def flat_map_impl(tf_example):
  count = tf.cond(tf.equal(tf_example["a"], 1)),
                  lambda: tf.constant(0, dtype=tf.int64),
                  lambda: tf.constant(2, dtype=tf.int64))

  return tf.data.Dataset.from_tensors(tf_example).repeat(count)

dataset = dataset.flat_map(flat_map_impl)