无法找到Torch Lua代码中定义addmm函数的位置

时间:2018-05-28 12:51:14

标签: lua torch

我正在理解在Torch Lua中实现的神经网络。在向后传递线性图层期间,它调用一个名为Linear的函数:updateGradInput(https://github.com/torch/nn/blob/master/Linear.lua#L75

function Linear:updateGradInput(input, gradOutput)
  if self.gradInput then

     local nElement = self.gradInput:nElement()
     self.gradInput:resizeAs(input)
     if self.gradInput:nElement() ~= nElement then
        self.gradInput:zero()
     end
     if input:dim() == 1 then
        self.gradInput:addmv(0, 1, self.weight:t(), gradOutput)
     elseif input:dim() == 2 then
        self.gradInput:addmm(0, 1, gradOutput, self.weight)
     end
     return self.gradInput
  end
end

在该函数中通过调用名为addmm(https://github.com/torch/nn/blob/master/Linear.lua#L86)的函数来执行基本矩阵乘法运算。我无法找到定义此addmm函数的位置。

在TH库(https://github.com/torch/torch7/blob/master/lib/TH/generic/THTensorMath.c#L1282)中定义了一个addmm函数,但我不确定Lua代码是如何在C中连接到此代码的。

1 个答案:

答案 0 :(得分:0)

刚刚弄清楚了Lua代码和C代码之间的联系。在Lua代码中调用addmm会指向此函数(https://github.com/torch/torch7/blob/master/TensorMath.lua#L487-L510),而这又会调用此处定义的C Torch库中定义的addmm函数(https://github.com/torch/torch7/blob/master/lib/TH/generic/THTensorMath.c#L1282)。

这很棘手,因为Lua通过字符串构造对C函数的调用。

相关问题