为什么这是对神经网络执行成本函数的正确方法?

时间:2019-03-25 04:11:25

标签: matlab neural-network vectorization octave backpropagation

因此,我将头撞在墙上几个小时后,我在网上寻找解决问题的方法,效果很好。我只想知道是什么原因导致了我本来的问题。

enter image description here

enter image description here

这里有更多详细信息。输入的是来自MNIST数据集的20x20px图像,并且有5000个样本,因此X或A1为5000x400。单个隐藏层中有25个节点。输出是一个0-9位数的热向量。 y(不是Y,它是y的一种热编码)是一个5000x1的矢量,其值是1-10。

enter image description here

这是我原先的成本函数代码:

Y = zeros(m, num_labels);
   for i = 1:m
   Y(i, y(i)) = 1; 
endfor
H = sigmoid(Theta2*[ones(1,m);sigmoid(Theta1*[ones(m, 1) X]'))
J = (1/m) * sum(sum((-Y*log(H]))' - (1-Y)*log(1-H]))')))

但是后来我发现了:

A1 = [ones(m, 1) X];
Z2 = A1 * Theta1';
A2 = [ones(size(Z2, 1), 1) sigmoid(Z2)];
Z3 = A2*Theta2';
H = A3 = sigmoid(Z3);

J = (1/m)*sum(sum((-Y).*log(H) - (1-Y).*log(1-H), 2));

我看到这可能会稍微干净一些,但是什么在功能上导致我的原始代码获得304.88而其他代码获得〜0.25?是元素明智的乘法吗?

仅供参考,如果您需要写出形式方程式,这与this question一样。

感谢您能提供的任何帮助!我真的很想了解我要去哪里哪里

1 个答案:

答案 0 :(得分:1)

从评论中转移:
快速浏览一下,J = (1/m) * sum(sum((-Y*log(H]))' - (1-Y)*log(1-H]))')))中肯定有括号,但是可能是您在此处粘贴的方式,而不是原始代码,因为在运行时会引发错误。如果我正确理解并且Y,H是矩阵,那么在您的第一个版本Y*log(H)中是矩阵乘法,而在第二个版本Y.*log(H)中是逐项乘法(不是矩阵乘法,只是c(i,j)=a(i,j)*b(i,j) )。

更新1:
关于您在评论中的问题。 从第一个屏幕截图中,您将Y矩阵的条目Y(i,k)中的每个值yk(i)表示为H(i,k),并将每个值h(x ^(i))k表示为Y(i,k) log(H(i,k)) + (1-Y(i,k)) log(1-H(i,k))。因此,基本上,您要为每个i,k计算C = Y.*log(H) + (1-Y).*log(1-H)。您可以对所有值进行处理,并将结果存储在矩阵C中。然后.*和每个C(i,k)都具有上述值。这是操作sum,因为您要对每个矩阵的每个元素(i,k)进行操作(与multiplying the matrices完全不同)。然后,要获取2D二维矩阵C内所有值的总和,请使用两次八度音阶函数sum(sum(C))sum(C(:))对列和行进行求和(或如@ Irreducible建议的那样, $ gem install sqlite3 warning: mingw-w64-x86_64-sqlite3-3.27.2-2 is up to date -- skipping ERROR: Error installing sqlite3: ERROR: Failed to build gem native extension. current directory: C:/Ruby25-x64/lib/ruby/gems/2.5.0/gems/sqlite3-1.4.0/ext/sqlite3 C:/Ruby25-x64/bin/ruby.exe -r ./siteconf20190325-16708-12lhjav.rb extconf.rb checking for sqlite3.h... yes checking for pthread_create() in -lpthread... yes checking for -ldl... no checking for dlopen()... no missing function dlopen *** extconf.rb failed *** Could not create Makefile due to some reason, probably lack of necessary libraries and/or headers. Check the mkmf.log file for more details. You may need configuration options. Provided configuration options: --with-opt-dir --without-opt-dir --with-opt-include --without-opt-include=${opt-dir}/include --with-opt-lib --without-opt-lib=${opt-dir}/lib --with-make-prog --without-make-prog --srcdir=. --curdir --ruby=C:/Ruby25-x64/bin/$(RUBY_BASE_NAME) --with-sqlcipher --without-sqlcipher --with-sqlite3-config --without-sqlite3-config --with-pkg-config --without-pkg-config --with-sqlcipher --without-sqlcipher --with-sqlite3-dir --without-sqlite3-dir --with-sqlite3-include --without-sqlite3-include=${sqlite3-dir}/include --with-sqlite3-lib --without-sqlite3-lib=${sqlite3-dir}/lib --with-pthreadlib --without-pthreadlib --with-dllib --without-dllib To see why this extension failed to compile, please check the mkmf.log which can be found here: C:/Ruby25-x64/lib/ruby/gems/2.5.0/extensions/x64-mingw32/2.5.0/sqlite3-1.4.0/mkmf.log extconf failed, exit code 1 Gem files will remain installed in C:/Ruby25-x64/lib/ruby/gems/2.5.0/gems/sqlite3-1.4.0 for inspection. Results logged to C:/Ruby25-x64/lib/ruby/gems/2.5.0/extensions/x64-mingw32/2.5.0/sqlite3-1.4.0/gem_make.out Temporarily enhancing PATH for MSYS/MINGW... Installing required msys2 packages: mingw-w64-x86_64-sqlite3 Building native extensions. This could take a while... C:\Ruby25-x64\bin\ruby.exe C:\Users\jpper\Projects\blog\bin\bundle install Fetching gem metadata from https://rubygems.org/............. Fetching gem metadata from https://rubygems.org/. Could not find gem 'sqlite3 (~> 3.0) x64-mingw32' in any of the gem sources listed in your Gemfile.

请注意,可能还会出现其他错误。

相关问题