TensorFlow train.string_input_producer reads a CSV file multiple times

Time: 2017-06-11 16:22:31

Tags: csv tensorflow

I am using the standard input pipeline to read a CSV file in TensorFlow. I have a single input file that I want to read in mini-batches. My concern is that the file is being read multiple times, even though I believe I have set the number of epochs correctly. I suspect this may be due to the behavior of the train.string_input_producer() function.

with self.graph.as_default():
    epochs = np.floor(fileSize / batchSize) + 1
    self.fileNameQ = tf.train.string_input_producer(fileNameList, num_epochs=epochs)
    self.batchInput, self.label = self.inputPipeline(batchSize, dim)

Here, fileSize is determined by:

fileSize = sum(1 for line in open(file.name))
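As written, the handle returned by open(file.name) is never closed; a with block performs the same count without leaking it. A minimal, self-contained version of the count (the temporary file below is a stand-in for the real input file, not part of the question):

```python
import tempfile

# Write a small stand-in CSV so the count can be demonstrated.
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
    f.write("a,b,c\n" * 5)
    path = f.name

# Same count as the original one-liner, but the handle is closed afterwards.
with open(path) as fh:
    fileSize = sum(1 for line in fh)

print(fileSize)  # 5
```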

I have defined my input pipeline functions as follows:

def readFromCsv(self, dim):
    reader = tf.TextLineReader()
    _, csvLine = reader.read(self.fileNameQ)
    recordDefaults = [["\0"] for cl in range(dim + 3)]
    recordStr = tf.decode_csv(csvLine, record_defaults=recordDefaults)
    self.label = tf.stack(recordStr[0:3])
    self.features = tf.stack(recordStr[3:dim + 3])
    return (self.features, self.label)

def inputPipeline(self, batchSize, dim): 
    minAfterDequeue = 10000
    capacity = minAfterDequeue + 3 * batchSize
    example, label = self.readFromCsv(dim)
    exampleBatchStr, labelBatch = tf.train.batch([example, label], batch_size=batchSize, capacity=capacity)
    exampleBatch = tf.string_to_number(exampleBatchStr)
    return (tf.transpose(exampleBatch), tf.transpose(labelBatch))
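If num_epochs is interpreted by string_input_producer as the number of complete passes over the file, then the value computed above (floor(fileSize / batchSize) + 1, i.e. roughly the number of batches per pass) would multiply the record count accordingly. A quick sanity check of that arithmetic, using illustrative values that are assumptions for the sake of the example:

```python
import math

# Illustrative values (assumptions, not taken from the question):
fileSize = 10000   # lines in the CSV
batchSize = 100

# The value passed as num_epochs in the first snippet:
epochs = math.floor(fileSize / batchSize) + 1   # 101

# If num_epochs means "full passes over the file", the pipeline
# would deliver this many records in total:
totalRecords = fileSize * epochs                # 1,010,000

# One dot is printed per 1,000 records:
dots = totalRecords // 1000
print(epochs, totalRecords, dots)               # 101 1010000 1010
```

Under this reading, roughly 1,010 dots would appear instead of the ~10 one pass over the file would produce.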

I then run training and print a dot for every 1,000 records processed.

def train(self, batchSize, dim):
    with self.sess:
        self.sess.run(tf.local_variables_initializer())
        # Start populating the filename queue.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        #Training iterations
        self.iterationInput = 0
        while self.iterationInput < self.iterations:
            #Train with each vector one by one
            self.iterationInput += 1
            print("iteration " + str(self.iterationInput) + " for window size " + str(dim))
            try:
                loopCount = 0
                while not coord.should_stop():
                    # Fill in input data.
                    self.sess.run([self.batchInput, self.label])
                    self.sess.run(self.trainingOp)
                    # For every 1,000 samples, print a dot.
                    if loopCount % 1000 == 0:
                        sys.stdout.write('.')
                        sys.stdout.flush()
                    loopCount += 1
            except tf.errors.OutOfRangeError:
                print("Done training -- epoch limit reached")
                coord.request_stop()

        # When done, join the threads
        coord.join(threads)

When I run the program, it prints noticeably more dots than (the number of records in the file) / 1,000. I know that train.string_input_producer() starts four threads. My worry is that each thread reads the file once, which would inflate the runtime when the hardware does not offer enough parallelism. Since the runtime is already long, I do not want to increase it further. Is there any way to stop the file from being read four times?
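One way to reason about keeping the pipeline to a single pass (a sketch of the arithmetic, not a tested fix): pass num_epochs=1 and, if a step count is needed, compute the number of batches per pass separately. Note also that np.floor returns a float, whereas an epoch count is naturally an integer. The values below are assumptions for illustration:

```python
import math

fileSize = 10000   # assumed line count of the CSV
batchSize = 100

numEpochs = 1                                    # one full pass over the file
stepsPerEpoch = math.ceil(fileSize / batchSize)  # batches needed to drain one pass

totalRecords = fileSize * numEpochs
dots = totalRecords // 1000
print(stepsPerEpoch, dots)                       # 100 10
```

With a single pass, the dot count would match the expected fileSize / 1,000, and the training loop can run for stepsPerEpoch iterations (or simply until OutOfRangeError is raised).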

0 answers:

There are no answers yet.