函数

时间:2018-09-22 18:52:54

标签: python tensorflow

我想做这样的事情:

def f():
    place = tf.placeholder(tf.int32)
    return 2 * place

y = f()

with tf.Session() as sess:
    a = sess.run(y, feed_dict={ place: 5 })

当然,在外面看不到占位符。

NameError                                 Traceback (most recent call last)
<ipython-input-88-8b3c17d16dce> in <module>()
      1 with tf.Session() as sess:
----> 2     a = sess.run(y, feed_dict={ place: 5 })

NameError: name 'place' is not defined

而且,我可以通过以下方式解决此问题:

def f():
    global place
    place = tf.placeholder(tf.int32)
    return 2 * place

但是,有人对此有更好的解决方案吗?如上例所示,当调用函数并将其返回值作为运算符传递给运行函数时,如何在函数内部创建占位符并向其外部馈送值。

2 个答案:

答案 0 :(得分:1)

这纯粹是Python,但您可以简单地将占位符与由此创建的图形op一起返回。

赞:

return place*2, place

然后像这样使用f():

y, place = f()

我认为在Tensorflow中的一个班级工作是一个好主意,并且能够轻松地在任何地方访问这些占位符。

编辑:当然,总是可以给占位符起一个名字,然后在需要输入时从图形中获取它。

答案 1 :(得分:1)

您可以按其名称访问占位符:

import tensorflow as tf

def f():
    place = tf.placeholder(tf.int32,name='place')
    return 2 * place

y = f()

with tf.Session() as sess:
    place = tf.get_default_graph().get_tensor_by_name('place:0')
    a = sess.run(y, feed_dict={ place: 5 })

说明

每当您定义一个占位符(或任何其他TensorFlow张量或操作)时,它就会被添加到计算图中,该图是位于后台并管理所有计算的对象。每个占位符都有一个默认名称,但是您也可以为其选择一个名称。在此示例中,我选择了名称place

现在,对于高级用例,您可能有多个计算图,但是始终有一个是默认图。为了获得默认值,我使用了tf.get_default_graph()。然后,为了获得对占位符的引用,我使用了get_tensor_by_name('place:0')。 (我使用的名称是'place:0'而不是'place',因为定义占位符时,实际上会创建一个tf.Tensor可以供稿,并且创建的也是执行供稿的操作。该操作将具有名称'place',而实际张量将具有名称'place:0'。)

相关问题