Gaussian process as the last layer of a CNN

Posted: 2018-06-20 09:22:05

Tags: python keras deep-learning

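(Disclaimer: I only started with machine learning recently, so I would appreciate any help in finding my mistakes.) After extracting features with a CNN for image recognition, I want to put a Gaussian process (GP) layer on top of them. For this I am using keras-gp (https://github.com/alshedivat/keras-gp). I get this error at the end of the first epoch: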

WARNING:tensorflow:Variable *= will be deprecated. Use variable.assign_mul if you want assignment to the variable value or 'x = x * y' if you want a new python Tensor object.
Epoch 1/50
2018-06-20 14:44:47.166910: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
 9/10 [==========================>...] - ETA: 0s - loss: 0.0000e+00 - acc: 0.5289 - gp_1_mse: 0.0000e+00 - gp_1_nlml: 0.0000e+00 - mse: 0.0000e+00 - nlml: 0.0000e+00
Traceback (most recent call last):
  File "train_network.py", line 102, in <module>
    H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS), validation_data= (testX, testY), steps_per_epoch=len(trainX) // BS, epochs=EPOCHS)
  File "/Users/s/PycharmProjects/Ropar_proj1/venv/lib/python2.7/site-packages/Keras-2.1.3-py2.7.egg/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/Users/s/PycharmProjects/Ropar_proj1/venv/lib/python2.7/site-packages/Keras-2.1.3-py2.7.egg/keras/engine/training.py", line 2203, in fit_generator
    verbose=0)
TypeError: evaluate() got an unexpected keyword argument 'sample_weight' 
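The last traceback frame shows Keras 2.1.3's fit_generator calling evaluate() on the validation data at the end of the epoch, and the TypeError says that the evaluate() actually reached does not declare a sample_weight parameter. One plausible cause (an assumption, not confirmed here) is that the GP-aware model class overrides evaluate() with a narrower signature than the Keras base class. The pure-Python sketch below reproduces that pattern with hypothetical classes BaseModel and GPModel (they are not the real Keras or keras-gp classes) and uses inspect.signature to check whether a method accepts a given keyword. It assumes Python 3; the traceback above comes from Python 2.7, which lacks inspect.signature.

```python
import inspect


class BaseModel:
    """Hypothetical stand-in for a framework base class (not Keras itself)."""

    def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):
        return 0.0

    def fit_generator(self, validation_data):
        # At the end of an epoch, validate by calling evaluate() and
        # forward sample_weight as a keyword argument.
        val_x, val_y = validation_data
        return self.evaluate(val_x, val_y, batch_size=32,
                             sample_weight=None, verbose=0)


class GPModel(BaseModel):
    """Hypothetical subclass that overrides evaluate() without sample_weight."""

    def evaluate(self, x, y, batch_size=32, verbose=0):
        return 0.0


def accepts_keyword(func, name):
    """Return True if `func` can be called with keyword argument `name`."""
    params = inspect.signature(func).parameters
    return name in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


if __name__ == "__main__":
    print(accepts_keyword(BaseModel.evaluate, "sample_weight"))  # True
    print(accepts_keyword(GPModel.evaluate, "sample_weight"))    # False
    try:
        # The inherited fit_generator passes sample_weight to the override.
        GPModel().fit_generator(validation_data=([0], [0]))
    except TypeError as exc:
        print(exc)
```

The same accepts_keyword check can be pointed at the evaluate method of whichever model class is actually in use to see whether this mismatch is what happens here.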

The relevant part of the code is:

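from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from kgp.layers import GP
from kgp.models import Model  # assumed: keras-gp's GP-aware Model, as in the keras-gp README
from kgp.losses import gen_gp_loss
# GP output layer from keras-gp (hyperparameters in gp_hypers)
MSGP = GP(
    gp_hypers,
    batch_size=25,
    nb_train_samples=250)
# `model` here is the output tensor of the CNN feature extractor defined earlier
gp = MSGP(model)
model = Model(inputs=input1, outputs=[gp])
opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
# one GP loss per GP output layer
loss = [gen_gp_loss(layer) for layer in model.output_layers]
model.compile(loss=loss, optimizer=opt, metrics=["accuracy"],
              sample_weight_mode=None)
# on-the-fly data augmentation for the training images
aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,
                         height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
                         horizontal_flip=True, fill_mode="nearest")
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
                        validation_data=(testX, testY),
                        steps_per_epoch=len(trainX) // BS, epochs=EPOCHS)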
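For context on what a GP output layer computes on top of CNN features, here is a minimal NumPy sketch of exact GP regression with a squared-exponential (RBF) kernel: given training feature vectors and targets, it returns the posterior mean and variance at new feature vectors. This is an illustration only, not keras-gp's implementation: keras-gp relies on a scalable approximation and learns its hyperparameters starting from gp_hypers, whereas here the length scale and noise variance are fixed, and the function names rbf_kernel and gp_posterior, the random "features" and the sine targets are all hypothetical stand-ins.

```python
import numpy as np


def rbf_kernel(a, b, length_scale=1.0, signal_var=1.0):
    """Squared-exponential kernel matrix between the rows of a and the rows of b."""
    sq_dists = (np.sum(a ** 2, axis=1)[:, None]
                + np.sum(b ** 2, axis=1)[None, :]
                - 2.0 * a @ b.T)
    # Clip tiny negative distances caused by floating-point rounding.
    return signal_var * np.exp(-0.5 * np.maximum(sq_dists, 0.0) / length_scale ** 2)


def gp_posterior(x_train, y_train, x_test, noise_var=1e-2, length_scale=1.0):
    """Exact GP regression: posterior mean and variance at the rows of x_test."""
    k_tt = rbf_kernel(x_train, x_train, length_scale)
    k_ts = rbf_kernel(x_train, x_test, length_scale)
    k_ss = rbf_kernel(x_test, x_test, length_scale)
    # Cholesky factor of (K + noise * I) for a numerically stable solve.
    chol = np.linalg.cholesky(k_tt + noise_var * np.eye(len(x_train)))
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y_train))
    mean = k_ts.T @ alpha
    v = np.linalg.solve(chol, k_ts)
    var = np.diag(k_ss) - np.sum(v ** 2, axis=0)
    return mean, var


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    features = rng.normal(size=(20, 8))  # stand-in for 20 CNN feature vectors
    targets = np.sin(features[:, 0])     # hypothetical regression targets
    mean, var = gp_posterior(features, targets, features[:5])
    print(mean.round(3), var.round(3))
```

Far from the training features the kernel values vanish, so the posterior mean falls back to zero and the variance to the prior signal variance; that growing uncertainty away from the data is what a GP head adds over a plain dense output layer.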

0 answers