从TFRecordDataset获取数据集为numpy数组

时间:2018-02-19 17:39:25

标签: python numpy tensorflow tensorflow-datasets

我使用新的 tf.data API 为CIFAR10数据集创建迭代器。我正在从两个 .tfrecord 文件中读取数据。一个保存训练数据(train.tfrecords),另一个保存测试数据(test.tfrecords)。这很好用。但是,在某些时候,我需要将数据集(训练数据和测试数据)作为 numpy数组

是否可以从tf.data.TFRecordDataset对象中检索数据集为numpy数组?

1 个答案:

答案 0 :(得分:4)

您可以使用tf.data.Dataset.batch()转换和tf.contrib.data.get_single_element()来执行此操作。 作为复习,<script src="https://ajax.googleapis.com/ajax/libs/jquery/2.1.1/jquery.min.js"></script> <div id="content"></div> <button class="check">Check</button> 将最多dataset.batch(n) n个连续元素,并通过连接每个组件将它们转换为一个元素。这要求所有元素每个组件具有固定的形状。如果dataset大于n中的元素数量(或dataset未精确划分元素数量),则最后一批可以更小。因此,您可以为n选择较大的值并执行以下操作:

n
相关问题