Keras model with multiple inputs and multiple outputs

Posted: 2019-01-22 17:35:26

Tags: tensorflow keras

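I want to build a Keras model with two inputs and two outputs that share the same architecture and weights. Both outputs are then used to compute a single loss.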

Here is a picture of the architecture I want.

[Image: diagram of the desired architecture]

Here is my pseudocode:

model = LeNet(inputs=[input1, input2, input3], outputs=[output1, output2, output3])

model.compile(optimizer='adam',
              loss=my_custom_loss_function([output1, output2, output3], target),
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

Will this approach work?
Do I need to use a different Keras API?

1 answer:

Answer 0 (score: 1)

The architecture is fine. Here is a toy example, including training data, of how to define it with Keras's functional API:


from keras.models import Model
from keras.layers import Dense, Input
import numpy as np

# two separate inputs
in_1 = Input((10, 10))
in_2 = Input((10, 10))

# both inputs share these layers
dense_1 = Dense(10)
dense_2 = Dense(10)

# both inputs are passed through the layers
out_1 = dense_1(dense_2(in_1))
out_2 = dense_1(dense_2(in_2))

# create and compile the model
model = Model(inputs=[in_1, in_2], outputs=[out_1, out_2])
model.compile(optimizer='adam', loss='mse')
model.summary()

# train the model on some dummy data
i_1 = np.random.rand(10, 10, 10)
i_2 = np.random.rand(10, 10, 10)
model.fit(x=[i_1, i_2], y=[i_1, i_2])

Edit: if you want to compute the loss on both outputs together, you can use Concatenate():

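from keras.layers import Concatenate

output = Concatenate()([out_1, out_2])

Any loss function you pass to model.compile will then be applied to output in its combined state. Once you get the output from a prediction, you can split it back into its original parts.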
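A self-contained sketch of the concatenated variant of the toy model, with one loss computed over the combined output. joint_loss is a hypothetical name: here it simply splits the combined tensor at the last-axis boundary (10, the width of each Dense output) and sums an MSE term for each half, so substitute whatever combination your task needs. The shared Dense layers mean the model holds only four trainable variables (kernel and bias for each layer), however many branches use them.

```python
import numpy as np
from keras.models import Model
from keras.layers import Concatenate, Dense, Input
from keras.losses import mean_squared_error

# Two separate inputs, as in the toy example.
in_1 = Input((10, 10))
in_2 = Input((10, 10))

# Shared layers: both branches reuse the same weights.
dense_1 = Dense(10)
dense_2 = Dense(10)
out_1 = dense_1(dense_2(in_1))
out_2 = dense_1(dense_2(in_2))

# Join the two outputs on the last axis: (None, 10, 10) x 2 -> (None, 10, 20).
output = Concatenate()([out_1, out_2])
model = Model(inputs=[in_1, in_2], outputs=output)


def joint_loss(y_true, y_pred):
    """Hypothetical single loss over both branches.

    Splits the combined tensors back into their two halves and sums
    an MSE term for each half.
    """
    loss_1 = mean_squared_error(y_true[..., :10], y_pred[..., :10])
    loss_2 = mean_squared_error(y_true[..., 10:], y_pred[..., 10:])
    return loss_1 + loss_2


model.compile(optimizer='adam', loss=joint_loss)

# Dummy data: the target is concatenated the same way as the output.
i_1 = np.random.rand(10, 10, 10)
i_2 = np.random.rand(10, 10, 10)
target = np.concatenate([i_1, i_2], axis=-1)
history = model.fit(x=[i_1, i_2], y=target, epochs=1, verbose=0)

# Prediction comes back in the combined (n, 10, 20) form.
pred = model.predict([i_1, i_2], verbose=0)
```

Because the loss sees both halves at once, any term that compares the two branches (not just a sum of independent terms) can go inside joint_loss.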

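A minimal sketch of splitting a prediction back into the two outputs, assuming they were joined with Concatenate()'s default axis=-1. The random arrays stand in for a real model.predict result, and width_1 is a hypothetical name for the last-axis size of out_1.

```python
import numpy as np

# Stand-in for model.predict(...) on the concatenated model:
# two (n, 10, 10) outputs joined on the last axis give (n, 10, 20).
n = 4
pred_1 = np.random.rand(n, 10, 10)
pred_2 = np.random.rand(n, 10, 10)
pred = np.concatenate([pred_1, pred_2], axis=-1)

# Concatenate() joins on axis=-1 by default, so slice on the same axis.
# Slicing also works when the two outputs have different widths.
width_1 = 10  # last-axis size of out_1
out_1_pred = pred[..., :width_1]
out_2_pred = pred[..., width_1:]
```

When both outputs have the same width, np.split(pred, 2, axis=-1) gives the same two arrays.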