How can I wrap a function with a Lambda layer?

Asked: 2019-02-27 16:20:44

Tags: python tensorflow lambda keras keras-layer

I want to do some processing on the output of one layer of my autoencoder and then pass the result to the next layer, but I cannot use the predefined Keras layers (such as Add, ...) for it, so I think I need a Lambda layer with a custom function. The output of my encoder is a tensor of shape (1, 28, 28, 1) called encoded, and I have an input tensor of shape (1, 4, 4, 1) named wtm. I want to treat encoded as a 4x4 grid of 7x7 blocks and add one value of wtm to the middle element of each 7x7 block (one wtm value per block of encoded). I wrote two functions to do this, but I get this error:

TypeError: 'Tensor' object does not support item assignment

I am a beginner in Python and Keras. I searched for the cause, but unfortunately I could not understand why this happens or what I should do. Please guide me on how to write my Lambda layer; I have attached my code here. I can do simple things with Lambda, for example:
add_const = Kr.layers.Lambda(lambda x: x[0] + x[1])
encoded_merged = add_const([encoded,wtm])

but when wtm has a different shape than encoded and something more complex has to be done on the layer, I do not know how to write the function.
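As far as I can tell, the error does not come from Keras itself: TensorFlow tensors are immutable, so any sliced assignment into them fails the same way. For example (a minimal repro, outside of any model):

import tensorflow as tf

t = tf.zeros([28, 28])
t[3::7, 3::7] = 1.0   # TypeError: 'Tensor' object does not support item assignment

So I probably have to express the addition with tensor ops instead of assignment, but I am not sure how. This is the full code I tried: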

from keras.layers import Input, Concatenate, GaussianNoise,Dropout,BatchNormalization,MaxPool2D,AveragePooling2D
from keras.layers import Conv2D, AtrousConv2D
from keras.models import Model
from keras.datasets import mnist
from keras.callbacks import TensorBoard
from keras import backend as K
from keras import layers
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as Kr
from keras.optimizers import SGD,RMSprop,Adam
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
import numpy as np
import pylab as pl
import matplotlib.cm as cm
import keract
from matplotlib import pyplot
from keras import optimizers
from keras import regularizers

from tensorflow.python.keras.layers import Lambda;
#-----------------building w train---------------------------------------------
    def grid_w(args):
        Enc, W = args
        # Ex, Ey, Ez = Enc.shape
        # Wx, Wy, Wz = W.shape
        Enc = tf.reshape(Enc, [28, 28])
        W = tf.reshape(W, [4, 4])
        Enc[3::7, 3::7] += W   # <-- this line raises: TypeError: 'Tensor' object does not support item assignment
        Enc = tf.reshape(Enc, [1, 28, 28, 1])
        W = tf.reshape(W, [1, 4, 4, 1])
        # Enc[:, 3::7, 3::7] = K.sum(W, axis=1)
        return Enc

    def grid_w_output_shape(shapes):
        shape1, shape2 = shapes
        return (shape1[0], 1)
    wt_random=np.random.randint(2, size=(49999,4,4))
    w_expand=wt_random.astype(np.float32)
    wv_random=np.random.randint(2, size=(9999,4,4))
    wv_expand=wv_random.astype(np.float32)
    x,y,z=w_expand.shape
    w_expand=w_expand.reshape((x,y,z,1))
    x,y,z=wv_expand.shape
    wv_expand=wv_expand.reshape((x,y,z,1))

    #-----------------building w test---------------------------------------------
    w_test = np.random.randint(2,size=(1,4,4))
    w_test=w_test.astype(np.float32)
    w_test=w_test.reshape((1,4,4,1))
    #-----------------------encoder------------------------------------------------
    #------------------------------------------------------------------------------
    wtm=Input((4,4,1))
    image = Input((28, 28, 1))
    conv1 = Conv2D(64, (5, 5), activation='relu', padding='same', name='convl1e')(image)
    conv2 = Conv2D(64, (5, 5), activation='relu', padding='same', name='convl2e')(conv1)
    conv3 = Conv2D(64, (5, 5), activation='relu', padding='same', name='convl3e')(conv2)
    #conv3 = Conv2D(8, (3, 3), activation='relu', padding='same', name='convl3e', kernel_initializer='Orthogonal',bias_initializer='glorot_uniform')(conv2)
    BN=BatchNormalization()(conv3)
    #DrO1=Dropout(0.25,name='Dro1')(BN)
    encoded =  Conv2D(1, (5, 5), activation='relu', padding='same',name='encoded_I')(BN)

    #-----------------------adding w---------------------------------------
    encoded_merged=Kr.layers.Lambda(grid_w, output_shape=grid_w_output_shape)([encoded, wtm])

    #-----------------------decoder------------------------------------------------
    #------------------------------------------------------------------------------

    deconv1 = Conv2D(64, (5, 5), activation='elu', padding='same', name='convl1d')(encoded_merged)
    deconv2 = Conv2D(64, (5, 5), activation='elu', padding='same', name='convl2d')(deconv1)
    deconv3 = Conv2D(64, (5, 5), activation='elu',padding='same', name='convl3d')(deconv2)
    deconv4 = Conv2D(64, (5, 5), activation='elu',padding='same', name='convl4d')(deconv3)
    BNd=BatchNormalization()(deconv3)  
    decoded = Conv2D(1, (5, 5), activation='sigmoid', padding='same', name='decoder_output')(BNd) 

    model=Model(inputs=[image,wtm],outputs=decoded)

    decoded_noise = GaussianNoise(0.5)(decoded)

    #----------------------w extraction------------------------------------
    convw1 = Conv2D(64, (3,3), activation='relu', name='conl1w')(decoded_noise)
    convw2 = Conv2D(64, (3, 3), activation='relu', name='convl2w')(convw1)
    Avw1=AveragePooling2D(pool_size=(2,2))
    convw3 = Conv2D(64, (3, 3), activation='relu', padding='same', name='conl3w')(convw2)
    convw4 = Conv2D(64, (3, 3), activation='relu', padding='same', name='conl4w')(convw3)
    Avw2=AveragePooling2D(pool_size=(2,2))
    convw5 = Conv2D(64, (3, 3), activation='relu', name='conl5w')(convw4)
    convw6 = Conv2D(64, (3, 3), activation='relu', padding='same', name='conl6w')(convw5)
    BNed=BatchNormalization()(convw6)
    #DrO3=Dropout(0.25, name='DrO3')(BNed)
    pred_w = Conv2D(1, (1, 1), activation='sigmoid', padding='same', name='reconstructed_W')(BNed)  
    watermark_extraction=Model(inputs=[image,wtm],outputs=[decoded,pred_w])

    watermark_extraction.summary()
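One idea I am considering (I am not sure it is the right approach) is to avoid item assignment completely: build a dense 28x28 map that holds the 16 wtm values at the block centres (rows/columns 3, 10, 17, 24, i.e. the 3::7 positions) and zeros elsewhere, and let the Lambda simply add this map to encoded. A sketch, assuming a batch size of 1 exactly like grid_w above (grid_w_add is just a placeholder name, and it reuses the tf / Kr imports and the encoded and wtm tensors defined in the code):

def grid_w_add(args):
    enc, w = args                           # enc: (1, 28, 28, 1), w: (1, 4, 4, 1)
    centres = tf.constant([3, 10, 17, 24])  # middle row/column of each 7x7 block
    # all 16 (row, col) pairs of block centres, shape (16, 2)
    idx = tf.reshape(tf.stack(tf.meshgrid(centres, centres, indexing='ij'), axis=-1), [-1, 2])
    w_flat = tf.reshape(w, [-1])            # the 16 wtm values, row-major like idx
    # dense 28x28 map: wtm values at the block centres, zeros elsewhere
    w_map = tf.reshape(tf.scatter_nd(idx, w_flat, [28, 28]), [1, 28, 28, 1])
    return enc + w_map                      # element-wise add instead of item assignment

encoded_merged = Kr.layers.Lambda(grid_w_add)([encoded, wtm])

Is something like this the right way to write the Lambda, or is there a simpler built-in way?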

0 Answers