Tensorflow input pipeline error when loading a csv file

Time: 2017-04-25 21:55:34

Tags: python csv tensorflow newline

I am trying to load a csv file with an input pipeline. I followed several online guides, but I cannot reproduce their results because of the following error:

InvalidArgumentError (see above for traceback): Expect 6 fields but have 751 in record 0
     [[Node: DecodeCSV_1 = DecodeCSV[OUT_TYPE=[DT_STRING, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_STRING], field_delim=",", _device="/job:localhost/replica:0/task:0/cpu:0"](ReaderReadV2_1:1, DecodeCSV_1/record_defaults_0, DecodeCSV_1/record_defaults_1, DecodeCSV_1/record_defaults_2, DecodeCSV_1/record_defaults_3, DecodeCSV_1/record_defaults_4, DecodeCSV_1/record_defaults_5)]]

It looks like I am running into a newline/delimiter problem. I would appreciate any feedback. The steps to reproduce the problem are below.
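A quick way to check which line terminators the downloaded file actually contains is a byte-level count like the rough sketch below (the path is the one used in the pipeline call further down). If the whole file were being read as a single record, 150 rows of 6 fields merging at 149 missing line breaks would give exactly the 751 fields reported above.

# Rough sketch: count the line terminators in the downloaded file.
with open("/Users/iiskin/Downloads/iris.csv", "rb") as f:
    data = f.read()
crlf = data.count(b"\r\n")
lf = data.count(b"\n") - crlf   # bare "\n" only
cr = data.count(b"\r") - crlf   # bare "\r" only (classic Mac line endings)
print("CRLF: %d, LF: %d, CR: %d" % (crlf, lf, cr))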

I downloaded the iris dataset to my local machine from https://vincentarelbundock.github.io/Rdatasets/csv/datasets/iris.csv and removed the header (a rough script for this step is sketched after the sample below).

The first few rows of the CSV then look like this:

"1",5.1,3.5,1.4,0.2,"setosa"
"2",4.9,3,1.4,0.2,"setosa"
"3",4.7,3.2,1.3,0.2,"setosa"

My code is as follows:

import tensorflow as tf

def read_my_file_format(filename_queue):
    # Read one line per record; the file has no header to skip.
    reader = tf.TextLineReader(skip_header_lines=0)
    key, value = reader.read(filename_queue)
    # One default per column: string index, four float measurements, string species.
    record_defaults = [[""], [0.0], [0.0], [0.0], [0.0], [""]]
    index, slength, swidth, plength, pwidth, species = tf.decode_csv(value, record_defaults=record_defaults, field_delim=',')
    features = tf.stack([slength, swidth, plength, pwidth])
    return features, [species]

def input_pipeline(filepaths, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filepaths, num_epochs=num_epochs, shuffle=True)
    features, label = read_my_file_format(filename_queue)
    # Keep a large buffer so shuffle_batch mixes examples well.
    min_after_dequeue = 10000
    capacity = min_after_dequeue + 3 * batch_size
    example_batch, label_batch = tf.train.shuffle_batch([features, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue)
    return example_batch, label_batch

example_batch, label_batch = input_pipeline(filepaths=["/Users/iiskin/Downloads/iris.csv"],batch_size=10,num_epochs=10)

with tf.Session() as sess:
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)
    # Start the queue-runner threads that feed the filename and example queues.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            features, label = sess.run([example_batch, label_batch])
            print(features)
    except tf.errors.OutOfRangeError:
        print('Done -- epoch limit reached')
    finally:
        coord.request_stop()

    coord.join(threads)

I am using tensorflow version 1.0.0-rc2.
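To narrow things down, tf.decode_csv can also be tested on a single hard-coded row, outside the queue pipeline, with the same record_defaults; a minimal sketch:

import tensorflow as tf

# Sanity check: parse one sample row with the same defaults and delimiter.
sample = tf.constant('"1",5.1,3.5,1.4,0.2,"setosa"')
record_defaults = [[""], [0.0], [0.0], [0.0], [0.0], [""]]
columns = tf.decode_csv(sample, record_defaults=record_defaults, field_delim=',')

with tf.Session() as sess:
    print(sess.run(columns))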

0 Answers