交互式会话eval和tf.random_uniform不会产生预期的输出

时间:2017-09-05 13:47:54

标签: python tensorflow

我对以下代码片段有疑问,obs_patternobs_seqobs_seq_s不会产生预期的行为。我试过TensorFlow 1.2.1。我怀疑有什么不对劲。

import tensorflow as tf

sess = tf.InteractiveSession()
seq_length = 5
num_bits = 4
obs_pattern_shape = [num_bits, seq_length]
obs_pattern = tf.cast(
    tf.random_uniform(obs_pattern_shape, minval=0, maxval=2, seed=1234, dtype=tf.int32), 
    tf.float32)
print(obs_pattern.eval())
seq_length_zeros = tf.zeros([1, seq_length])
obs_seq = tf.concat([obs_pattern, seq_length_zeros], axis=0)
print(obs_seq.eval())
add_vec = tf.one_hot([num_bits], (num_bits + 1), on_value = 1.0, off_value=0.0, axis=0)
obs_seq_s = tf.concat([obs_seq, add_vec], axis=1)
print(obs_seq_s.eval())

sess.close()

obs_pattern

[[ 1.  1.  1.  0.  0.]
[ 0.  0.  0.  1.  0.]
[ 0.  1.  1.  0.  0.]
[ 0.  1.  0.  0.  0.]]

obs_seq

[[ 1.  1.  0.  1.  1.]
[ 0.  1.  1.  0.  1.]
[ 1.  1.  1.  0.  0.]
[ 1.  1.  0.  1.  0.]
[ 0.  0.  0.  0.  0.]]

obs_seq_s

[[ 1.  0.  0.  0.  0.  0.]
[ 0.  0.  1.  0.  0.  0.]
[ 0.  1.  0.  0.  0.  0.]
[ 0.  1.  0.  0.  0.  0.]
[ 0.  0.  0.  0.  0.  1.]]

修改 根据下面的评论我改变了obs_pattern,它的行为和我想的一样

import numpy as np
arr = np.random.randint(2, size=(num_bits, seq_length))
obs_pattern = tf.convert_to_tensor(arr, dtype=tf.float32)

1 个答案:

答案 0 :(得分:0)

预计每次在用obs_pattern形成的张量上运行sess时,它会重新运行一个新的随机统一。 为了获得您期望的行为,生成 ONCE 一个numpy的numpy数组,然后将其提供给相同形状的占位符:obs_pattern。 为了澄清您使用种子,您运行了多次操作,以便获得序列的第一次迭代。

相关问题