通过keras使用batch_size> 1获得不兼容的形状错误

时间:2018-09-19 17:25:29

标签: python python-3.x tensorflow keras keras-layer

我最近开始与keras合作,并研究了每种可用的解决方案,但是它没有用。我有一个简单的神经网络,可以从5个输入常数和10个输入张量计算张量值。这是我的神经网络,主输入层输入5个常数,张量输入层输入10个张量:

This is my neural network, the main input layer inputs 5 constants and tensor input layer inputs 10 tensors.

import numpy as np
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input,Lambda
from keras.layers import Dense
from keras import backend as K


def function(x): #This function is used for the last layer to compute Anisotropic R-S
    tensor = x[0]
    constants = x[1]
    a = K.zeros(shape=(3,3,))

    for i in range(10):
        a = a + constants[:,i]*tensor[:,:,:,i]
    return a


main_input = Input(shape = (5,),name = 'main_input') #The invariant inputs
hidden1 = Dense(10,activation = 'relu')(main_input)
hidden2 = Dense(10,activation= 'relu')(hidden1) #10 constants

tensor_input = Input(shape= (3,3,10,),name = 'tensor_input')

output_layer = Lambda(function)([tensor_input,hidden2])

model = Model(inputs = [main_input,tensor_input], outputs = output_layer)
print(model.summary())

model.compile(optimizer = 'adam', loss='mean_squared_error', metrics=['accuracy'])
plot_model(model, to_file='multilayer_perceptron_graph.png')

#Just test inputs and outputs to correct shape
I_test = np.ones((120,5))
T_test = np.ones((120,3,3,10))
a_test = np.ones((120,3,3))

model.fit({'main_input': I_test, 'tensor_input': T_test},a_test,epochs=50,batch_size=2)

我已经使用了lambda层作为输出层。将张量a计算为: a = g1 * T1 + g2 * T2 + .... g10 * T10 其中g是从密集24层计算得出的常数。 这里没有激活,它是简单的线性组合。 输出a是(3,3)矩阵。输入张量是10 3 * 3张量,因此I的形状为(129,3,3,10)。

每当batchsize> 1时,我都会收到以下错误消息:

InvalidArgumentError: Incompatible shapes: [2] vs. [2,3,3]
     [[Node: training_7/Adam/gradients/lambda_12/mul_9_grad/BroadcastGradientArgs = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@training_7/Adam/gradients/lambda_12/mul_9_grad/Reshape"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training_7/Adam/gradients/lambda_12/mul_9_grad/Shape, training_7/Adam/gradients/lambda_12/mul_9_grad/Shape_1)]]

请帮助我解决这个问题。

0 个答案:

没有答案
相关问题