TensorFlow variable initialization

Time: 2017-01-21 07:26:10

Tags: python tensorflow

rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
state = rnn_cell.zero_state(batch_size, tf.float32)
init = tf.global_variables_initializer()
sess = tf.Session()
for i in range(len(x_data)):
    x = process_x(x_data[i])[:std_size]
    y = word[i][:std_size]
    x_split = tf.split(0, time_step_size, x)
    outputs, state = tf.nn.rnn(rnn_cell, x_split, state)

    prediction = tf.reshape(tf.concat(1, outputs), [-1, rnn_size])
    real = tf.reshape(y, [-1])
    ratio = tf.ones([time_step_size * batch_size])

    loss = tf.nn.seq2seq.sequence_loss_by_example([prediction], [real], [ratio])
    cost = tf.reduce_mean(loss)/batch_size
    train = tf.train.AdamOptimizer(0.01).minimize(cost)

    tf.global_variables_initializer().run(session=sess)
    step = 0
    print state
    while step < 1000:
        sess.run(train)
        step+=1
    result = sess.run(tf.arg_max(prediction, 1))
    print result, [t for t in result] == y
    tf.get_variable_scope().reuse_variables()

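With source code like the above, are rnn_cell and state initialized again at every step of the for loop?
If I want to use the state in other training cases, I have to reuse it. So rnn_cell and state should be initialized first, before the loop, not inside it. I can't see how this code is supposed to work.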
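The carry-over described above can be sketched without TensorFlow. This is a toy, not the TensorFlow API: `rnn_step`, `run_sequence`, the weight `w=0.5` and the example inputs are all hypothetical. The idea is that the zero state is created once, and the final state of one example becomes the initial state of the next. In TF 1.x the corresponding pattern, with a graph built once, is to fetch the final `state` from `sess.run` and pass it back through `feed_dict` as the initial state on the next run; TensorFlow's PTB language-model tutorial carries LSTM state across batches this way.

```python
import math


def rnn_step(h, x, w=0.5):
    # Hypothetical one-unit recurrent cell: h_new = tanh(w * h + x)
    return math.tanh(w * h + x)


def run_sequence(xs, initial_state):
    # Plays the role of one sess.run over an already-built graph:
    # the initial state is passed in from outside, the final state is returned.
    h = initial_state
    for x in xs:
        h = rnn_step(h, x)
    return h


examples = [[0.1, 0.2], [0.3], [0.0, -0.1]]  # made-up inputs for illustration
state = 0.0  # zero state, created once before the loop
for xs in examples:
    # The final state of one example seeds the next one
    state = run_sequence(xs, state)

print(state)
```

Because the state is threaded through the loop, the three examples end in exactly the same final state as running the concatenated sequence `[0.1, 0.2, 0.3, 0.0, -0.1]` in one go.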

1 answer:

Answer 0 (score: 0)

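I think the problem is that you have to separate the part that builds the computation graph from the part that runs the session. What you are doing now is not how TensorFlow usually works. Maybe try this: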

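import tensorflow as tf  # TF 0.12-era API, as used in the question

# Graph construction: build every op exactly once, before any session runs.
# x_input, y_input and input_dim are illustrative names (not in the original);
# x holds one example's inputs, y its integer target ids.
x_input = tf.placeholder(tf.float32, [None, input_dim])
y_input = tf.placeholder(tf.int32, [None])

rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
state = rnn_cell.zero_state(batch_size, tf.float32)
x_split = tf.split(0, time_step_size, x_input)
outputs, state = tf.nn.rnn(rnn_cell, x_split, state)

prediction = tf.reshape(tf.concat(1, outputs), [-1, rnn_size])
real = tf.reshape(y_input, [-1])
ratio = tf.ones([time_step_size * batch_size])

loss = tf.nn.seq2seq.sequence_loss_by_example([prediction], [real], [ratio])
cost = tf.reduce_mean(loss)/batch_size
train = tf.train.AdamOptimizer(0.01).minimize(cost)
predicted_ids = tf.argmax(prediction, 1)  # also built once, not inside the loop

# Execution: initialize variables once, then only run ops and feed data.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

for i in range(len(x_data)):
    x = process_x(x_data[i])[:std_size]
    y = word[i][:std_size]
    # Feed placeholders (not a Python list such as x_split) in feed_dict
    for step in range(1000):
        sess.run(train, feed_dict={x_input: x, y_input: y})
    result = sess.run(predicted_ids, feed_dict={x_input: x})
    print result, [t for t in result] == y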

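Your code probably has some other design problems as well, but the key point is to separate building the graph from running the "training".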
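To make the difference concrete, here is a toy pure-Python model of the TF 1.x build-then-run split. It is not TensorFlow: `ToyGraph`, `get_variable`, `build_train_op`, `run_initializer`, `question_style` and `answer_style` are hypothetical names. It only models two effects of the question's loop: every pass adds new ops to the graph, and re-running the global initializer resets the weights reused through `reuse_variables()`, discarding what earlier examples learned.

```python
class ToyGraph:
    """Hypothetical toy model of a TF 1.x graph plus session; not the TensorFlow API."""

    def __init__(self):
        self.ops = []             # every op ever built stays in the graph
        self.initial_values = {}  # variable name -> value set by the initializer
        self.values = {}          # current (trained) values held by the "session"

    def get_variable(self, name, init_value=0.0):
        # With reuse, asking for an existing name returns the same variable
        self.initial_values.setdefault(name, init_value)
        return name

    def build_train_op(self, var, delta):
        # Each call appends a new op, like calling minimize() inside the loop
        self.ops.append((var, delta))
        return len(self.ops) - 1

    def run_initializer(self):
        # Like sess.run(tf.global_variables_initializer()): resets every variable
        self.values = dict(self.initial_values)

    def run(self, op_index):
        # Like sess.run(train): applies one stored update
        var, delta = self.ops[op_index]
        self.values[var] += delta


def question_style(num_examples, steps):
    """Graph building and initialization inside the data loop."""
    g = ToyGraph()
    for _ in range(num_examples):
        w = g.get_variable("lstm/weights")
        train = g.build_train_op(w, 1.0)
        g.run_initializer()  # throws away what earlier examples learned
        for _ in range(steps):
            g.run(train)
    return g


def answer_style(num_examples, steps):
    """Build once, initialize once, then only run ops in the loop."""
    g = ToyGraph()
    w = g.get_variable("lstm/weights")
    train = g.build_train_op(w, 1.0)
    g.run_initializer()
    for _ in range(num_examples):
        for _ in range(steps):
            g.run(train)  # in real code each example's data would be fed here
    return g


q = question_style(num_examples=3, steps=5)
a = answer_style(num_examples=3, steps=5)
print(len(q.ops), q.values["lstm/weights"])  # 3 5.0
print(len(a.ops), a.values["lstm/weights"])  # 1 15.0
```

In the first style the graph grows by one op per example and only the last example's training survives; in the second the graph stays fixed and training accumulates across all examples.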