使用tf.map_fn将多个图像作为张量读取

时间:2018-02-23 17:01:02

标签: python tensorflow

我正在使用Tensorflow的数据集API阅读各种图像(数据和标签)。由于数据集队列在CPU上,因此复制数据非常昂贵。但是,我似乎无法找到避免这种情况的方法。

问题:我可以按统一顺序高效加载各种图像(例如h,w,c)吗?

假设我想在单个单通道图像中阅读,我可以按如下方式进行:

image = tf.image.decode_png(tf.read_file(file_name), channels=1)  # h,w,c

或者对于多通道RGB:

image = tf.image.decode_png(tf.read_file(file_name), channels=3)  # h,w,c

这为我提供了一个高度 - 宽度 - 通道排序,便于数据增强和预处理功能,如tf.image.per_image_standardization

但是,如果我加载多个图像并希望将它们堆叠在一起(例如,具有多个RGB输入的CNN或多标签语义分段问题),我似乎总是要复制数据。以下是使用tf.stack中的副本的一种方式:

images = []
for image_id in range(0, images):
    file = file_names[image_id]
    images.append(tf.image.decode_png(tf.read_file(file), channels=1)[:, :, 0])
images = tf.stack(images, axis=2)  # Packs as h,w,c

另一种方法是使用tf.map_fn,它看起来就像是为了这个目的。然而,它“叠加”在错误的维度,所以我仍然需要一个昂贵的转置:

map = tf.map_fn(lambda f: tf.image.decode_png(tf.read_file(f), channels=1)[:, :, 0],
                file_names, back_prop=False, dtype=tf.uint8)
images = tf.transpose(map, [1, 2, 0])  # from c,h,w to h,w,c

是否可以避免tf.stacktf.transpose

1 个答案:

答案 0 :(得分:0)

一般来说,删除副本非常困难,因为张量通常是不可变的。只要操作系统想要输出内容,它就会分配新内存并写入内存。

可以想象将map_fn实现更改为沿不同维度堆栈张量。不幸的是,它是使用TensorArray构建的,它不支持此功能。

有一点需要注意的是,CHW通常更适合GPU,因为它们更喜欢内部尺寸。大多数TF操作都支持这种布局。

如果您有冒险精神,可以尝试通过XLA运行此部分。因为XLA获得了图表的全局视图,所以它可以潜在地优化其中的一些操作。它正在大力发展,可能会或可能不会使您的用例受益。

您还可以查看使用图像的官方张量流模型(例如https://github.com/tensorflow/models/tree/master/official/resnet)以获得最佳实践。

相关问题