关于Keras中n << p数据集=“”与=“ =” batch_size =“”> n中的batch_size的混淆

时间:2018-07-25 06:07:37

标签: python tensorflow keras

我对运行得很好的代码有些困惑,但是我不明白为什么。

我处于n <

那么keras的默认行为是什么?只是回到纯粹的梯度下降阶段,权重在一个时期的末尾求平均值?

我正在有监督的二进制分类器设置以及基于LSTM / Autoencoder的无监督异常检测器中使用此设置

由于我一直认为在LSTM中-情况n%batch_size应该为零,所以这又增加了一个困惑。

1 个答案:

答案 0 :(得分:0)

我已经对源代码进行了深入研究,我认为我对这些问题有答案。

  1. 如果batch_size> n,Keras是否会“退化”为普通梯度下降?

答案是肯定的。如在方法batch_shuffle中从第334行开始所看到的(注意:我链接到V2.2以保留行号),如果batch_size> n,则返回整个批次。这里是相关的代码和输出:

import numpy as np
index_array = np.array([0,1,2,3,4,5])
batch_size = 72
batch_count = int(len(index_array) / batch_size)
#batch_count = 0

last_batch = index_array[batch_count * batch_size:]
# last_batch = array([0, 1, 2, 3, 4, 5])

index_array = index_array[:batch_count * batch_size]
#index_array = array([], dtype=int64)

index_array = index_array.reshape((batch_count, batch_size))
#index_array =array([], shape=(0, 72), dtype=int64)

np.random.shuffle(index_array)
index_array = index_array.flatten()
return np.append(index_array, last_batch)
# np.append(index_array, last_batch) = array([0, 1, 2, 3, 4, 5])
  1. 即使n%batch_size <> 0,为什么LSTM也能正常工作?

n%batch_size = 0仅适用于有状态LSTM的要求,再次在以下code,第817行中找到了证据,仅当stateful == True和n%batch_size <> 0时,才会引发错误