从TensorFlow图中的TFRecords文件中读取顺序数据?

时间:2017-02-19 01:41:38

标签: tensorflow recurrent-neural-network sequential

我正在使用视频数据,但我相信这个问题应该适用于任何顺序数据。我想从TFRecords文件传递我的RNN 10连续示例(视频帧)。当我第一次开始阅读文件时,我需要抓住10个示例,并使用它来创建一个序列示例,然后将其推送到队列中以供RNN准备好时使用。但是,现在我有了10帧,下次我从TFRecords文件中读取时,我只需要取一个例子,然后将其他9个移位。但是当我点击第一个TFRecords文件的末尾时,我需要在第二个TFRecords文件上重启该进程。我的理解是cond op将处理每种条件下所需的操作,即使该条件不是要使用的条件。当使用条件检查是否只读取10个示例或仅1时,这将是一个问题。无论如何解决此问题仍然具有上述期望的结果?

1 个答案:

答案 0 :(得分:1)

您可以使用TensorFlow 1.12中最近添加的Dataset.window()转换来做到这一点:

filenames = tf.data.Dataset.list_files(...)

# Define a function that will be applied to each filename, and return the sequences in that
# file.
def get_examples_from_file(filename):
  # Read and parse the examples from the file using the appropriate logic.
  examples = tf.data.TFRecordDataset(filename).map(...)

  # Selects a sliding window of 10 examples, shifting along 1 example at a time.
  sequences = examples.window(size=10, shift=1, drop_remainder=True)

  # Each element of `sequences` is a nested dataset containing 10 consecutive examples.
  # Use `Dataset.batch()` and get the resulting tensor to convert it to a tensor value
  # (or values, if there are multiple features in an example).
  return sequences.map(
      lambda d: tf.data.experimental.get_single_element(d.batch(10)))

# Alternatively, you can use `filenames.interleave()` to mix together sequences from
# different files.
sequences = filenames.flat_map(get_examples_from_file)