自定义Keras图层的未定义输出形状

时间:2019-05-27 13:57:18

标签: tensorflow keras keras-layer

我正在编写一个自定义的Keras层,该层会平整除输入的最后一个维度以外的所有维度。但是,当将层的输出馈送到下一层时,会发生错误,因为该层的输出形状在所有维度上均为None

class FlattenLayers( Layer ):
    """
    Takes a nD tensor flattening the middle dimensions, ignoring the edge dimensions.
    [ n, x, y, z ] -> [ n, xy, z ]
    """

    def __init__( self, **kwargs ):
        super( FlattenLayers, self ).__init__( **kwargs )


    def build( self, input_shape ):
        super( FlattenLayers, self ).build( input_shape )


    def call( self, inputs ):
        input_shape = tf.shape( inputs )

        flat = tf.reshape(
            inputs,
            tf.stack( [ 
                -1, 
                K.prod( input_shape[ 1 : -1 ] ),
                input_shape[ -1 ]
            ] )
        )

        return flat


    def compute_output_shape( self, input_shape ):
        if not all( input_shape[ 1: ] ):
            raise ValueError( 'The shape of the input to "Flatten" '
                             'is not fully defined '
                             '(got ' + str( input_shape[ 1: ] ) + '). '
                             'Make sure to pass a complete "input_shape" '
                             'or "batch_input_shape" argument to the first '
                             'layer in your model.' )

        output_shape = ( 
            input_shape[ 0 ], 
            np.prod( input_shape[ 1 : -1 ] ), 
            input_shape[ -1 ] 
        )

        return output_shape

例如,当紧跟一层时,我会收到错误ValueError: The last dimension of the inputs to Dense should be defined. Found None.

1 个答案:

答案 0 :(得分:1)

为什么您的tf.stack()具有新的形状?您想要展平除最后一个尺寸外的所有尺寸;这是你怎么做的:

import tensorflow as tf
from tensorflow.keras.layers import Layer
import numpy as np

class FlattenLayer(Layer):

    def __init__( self, **kwargs):
        super(FlattenLayer, self).__init__(**kwargs)

    def build( self, input_shape ):
        super(FlattenLayer, self).build(input_shape)

    def call( self, inputs):
        new_shape = self.compute_output_shape(tf.shape(inputs))
        return tf.reshape(inputs, new_shape)

    def compute_output_shape(self, input_shape):
        new_shape = (input_shape[0]*input_shape[1]*input_shape[2],
                     input_shape[3])
        return new_shape

使用单个数据点(tf.__version__=='1.13.1')进行测试:

inputs = tf.keras.layers.Input(shape=(10, 10, 1))    
res = tf.keras.layers.Conv2D(filters=3, kernel_size=2)(inputs)
res = FlattenLayer()(res)
model = tf.keras.models.Model(inputs=inputs, outputs=res)

x_data = np.random.normal(size=(1, 10, 10, 1))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    evaled = model.outputs[0].eval({model.inputs[0]:x_data})
    print(evaled.shape) # (81, 3)