在张量流中初始化变量,变量范围和import_graph_def

时间:2018-02-10 00:08:29

标签: python variables tensorflow machine-learning initialization

在尝试使用import_graph_def进行图形手术时,我有许多关于张量流行为的相关问题。 this post

在上图中,我用粗体红色箭头表示2个不同的图形手术。在左侧,有2个图形,g1和g2,手术包括用图形g1替换图形g2中的节点 - 以及它下面的所有内容。如何做到这一点在a project can depend on both Jackson 1.x and 2.x, without conflicts中有所解释。右侧的手术涉及替换属于相同图形的节点,我无法弄清楚如何执行,或者即使它完全可能。我最终得到了这个最小的例子

with tf.Graph().as_default() as g1:
    with tf.variable_scope('foo', reuse=tf.AUTO_REUSE):
        x = tf.placeholder(dtype=tf.float64, shape=[2], name='x')
        c = tf.get_variable('c', initializer=tf.cast(1.0, tf.float64))
        y = tf.identity(2*x, 'y')

        z = tf.identity(3*x*c, 'z')

        g1_def = g1.as_graph_def()
        z1, = tf.import_graph_def(g1_def, input_map={'foo/x:0' : y}, return_elements=["foo/z:0"],
                              name='z1')
        init_op = tf.global_variables_initializer()
        print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='foo'))


with tf.Session(graph=g1) as sess:
    sess.run(init_op)
    print(sess.run(z, feed_dict={'foo/x:0' : np.array([1.0, 2.0])}) )
    print(sess.run(tf.report_uninitialized_variables()))
    # z1 = sess.run(z1, feed_dict={'foo/x:0' : np.array([1.0, 2.0])})

此代码按原样运行。 3张印刷品分别产量:

[<tf.Variable 'foo/c:0' shape=() dtype=float64_ref>]
[ 3.  6.]
[]

特别是,最后一次打印通知没有未初始化的变量。但是,取消注释最后一行会产生错误

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value foo/z1/foo/c

请注意,如果我从上面的c定义中删除z,那么也可以。但是,我想了解这个错误。首先,为什么变量报告为foo/z1/foo/c?为什么范围foo出现两次?当我打印未初始化的变量时,为什么没有报告?为什么我在GLOBAL_VARIABLES范围内打印foo集合时仅报告foo / c?

PS:我想有一种更简单的方法可以提出问题,即什么是

的张量流模拟?
theano.clone(some_tensor, replace={input_var : replace_var})

2 个答案:

答案 0 :(得分:2)

  

首先,为什么变量报告为foo/z1/foo/c?   为什么范围foo出现两次?

在您致电tf.import_graph_def(...)后,图表重复了。第一个图表在foo得分中定义。第二个子图已在范围foo/z1下导入(因为name='z1'加上foo保留在上面的范围内)。因此,图表g1现在包含以下张量:

foo/x
foo/y
foo/c
...
foo/z1/foo/x
foo/z1/foo/y
foo/z1/foo/c
...

第一个foo/c已初始化,但第二个foo/z1/foo/c未初始化(见下文)。

  

为什么在打印未初始化的变量时没有报告?为什么在foo/c范围内打印GLOBAL_VARIABLES集合时仅报告foo

由于report_uninitialized_variables()默认情况下会扫描LOCAL_VARIABLESGLOBAL_VARIABLES,因此这基本上是同一个问题。

可能是一个错误:GLOBAL_VARIABLES调用后tf.import_graph_def集合未更新。我说可能是因为GLOBAL_VARIABLES被设计为一个简单的便利集合。 Tensorflow试图让它保持日期,但可能并不能保证它始终具有所有变量。 tf.add_to_collection存在的事实公开支持这一想法 - 如果他们想要,可以为任何集合添加任何值。底线:此行为在将来的版本中可能会或可能不会更改,但从1.5开始,客户端负责在图表导入后更新全局变量。

  

特别是,最后一次打印通知没有未初始化的变量。但是,取消注释最后一行会产生错误

要修复此错误,您只需运行z1子图的初始值设定项。像这样:

# note that it's defined before `g1.as_graph_def()` to be a part of graph def
init_op = tf.global_variables_initializer()

g1_def = g1.as_graph_def()
z1, = tf.import_graph_def(g1_def, input_map={'foo/x:0': y}, return_elements=["foo/z:0"],
                          name='z1')

# find the init op
z1_init_op = tf.get_default_graph().get_operation_by_name('foo/z1/foo/init')

...

sess.run(z1_init_op)

瞧!你有重复的图表,就像你想要的那样。

答案 1 :(得分:1)

我遇到了类似的问题,但是仅仅运行init操作是行不通的。

我通过手动运行导入图的全局变量的所有“分配”操作来解决此问题。

在我的场景中,我想使用两个不同的输入张量来运行输入为“ patch:0”的编码操作“ z”。

    with tf.Session(graph=tf.get_default_graph()).as_default() as sess:

        g = tf.Graph()       
        saved_model = predictor.from_saved_model(args.export_dir, graph=g)
        variables = g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)]

        fetch_ops = ['z:0','init']
        fetch_ops.extend([v.name.strip(":0") + "/Assign" for v in variables)

        image_graph = tf.graph_util.import_graph_def(
            g.as_graph_def(),
            input_map={'patch:0': image},
            return_elements=fetch_ops,
            name='image')

        warped_graph = tf.graph_util.import_graph_def(
            g.as_graph_def(),
            input_map={'patch:0': warped_image},
            return_elements=fetch_ops,
            name='warp')

        loss = tf.reduce_sum(tf.math.squared_difference(image_graph[0], warped_graph[0]))

        optimizer =  tf.train.GradientDescentOptimizer(learning_rate=0.0001)
        compute_gradients = optimizer.compute_gradients(
            loss,
            var_list=[dest_control_point_locations])

        apply_gradients = optimizer.apply_gradients(compute_gradients, global_step=step)

        sess.run(image_graph[1:])
        sess.run(warped_graph[1:])
        sess.run(tf.global_variables_initializer())

        gradients = sess.run(compute_gradients)

在提取操作并通过使用feed_dict给我的张量馈送运行它时,gradient_computation不起作用,这就是为什么我使用tf.graph_util.import_graph_def(...)

希望这可能会帮助遇到相同问题的任何人。

相关问题