Opening images with PIL in a tf.data pipeline

Posted: 2018-12-11 08:21:57

Tags: python tensorflow

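I am currently trying to load the VOC2012 dataset with tf.data for semantic segmentation. The VOC2012 labels are stored as color-mapped (palette) PNGs. PIL handles the color map automatically and gives me the class index of each pixel, but that does not happen when I read the files with tf.read_file.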
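To make the "color map" point concrete, here is a small self-contained check. It is a sketch that assumes only Pillow and NumPy and builds a synthetic 2x2 label (not real VOC data) using a few entries of the VOC colormap. Opening the palette PNG with PIL keeps the per-pixel palette indices (the class ids), while expanding the palette to RGB gives display colors. As far as I know, decoding such a file as RGB (for example with tf.image.decode_png) behaves like the RGB branch here.

```python
import io

import numpy as np
from PIL import Image

# A synthetic 2x2 label map of class ids (0 = background, 255 = void/boundary).
ids = np.array([[0, 1], [15, 255]], dtype=np.uint8)

# Attach a palette using VOC colormap entries for the ids used here.
palette = [0] * 768
palette[1 * 3:1 * 3 + 3] = [128, 0, 0]          # id 1 (aeroplane)
palette[15 * 3:15 * 3 + 3] = [192, 128, 128]    # id 15 (person)
palette[255 * 3:255 * 3 + 3] = [224, 224, 192]  # id 255 (void)
label = Image.fromarray(ids)  # an "L" (8-bit grayscale) image
label.putpalette(palette)     # turns it into a "P" (palette) image

buf = io.BytesIO()
label.save(buf, format="PNG")
png_bytes = buf.getvalue()

# PIL keeps the palette indices, so the array holds class ids directly.
as_ids = np.array(Image.open(io.BytesIO(png_bytes)))
# Expanding the palette to RGB yields the display colors instead.
as_rgb = np.array(Image.open(io.BytesIO(png_bytes)).convert("RGB"))

print(as_ids.tolist())        # [[0, 1], [15, 255]]
print(as_rgb[1, 0].tolist())  # [192, 128, 128]
```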

import tensorflow as tf
from PIL import Image

# img_path, img_filename_list and lbl_filename_list are defined elsewhere
train_data = tf.data.Dataset.from_tensor_slices((img_filename_list, lbl_filename_list))

def preprocessing(img_filename, lbl_filename):
    # Load and decode the RGB image
    train_img = tf.read_file(img_path + img_filename)
    train_img = tf.image.decode_jpeg(train_img, channels=3)
    train_img = tf.cast(train_img, tf.float32) / 255.0  # Normalize to [0, 1]

    # The label is still only a filename at this point
    return train_img, lbl_filename

train_data = train_data.map(preprocessing).shuffle(100).repeat().batch(2)
iterator = train_data.make_initializable_iterator()
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(train_data)

with tf.Session() as sess:
    sess.run(training_init_op)
    while True:  # the dataset repeats indefinitely
        train_images, lbl_filename = sess.run(next_element)

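This is what I am doing at the moment. Ideally, though, I would like the preprocessing function to return the label image loaded with PIL, so that I can turn it into a one-hot vector.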
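For the one-hot step itself, once a label has been loaded as an (H, W) array of class ids, NumPy can do the conversion. This is a minimal sketch: `to_one_hot` is a hypothetical helper, `num_classes=21` assumes VOC2012's 20 object classes plus background, and mapping the void/boundary id 255 to an all-zero vector is an assumption about how those pixels should be treated.

```python
import numpy as np


def to_one_hot(label, num_classes=21, ignore_id=255):
    """Convert an (H, W) class-id map into an (H, W, num_classes) one-hot array."""
    valid = label != ignore_id
    # Point ignored pixels at class 0 temporarily so the lookup stays in range.
    safe = np.where(valid, label, 0)
    one_hot = np.eye(num_classes, dtype=np.float32)[safe]
    # Ignored pixels end up with an all-zero vector.
    one_hot[~valid] = 0.0
    return one_hot


lbl = np.array([[0, 1], [15, 255]], dtype=np.uint8)
oh = to_one_hot(lbl)
print(oh.shape)  # (2, 2, 21)
```

Indexing rows of an identity matrix is a common NumPy idiom for one-hot encoding; the extra mask is only needed because 255 lies outside the class range.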

def preprocessing(img_filename, lbl_filename):
    # ... load train_img as above ...
    train_lbl = Image.open(lbl_path + lbl_filename)
    # ... do some other stuff ...
    return train_img, train_lbl

This results in the error

AttributeError: 'Tensor' object has no attribute 'read'
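The error comes from how tf.data works: Dataset.map traces the preprocessing function once with symbolic tensors, so `lbl_path + lbl_filename` is a tf.Tensor, not a Python string. PIL's Image.open treats any argument that is not a path as a file object, and when that object has no `seek()` it falls back to calling `read()`. The sketch below reproduces the same failure without TensorFlow; `FakeTensor` is a hypothetical stand-in for the tensor, and only Pillow is assumed.

```python
from PIL import Image


class FakeTensor:
    """Hypothetical stand-in for a symbolic tf.Tensor: neither a path nor a file."""


message = None
try:
    # Not a path, so PIL treats it as a file object; it has no .seek(),
    # so PIL falls back to .read(), which does not exist either.
    Image.open(FakeTensor())
except AttributeError as err:
    message = str(err)

print(message)
```

The printed message has the same shape as the one in the question, with the class name swapped.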

Is there any way around this?

1 Answer

Answer 0 (score: 1)

As @GPhilo suggested, using tf.py_func solves this problem.
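A minimal sketch of the tf.py_func approach follows, assuming TensorFlow 1.14 or later, where tf.numpy_function is available. On earlier 1.x releases, tf.py_func takes the same (func, inp, Tout) arguments and can be swapped in. `load_label`, `make_preprocessing` and `NUM_CLASSES` are hypothetical names, and the final tf.one_hot step assumes VOC2012's 21 classes. The key point is that inside the Python op the filename arrives as plain bytes, so PIL can open it.

```python
import numpy as np
import tensorflow as tf
from PIL import Image

NUM_CLASSES = 21  # assumption: VOC2012's 20 object classes + background


def load_label(path):
    """Open a VOC label PNG with PIL and return its (H, W) class-id array."""
    # Inside tf.numpy_function a string tensor arrives as Python bytes,
    # not as a tf.Tensor, so it can be decoded and handed to PIL.
    if isinstance(path, bytes):
        path = path.decode("utf-8")
    # For a palette ("P" mode) PNG, np.array() yields the palette indices,
    # i.e. one class id per pixel, rather than RGB colors.
    with Image.open(path) as im:
        return np.array(im, dtype=np.uint8)


def make_preprocessing(img_path, lbl_path):
    """Hypothetical factory binding the directory prefixes used in the question."""

    def preprocessing(img_filename, lbl_filename):
        # Image branch: ordinary TensorFlow ops.
        train_img = tf.io.read_file(tf.strings.join([img_path, img_filename]))
        train_img = tf.image.decode_jpeg(train_img, channels=3)
        train_img = tf.cast(train_img, tf.float32) / 255.0

        # Label branch: run the PIL loader as a Python op.
        lbl_file = tf.strings.join([lbl_path, lbl_filename])
        train_lbl = tf.numpy_function(load_label, [lbl_file], tf.uint8)
        # The Python op has no static shape; declare at least the rank.
        train_lbl.set_shape([None, None])
        # Ids outside [0, NUM_CLASSES), such as VOC's 255 boundary value,
        # become all-zero vectors in tf.one_hot.
        train_lbl = tf.one_hot(tf.cast(train_lbl, tf.int32), NUM_CLASSES)
        return train_img, train_lbl

    return preprocessing


# Usage with the question's (externally defined) lists and directories:
# train_data = tf.data.Dataset.from_tensor_slices((img_filename_list, lbl_filename_list))
# train_data = train_data.map(make_preprocessing(img_path, lbl_path))
```

One design trade-off: a Python op runs under the GIL and is not serialized into the graph. It is therefore usually slower than native ops and ties the input pipeline to the Python process. For a one-off dataset like VOC2012, that is often acceptable.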