在张量流中,计算张量中大于阈值的元素的最简单方法是什么?

时间:2018-01-17 00:45:15

标签: tensorflow

e.g。给出m x n张量,我试图找到大于阈值的元素。

看来这可以用tf.greater来完成,但似乎我需要构建一个m x n张量的阈值?

有什么好办法吗?

2 个答案:

答案 0 :(得分:1)

看起来你没有长时间搜索:

import tensorflow as tf

x= tf.constant([[0, 1, 2], [3, 4, 5]], dtype=tf.float32)
out=  tf.greater(x, 2.5)
with tf.Session() as sess:
    print(sess.run(out))

给出:

  

[[False False False] [True True True]]

答案 1 :(得分:1)

这是一种计数大于阈值的元素数量的方法:

x = tf.constant([[1,2,3,4],[2,3,4,5],[3,4,5,6]])
threshold = 4
elements_gt = tf.math.greater(x,threshold)
num_elements_gt = tf.math.reduce_sum(tf.cast(elements_gt, tf.int32))
print(num_elements_gt)

计算tf.greater时,可以使用tf.greater_equaltf.lesstf.less_equalelements_gt作为过滤器。