为什么在Tensorflow上的PTB教程中运行epoch时构造了feed_dict?

时间:2018-03-21 12:07:42

标签: python tensorflow

Q1:我正在关注Recurrent Neural Networks上的this tutorial,我想知道你为什么需要在代码的以下部分创建feed_dict

def run_epoch(session, model, eval_op=None, verbose=False):

  state = session.run(model.initial_state)

  fetches = {
      "cost": model.cost,
      "final_state": model.final_state,
  }
  if eval_op is not None:
    fetches["eval_op"] = eval_op

  for step in range(model.input.epoch_size):
    feed_dict = {}
    for i, (c, h) in enumerate(model.initial_state):
      feed_dict[c] = state[i].c
      feed_dict[h] = state[i].h

    vals = session.run(fetches, feed_dict)

我测试了,似乎如果删除这部分代码,代码也会运行:

def run_epoch(session, model, eval_op=None, verbose=False):

  fetches = {
      "cost": model.cost,
      "final_state": model.final_state,
  }
  if eval_op is not None:
    fetches["eval_op"] = eval_op

  for step in range(model.input.epoch_size):
    vals = session.run(fetches)

所以我的问题是,为什么在提供新批量数据后需要将初始状态重置为零?

Q2:此外,据我所知,使用feed_dict被认为是缓慢的。这就是为什么建议使用tf.data API提供数据的原因。在这种情况下使用feed_dict也是一个问题吗?如果是这样,如何避免在此示例中使用feed_dict

UPD:非常感谢@jdehesa的详细回复。它帮助很大!在我结束这个问题并接受你的回答之前,你能澄清一下你提到回答Q1的一点。

我现在看到了feed_dict的目的。但是,我不确定这是教程中实现的内容。从你说的话:

  

在每个纪元的开头,代码首先采用默认的"零状态"然后进入一个循环,其中当前状态为初始状态,运行模型并将输出状态设置为下一次迭代的新当前状态。

我刚看了一下本教程的the source code,我没有看到输出状态被设置为下一次迭代的新当前状态。它是隐含地在某处完成的还是我错过了什么?

我也许在理论方面也缺少一些东西。为了确保我理解正确,这里有一个简单的例子。假设输入数据是一个存储0到120整数值的数组。我们将批量大小设置为5,一批中的数据点数为24,以及时间步长数展开的RNN为10。在这种情况下,您只能在020的时间点使用数据点。然后分两步处理数据(model.input.epoch_size = 2)。当您遍历model.input.epoch_size

state = session.run(model.initial_state)
# ...
for step in range(model.input.epoch_size):
  feed_dict = {}
  for i, (c, h) in enumerate(model.initial_state):
    feed_dict[c] = state[i].c
    feed_dict[h] = state[i].h

  vals = session.run(fetches, feed_dict)

您提供了一批这样的数据:

> Iteration (step) 1:
x:
 [[  0   1   2   3   4   5   6   7   8   9]
 [ 24  25  26  27  28  29  30  31  32  33]
 [ 48  49  50  51  52  53  54  55  56  57]
 [ 72  73  74  75  76  77  78  79  80  81]
 [ 96  97  98  99 100 101 102 103 104 105]]
y:
 [[  1   2   3   4   5   6   7   8   9  10]
 [ 25  26  27  28  29  30  31  32  33  34]
 [ 49  50  51  52  53  54  55  56  57  58]
 [ 73  74  75  76  77  78  79  80  81  82]
 [ 97  98  99 100 101 102 103 104 105 106]]

> Iteration (step) 2:
x:
 [[ 10  11  12  13  14  15  16  17  18  19]
 [ 34  35  36  37  38  39  40  41  42  43]
 [ 58  59  60  61  62  63  64  65  66  67]
 [ 82  83  84  85  86  87  88  89  90  91]
 [106 107 108 109 110 111 112 113 114 115]]
