针对多次运行优化Tensorflow

时间:2016-06-27 16:54:38

标签: tensorflow

我需要多次执行语句sess.run()。 我在代码的开头创建了sess一次。但是,每个sess.run()语句在我的CPU机器上花费大约0.5-0.8秒。有什么办法可以优化吗?由于Tensorflow会进行延迟加载,有什么方法可以让它不能做到这一点,并使其更快?

我正在使用图像分类示例中的Inception模型。

def load_network():
    with gfile.FastGFile('model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        data = f.read()
        graph_def.ParseFromString(data)
        png_data = tf.placeholder(tf.string, shape=[])
        decoded_png = tf.image.decode_png(png_data, channels=3)
        _ = tf.import_graph_def(graph_def, name=input_map={'DecodeJpeg': decoded_png})
        return png_data

def get_pool3(sess, png_data, imgBuffer):
    pool3 = sess.graph.get_tensor_by_name('pool_3:0')
    pool3Vector = sess.run(pool3, {png_data: imgBuffer.getvalue()})
    return pool3Vector

def main():
    sess = getTensorSession()
    png_data = load_network()

    # The below line needs to be called multiple times, which is what takes
    # nearly 0.5-0.8 seconds.
    # imgBuffer contains the stored value of the image.
    pool3 = get_pool3(sess, png_data, imgBuffer)

1 个答案:

答案 0 :(得分:2)

Tensorflow懒惰地运行操作---在调用sess.run()之前实际上没有计算任何内容。当您调用sess.run()时,Tensorflow会执行计算图中的所有操作。因此,如果sess.run()花费0.5-0.8秒,那么你的计算本身可能需要0.5-0.8秒。

(sess.run()有一些开销,但它不应该接近半秒左右。)

希望有所帮助!

添加了:

以下是一些可以加快计算速度的方法:

  • 使用Tensorflow的分析工具来查看计算的哪些部分花费时间。它们尚未记录,但您可以在此github问题中找到有关它们的一些信息:https://github.com/tensorflow/tensorflow/issues/1824

  • 让您的计算更便宜 - 降低模型的复杂性,使用更小的图像等。

  • 在GPU而不是CPU上运行计算。