将张量值按阈值映射到1或0

时间:2019-05-29 16:15:46

标签: python tensorflow

我有一个t=[[0.2],[0.8],[0.4]]之类的张量。

,并希望将每个value>0.8映射到1,如果在该阈值以下,则映射到0

我用tf.where进行了尝试,但是它不起作用。

tf.where(t>0.8,tf.ones(t),tf.zeros_like(t))

代码:

self.temp_sim = tf.where(self.predictions>0.8,tf.ones(self.predictions),tf.zeros_like(self.predictions))

错误:

  File "C:\Users\\Anaconda3\lib\site-packages\tensorflow\python\ops\array_ops.py", line 1700, in ones
    output = fill(shape, constant(one, dtype=dtype), name=name)
  File "C:\Users\\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 3468, in fill
    "Fill", dims=dims, value=value, name=name)
  File "C:\Users\\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 609, in _apply_op_helper
    param_name=input_name)
  File "C:\Users\\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'dims' has DataType float32 not in list of allowed values: int32, int64

将其格式化为代码无效!

0 个答案:

没有答案
相关问题