如何使用tf.data API读取(解码)tfrecords

时间:2018-08-30 14:46:03

标签: tensorflow

我有一个自定义数据集,然后将其存储为tfrecord,

# toy example data
label = np.asarray([[1,2,3],
                    [4,5,6]]).reshape(2, 3, -1)

sample = np.stack((label + 200).reshape(2, 3, -1))

def bytes_feature(values):
    """Returns a TF-Feature of bytes.
    Args:
    values: A string.
    Returns:
    A TF-Feature.
    """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def labeled_image_to_tfexample(sample_binary_string, label_binary_string):
    return tf.train.Example(features=tf.train.Features(feature={
      'sample/image': bytes_feature(sample_binary_string),
      'sample/label': bytes_feature(label_binary_string)
    }))


def _write_to_tf_record():
    with tf.Graph().as_default():
        image_placeholder = tf.placeholder(dtype=tf.uint16)
        encoded_image = tf.image.encode_png(image_placeholder)

        label_placeholder = tf.placeholder(dtype=tf.uint16)
        encoded_label = tf.image.encode_png(image_placeholder)

        with tf.python_io.TFRecordWriter("./toy.tfrecord") as writer:
            with tf.Session() as sess:
                feed_dict = {image_placeholder: sample,
                             label_placeholder: label}

                # Encode image and label as binary strings to be written to tf_record
                image_string, label_string = sess.run(fetches=(encoded_image, encoded_label),
                                                      feed_dict=feed_dict)

                # Define structure of what is going to be written
                file_structure = labeled_image_to_tfexample(image_string, label_string)

                writer.write(file_structure.SerializeToString())
                return

但是我看不懂它。首先,我尝试过(基于http://www.machinelearninguru.com/deep_learning/tensorflow/basics/tfrecord/tfrecord.htmlhttps://medium.com/coinmonks/storage-efficient-tfrecord-for-images-6dc322b81db4https://medium.com/mostly-ai/tensorflow-records-what-they-are-and-how-to-use-them-c46bc4bbb564

def read_tfrecord_low_level():
    data_path = "./toy.tfrecord"
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
    reader = tf.TFRecordReader()
    _, raw_records = reader.read(filename_queue)

    decode_protocol = {
        'sample/image': tf.FixedLenFeature((), tf.int64),
        'sample/label': tf.FixedLenFeature((), tf.int64)
    }
    enc_example = tf.parse_single_example(raw_records, features=decode_protocol)
    recovered_image = enc_example["sample/image"]
    recovered_label = enc_example["sample/label"]

    return recovered_image, recovered_label

我还尝试了变体方法,例如,将enc_example转换并解码,例如在Unable to read from Tensorflow tfrecord file中。但是,当我尝试评估它们时,我的python会话只会冻结并且不提供任何输出或回溯。

然后我尝试使用热切的执行来查看发生了什么,但是显然它仅与tf.data API兼容。但是据我了解,对tf.data API的转换是在整个数据集上进行的。 https://www.tensorflow.org/api_guides/python/reading_data提到必须编写一个解码函数,但是没有给出如何执行该操作的示例。我发现的所有教程都是针对TFRecordReader制作的(对我不起作用)。

任何帮助(指出我在做什么错/解释正在发生的事情/有关如何使用tf.data API解码tfrecord的指示)都受到高度赞赏。

根据https://www.youtube.com/watch?v=4oNdaQk0Qv4https://www.youtube.com/watch?v=uIcqeP7MFH0 tf.data是创建输入管道的最佳方法,所以我对学习这种方法非常感兴趣。

谢谢!

1 个答案:

答案 0 :(得分:2)

我不确定为什么存储编码的png会导致评估不起作用,但是这是解决此问题的一种可能方法。既然您提到要使用tf.data创建输入管道的方式,我将在玩具示例中展示如何使用它:

label = np.asarray([[1,2,3],
                [4,5,6]]).reshape(2, 3, -1)

sample = np.stack((label + 200).reshape(2, 3, -1))

首先,必须将数据保存到TFRecord文件。与您所做的不同之处在于,该图像未编码为png。

def _bytes_feature(value):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

writer = tf.python_io.TFRecordWriter("toy.tfrecord")

example = tf.train.Example(features=tf.train.Features(feature={
            'label_raw': _bytes_feature(tf.compat.as_bytes(label.tostring())),
             'sample_raw': _bytes_feature(tf.compat.as_bytes(sample.tostring()))}))

writer.write(example.SerializeToString())

writer.close()

上面的代码中发生的事情是将数组变成字符串(一维对象),然后存储为字节特征。

然后,使用tf.data.TFRecordDatasettf.data.Iterator类读回数据:

filename = 'toy.tfrecord'

# Create a placeholder that will contain the name of the TFRecord file to use
data_path = tf.placeholder(dtype=tf.string, name="tfrecord_file")

# Create the dataset from the TFRecord file
dataset = tf.data.TFRecordDataset(data_path)

# Use the map function to read every sample from the TFRecord file (_read_from_tfrecord is shown below)
dataset = dataset.map(_read_from_tfrecord)

# Create an iterator object that enables you to access all the samples in the dataset
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
label_tf, sample_tf = iterator.get_next()

# Similarly to tf.Variables, the iterators have to be initialised
iterator_init = iterator.make_initializer(dataset, name="dataset_init")

with tf.Session() as sess:
    # Initialise the iterator passing the name of the TFRecord file to the placeholder
    sess.run(iterator_init, feed_dict={data_path: filename})

    # Obtain the images and labels back
    read_label, read_sample = sess.run([label_tf, sample_tf])

函数_read_from_tfrecord()是:

def _read_from_tfrecord(example_proto):
        feature = {
            'label_raw': tf.FixedLenFeature([], tf.string),
            'sample_raw': tf.FixedLenFeature([], tf.string)
        }

    features = tf.parse_example([example_proto], features=feature)

    # Since the arrays were stored as strings, they are now 1d 
    label_1d = tf.decode_raw(features['label_raw'], tf.int64)
    sample_1d = tf.decode_raw(features['sample_raw'], tf.int64)

    # In order to make the arrays in their original shape, they have to be reshaped.
    label_restored = tf.reshape(label_1d, tf.stack([2, 3, -1]))
    sample_restored = tf.reshape(sample_1d, tf.stack([2, 3, -1]))

    return label_restored, sample_restored

除了对形状[2, 3, -1]进行硬编码之外,您还可以将其也存储到TFRecord文件中,但是为了简单起见,我没有这样做。

我用一个可行的例子做了一点gist

希望这会有所帮助!

相关问题