为什么tf.Print()不起作用?

时间:2018-07-19 12:34:19

标签: python python-3.x tensorflow

我有以下代码段:

import tensorflow as tf
import numpy as np

    # batch x time x events x dim
batch = 2
time = 3
events = 4
tensor = np.random.rand(batch, time, events)

tensor[0][0][2] = 0
tensor[0][0][3] = 0

tensor[0][1][3] = 0

tensor[0][2][1] = 0
tensor[0][2][2] = 0
tensor[0][2][3] = 0

tensor[1][0][3] = 0

def cum_sum(prev, cur):
    non_zeros = tf.equal(cur, 0.)
    tf.Print(non_zeros, [non_zeros], "message ")
    tf.Print(cur, [cur])
    return cur

elems = tf.constant([1,2,3],dtype=tf.int64)
#alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))
cum_sum_ = tf.scan(cum_sum, tensor)

s = tf.Session()

s.run(cum_sum_)

我在传递给tf.Print的函数中有两个tf.scan语句,但是运行累积总和时,没有得到任何打印语句。我在做错什么吗?

2 个答案:

答案 0 :(得分:2)

tf.Print不能那样工作。打印节点必须位于图中才能突出。我强烈建议您查看this教程,以了解如何使用它。

如果您有任何问题,请随时提问。

答案 1 :(得分:0)

尽管,@ Ignacio Peletier的答案是完全有帮助的,它取决于外部站点。我发现没有人在这里提到这一点很奇怪。无论如何,为了遵守规则,我还在这里提供了答案(无需访问外部链接),并提供了一些额外的信息:

要让tf.Print实际打印某些内容,它应该属于图形。为此,您只需重复使用tf.Print中返回的Tensor,并将其传递给下一个操作。 将其传递到下一个操作对于实际显示消息至关重要。

因此,使用您的示例可以将其重写:

def cum_sum(prev, cur):
    non_zeros = tf.equal(cur, 0.)
    non_zeros = tf.Print(non_zeros, [non_zeros], "message ")
    cur = tf.Print(cur, [cur])
    return cur

,它将打印cur,但不会打印non_zeros,因为此节点是悬空的。另外,我不确定是否可以重写您的代码,以便显示non_zeros,因为在您定义它们之后,它实际上并没有在您的代码中使用(因此,tensorflow在非急切模式下只会忽略它)。 / p>

通常对于(非悬挂)节点,我们在代码中的某个地方将其命名为result

result = tf.some_op_here()
# the following is actually displaying the contents of result
result = tf.Print(result, [result], 'some arbitrary message here to be displayed before the actual message')
tf.another_op_here_using_result(result)

将打印result的(可能是串联的)内容。为了控制显示的信息量,您还可以使用参数summarize=x,其中x是显示的参数数量。