Unable to load images using tf.data

Time: 2019-03-29 03:29:59

Tags: python tensorflow

I am trying to load images with tf.data, but I get an error. Here is my code:

import cv2
import tensorflow as tf

# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
  image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
  image_decoded = tf.expand_dims(image_decoded, dim=0)
  return image_decoded, label

# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
  image_decoded.set_shape([None, None, None])
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  image_resized = tf.expand_dims(image_resized, dim=0)
  return image_resized, label

filenames = ["data/img.jpeg", "data/img.jpeg"]
labels = [0, 37]

dataset = tf.data.Dataset.from_tensor_slices((tf.constant(filenames), tf.constant(labels)))
dataset = (dataset.map(
    lambda filename, label: tuple(tf.py_func(
        _read_py_function, [filename, label], [tf.uint8, label.dtype]))))
dataset = dataset.map(_resize_function)
dataset = dataset.batch(2)
dataset = dataset.prefetch(2)

# iterator = dataset.make_initializable_iterator()
configProt = tf.ConfigProto()
configProt.gpu_options.allow_growth = True
configProt.allow_soft_placement = True
sess = tf.Session(config = configProt)

iterator = dataset.make_one_shot_iterator()
# next_element = iterator.get_next()
images, labels = iterator.get_next()

print(sess.run(labels))

However, I get this error:

tensorflow.python.framework.errors_impl.UnimplementedError: Unsupported object type Tensor
         [[Node: PyFunc = PyFunc[Tin=[DT_STRING, DT_INT32], Tout=[DT_UINT8, DT_INT32], token="pyfunc_0"](arg0, arg1)]]
         [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,1,28,28,?], <unknown>], output_types=[DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

I can't get it to run with TF 1.8. What is the problem?

1 answer:

Answer 0 (score: 0)

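I solved the problem by removing the tf.xxx calls from _read_py_function. The function wrapped by tf.py_func runs as ordinary Python and has to return NumPy values; tf.expand_dims returns a tf.Tensor instead, which is what raises "Unsupported object type Tensor".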
