Tensorflow:如何重复运行tensorflow作业?

时间:2017-05-30 00:31:26

标签: tensorflow

我尝试使用不同的超参数重复运行tensorflow深度学习程序。

for i in range(10):
    my_learner = DQNLearner()
    my_learner.build_network()
    my_learner.run()


class DQNLearner():
    def build_network(self):
        W1 = tf.get_variable(
            "W1",
            shape=[self.input_size, h_size],
            initializer=tf.contrib.layers.xavier_initializer()
        )
        b1 = tf.Variable(tf.random_normal([h_size]))
        L1 = tf.nn.relu(tf.matmul(self._X, W1) + b1)
        L1 = tf.nn.dropout(L1, keep_prob=self.keep_prob)

        W2 = tf.get_variable(
            "W2",
            shape=[h_size, h_size],
            initializer=tf.contrib.layers.xavier_initializer()
        )
        b2 = tf.Variable(tf.random_normal([h_size]))
        L2 = tf.nn.relu(tf.matmul(L1, W2) + b2)
        L2 = tf.nn.dropout(L2, keep_prob=self.keep_prob)
        .
        .
        .
        .
        .

它在第一个循环中运行良好。但在第二个循环中,它出现如下:

ValueError: Variable W1 already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:

我该如何解决?

1 个答案:

答案 0 :(得分:0)

tf.get_variable方法尝试在已调用的情况下为您提供现有变量,否则会创建一个新变量。见the variable guide

您可以通过调用tf.reset_default_graph()来重置两个循环之间的图形。通过这样做,新变量将被添加到一个全新的默认图表中:

for i in range(10):
    tf.reset_default_graph()
    my_learner = DQNLearner()
    my_learner.build_network()
    my_learner.run()

另一种方法是保留变量并使用初始化操作在两个循环之间重新初始化它们。例如,您可以这样做:

tf.reset_default_graph()
my_learner = DQNLearner()
my_learner.build_network()

init_op = tf.global_variables_initializer()
sess = tf.InteractiveSession()
for i in range(10):
    sess.run(init_op)
    my_learner.run()
相关问题