因此,我将头撞在墙上几个小时后,我在网上寻找解决问题的方法,效果很好。我只想知道是什么原因导致了我本来的问题。
这里有更多详细信息。输入的是来自MNIST数据集的20x20px图像,并且有5000个样本,因此X或A1为5000x400。单个隐藏层中有25个节点。输出是一个0-9位数的热向量。 y
(不是Y,它是y的一种热编码)是一个5000x1的矢量,其值是1-10。
这是我原先的成本函数代码:
Y = zeros(m, num_labels);
for i = 1:m
Y(i, y(i)) = 1;
endfor
H = sigmoid(Theta2*[ones(1,m);sigmoid(Theta1*[ones(m, 1) X]'))
J = (1/m) * sum(sum((-Y*log(H]))' - (1-Y)*log(1-H]))')))
但是后来我发现了:
A1 = [ones(m, 1) X];
Z2 = A1 * Theta1';
A2 = [ones(size(Z2, 1), 1) sigmoid(Z2)];
Z3 = A2*Theta2';
H = A3 = sigmoid(Z3);
J = (1/m)*sum(sum((-Y).*log(H) - (1-Y).*log(1-H), 2));
我看到这可能会稍微干净一些,但是什么在功能上导致我的原始代码获得304.88
而其他代码获得〜0.25
?是元素明智的乘法吗?
仅供参考,如果您需要写出形式方程式,这与this question一样。
感谢您能提供的任何帮助!我真的很想了解我要去哪里哪里
答案 0 :(得分:1)
从评论中转移:
快速浏览一下,J = (1/m) * sum(sum((-Y*log(H]))' - (1-Y)*log(1-H]))')))
中肯定有括号,但是可能是您在此处粘贴的方式,而不是原始代码,因为在运行时会引发错误。如果我正确理解并且Y,H是矩阵,那么在您的第一个版本Y*log(H)
中是矩阵乘法,而在第二个版本Y.*log(H)
中是逐项乘法(不是矩阵乘法,只是c(i,j)=a(i,j)*b(i,j)
)。
更新1:
关于您在评论中的问题。
从第一个屏幕截图中,您将Y矩阵的条目Y(i,k)
中的每个值yk(i)表示为H(i,k)
,并将每个值h(x ^(i))k表示为Y(i,k) log(H(i,k)) + (1-Y(i,k)) log(1-H(i,k))
。因此,基本上,您要为每个i,k计算C = Y.*log(H) + (1-Y).*log(1-H)
。您可以对所有值进行处理,并将结果存储在矩阵C中。然后.*
和每个C(i,k)都具有上述值。这是操作sum
,因为您要对每个矩阵的每个元素(i,k)进行操作(与multiplying the matrices完全不同)。然后,要获取2D二维矩阵C内所有值的总和,请使用两次八度音阶函数sum(sum(C))
:sum(C(:))
对列和行进行求和(或如@ Irreducible建议的那样, $ gem install sqlite3
warning: mingw-w64-x86_64-sqlite3-3.27.2-2 is up to date -- skipping
ERROR: Error installing sqlite3:
ERROR: Failed to build gem native extension.
current directory: C:/Ruby25-x64/lib/ruby/gems/2.5.0/gems/sqlite3-1.4.0/ext/sqlite3
C:/Ruby25-x64/bin/ruby.exe -r ./siteconf20190325-16708-12lhjav.rb extconf.rb
checking for sqlite3.h... yes
checking for pthread_create() in -lpthread... yes
checking for -ldl... no
checking for dlopen()... no
missing function dlopen
*** extconf.rb failed ***
Could not create Makefile due to some reason, probably lack of necessary
libraries and/or headers. Check the mkmf.log file for more details. You may
need configuration options.
Provided configuration options:
--with-opt-dir
--without-opt-dir
--with-opt-include
--without-opt-include=${opt-dir}/include
--with-opt-lib
--without-opt-lib=${opt-dir}/lib
--with-make-prog
--without-make-prog
--srcdir=.
--curdir
--ruby=C:/Ruby25-x64/bin/$(RUBY_BASE_NAME)
--with-sqlcipher
--without-sqlcipher
--with-sqlite3-config
--without-sqlite3-config
--with-pkg-config
--without-pkg-config
--with-sqlcipher
--without-sqlcipher
--with-sqlite3-dir
--without-sqlite3-dir
--with-sqlite3-include
--without-sqlite3-include=${sqlite3-dir}/include
--with-sqlite3-lib
--without-sqlite3-lib=${sqlite3-dir}/lib
--with-pthreadlib
--without-pthreadlib
--with-dllib
--without-dllib
To see why this extension failed to compile, please check the mkmf.log which can be found here:
C:/Ruby25-x64/lib/ruby/gems/2.5.0/extensions/x64-mingw32/2.5.0/sqlite3-1.4.0/mkmf.log
extconf failed, exit code 1
Gem files will remain installed in C:/Ruby25-x64/lib/ruby/gems/2.5.0/gems/sqlite3-1.4.0 for inspection.
Results logged to C:/Ruby25-x64/lib/ruby/gems/2.5.0/extensions/x64-mingw32/2.5.0/sqlite3-1.4.0/gem_make.out
Temporarily enhancing PATH for MSYS/MINGW...
Installing required msys2 packages: mingw-w64-x86_64-sqlite3
Building native extensions. This could take a while...
C:\Ruby25-x64\bin\ruby.exe C:\Users\jpper\Projects\blog\bin\bundle install
Fetching gem metadata from https://rubygems.org/.............
Fetching gem metadata from https://rubygems.org/.
Could not find gem 'sqlite3 (~> 3.0) x64-mingw32' in any of the gem sources
listed in your Gemfile.
。
请注意,可能还会出现其他错误。