张量流中的索引慢于收集

时间:2017-09-05 06:23:20

标签: tensorflow tensorflow-serving tensorflow-gpu tensor tensorflow-xla

我正在尝试索引张量以从1d张量获得切片或单个元素。我发现使用numpy索引[:]slice vs tf.gather的方式(差不多30-40%)时,性能存在显着差异。

另外我观察到tf.gather在标量上使用时会产生很大的开销(在未堆叠的张量上循环)而不是张量。这是一个已知的问题吗?

示例代码(效率低下):

for node_idxs in graph.nodes():
    node_indice_list = tf.unstack(node_idxs)
    result = []
    for nodeid in node_indices_list:
        x = tf.gather(..., nodeid)
        y = tf.gather(..., nodeid)
        result.append(tf.mul(x,y))
return tf.stack(result)

而不是 示例代码(高效):

for node_idxs in graph.nodes():
    x = tf.gather(..., node_idxs)
    y = tf.gather(..., node_idxs)
return tf.mul(x, y)

据我所知,第一个低效的实现正在做更多的卸载,堆叠然后循环以及更多聚集操作的工作,但是当我运行的节点的顺序是几百个节点时,我没想到100x减速(正在拆卸和聚集在单个标量上的开销很慢,在第一种情况下,我有更多的聚集操作,每个操作单个元素而不是张量的偏移)。是否有更快的索引方式,我尝试了numpy和slice,结果比收集慢。

1 个答案:

答案 0 :(得分:0)

首先,代码并没有真正比较 gather 与 Numpy 索引 - 它比较了矢量化索引(tf.gather)与循环索引(Python“for”循环)。循环很慢也就不足为奇了。

请注意,在 Tensorflow 中无论如何都限制了类似 Numpy 的索引 tensor[idxs]

<块引用>

仅整数、切片 (:)、省略号 (...)、tf.newaxis (None) 和 标量 tf.int32/tf.int64 张量是有效索引

因此将 tf.gather 用于一般应用。