Tensorflow:无效参数:元组组件

时间:2016-08-09 22:23:49

标签: python numpy queue tensorflow labels

运行输入记录功能时出现以下错误。

Invalid argument: Shape mismatch in tuple component 1
Expected [2], got [4]

我的标签是每个记录的(0,1)numpy数组。设定尺寸为2.但它给出了不正确的尺寸误差。 解码后打印张量对象也给出了维度2

 Tensor("DecodeRaw_1:0", shape=(2,), dtype=float32)

输入功能是:

def input():

filename_queue = tf.train.string_input_producer(["path_to_record"])
print filename_queue

label = read_and_decode(f_queue)

min_queue_examples = n
labels_batch = tf.train.shuffle_batch(
    [ label_record],
    batch_size=batch_size,
    num_threads=2,
    capacity=min_queue_examples + 3 * batch_size,
    min_after_dequeue=min_queue_examples)

return labels_batch

读取记录功能

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    features={

        'label': tf.FixedLenFeature([], tf.string),
    })

)

label = tf.decode_raw(features['label'], tf.float32)
label.set_shape([2])
return label

0 个答案:

没有答案
相关问题