使用tensorflow时在例外中包含tfrecord名称

时间:2018-08-06 21:07:07

标签: python tensorflow tensorflow-datasets

我正在尝试通过tensorflow数据集管道内置一些调试代码。基本上,如果对某个文件的tfrecord解析失败,我希望能够找出是哪个文件。我的梦想是在parsing_function中运行多个断言,如果断言,它们会提供文件名。

我的管道如下所示:

tf.data.Dataset.from_tensor_slices(file_list)
        .apply(tf.contrib.data.parallel_interleave(lambda f: tf.data.TFRecordDataset(f), cycle_length=4))
        .map(parse_func, num_parallel_calls=params.num_cores)
        .map(_func_for_other_stuff)

理想情况下,我会在parallel_interleave步骤中传递文件名,但是如果我有匿名函数返回文件名tfrecordataset元组,则会得到:

TypeError: `map_func` must return a `Dataset` object.

我也曾尝试像this问题一样在文件本身中包含文件名,但由于文件名的长度可变,因此在这里遇到了问题。

1 个答案:

答案 0 :(得分:1)

传递给tf.contrib.data.parallel_interleave()的函数的返回值必须为tf.data.Dataset。因此,您可以通过使用tf.data.Dataset.zip()如下将文件名张量附加到TFRecordDataset的每个元素来解决此问题:

def read_records_func(filename):
  records = tf.data.TFRecordDataset(filename)

  # Create a dataset from the filename tensor and repeat it indefinitely.
  filename_as_dataset = tf.data.Dataset.from_tensors(filename).repeat(None)

  return tf.data.Dataset.zip((filename_as_dataset, records))

dataset = (tf.data.Dataset.from_tensor_slices(file_list)
           .apply(tf.contrib.data.parallel_interleave(read_records_func, cycle_length=4))
           .map(parse_func, num_parallel_calls=params.num_cores)
           .map(_func_for_other_stuff))