tensorflow中等效的以下代码是什么?

时间:2017-11-14 01:09:22

标签: for-loop tensorflow

我有以下功能:

import random

lst = []
for i in range(100):
    lst.append(random.randint(1, 10))

print(lst)

buffer = []

# This is the peace of code which I am interested to convert into tensorflow.
for a in lst:
    buffer.append(a)

    if len(buffer) > 5:
        buffer.pop(0)

    if len(buffer) == 5:
        print(buffer)

因此,从代码中,我需要创建一个缓冲区(可能是张量流中的变量)。此缓冲区应保留上一个conv layer提取的要素。在我的情况下,variable将是RNN的输入。

这种方法的优势在于,当我们拥有大型图像时,以及当我们需要为(batch of images) * (sequence length) * (size of 1 image)提供RNN时,需要将大量图像加载到主要图像中记忆。另一方面,根据上面的代码,我们将使用张量流Datasetsinput queue或任何其他替代方案一次提供1张图像。因此,我们将在内存中存储大小为batch_size * sequence_length * feature space的功能。此外,我们可以说:

if len(buffer) == n:
    # empty out the buffer after using its elements
    buffer = [] # Or any other alternative way

我知道我可以提供我的网络batches图片,但我需要根据一些文献完成上述代码。

非常感谢任何帮助!!

1 个答案:

答案 0 :(得分:2)

我尝试使用 tf.FIFOQueue https://www.tensorflow.org/api_docs/python/tf/FIFOQueue)重新生成输出。我已经在下面给出了我的代码以及必要的评论。

BATCH_SIZE = 20

lst = []
for i in range(BATCH_SIZE):
    lst.append(random.randint(1, 10))
print(lst)

curr_data = np.reshape(lst, (BATCH_SIZE, 1)) # reshape the tensor so that [BATCH_SIZE 1]

# queue starts here
queue_input_data = tf.placeholder(tf.int32, shape=[1]) # Placeholder for feed the data

queue = tf.FIFOQueue(capacity=50, dtypes=[tf.int32], shapes=[1]) # Queue define here

enqueue_op = queue.enqueue([queue_input_data])  # enqueue operation
len_op = queue.size()  # chek the queue size

#check the length of the queue and dequeue one if greater than 5
dequeue_one = tf.cond(tf.greater(len_op, 5), lambda: queue.dequeue(), lambda: 0)
#check the length of the queue and dequeue five elemts if equals to 5
dequeue_many = tf.cond(tf.equal(len_op, 5), lambda:queue.dequeue_many(5), lambda: 0)

with tf.Session() as session:
    for i in range(BATCH_SIZE):
        _ = session.run(enqueue_op, feed_dict={queue_input_data: curr_data[i]}) # enqueue one element each ietaration
        len = session.run(len_op)  # check the legth of the queue
        print(len)

        element = session.run(dequeue_one)  # dequeue the first element
        print(element)

但是,以下两个问题与上述代码有关,

  1. 只有出列一个出列许多操作可用,您无法看到队列中的元素(我不认为您需要这是因为你看起来像管道)。

  2. 我认为 tf.cond 是实现条件操作的唯一方法(我找不到任何其他类似于此的合适函数)。但是,由于它类似于 if-then-else语句,它必须在语句为false时定义一个操作(不仅仅是 if语句而没有<强>否则)。由于 Tensorflow 完全是关于构建图形,我认为必须包含两个分支(当条件为真和假时)。

  3. 此外,可以在此处找到Tensorflow 输入管道的详细说明(http://ischlag.github.io/2016/11/07/tensorflow-input-pipeline-for-large-datasets/)。

    希望这有帮助。