仅使用tf.data API的高性能输入管道的最佳做法(无feed_dict)

时间:2018-09-20 04:19:59

标签: python-3.x tensorflow

官方TensorFlow Performance Guide声明如下:

  

使用feed_dict喂养数据时,   灵活性,通常feed_dict不提供可扩展性   解。如果仅使用单个GPU,则两者之间的差异   tf.data API和feed_dict的性能可能微不足道。我们的   建议避免对所有琐碎的东西使用feed_dict   例子。尤其要避免在输入较大的情况下使用feed_dict。

但是,完全避免使用 feed_dict 似乎是不可能的。考虑使用以下训练,验证和测试数据集的设置。

ds = tf.data.Dataset
n_files = 1000 # total number of tfrecord files
split = int(.67 * n_files)
files = ds.zip((ds.range(n_files),ds.list_files("train/part-r-*")))
train_files = files.filter(lambda a, b: a < split).map(lambda a,b: b)
validation_files = files.filter(lambda a, b: a >= split).map(lambda a,b: b)
test_files = ds.list_files("test/part-r-*")

解析数据集的常用方法如下所示:

def setup_dataset(self, file_ds, mode="train"):

   data = file_ds.apply(tf.contrib.data.parallel_interleave(
       tf.data.TFRecordDataset,
       cycle_length=4,
       sloppy=True,
       buffer_output_elements=self.batch_size * 8,
       prefetch_input_elements=self.batch_size * 8
   ))

   if mode == "train":
       data = data.map(self.train_data_parser)
   else:
       data = data.map(self.test_data_parser)

   return data

然后,您可以使用session.run()Iterator.from_structure()创建可重用的迭代器,而不是通过Iterator.from_string_handle()中的 feed_dict 来馈送各个功能。我将以前者为例,但无论哪种方式,您都会遇到相同的问题。

train = self.setup_dataset(train_files)
self.ops["template_iterator"] = tf.data.Iterator.from_structure(train.output_types, train.output_shapes)
self.ops["next_batch"] = self.ops["template_iterator"].get_next(name="next_batch")
self.ops["train_init"] = self.ops["template_iterator"].make_initializer(train)

validation = self.setup_dataset(validation_files)
self.ops["validation_init"] = self.ops["template_iterator"].make_initializer(validation)

这一切都很好,但是我应该如何处理测试数据集?测试数据集将不包含标签要素,因此不符合与训练和验证数据集相同的 output_types output_shapes

理想情况下,我想从SavedModel还原并初始化测试数据集,而不是通过API提供模型。

在推理过程中合并测试数据集所缺少的诀窍是什么?

1 个答案:

答案 0 :(得分:0)

我为这样的训练和推理设置了数据集和迭代器:

# Train dataset
images_train = tf.placeholder(tf.float32, train_images.shape)
labels_train = tf.placeholder(tf.float32, train_masks.shape)
dataset_train = tf.data.Dataset.from_tensor_slices({"images": images_train, "masks": labels_train})
dataset_train = dataset_train.batch(MINIBATCH)
dataset_train = dataset_train.map(lambda x: map_helper(x, augmentation), num_parallel_calls=8)
dataset_train = dataset_train.shuffle(buffer_size=10000)

iterator_train = tf.data.Iterator.from_structure(dataset_train.output_types, dataset_train.output_shapes)
training_init_op = iterator_train.make_initializer(dataset_train)
batch_train = iterator_train.get_next()

# Inference dataset
images_infer = tf.placeholder(tf.float32, shape=[None] + list(valid_images.shape[1:]))
labels_infer = tf.placeholder(tf.float32, shape=[None] + list(valid_masks.shape[1:]))
dataset_infer = tf.data.Dataset.from_tensor_slices({"images": images_infer, "masks": labels_infer})
dataset_infer = dataset_infer.batch(MINIBATCH)

iterator_infer = tf.data.Iterator.from_structure(dataset_infer.output_types, dataset_infer.output_shapes)
infer_init_op = iterator_infer.make_initializer(dataset_infer)
batch_infer = iterator_infer.get_next()

培训

使用training_init_op初始化迭代器进行训练

sess.run(training_init_op, feed_dict={images_train: train_images, labels_train: train_masks})

验证

使用infer_init_op初始化推理迭代器进行验证

sess.run(infer_init_op, feed_dict={images_infer: images_val, labels_infer: masks_val})

测试

使用infer_init_op初始化推理迭代器以进行测试。这有点骇人听闻,但是我创建了一个带有零的数组,标签将移至其中,并使用与验证相同的迭代器

sess.run(infer_init_op, feed_dict={images_infer: images_test, labels_infer: np.zeros(images_test.shape)})

或者,您可以为训练/验证/测试创建3个不同的数据集/迭代器