如何使用元组

时间:2016-12-28 21:06:46

标签: python-2.7 tensorflow neural-network recurrent-neural-network

我最近将我的tesnorflow从Rev8升级到Rev12。在Rev8中,rnn_cell.LSTMCell中的默认“state_is_tuple”标志设置为False,因此我使用列表初始化了LSTM Cell,请参阅下面的代码。

#model definition  
lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim)
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state)


#init_state place holder and feed_dict
def add_placeholders(self):
     self.init_state = tf.placeholder("float", [None, self.cell_size])

def get_feed_dict(self, data, label):
    feed_dict = {self.input_data: data,
             self.input_label: reg_label,
             self.init_state: np.zeros((self.config.batch_size, self.cell_size))}
    return feed_dict

在Rev12中,默认的“state_is_tuple”标志设置为True,为了使我的旧代码工作,我必须明确地将标志变为False。但是,现在我收到了tensorflow的警告说:

  

“使用连接状态较慢,很快就会被弃用。   使用state_is_tuple = True“

我尝试通过将self.init_state的占位符定义更改为以下内容来使用元组初始化LSTM单元格:

self.init_state = tf.placeholder("float", (None, self.cell_size))

但现在我收到一条错误消息:

  

“'Tensor'对象不可迭代”

有谁知道如何使这项工作?

1 个答案:

答案 0 :(得分:1)

现在使用cell.zero_state向LSTM提供“零状态”要简单得多。您不需要明确地将初始状态定义为占位符。将其定义为张量,并在需要时提供它。这是它的工作原理,

lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim)
self.initial_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state)

如果你想提供一些其他值作为初始状态,比如说next_state = states[-1],在你的会话中计算并在feed_dict中传递它 -

feed_dict[self.initial_state] = next_state

在您的问题中,lstm_cell.zero_state()应该足够了。

无关,但请记住,您可以在Feed字典中传递张量和占位符!这就是self.initial_state在上面示例中的工作方式。有关工作示例,请查看PTB Tutorial