在TensorFlow中将多个字节读入单个值

时间:2015-12-31 06:40:20

标签: machine-learning computer-vision tensorflow

我试图以与TensorFlow中的cifar10示例中描述的方式类似的方式阅读标签:

 ....
 label_bytes = 2 # it was 1 in the original version
 result.key, value = reader.read(filename_queue)
 record_bytes = tf.decode_raw(value, tf.uint8)
 result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
 ....

问题是,如果label_byte大于1(例如2),result.label似乎成为两个元素的张量(每个元素都是1字节)。我只想将连续label_bytes个字节表示为单个值。我该怎么做?

由于

1 个答案:

答案 0 :(得分:3)

创建第二个解码器,用它解码int16并将第一个元素作为标签

shorts = tf.decode_raw(value, tf.int16)
result.label = tf.cast(shorts[0], tf.int32)

这可能是一个更好的解决方案,但它确实有效。