Matlab custom mean squared logarithmic error regression layer

Date: 2019-06-27 16:34:52

Tags: matlab neural-network regression layer

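I'm trying to build a mean squared logarithmic error (MSLE) regression layer in Matlab, but without success.

I'm following the Matlab custom regression layer template. Can anyone help?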

classdef msleRegressionLayer < nnet.layer.RegressionLayer
    % Custom regression layer with mean-squared-logarithmic-error loss.

    methods
        function layer = msleRegressionLayer(name)
            % layer = msleRegressionLayer(name) creates a
            % mean-squared-logarithmic-error regression layer and specifies the layer
            % name.

            % Set layer name.
            layer.Name = name;

            % Set layer description.
            layer.Description = 'Mean squared logarithmic error';
        end

        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the MSLE loss between
            % the predictions Y and the training targets T.


            % Calculate MSLE.
            R = size(Y,3);
            %meanAbsoluteError = sum(abs(Y-T),3)/R;
            msle=sum((log10((Y+1)./(T+1))).^2,3)/R;

            % Take mean over mini-batch.
            N = size(Y,4);
            loss = sum(msle)/N;
        end

        function dLdY = backwardLoss(layer, Y, T)
            % Returns the derivatives of the MSLE loss with respect to the predictions Y

            R = size(Y,3);
            N = size(Y,4);

            dLdY = 2/(N*R)*(log10(Y+1)-log10(T+1))./(Y+1)*2.3;

        end
    end
end
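As a worked derivation (an editor's sketch, not part of the original post), write the loss that forwardLoss computes for R responses and N observations and differentiate it with the chain rule, using d/dx log10(x) = 1/(x ln 10):

```latex
L = \frac{1}{NR}\sum_{n=1}^{N}\sum_{r=1}^{R}
    \left(\log_{10}(Y_{rn}+1) - \log_{10}(T_{rn}+1)\right)^{2}

\frac{\partial L}{\partial Y_{rn}}
  = \frac{2}{NR}\left(\log_{10}(Y_{rn}+1) - \log_{10}(T_{rn}+1)\right)
    \cdot \frac{1}{(Y_{rn}+1)\,\ln 10}
```

The factor ln 10 ≈ 2.3026 belongs in the denominator, whereas the backwardLoss above multiplies by 2.3, so its result is off by a factor of roughly 2.3 × 2.3026 ≈ 5.3.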

This is the error I get when I check the layer:

validInputSize = [1 1 64];
checkLayer(layer,validInputSize,'ObservationDimension',2);
Skipping GPU tests. No compatible GPU device found.

Running nnet.checklayer.OutputLayerTestCase
.......... ..
================================================================================
Verification failed in nnet.checklayer.OutputLayerTestCase/gradientsAreNumericallyCorrect.
    ----------------
    Test Diagnostic:
    ----------------
    The derivative 'dLdY' for 'backwardLoss' is inconsistent with the numerical gradient. Either 'dLdY' is incorrectly computed, the function is non-differentiable at some input points, or the error tolerance is too small.
    ---------------------
    Framework Diagnostic:
    ---------------------
    IsEqualTo failed.
    --> NumericComparator failed.
        --> The numeric values are not equal using "isequaln".
        --> OrTolerance failed.
            --> RelativeTolerance failed.
                --> The error was not within relative tolerance.
            --> AbsoluteTolerance failed.
                --> The error was not within absolute tolerance.
            --> Failure table (First 50 out of 128 failed indices):
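A minimal numerical check, run outside checkLayer so it needs no toolbox. It compares the derivative with log(10) in the denominator, and the original one that multiplies by 2.3, against central finite differences of the same loss. The random test data, the sizes R = 64 and N = 2, the step h, and the 1x1xRxN layout (responses in dimension 3, observations in dimension 4, as the question's size(Y,3) and size(Y,4) calls assume) are illustrative assumptions.

```matlab
% Finite-difference check of the MSLE gradient (illustrative data only).
% Assumed layout: 1x1xRxN, responses in dim 3, observations in dim 4.
R = 64; N = 2;
rng(0);
Y = rand(1,1,R,N);   % predictions, kept > -1 so log10(Y+1) is defined
T = rand(1,1,R,N);   % targets

% Loss as computed in forwardLoss: mean over responses, then over observations.
lossFcn = @(Y) sum(sum((log10(Y+1) - log10(T+1)).^2, 3), 4) / (R*N);

% Chain rule: d/dY log10(Y+1) = 1 ./ ((Y+1)*log(10)), so divide by log(10).
dLdY = 2/(N*R) * (log10(Y+1) - log10(T+1)) ./ ((Y+1) * log(10));

% Original expression from the question (multiplies by 2.3), for comparison.
dLdYOld = 2/(N*R) * (log10(Y+1) - log10(T+1)) ./ (Y+1) * 2.3;

% Central finite differences, perturbing one element at a time.
h = 1e-6;
numGrad = zeros(size(Y));
for k = 1:numel(Y)
    E = zeros(size(Y));
    E(k) = h;
    numGrad(k) = (lossFcn(Y+E) - lossFcn(Y-E)) / (2*h);
end

fprintf('max abs error, log(10) version: %g\n', max(abs(dLdY(:) - numGrad(:))));
fprintf('max abs error, original *2.3:   %g\n', max(abs(dLdYOld(:) - numGrad(:))));
```

If the first error is small and the second is not, the corresponding line in backwardLoss would become `dLdY = 2/(N*R)*(log10(Y+1)-log10(T+1))./((Y+1)*log(10));`. Using log(10) rather than a rounded 2.3 also avoids a small constant relative error that could still exceed checkLayer's tolerance. Separately, the layer is checked with 'ObservationDimension',2 while N is read from size(Y,4), which matches the template's layout (the template example checks with 'ObservationDimension',4); it may be worth confirming that R and N are read from the dimensions checkLayer actually uses.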
