conv2d in a custom Keras loss function

Time: 2018-06-01 23:45:31

Tags: python tensorflow keras

I am trying to implement a custom loss function in Keras with the TF backend, based on the Laplacian of two images.

def blur_loss(y_true, y_pred):
    #weighting of blur loss
    alpha = 1
    mae = losses.mean_absolute_error(y_true, y_pred)
    lapKernel = K.constant([0, 1, 0, 1, -4, 1, 0, 1, 0],shape = [3, 3])

    trueLap = K.conv2d(y_true, lapKernel)
    predLap = K.conv2d(y_pred, lapKernel)
    trueBlur = K.var(trueLap)
    predBlur = K.var(predLap)
    blurLoss = alpha * K.abs(trueBlur - predBlur)
    loss = (1-alpha) * mae + alpha * blurLoss
    return loss

When I try to compile the model, I get this error:

Traceback (most recent call last):
  File "kitti_train.py", line 65, in <module>
    model.compile(loss='mean_absolute_error', optimizer='adam', metrics=[blur_loss])
  File "/home/ubuntu/.virtualenvs/dl4cv/lib/python3.5/site-packages/keras/engine/training.py", line 924, in compile
    handle_metrics(output_metrics)
  File "/home/ubuntu/.virtualenvs/dl4cv/lib/python3.5/site-packages/keras/engine/training.py", line 921, in handle_metrics
    mask=masks[i])
  File "/home/ubuntu/.virtualenvs/dl4cv/lib/python3.5/site-packages/keras/engine/training.py", line 450, in weighted
    score_array = fn(y_true, y_pred)
  File "/home/ubuntu/prednet/blur_loss.py", line 14, in blur_loss
    trueLap = K.conv2d(y_true, lapKernel)
  File "/home/ubuntu/.virtualenvs/dl4cv/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py", line 3164, in conv2d
    data_format='NHWC')
  File "/home/ubuntu/.virtualenvs/dl4cv/lib/python3.5/site-packages/tensorflow/python/ops/nn_ops.py", line 655, in convolution
    num_spatial_dims, strides, dilation_rate)
  File "/home/ubuntu/.virtualenvs/dl4cv/lib/python3.5/site-packages/tensorflow/python/ops/nn_ops.py", line 483, in _get_strides_and_dilation_rate
    (len(dilation_rate), num_spatial_dims))
ValueError: len(dilation_rate)=2 but should be 0

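After reading other questions, my understanding is that this problem stems from compiling with placeholder tensors for y_true and y_pred. I have tried checking whether the inputs are placeholders and replacing them with zero tensors, but that gave me other errors.

How can I use a convolution (the image processing function, not the layer) in a loss function without getting these errors?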

1 answer:

Answer 0 (score: 1)

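The problem here is a misunderstanding of the conv2d function: it is not simply a 2-dimensional convolution. It is a batched 2-d convolution over multiple channels. So while you might expect a *2d function to accept 2-dimensional tensors, the input should actually be 4-dimensional (batch_size, height, width, channels), and the filter should also be 4-dimensional (filter_height, filter_width, input_channels, output_channels). Your lapKernel has shape [3, 3], i.e. rank 2, which is most likely why TF infers 0 spatial dimensions and rejects the 2-element dilation_rate. Details can be found in the TF docs.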
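To illustrate the shapes described above, here is a minimal sketch, not a definitive implementation. It calls TensorFlow's tf.nn.conv2d directly instead of K.conv2d (my substitution, so the example does not depend on a particular Keras version) and assumes TF 2.x with 4-D (batch, height, width, channels) inputs. The helper laplacian_variance, the channel averaging, and the default alpha=0.5 are all assumptions made for illustration; unlike the question's code, alpha is applied only once.

```python
import tensorflow as tf

# 3x3 Laplacian stored as a 4-D filter:
# (filter_height, filter_width, input_channels, output_channels).
LAP_KERNEL = tf.constant(
    [0, 1, 0, 1, -4, 1, 0, 1, 0], dtype=tf.float32, shape=[3, 3, 1, 1]
)


def laplacian_variance(images):
    """Variance of the Laplacian response of a (batch, H, W, C) tensor."""
    # Assumption: average the channels down to one so that the
    # single-input-channel kernel above can be applied.
    gray = tf.reduce_mean(tf.cast(images, tf.float32), axis=-1, keepdims=True)
    lap = tf.nn.conv2d(gray, LAP_KERNEL, strides=[1, 1, 1, 1], padding="VALID")
    return tf.math.reduce_variance(lap)


def blur_loss(y_true, y_pred, alpha=0.5):
    """Weighted sum of MAE and the difference in Laplacian variance."""
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    mae = tf.reduce_mean(tf.abs(y_true - y_pred))
    blur = tf.abs(laplacian_variance(y_true) - laplacian_variance(y_pred))
    # alpha is a hypothetical weighting between the two terms.
    return (1.0 - alpha) * mae + alpha * blur
```

If this matches your data layout, it could be passed as model.compile(loss=blur_loss, optimizer='adam'). Note that the compile call in your traceback registers blur_loss under metrics rather than as the loss.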