完成Keras Lambda的困惑

时间:2018-06-01 01:42:27

标签: python keras

我试图定义一个Lambda图层Keras,如下所示:

首先,这是一个计算图像的小波变换然后将它一起发光的函数:

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras.layers import BatchNormalization
from keras.layers import Lambda
from keras import regularizers
from keras import backend as K
import pywt
import numpy as np
from keras.engine.topology import Layer

def mkwtarray(image):
    channels = K.image_data_format()
    if channels is 'channels_first':
        axbase = 1
    else:
        axbase = 0
    print(axbase)
    print(image.shape)
    (a,( b, c, d ))= pywt.dwt2(image, 'db1', axes=(axbase, axbase+1))
    ab = np.concatenate((a, b), axis=axbase)
    cd = np.concatenate((c, d), axis=axbase)
    abcd = np.concatenate((ab, cd), axis=axbase+1)
    return abcd

def wtoutshape(input_shape):
    return input_shape

train_data_dir = 'train'
validation_data_dir = 'validation'
nb_train_samples = 21558
nb_validation_samples = 3446
epochs = 30
batch_size = 32

if K.image_data_format() == 'channels_first':
    input_shape = (3, img_width, img_height)
else:
    input_shape = (img_width, img_height, 3)

model = Sequential()
model.add(Lambda(mkwtarray, input_shape=input_shape, output_shape = wtoutshape))
<more random  layers>

令我惊讶的是,当我定义模型(意思是,评估上面的线条)时,它出错了,声称:     ValueError:输入数组的尺寸小于指定的轴

此外,'print'语句打印了预期值0(?, 150, 150, 3),这意味着该函数实际上是在定义时评估的,而不是在模型实际运行时。我显然缺少一些关于Keras'Lambda功能的东西 - 任何启示都会受到赞赏。

UPDATE 如果以“常规”方式定义图层,则会出现完全相同的问题(通过类,其中lambda现在位于图层的调用函数中,所以这不是λ-特定

1 个答案:

答案 0 :(得分:1)

这看起来像是NumPy和Keras的灾难性组合。让我们来看看两个主要的混淆点:

  1. 进入Keras模型(示例Lambda图层)后,您正在处理张量 NumPy阵列。虽然它很方便,但你不能在模型中使用任何NumPy操作,外部库。话虽如此,张量运算符与数组非常相似是有充分理由的。因为它是你的第一层,你可以在NumPy中预处理然后将它传递给你的模型,这样就行了。

  2. 为什么要让打印件正常工作? Keras有两个主要步骤,Tensorflow:1-&gt;构建计算图,2->实际上运行它。因此,您正在构建图形,并且您的操作被称为是,但它们会创建没有价值的符号张量。因此,您可以打印在构建图形时可以确定的形状,但不能打印它所保存的值。

  3. 带走信息,不要将NumPy与Tensorflow混合在计算图形(模型)中,并且在构建图形时一定要打印形状,以便了解图形的外观,但是你不会得到更多的东西建设时的象征性张量。

相关问题