我使用数据集API,按如下方式读取数据:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))
我现在想要使用flat_map
来过滤掉一些,同时在训练时动态复制一些其他样本(这是导致我的模型的输入函数)。
flat_map
的API需要返回Dataset
个对象,但我不知道如何创建它。这是我想要实现的伪代码实现:
def flat_map_impl(tf_example):
# Pseudo-code:
# if tf_example["a"] == 1:
# return []
# else:
# return [tf_example, tf_example]
dataset.flat_map(flat_map_impl)
如何在flat_map
函数中实现它?
注意:我想通过py_func
可以实现这一点,但我更愿意避免这种情况。
答案 0 :(得分:1)
从tf.data.Dataset
返回时,创建Dataset.flat_map()
的最常用方法可能是使用Dataset.from_tensors()
或Dataset.from_tensor_slices()
。在这种情况下,由于tf_example
是字典,因此最简单的方法是使用Dataset.from_tensors()
和Dataset.repeat(count)
的组合,其中conditional expression计算count
:< / p>
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))
def flat_map_impl(tf_example):
count = tf.cond(tf.equal(tf_example["a"], 1)),
lambda: tf.constant(0, dtype=tf.int64),
lambda: tf.constant(2, dtype=tf.int64))
return tf.data.Dataset.from_tensors(tf_example).repeat(count)
dataset = dataset.flat_map(flat_map_impl)