Keras点/点层在3D张量上的行为

时间:2018-08-02 15:37:25

标签: python python-3.x keras keras-layer keras-2

Keras documentation for the dot/Dot layer指出:

“计算两个张量的样本之间的点积的层。

例如如果应用于形状为(batch_size,n)的两个张量a和b的列表,则输出将为形状(batch_size,1)的张量,其中每个条目i将是a [i]和b [i之间的点积]。

参数

轴:整数,整数或元组,沿其取点积的一个或多个轴。”

我不明白这一点,这是一个快速,可重复的示例来演示:

from keras.layers import Input, dot
input_a = Input(batch_shape=(99,45000,300))
input_b = Input(batch_shape=(99,45000,300))
element_wise_dot_product = dot([input_a,input_b], axes = -1)
print(input_a.get_shape(),input_b.get_shape(),element_wise_dot_product.get_shape()) 

输出:(99,45000,300)(99,45000,300)(99,45000,45000)

为什么元素明智的点积形状不是(99,45000,1)?我在做什么错,我该如何解决?

1 个答案:

答案 0 :(得分:2)

由于这些点是3D张量而不是2D,因此点层正在沿最后一个轴执行矩阵乘法。因此,您得到的形状反映了这一点。您要做的是将产品放在每个输入的最后一列。您可以取两个输入的按元素乘积,然后沿最后一个轴求和。例如,

import keras.backend as K
import tensorflow as tf

K.sum(tf.multiply(input_a, input_b[:tf.newaxis]), axis=-1, keepdims=True)

如果仅需要keras解决方案,则可以使用keras.layers.multiply代替tf.multiply,并使用K.expand_dims代替用tf.newaxis广播。

相关问题