为 Keras 中的自动编码器导出低级矩阵乘法

时间:2021-06-02 19:06:08

标签: python tensorflow keras autoencoder

我正在尝试推导低级矩阵乘法并了解自动编码器第一个隐藏层的矩阵架构,但我无法验证结果。
我的自动编码器是一个 16 位二进制输入,编码为 7,然后解码回 16。
我只是想了解第一个隐藏层,在这种情况下,“dense_8”并推导出乘法根据结果进一步向下。

自编码器定义如下:

M = 16 
N = 10000
Input(shape=(M,))
<块引用>

input_signal = Input(shape=(M,))
encoded = Dense(M, activation='relu')(input_signal)
print(encoded)
<块引用>

张量("dense_8/Relu:0", shape=(None, 16), dtype=float32)

encoded1 = Dense(n_channel, activation='linear')(encoded)
encoded2 = BatchNormalization()(encoded1)
print (encoded1)
print (encoded2)
<块引用>

Tensor("dense_9/BiasAdd:0", shape=(None, 7), dtype=float32)
张量(“batch_normalization_2/batchnorm/add_1:0”,形状=(无,7), dtype=float32)

等等...

autoencoder.summary()
encoder.summary()

Autoencoder model

当我检查乘法时,我无法得出数字。例如下面的前 16 位输入我无法得到输出

layer_name = 'dense_8'

intermediate_layer_model = keras.Model(inputs=autoencoder.input, outputs=autoencoder.get_layer(layer_name).output)
intermediate_layer_model_o = keras.Model(inputs=autoencoder.input, outputs=autoencoder.get_layer(layer_name).input)


intermediate_output1 = intermediate_layer_model(data)
intermediate_input1 = intermediate_layer_model_o(data)

print ('--------THIS IS THE INPUT-------------')
print(intermediate_input1)
print ('--------THIS IS THE OUTPUT-------------')
print(intermediate_output1)

print ('--------THESE ARE THE WEIGHTS-------------')

w = np.array(encoder.weights)
print(w.shape)
print ('--------THIS IS FIRST ROW FIRST COL-------------')
print(w[0][0])
print ('--------THIS IS SEC ROW FIRST COL-------------')
print(w[1][0])
print ('--------THIS IS FIRST ROW SEC COL-------------')
print(w[0][1])
print ('--------THIS IS SEC ROW SEC COL-------------')
print(w[1][1])
print ('--------THIS IS FIRST ROW THIRD COL-------------')
print(w[0][2])


#print(w[1])
print ('--------INPUT/OUTPUT-------------')
inp = np.array(intermediate_input1)
out = np.array(intermediate_output1)
print ('---------------------')

print ('--------THIS IS THE FIRST INPUT-------------')
print(inp[0,])
print ('--------THIS IS THE FIRST OUTPUT-------------')
print(out[0,])

Displaying the results after multiplication

0 个答案:

没有答案
相关问题