如何在张量流张量中切片数组?

时间:2018-07-20 08:34:50

标签: python tensorflow keras slice

import tensorflow as tf
import numpy as np
import keras

a = np.array([0.1,0.2,0.3,0.4,0.5])
b = np.array([1,2,3,4,5])

def helper(m):
    x = m[0]
    y = m[1]
    idx = tf.where(x>0.3)
    y = tf.gather(y, idx)
    return y
def helper_output_shape(input_shape):
    return (2,2,1)
x1 = keras.layers.Input(shape=(1,))
x2 = keras.layers.Input(shape=(1,))
r = keras.layers.Lambda(helper, output_shape=helper_output_shape)([x1,x2])

model = keras.Model([x1,x2],r)
model.compile(optimizer='sgd',loss='mean_squared_error')
model.predict([a,b])

如代码所示,有两个数组:分数和目标。我想根据分数对目标数组进行切片。 例如(如上面的代码所示):

a = np.array([0.1,0.2,0.3,0.4,0.5]) # scores array
b = np.array([1,2,3,4,5])           # target array

# I want to keep the value that has a higher score than 0.3 in b, and remove the rest. So the result should be array([4,5])

但出现此错误。 错误:

ValueError: could not broadcast input array from shape (2,2,1) into shape (5,2,1)

我试图在Lambda中使用output_shape解决输入和输出之间形状不匹配的问题,但无法正常工作。如何解决这个问题? 谢谢!

1 个答案:

答案 0 :(得分:0)

就此特定错误而言,您可以像这样返回。

def helper(m):
    x = m[0]
    y = m[1]
    idx = tf.where(x>0.3)
    y = tf.gather(y, idx)
    print( y.get_shape())  #This prints (?, 2, 1)
    return tf.constant([0,0,0,1,1],tf.float32)   #This can be coded generically.

此示例返回值解决了形状问题。我们匹配它期望的(5,2,1)中的5。 5是输入中的元素数。