将numpy数组传递给tensorflow队列

时间:2016-08-21 19:47:25

标签: python numpy tensorflow

我有一个NumPy数组,希望使用Queue在TensorFlow的代码中读取它。我想队列返回整个数据洗牌,一些指定的纪元数,然后抛出一个错误。如果我不需要硬编码示例的大小和示例的数量,那将是最好的。 我认为shuffle batch旨在实现这一目的。我尝试过如下使用它:

data = tf.constant(train_np) # train_np is my numpy array of shape (num_examples, example_size)
batch = tf.train.shuffle_batch([data], batch_size=5, capacity=52200, min_after_dequeue=10, num_threads=1, seed=None, enqueue_many=True)

sess.run(tf.initialize_all_variables())
tf.train.start_queue_runners(sess=sess)
batch.eval()

该方法的问题在于它连续读取所有数据,并且我无法指定它在一些时期之后完成。我知道我可以使用RandomShuffleQueue并将数据插入其中几次,但是: a)我不想浪费epoch *内存数据和b)它将允许队列在时代之间进行洗牌。

在没有编写自己的队列的情况下,是否有一种很好的方法可以在Tensorflow中阅读时代中的混洗数据?

1 个答案:

答案 0 :(得分:6)

您可以创建另一个队列,将数据排入num_epoch次,关闭它,然后将其连接到batch。为了节省内存,您可以使此队列变小,并将项目并行排入其中。时代之间会有一些混合。为了完全防止混音,您可以使用num_epochs=1下面的代码并将其称为num_epochs次。

tf.reset_default_graph()
data = np.array([1, 2, 3, 4])
num_epochs = 5
queue1_input = tf.placeholder(tf.int32)
queue1 = tf.FIFOQueue(capacity=10, dtypes=[tf.int32], shapes=[()])

def create_session():
    config = tf.ConfigProto()
    config.operation_timeout_in_ms=20000
    return tf.InteractiveSession(config=config)

enqueue_op = queue1.enqueue_many(queue1_input)
close_op = queue1.close()
dequeue_op = queue1.dequeue()
batch = tf.train.shuffle_batch([dequeue_op], batch_size=4, capacity=5, min_after_dequeue=4)

sess = create_session()

def fill_queue():
    for i in range(num_epochs):
        sess.run(enqueue_op, feed_dict={queue1_input: data})
    sess.run(close_op)

fill_thread = threading.Thread(target=fill_queue, args=())
fill_thread.start()

# read the data from queue shuffled
tf.train.start_queue_runners()
try:
    while True:
        print batch.eval()
except tf.errors.OutOfRangeError:
    print "Done"
如果队列不够大,无法将整个numpy数据集加载到其中,那么BTW,上面的enqueue_many模式将会挂起。您可以通过如下所示以块的形式加载数据来为自己提供更小队列的灵活性。

tf.reset_default_graph()
data = np.array([1, 2, 3, 4])
queue1_capacity = 2
num_epochs = 2
queue1_input = tf.placeholder(tf.int32)
queue1 = tf.FIFOQueue(capacity=queue1_capacity, dtypes=[tf.int32], shapes=[()])

enqueue_op = queue1.enqueue_many(queue1_input)
close_op = queue1.close()
dequeue_op = queue1.dequeue()

def dequeue():
    try:
        while True:
            print sess.run(dequeue_op)
    except:
        return 

def enqueue():
    for i in range(num_epochs):
        start_pos = 0
        while start_pos < len(data):
            end_pos = start_pos+queue1_capacity
            data_chunk = data[start_pos: end_pos]
            sess.run(enqueue_op, feed_dict={queue1_input: data_chunk})
            start_pos += queue1_capacity
    sess.run(close_op)

sess = create_session()

enqueue_thread = threading.Thread(target=enqueue, args=())
enqueue_thread.start()

dequeue_thread = threading.Thread(target=dequeue, args=())
dequeue_thread.start()
相关问题