sess.run动态增加内存使用量

时间:2018-12-05 14:35:59

标签: python tensorflow memory

我试图通过以下代码来训练模型。

sess.run([train_op, model.global_step, model.loss, model.prediction], feed_dict)

但是,我发现运行“ model.prediction”时内存使用量动态增加。

在迭代过程中,我从不保留“ sess.run()”的结果。

“ model.prediction”为

@property
def prediction(self):
    return [tf.argmax(self.logits_b, 1),
            tf.argmax(self.logits_m, 1),
            tf.argmax(self.logits_s, 1),
            tf.argmax(self.logits_d, 1)]

我不知道为什么会这样。 请帮助我。

1 个答案:

答案 0 :(得分:1)

每次使用属性prediction时,都会在图形中创建新操作。您应该只创建一次操作,然后将其返回到属性中:

def create_model(self):
    # Hypothetical function that creates the model, only called once
    # ...
    self._prediction = (tf.argmax(self.logits_b, 1),
                        tf.argmax(self.logits_m, 1),
                        tf.argmax(self.logits_s, 1),
                        tf.argmax(self.logits_d, 1))
    # ...

@property
def prediction(self):
    return self._prediction
相关问题