Keras:屏蔽非RNN的零填充输入

时间:2017-02-17 16:55:35

标签: python keras

当我需要屏蔽Keras中非RNN的输入时,我的一位同事指出了使用sample_weight代替遮蔽层的非常酷的选项。

就我而言,输入中有62列,第63列是响应。 62列中超过97%的非零条目包含在前30列中。我试图让这个工作,所以我想在训练中将最后32列加权为0,基本上创造了一个'穷人的面具'。

这是一个使用MLP的8级分类任务。响应变量已使用Keras中的to_categorical()函数进行转换。

以下是实施:

model = Sequential()
model.add(Dense(100, input_dim=X.shape[1], init='uniform', activation='relu'))
model.add(Dense(8, init='uniform', activation='sigmoid'))
hist = model.fit(X, y, 
                 validation_data=(X_test, ytest), 
                 nb_epoch=epochs_, 
                 batch_size=batch_size_, 
                 callbacks=callbacks_list, 
                 sample_weight = np.array([X.shape[1]-32, 30])) 

我收到了这个错误:

in standardize_weights
assert y.shape[:sample_weight.ndim] == sample_weight.shape

如何修复我的sample_weight以'掩盖'输入的前32列?

1 个答案:

答案 0 :(得分:2)

样品重量不是这样的:

  

sample_weight:与x长度相同的可选数组,包含适用于每个样本的模型损失的权重。对于时态数据,您可以传递形状为(samples, sequence_length)的2D数组,以对每个样本的每个时间步应用不同的权重。在这种情况下,您应确保在sample_weight_mode="temporal"中指定compile()source

换句话说,此设置对训练数据的样本赋予不同的权重,而不是每个样本的特征。这仅用于训练步骤。 如果您不希望图层使用这些功能,我认为您应该使用遮罩。或者只是从数据集中删除它们?或者,如果它不是太复杂,让网络自己学习哪些有用的功能。

这有帮助吗?