y:
 [[ 11  12  13  14  15  16  17  18  19  20]
 [ 35  36  37  38  39  40  41  42  43  44]
 [ 59  60  61  62  63  64  65  66  67  68]
 [ 83  84  85  86  87  88  89  90  91  92]
 [107 108 109 110 111 112 113 114 115 116]]

在每次迭代中,您构造一个新的feed_dict,其初始状态为周期单位为零。因此,您假设在每个步骤中从头开始处理序列。这是对的吗?

1 个答案:

答案 0 :(得分:1)

  • Q1。 feed_dict用于设置周期性单位的初始状态。默认情况下,每次调用run周期性单位时,处理初始“零”状态的数据。但是,如果您的序列很长,您可能需要将它们分成几个步骤。重要的是,在每个步骤之后,保存循环单位的最终状态并输入下一步的初始状态,否则就好像下一步再次是序列的开始(特别是,如果你的输出只是处理整个序列后网络的最终输出,就像在最后一步之前丢弃所有数据一样)。在每个纪元的开始,代码首先采用默认的“零状态”,然后继续进行循环,其中当前状态作为初始状态给出,模型运行并且输出状态被设置为下一个的新当前状态迭代。

  • Q2。声称“feed_dict缓慢”可能会有些误导,被视为一般性的真相(我不是责怪你说的,我见过它也很多次)。 feed_dict的问题在于它的功能是将非TensorFlow数据(通常是NumPy数据)带入TensorFlow世界。这并不是很糟糕,只是需要一些额外的时间来移动数据,这在涉及大量数据时尤其值得注意。例如,如果要通过feed_dict输入一批图像,则需要从磁盘加载它们,解码它们,将其转换为大的NumPy数组并将其传递到feed_dict,然后TensorFlow将将所有数据复制到会话中(GPU内存或其他);所以你会在内存和额外的内存交换中存储两份数据。 tf.data有帮助,因为它可以在TensorFlow中执行所有操作(这也减少了Python / C行程的数量,有时通常更方便)。在你的情况下,通过feed_dict提供的是经常性单位的初始状态。除非你有几个非常大的重复层,否则我认为性能影响可能相当小。但 可能是为了在这种情况下避免feed_dict,你需要有一组TensorFlow变量保持当前状态,设置循环单位以使用它们的输出初始状态(initial_state参数tf.nn.dynamic_rnn)并使用其最终状态更新变量值;然后在每个新批次上,您必须再次将变量重新初始化为“零”状态。但是,我会确保在沿着该路线行进之前这将带来显着的好处(例如,使用和不使用feed_dict测量运行时间,即使结果是错误的)。

编辑:

作为更新的说明,我在这里复制了代码的相关行:

state = session.run(model.initial_state)

fetches = {
    "cost": model.cost,
    "final_state": model.final_state,
}
if eval_op is not None:
  fetches["eval_op"] = eval_op

for step in range(model.input.epoch_size):
  feed_dict = {}
  for i, (c, h) in enumerate(model.initial_state):
    feed_dict[c] = state[i].c
    feed_dict[h] = state[i].h

  vals = session.run(fetches, feed_dict)
  cost = vals["cost"]
  state = vals["final_state"]

  costs += cost
  iters += model.input.num_steps

在一个纪元的开头,statemodel.initial_state的值,除非给出feed_dict替换其值,否则它将是默认的“零”初始状态值。 fetches是一个稍后传递给session.run的字典,因此它返回另一个字典,其中({等式}}将保存最终状态值。然后,在每个步骤中,创建"final_state",用feed_dict中的数据替换initial_state张量值,并使用state调用runfeed_dict中检索张量的值,然后fetches保留vals调用的输出。第run行替换了state = vals["final_state"]的内容,它是我们当前的状态值,具有上次运行的输出状态;所以在下一次迭代state将保留前一个最后一个状态的值,因此网络将继续“仿佛”一次性给出整个序列。在下一次调用feed_dict时,run_epoch将再次初始化为默认值state,并且该过程将再次从“零”开始。

相关问题