我正在尝试在Matlab中建立均方对数误差回归层,但没有成功。
我正在关注Matlab模板。有帮助吗?
classdef msleRegressionLayer < nnet.layer.RegressionLayer
% Custom regression layer with mean-squared-logarithmic-error loss.
methods
function layer = msleRegressionLayer(name)
% layer = msleRegressionLayer(name) creates a
% mean-squared-logarithmic-error regression layer and specifies the layer
% name.
% Set layer name.
layer.Name = name;
% Set layer description.
layer.Description = 'Mean squared logarithmic error';
end
function loss = forwardLoss(layer, Y, T)
% loss = forwardLoss(layer, Y, T) returns the MSLE loss between
% the predictions Y and the training targets T.
% Calculate MSLE.
R = size(Y,3);
%meanAbsoluteError = sum(abs(Y-T),3)/R;
msle=sum((log10((Y+1)./(T+1))).^2,3)/R;
% Take mean over mini-batch.
N = size(Y,4);
loss = sum(msle)/N;
end
function dLdY = backwardLoss(layer, Y, T)
% Returns the derivatives of the MSLE loss with respect to the predictions Y
R = size(Y,3);
N = size(Y,4);
dLdY = 2/(N*R)*(log10(Y+1)-log10(T+1))./(Y+1)*2.3;
end
end
end
这是我在检查图层时遇到的错误:
validInputSize = [1 1 64];
checkLayer(layer,validInputSize,'ObservationDimension',2);
Skipping GPU tests. No compatible GPU device found.
Running nnet.checklayer.OutputLayerTestCase
.......... ..
================================================================================
Verification failed in nnet.checklayer.OutputLayerTestCase/gradientsAreNumericallyCorrect.
----------------
Test Diagnostic:
----------------
The derivative 'dLdY' for 'backwardLoss' is inconsistent with the numerical gradient. Either 'dLdY' is incorrectly computed, the function is non-differentiable at some input points, or the error tolerance is too small.
---------------------
Framework Diagnostic:
---------------------
IsEqualTo failed.
--> NumericComparator failed.
--> The numeric values are not equal using "isequaln".
--> OrTolerance failed.
--> RelativeTolerance failed.
--> The error was not within relative tolerance.
--> AbsoluteTolerance failed.
--> The error was not within absolute tolerance.
--> Failure table (First 50 out of 128 failed indices):