使`while_loop`运行更快

时间:2019-06-13 01:55:53

标签: python tensorflow optimization

由于即使设置了 back_prop=False，tf.scan 仍然会导致 OOM（内存不足），所以我改用 while_loop，并手动计算这个很长的导数。

但是我的while_loop很慢。 是否可以加快while_loop的速度?

结果:

result: 1000000
second per iteration 5.064262390136719e-06

测试:

import tensorflow as tf
import numpy as np
import time


def make_while_loop(counts, dtype=np.int32):
    """Build a graph-mode ``tf.while_loop`` that counts up to ``counts``.

    The loop carries a single counter, initialised from a variable named
    ``'i'`` (created via ``tf.get_variable`` with initial value ``dtype(0)``).
    Each iteration adds one while the counter is less than ``counts``.

    Args:
        counts: Loop bound; iteration stops once the counter is no longer
            less than this value.
        dtype: Type used for the counter variable, the step and the bound.
            It is called as ``dtype(0)`` to build the initial value, so it
            must be a callable scalar type such as ``np.int32``.

    Returns:
        The output of ``tf.while_loop`` for the single loop variable.
    """
    # Use ``dtype`` for the variable as well, so its type matches ``one`` and
    # ``loop_ends``. A hard-coded int32 would mismatch them for any other dtype.
    i = tf.get_variable('i', dtype=dtype, initializer=dtype(0))
    one = tf.constant(1, dtype=dtype)
    loop_ends = tf.constant(counts, dtype=dtype)

    def condition(counter):
        # Keep looping while the counter is below the bound.
        return tf.less(counter, loop_ends)

    def increment(counter):
        # Loop body: advance the counter by one step.
        return tf.add(counter, one)

    # back_prop=False: this loop is never differentiated.
    return tf.while_loop(condition, increment, [i], back_prop=False)


def main():
    """Time a single run of a 1e6-iteration counting while_loop.

    Prints the loop result and the average wall time per iteration.
    """
    count = int(1e6)
    i = make_while_loop(count)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        # Make the graph read-only before running, so nothing adds ops to it.
        sess.graph.finalize()
        sess.run(init)
        # perf_counter is monotonic and high-resolution, which makes it suited
        # to measuring durations. time.time() follows the adjustable wall clock.
        st = time.perf_counter()
        result = sess.run(i)
        # Stop the clock before printing, so only the loop run is measured.
        elapsed = time.perf_counter() - st
        # NOTE(review): indexing [1] assumes sess.run returns a sequence for
        # this loop output. A single-variable tf.while_loop may return a bare
        # tensor in some TF 1.x versions; confirm against the installed TF.
        print("result:", result[1])
        print("second per iteration", elapsed / count)


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 个答案:

没有答案