我正在尝试将随机生成的 NumPy 数组写入 TFRecord 文件,但结果似乎不正确:标签能正确读回,但读回的图像与原始图像不一致。
运行结束时打印的是 False, True,而我预期应该是 True, True。为什么会这样?这是 TensorFlow 的 bug,还是我遗漏了什么?
import numpy as np
import tensorflow as tf
def wrap_bytes(value):
    """Return a ``tf.train.Feature`` holding ``value`` as a one-element bytes list."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
def wrap_int64(value):
    """Return a ``tf.train.Feature`` holding ``value`` as a one-element int64 list."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)
def convert(images, labels, out_path):
    """Serialize ``(image, label)`` pairs into a TFRecord file.

    Each record is a ``tf.train.Example`` with two features:

    * ``'image'`` -- the image's raw bytes in C order, written in the
      array's own dtype (no cast is performed here), and
    * ``'label'`` -- the label stored as an int64 feature.

    Args:
        images: Sequence of numpy arrays, one per label.
        labels: Sequence of integer labels, same length as ``images``.
        out_path: Path of the TFRecord file to write.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length.
    """
    if len(images) != len(labels):
        raise ValueError(
            "images and labels must have the same length, got {} and {}".format(
                len(images), len(labels)))
    with tf.python_io.TFRecordWriter(out_path) as writer:
        for image, label in zip(images, labels):
            # Raw bytes carry no dtype or shape: the reader must decode with
            # the same dtype (tf.decode_raw out_type) and reshape explicitly.
            # tobytes() always emits C order, so no flatten() is needed, and
            # it already returns bytes, so no tf.compat.as_bytes() wrap either.
            features = {
                'image': wrap_bytes(image.tobytes()),
                'label': wrap_int64(label),
            }
            example = tf.train.Example(
                features=tf.train.Features(feature=features))
            writer.write(example.SerializeToString())
def parse(serialized, shape=(5, 5), dtype=tf.int32):
    """Decode one serialized ``tf.train.Example`` with 'image' and 'label' features.

    Args:
        serialized: String tensor holding one serialized Example.
        shape: Shape the decoded image is reshaped into. Defaults to
            ``(5, 5)``, the previously hard-coded value.
        dtype: Element type passed to ``tf.decode_raw``. It must match the
            numpy dtype the image bytes were written with; the default
            ``tf.int32`` matches the int32 images generated by this script.

    Returns:
        Tuple ``(image, label)``: the image tensor reshaped to ``shape`` with
        element type ``dtype``, and the int64 label tensor.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed_example = tf.parse_single_example(
        serialized=serialized,
        features=feature_spec)
    # The bytes feature is a flat byte string; reinterpret it as `dtype`
    # elements and restore the image shape.
    image_flat = tf.decode_raw(parsed_example['image'], dtype)
    image_reshaped = tf.reshape(image_flat, shape)
    return image_reshaped, parsed_example['label']
def input_fn(filenames, batch_size):
    """Build a one-shot, batched input pipeline over TFRecord file(s).

    Records are decoded with ``parse`` and grouped into batches of
    ``batch_size``. Returns the ``(images, labels)`` tensors from a single
    ``get_next()`` call. Both belong to the same iterator step, so they must
    be fetched together in one ``sess.run`` call. Every separate run of
    ``get_next()``'s outputs advances the iterator by one batch.
    """
    batched = (
        tf.data.TFRecordDataset(filenames=filenames)
        .map(parse)
        .batch(batch_size)
    )
    iterator = batched.make_one_shot_iterator()
    next_images, next_labels = iterator.get_next()
    return next_images, next_labels
n = 10
num_classes = 15
batch_size = 2
out_path = 'bug.tfrecords'
labels = np.random.randint(0, num_classes, n)
image_shape = (5, 5)
# Cast to int32 so the bytes written by convert() match the tf.int32 element
# type that parse() hands to tf.decode_raw.
images_ = np.random.randint(0, 255, (n,) + image_shape).astype(np.int32)
convert(images_, labels, out_path)
batch_images_tf, batch_labels_tf = input_fn(out_path, batch_size)
with tf.Session() as sess:
    # Fetch images and labels in ONE run. Each sess.run() on the iterator's
    # get_next() outputs pulls a fresh batch. Two separate runs returned labels
    # from the first batch but images from the second, which is why the image
    # comparison printed False while the label comparison printed True.
    batch_images_np, batch_labels_np = sess.run(
        [batch_images_tf, batch_labels_tf])
# checking whether the converted data is the same as the original
print(np.array_equal(batch_images_np, images_[0:batch_size]))
print(np.array_equal(batch_labels_np, labels[0:batch_size]))