How to in TensorFlow

Time: 2017-05-09 02:26:31

Tags: tensorflow

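I am trying to write and read TFRecords with my data, which is represented as follows:

The input has shape [100000, 600] and type float.

The labels have shape [100000, 185, 17] and type int.

My main question is how to handle the float-typed input when reading. I created the TFRecordWriter shown below without errors (although I am not 100% sure it is correct). However, I don't know how to decode the resulting raw float feature when reading with TFRecordReader (if it were a string, I would use tf.decode_raw).

EDIT --- I have figured out how to read the float feature. It requires using tf.VarLenFeature, which produces a sparse tensor. The float tensor is then extracted from it through its .values property. I have added the working read_and_decode function below, after the error trace.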

import os
import tensorflow as tf

# DATA_DIR and the _int64_feature / _float_feature / _bytes_feature helpers
# are defined elsewhere (not shown).
def convert_to(input, labels, name):
    num_examples = input.shape[0]
    input_dim1 = input.shape[1]
    labels_dim1 = labels.shape[1]
    labels_dim2 = labels.shape[2]
    filename = os.path.join(DATA_DIR, name + '.tfrecords')
    print('Writing', filename)
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        input_raw = input[index]               # one row of floats, stored as a float list
        labels_raw = labels[index].tobytes()   # raw bytes; the reader decodes them as uint8
        example = tf.train.Example(features=tf.train.Features(feature={
            # Keys must match the names the reader parses.
            'input_dim1': _int64_feature(input_dim1),
            'labels_dim1': _int64_feature(labels_dim1),
            'labels_dim2': _int64_feature(labels_dim2),
            'input_raw': _float_feature(input_raw),
            'labels_raw': _bytes_feature(labels_raw)}))
        writer.write(example.SerializeToString())
    writer.close()

Below is the reader code I tried. It gives me an error when I try to decode the raw input:

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'input_dim1': tf.FixedLenFeature([], tf.int64),
            'labels_dim1': tf.FixedLenFeature([], tf.int64),
            'labels_dim2': tf.FixedLenFeature([], tf.int64),
            # Shape [] declares a scalar, i.e. a single float.
            'input_raw': tf.FixedLenFeature([], tf.float32),
            'labels_raw': tf.FixedLenFeature([], tf.string)
        })

    input = features['input_raw']       # CONFUSION HERE
    labels = tf.decode_raw(features['labels_raw'], tf.uint8)

    labels = tf.reshape(labels, [185, 17])
    print(labels.shape)                 # CORRECTLY GIVES (185, 17)

    input = tf.reshape(input, [600])    # ERROR HERE
    print(input.shape)

The error trace is as follows:

  File "/users/trabinow/compound_prediction/spectra2smiles/spectra2smiles_refined/spectra2smiles_input.py", line 53, in read_and_decode
    input = tf.reshape(input, [600])
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 2630, in reshape
    name=name)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2397, in create_op
    set_shapes_for_outputs(ret)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1757, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1707, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 675, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Cannot reshape a tensor with 1 elements to shape [600] (600 elements) for 'tower_0/Reshape_1' (op: 'Reshape') with input shapes: [], [1].
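
The trace reports a tensor with 1 element because tf.FixedLenFeature([], tf.float32) declares a scalar feature, so reshaping it to [600] cannot succeed. One possible alternative is to declare the full shape [600], in which case the parser returns a dense tensor directly. The sketch below assumes TensorFlow 2.x with eager execution (where the parsing functions live under tf.io), and the sample values are made up for illustration:

```python
import numpy as np
import tensorflow as tf

# Build one serialized Example holding a 600-element float feature.
# The values are illustrative only.
values = np.linspace(0.0, 1.0, 600).astype(np.float32)
example = tf.train.Example(features=tf.train.Features(feature={
    'input_raw': tf.train.Feature(
        float_list=tf.train.FloatList(value=values.tolist())),
}))
serialized = example.SerializeToString()

# Declaring the shape [600] makes the parser return a dense float tensor.
parsed = tf.io.parse_single_example(
    serialized,
    features={'input_raw': tf.io.FixedLenFeature([600], tf.float32)})
dense_input = parsed['input_raw']

# For comparison: VarLenFeature returns a SparseTensor, and its .values
# attribute holds the same 600 floats.
sparse_parsed = tf.io.parse_single_example(
    serialized,
    features={'input_raw': tf.io.VarLenFeature(tf.float32)})
sparse_values = sparse_parsed['input_raw'].values

print(dense_input.shape, sparse_values.shape)
```

In the queue-based TF 1.x code from the question, the corresponding change would presumably be 'input_raw': tf.FixedLenFeature([600], tf.float32), after which the tf.reshape call on the input is no longer needed.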

=============================================== =========================

Here is the new, working code for the reader:

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'input_dim1': tf.FixedLenFeature([], tf.int64),
            'labels_dim1': tf.FixedLenFeature([], tf.int64),
            'labels_dim2': tf.FixedLenFeature([], tf.int64),
            # VarLenFeature parses the float list into a SparseTensor.
            'input_raw': tf.VarLenFeature(tf.float32),
            'labels_raw': tf.FixedLenFeature([], tf.string)
        })

    # .values holds the floats of the sparse tensor as a dense 1-D tensor.
    input = features['input_raw'].values
    labels = tf.decode_raw(features['labels_raw'], tf.uint8)

    labels = tf.reshape(labels, [185, 17])
    print(labels.shape)                 # CORRECTLY GIVES (185, 17)

    input = tf.reshape(input, [600])
    print(input.shape)                  # CORRECTLY GIVES (600,)
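
One detail worth checking in the reader above: the question describes the labels as type int, while the reader decodes them with tf.uint8. tf.decode_raw reinterprets the stored bytes as the requested dtype, so the element count only comes out as 185 * 17 if the labels were written as uint8. The NumPy sketch below illustrates the byte counts, using np.frombuffer as a stand-in for the decode step. The label values and the int64 dtype are assumptions chosen for illustration:

```python
import numpy as np

# Illustrative labels with the per-example shape from the question.
# int64 is an assumption; the question only says "int".
labels = (np.arange(185 * 17) % 17).reshape(185, 17).astype(np.int64)

# Serialize once with the original dtype and once after casting to uint8.
raw_int64 = labels.tobytes()
raw_uint8 = labels.astype(np.uint8).tobytes()

# Reinterpret both byte strings as uint8, as the reader does.
decoded_from_int64 = np.frombuffer(raw_int64, dtype=np.uint8)
decoded_from_uint8 = np.frombuffer(raw_uint8, dtype=np.uint8)

print(decoded_from_int64.size)  # 8 bytes per label, too many for [185, 17]
print(decoded_from_uint8.size)  # exactly 185 * 17 elements
```

Note that in graph mode the shape printed after tf.reshape is the static target shape, so a dtype mismatch of this kind would only surface when the graph is actually run.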

0 answers:

No answers yet