在keras中实现切片层

时间:2020-02-11 11:47:21

标签: arrays tensorflow keras

(免责声明:我已将我的问题简化为要点,我想做的工作稍微复杂一些,但我在这里描述了核心问题。)

我正在尝试使用keras建立网络,以学习约5×5矩阵的属性。

输入数据采用1000 x 5 x 5 numpy数组的形式,其中每个5 x 5子数组代表一个矩阵。

我想让网络做的是使用矩阵中每一行的属性,所以我想将每5 x 5数组拆分为单独的1 x 5数组,然后将这5个数组中的每一个传递给下一个网络的一部分。

这是我到目前为止所拥有的:

input_mat = keras.Input(shape=(5,5), name='Input')

part_list = list()   
for i in range(5):
    part_list.append(keras.layers.Lambda(lambda x: x[i,:])(input_mat)) 

dense_list = list()
for i in range(5):
    dense_list.append( keras.layers.Dense(10, activation='selu', 
                                          use_bias=True)(part_list[i]) )


conc = keras.layers.Concatenate(axis=-1, name='Concatenate')(dense_list)
dense_out = keras.layers.Dense(1, name='D_out', activation='sigmoid')(conc)


model = keras.Model(inputs= input_mat, outputs=dense_out)
model.compile(optimizer='adam', loss='mean_squared_error')

我的问题是,这似乎不能很好地训练,并且查看模型摘要,我不确定网络是否按我的意愿分配输入:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
Input (InputLayer)              (None, 5, 5)         0                                            
__________________________________________________________________________________________________
lambda_5 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_6 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_7 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_8 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_9 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
dense (Dense)                   (5, 10)              60          lambda_5[0][0]                   
__________________________________________________________________________________________________
dense_1 (Dense)                 (5, 10)              60          lambda_6[0][0]                   
__________________________________________________________________________________________________
dense_2 (Dense)                 (5, 10)              60          lambda_7[0][0]                   
__________________________________________________________________________________________________
dense_3 (Dense)                 (5, 10)              60          lambda_8[0][0]                   
__________________________________________________________________________________________________
dense_4 (Dense)                 (5, 10)              60          lambda_9[0][0]                   
__________________________________________________________________________________________________
Concatenate (Concatenate)       (5, 50)              0           dense[0][0]                      
                                                                 dense_1[0][0]                    
                                                                 dense_2[0][0]                    
                                                                 dense_3[0][0]                    
                                                                 dense_4[0][0]                    
__________________________________________________________________________________________________
D_out (Dense)                   (5, 1)               51          Concatenate[0][0]                
==================================================================================================
Total params: 351
Trainable params: 351
Non-trainable params: 0

Lambda层的输入和输出节点对我来说似乎是错误的,尽管我恐怕仍在努力理解这一概念。

2 个答案:

答案 0 :(得分:1)

在行

part_list.append(keras.layers.Lambda(lambda x: x[i,:])(input_mat)) 

您基本上正在拍摄1000张图像中的前5张,这不是您想要的。

要实现所需的功能,请尝试张量流的unstack操作:

part_list = tf.unstack(input_mat, axis=1)

这应该给您一个包含5个元素的列表,每个元素的形状为[1000, 5]

答案 1 :(得分:0)

要避免使用 Lambda。

代替子类层:

class Slice(keras.layers.Layer):
    def __init__(self, begin, size,**kwargs):
        super(Slice, self).__init__(**kwargs)
        self.begin = begin
        self.size = size
    def get_config(self):

        config = super().get_config().copy()
        config.update({
            'begin': self.begin,
            'size': self.size,
        })
        return config
    def call(self, inputs):
        return tf.slice(inputs, self.begin, self.size)
相关问题