Strange root mean squared error behavior in a Keras CNN regression task

Time: 2017-11-01 13:44:56

Tags: keras regression

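I am using an AlexNet-like CNN for an image-related regression task, with a custom RMSE function as the loss. During the first epoch of training the loss takes enormous values, but from the second epoch on it drops to a sensible value. Here is the output: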

  

 1/51 [..............................] - ETA: 847s - loss: 104.1821 - acc: 0.2500 - root_mean_squared_error: 104.1821
 2/51 [>.............................] - ETA: 470s - loss: 5277326.0910 - acc: 0.5938 - root_mean_squared_error: 5277326.0910
 3/51 [>.............................] - ETA: 345s - loss: 3518246.7337 - acc: 0.5000 - root_mean_squared_error: 3518246.7337
 4/51 [=>............................] - ETA: 281s - loss: 2640801.3379 - acc: 0.6094 - root_mean_squared_error: 2640801.3379
 5/51 [=>............................] - ETA: 241s - loss: 2112661.3062 - acc: 0.5000 - root_mean_squared_error: 2112661.3062
 6/51 [==>...........................] - ETA: 214s - loss: 1760566.4758 - acc: 0.4375 - root_mean_squared_error: 1760566.4758
 7/51 [===>..........................] - ETA: 194s - loss: 1509067.6495 - acc: 0.4464 - root_mean_squared_error: 1509067.6495
 8/51 [===>..........................] - ETA: 178s - loss: 1320442.6319 - acc: 0.4570 - root_mean_squared_error: 1320442.6319
 9/51 [====>.........................] - ETA: 165s - loss: 1173734.9212 - acc: 0.4792 - root_mean_squared_error: 1173734.9212
10/51 [====>.........................] - ETA: 155s - loss: 1056369.3193 - acc: 0.4875 - root_mean_squared_error: 1056369.3193
11/51 [=====>........................] - ETA: 146s - loss: 960343.5998 - acc: 0.4943 - root_mean_squared_error: 960343.5998
12/51 [======>.......................] - ETA: 139s - loss: 880320.3762 - acc: 0.5052 - root_mean_squared_error: 880320.3762
13/51 [======>.......................] - ETA: 131s - loss: 812608.7112 - acc: 0.5216 - root_mean_squared_error: 812608.7112
14/51 [=======>......................] - ETA: 125s - loss: 754570.1939 - acc: 0.5402 - root_mean_squared_error: 754570.1939
15/51 [=======>......................] - ETA: 120s - loss: 704269.2443 - acc: 0.5479 - root_mean_squared_error: 704269.2443
16/51 [========>.....................] - ETA: 114s - loss: 660256.3035 - acc: 0.5508 - root_mean_squared_error: 660256.3035
17/51 [========>.....................] - ETA: 109s - loss: 621420.7248 - acc: 0.5607 - root_mean_squared_error: 621420.7248
18/51 [=========>....................] - ETA: 104s - loss: 586900.8398 - acc: 0.5712 - root_mean_squared_error: 586900.8398
19/51 [==========>...................] - ETA: 100s - loss: 556014.6719 - acc: 0.5806 - root_mean_squared_error: 556014.6719
20/51 [==========>...................] - ETA: 95s - loss: 528216.9077 - acc: 0.5875 - root_mean_squared_error: 528216.9077
21/51 [===========>..................] - ETA: 91s - loss: 503065.7743 - acc: 0.5967 - root_mean_squared_error: 503065.7743
22/51 [===========>..................] - ETA: 87s - loss: 480206.3521 - acc: 0.6094 - root_mean_squared_error: 480206.3521
23/51 [============>.................] - ETA: 83s - loss: 459331.8636 - acc: 0.6114 - root_mean_squared_error: 459331.8636
24/51 [=============>................] - ETA: 80s - loss: 440196.2991 - acc: 0.6159 - root_mean_squared_error: 440196.2991
25/51 [=============>................] - ETA: 76s - loss: 422590.8381 - acc: 0.6162 - root_mean_squared_error: 422590.8381
26/51 [==============>...............] - ETA: 73s - loss: 406339.5179 - acc: 0.6178 - root_mean_squared_error: 406339.5179
27/51 [==============>...............] - ETA: 69s - loss: 391292.6992 - acc: 0.6238 - root_mean_squared_error: 391292.6992
28/51 [===============>..............] - ETA: 66s - loss: 377319.9851 - acc: 0.6306 - root_mean_squared_error: 377319.9851
29/51 [===============>..............] - ETA: 63s - loss: 364310.7557 - acc: 0.6336 - root_mean_squared_error: 364310.7557
30/51 [================>.............] - ETA: 60s - loss: 352169.1059 - acc: 0.6385 - root_mean_squared_error: 352169.1059
31/51 [=================>............] - ETA: 57s - loss: 340810.8854 - acc: 0.6401 - root_mean_squared_error: 340810.8854
32/51 [=================>............] - ETA: 53s - loss: 330162.1334 - acc: 0.6455 - root_mean_squared_error: 330162.1334
33/51 [==================>...........] - ETA: 50s - loss: 320158.7622 - acc: 0.6553 - root_mean_squared_error: 320158.7622
34/51 [==================>...........] - ETA: 47s - loss: 310744.0080 - acc: 0.6645 - root_mean_squared_error: 310744.0080
35/51 [===================>..........] - ETA: 44s - loss: 301866.8259 - acc: 0.6714 - root_mean_squared_error: 301866.8259
36/51 [====================>.........] - ETA: 41s - loss: 293483.0129 - acc: 0.6762 - root_mean_squared_error: 293483.0129
37/51 [====================>.........] - ETA: 39s - loss: 285552.8197 - acc: 0.6757 - root_mean_squared_error: 285552.8197
38/51 [=====================>........] - ETA: 36s - loss: 278039.4488 - acc: 0.6752 - root_mean_squared_error: 278039.4488
39/51 [=====================>........] - ETA: 33s - loss: 270911.4670 - acc: 0.6795 - root_mean_squared_error: 270911.4670
40/51 [======================>.......] - ETA: 30s - loss: 264140.2391 - acc: 0.6820 - root_mean_squared_error: 264140.2391
41/51 [=======================>......] - ETA: 27s - loss: 257699.1895 - acc: 0.6852 - root_mean_squared_error: 257699.1895
42/51 [=======================>......] - ETA: 25s - loss: 251564.6846 - acc: 0.6890 - root_mean_squared_error: 251564.6846
43/51 [========================>.....] - ETA: 22s - loss: 245715.4124 - acc: 0.6933 - root_mean_squared_error: 245715.4124
44/51 [========================>.....] - ETA: 19s - loss: 240131.9916 - acc: 0.6960 - root_mean_squared_error: 240131.9916
45/51 [=========================>....] - ETA: 16s - loss: 234796.6948 - acc: 0.7007 - root_mean_squared_error: 234796.6948
46/51 [=========================>....] - ETA: 14s - loss: 229693.3717 - acc: 0.7045 - root_mean_squared_error: 229693.3717
47/51 [==========================>...] - ETA: 11s - loss: 224807.2748 - acc: 0.7055 - root_mean_squared_error: 224807.2748
48/51 [===========================>..] - ETA: 8s - loss: 220125.0731 - acc: 0.7077 - root_mean_squared_error: 220125.0731
49/51 [===========================>..] - ETA: 5s - loss: 215634.5638 - acc: 0.7117 - root_mean_squared_error: 215634.5638
50/51 [============================>.] - ETA: 3s - loss: 211323.1692 - acc: 0.7144 - root_mean_squared_error: 211323.1692
51/51 [============================>.] - ETA: 0s - loss: 207180.6328 - acc: 0.7151 - root_mean_squared_error: 207180.6328
52/51 [==============================] - 143s - loss: 203253.6237 - acc: 0.7157 - root_mean_squared_error: 203253.6237 - val_loss: 44.4203 - val_acc: 0.9878 - val_root_mean_squared_error: 44.4203
Epoch 2/128
 1/51 [..............................] - ETA: 117s - loss: 52.6087 - acc: 0.7188 - root_mean_squared_error: 52.6087

How should I understand this behavior? Here is my implementation. First, the RMSE function:

from keras import backend as K
def root_mean_squared_error(y_true, y_pred):
   return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))

Then the model is compiled:

model.compile(optimizer="rmsprop", loss=root_mean_squared_error, metrics=['accuracy', root_mean_squared_error])

Then the model is fitted:

import time
from keras.preprocessing.image import ImageDataGenerator

estimator = alexmodel()  # AlexNet-like model, compiled as shown above
datagen = ImageDataGenerator()
datagen.fit(x_train)
start = time.time()
history = estimator.fit_generator(datagen.flow(x_train, x_train,batch_size=batch_size, shuffle=True),
           epochs=epochs,
           steps_per_epoch=x_train.shape[0]/batch_size,
           validation_data=(x_test, y_test))
end = time.time()

Can anyone tell me why this happens? Is there a potential bug somewhere?

1 answer:

Answer 0 (score: 1)

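Normalizing your data really matters here. It looks like you have not normalized your targets. A network is usually initialized so that it produces relatively small output values at the start, so when the targets are on a large scale the loss in the first epoch is huge. I would therefore advise you to normalize your targets (using either StandardScaler or MinMaxScaler): having to produce large-scale values pushes the network's weights toward larger absolute values, which is something you should keep your network from doing.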
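To make the suggestion concrete, here is a minimal sketch of target standardization. It assumes the targets are a 2-D array of shape (samples, outputs). The sample values and the helper names scale_targets and unscale_targets are made up for illustration and are not from the question. It uses plain NumPy and applies the same zero-mean, unit-variance transform that scikit-learn's StandardScaler performs.

```python
import numpy as np

# Hypothetical regression targets on a large scale (not the asker's data).
y_train = np.array([[120.0], [95.0], [143.0], [88.0], [110.0]])
y_test = np.array([[101.0], [130.0]])

# Fit the scaling statistics on the training targets only,
# so no information from the test set leaks into training.
y_mean = y_train.mean(axis=0)
y_std = y_train.std(axis=0)

def scale_targets(y):
    """Map targets to zero mean and unit variance per output column."""
    return (y - y_mean) / y_std

def unscale_targets(y_scaled):
    """Map scaled values (e.g. model outputs) back to the original units."""
    return y_scaled * y_std + y_mean

y_train_scaled = scale_targets(y_train)
y_test_scaled = scale_targets(y_test)

print(y_train_scaled.ravel())
```

The model would then be trained and validated on y_train_scaled and y_test_scaled, and its predictions passed through unscale_targets before reporting an RMSE in the original units. Note that an RMSE computed on scaled targets is in units of the target's standard deviation, so it is not directly comparable to the numbers in the log above.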