张量流而循环变慢

时间:2019-07-15 14:39:02

标签: python tensorflow

问题是一个Tensorflow While循环( tf.while_loop )随着时间的流逝而变慢。该循环应返回一些矩阵。我通过字典提供所有输入。

我知道问题很可能是由于一遍又一遍地添加操作而污染了图形。我是TF初学者,对我来说,导致图形污染的原因并不明显。我们非常感谢您的帮助。

def predict(self, actions, ...):


    feed_dict = {
        self.agent.actions: actions.reshape(-1, self.kwargs["dim_actions"]),
        ...
    }

    states_mu, states_var = self.session.run(self.agent.predict_states(), feed_dict=feed_dict)

    return states_mu, states_var


def predict_states(self):
   ...

    def loop_cond(i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov):
        return i < self.episode_length

    def loop_body(i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov):
        state_mu_i = state_mus[-1][None, :]
        ...
        state_var_tf = state_vars_tf[-1][None, :, :]

        #Some math operations
        ...

        new_state_mu = state_mu_i + delta_mu
        new_state_var = state_var_i + delta_var + inp_out_cov

        new_mu_tf, new_var_tf, inp_tf_cov = some_transform(
            new_state_mu, ....)

        state_mus = tf.concat([state_mus, new_state_mu], 0)
        ...
        state_vars_tf = tf.concat([state_vars_tf, new_var_tf], 0)

        i += 1

        return i, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov

    loop_step = tf.constant(0, tf.int32)
    init_mus_tf, init_vars_tf, inp_tf_cov = some_transform(
        self.state_mu, self.state_var, self.dim_angles)

    loop_vars = [
        loop_step,
        self.state_mu,
        self.state_var,
        init_mus_tf,
        init_vars_tf,
        inp_tf_cov]

    shapes = [loop_step.get_shape(),
              tf.TensorShape([None, self.dim_states]),
              tf.TensorShape([None, self.dim_states, self.dim_states]),
              tf.TensorShape([None, self.dim_states_tf]),
              tf.TensorShape([None, self.dim_states_tf, self.dim_states_tf]),
              inp_tf_cov.get_shape()]

    _, state_mus, state_vars, state_mus_tf, state_vars_tf, inp_tf_cov = tf.while_loop(
        loop_cond,
        loop_body,
        loop_vars=loop_vars,
        shape_invariants=shapes)

    return state_mus_tf[1:], state_vars_tf[1:]

该循环被多次调用。它在运行中会变慢,即在每次迭代后,甚至在重复调用后甚至会变慢。每次运行的迭代速度从上次运行结束的地方开始。 例如,在第一次运行的开始,每个迭代花费1秒,在第一次运行的结束,每个迭代花费3秒。在第二次运行开始时,每次迭代需要3秒,...直到使其无法运行(例如,每次迭代100秒)。

1 个答案:

答案 0 :(得分:0)

该代码似乎大部分都很好,但是在创建类的实例时(或在其他一些初始化步骤中),并且应该将返回值存储在类属性中,您仅应调用predict_states一次。例如:

def __init__(self, ...):
    # ...
    self.states_mu_tf, self.states_var_tf = self.agent.predict_states()

然后,您在predict中使用这些属性:

states_mu, states_var = self.session.run((self.states_mu_tf, self.states_var_tf),
                                         feed_dict=feed_dict)

那样,您将不会在图形中重新创建操作。

相关问题