在python numpy中实现Relu派生

时间:2017-09-25 17:35:59

标签: machine-learning python derivative numpy

我试图实现一个函数来计算矩阵中每个元素的Relu导数,然后将结果返回到矩阵中。我使用的是Python和Numpy。

根据其他交叉验证帖子,x的Relu衍生物是 当x> 1时当x <0时,0,0。 0,未定义或当x == 0

时为0

目前,到目前为止,我有以下代码:

def reluDerivative(self, x):
    return np.array([self.reluDerivativeSingleElement(xi) for xi in x])

def reluDerivativeSingleElement(self, xi):
    if xi > 0:
        return 1
    elif xi <= 0:
        return 0

不幸的是,xi是一个数组,因为x是一个矩阵。 reluDerivativeSingleElement函数不适用于数组。所以我想知道有没有办法使用numpy将矩阵中的值映射到另一个矩阵,就像numpy中的exp函数一样?

提前多多感谢。

10 个答案:

答案 0 :(得分:12)

这是矢量化的练习。

此代码

if x > 0:
  y = 1
elif xi <= 0:
  y = 0

可以重新制定成

y = (x > 0) * 1

这适用于numpy数组,因为涉及它们的布尔表达式将转换为所述数组中元素的这些表达式的值数组。

答案 1 :(得分:8)

我想这就是你要找的东西:

>>> def reluDerivative(x):
...     x[x<=0] = 0
...     x[x>0] = 1
...     return x

>>> z = np.random.uniform(-1, 1, (3,3))
>>> z
array([[ 0.41287266, -0.73082379,  0.78215209],
       [ 0.76983443,  0.46052273,  0.4283139 ],
       [-0.18905708,  0.57197116,  0.53226954]])
>>> reluDerivative(z)
array([[ 1.,  0.,  1.],
       [ 1.,  1.,  1.],
       [ 0.,  1.,  1.]])

答案 2 :(得分:5)

返回relu导数的基本函数可归纳如下:

f'(x) = x > 0

所以,numpy将是:

def relu_derivative(z):
    return np.greater(z, 0).astype(int)

答案 3 :(得分:2)

def dRelu(z):
    return np.where(z <= 0, 0, 1)

在我的情况下,z是一个ndarray。

答案 4 :(得分:0)

你正在走上良好的轨道:思考矢量化操作。在我们定义函数的地方,我们将此函数应用于矩阵,而不是编写for循环。

这个主题回答了你的问题,它替换满足条件的所有元素。您可以将其修改为ReLU衍生物。

https://stackoverflow.com/questions/19766757/replacing-numpy-elements-if-condition-is-met

另外,python非常支持函数式编程,尝试使用lambda函数。

https://www.python-course.eu/lambda.php

答案 5 :(得分:0)

这有效:

def dReLU(x):
    return 1. * (x > 0)

答案 6 :(得分:0)

正如Neil在评论中所提到的,你可以使用numpy的heaviside函数。

def reluDerivative(self, x):
    return np.heaviside(x, 0)

答案 7 :(得分:0)

如果要使用纯Python:

def relu_derivative(x):
    return max(sign(x), 0)

答案 8 :(得分:0)

def reluDerivative(self, x): 
    return 1 * (x > 0)

答案 9 :(得分:-1)

当x大于0时,斜率为1。 当x小于或等于0时,斜率为0.

if (x > 0):
    return 1
if (x <= 0):
    return 0

这可以写得更紧凑:

return 1 * (x > 0)