Suppose c = a + b, where a and b are ndarrays whose shapes are not necessarily the same. That is, they can be any two arrays that are compatible under the general broadcasting rules.
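I have an upstream gradient dl/dc and I want to compute dl/da. If a and b have the same shape, then dl/da = dl/db = dl/dc. However, I might have an addition with a.shape == (3,) and b.shape == (2,3), so that c[i][j] = a[j] + b[i][j]. This means dl/da[j] = sum_i dl/dc[i][j]. In general, dl/da is the sum of dl/dc over all axes along which a was broadcast.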
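The (3,) plus (2,3) case above can be checked numerically. This is only a minimal sketch: the upstream gradient dl_dc is made up, and the loss l = sum(dl_dc * c) is an assumption chosen purely so that its gradient with respect to c is exactly dl_dc.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])        # shape (3,)
b = np.arange(6.0).reshape(2, 3)     # shape (2, 3)
c = a + b                            # a is broadcast along axis 0

# Hypothetical upstream gradient dl/dc, same shape as c.
dl_dc = np.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

# a was broadcast along axis 0, so dl/da sums dl/dc over that axis.
dl_da = dl_dc.sum(axis=0)
# b was not broadcast, so dl/db is dl/dc unchanged.
dl_db = dl_dc

# Illustrative loss whose gradient with respect to c is dl_dc.
def loss(a_val):
    return np.sum(dl_dc * (a_val + b))

# Central finite differences, one component of a at a time.
eps = 1e-6
basis = np.eye(3)
numeric = np.array([
    (loss(a + eps * basis[j]) - loss(a - eps * basis[j])) / (2 * eps)
    for j in range(3)
])

print(dl_da)                         # analytic gradient
print(np.allclose(numeric, dl_da))   # agreement with finite differences
```

The finite-difference gradient agrees with the column sums of dl_dc, which is the "sum over broadcast axes" rule for this one shape pair.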
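To compute the chain-rule derivatives with respect to a and b, I wrote the following function, but I feel it is not very Pythonic and could probably be done more efficiently: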
import numpy as np

def addition_derivatives(x, y, d):
    flip = False
    if x.ndim < y.ndim:  # x should have higher ndim
        flip = True
        x, y = y, x
    S = x.shape  # shape of array with higher ndim
    s = y.shape  # shape of array with lower ndim
    # figure out which axes will be broadcast in which arrays
    n = len(S)
    # impute missing ones in the shape of the smaller array as per:
    # https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
    s = tuple(1 if i < len(S) - len(s) else s[i - (len(S) - len(s))] for i in range(n))
    axis_x = []
    axis_y = []
    for i in range(n):
        assert s[i] == S[i] or s[i] == 1 or S[i] == 1
        if S[i] == 1 and s[i] != 1:
            axis_x.append(i)
        if s[i] == 1 and S[i] != 1:
            axis_y.append(i)
    axis_x, axis_y = map(tuple, (axis_x, axis_y))
    # compute the derivatives
    dx = np.sum(d, axis=axis_x).reshape(x.shape)
    dy = np.sum(d, axis=axis_y).reshape(y.shape)
    if flip:
        dx, dy = dy, dx
    return dx, dy
Answer 0 (score: 0)
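I actually ended up finding a hacky way to do this using np.broadcast_arrays and the strides attribute of the resulting arrays. I'm not sure this will work in all cases, but it has worked so far, because the views returned by np.broadcast_arrays have a stride of 0 along every axis that was broadcast.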
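To make the stride behaviour concrete, here is a small sketch for the (3,) and (2,3) shapes from the question. The variable names are illustrative only, and the exact non-zero stride values depend on dtype and memory layout, so only the zero versus non-zero pattern matters here.

```python
import numpy as np

a = np.arange(3.0)       # shape (3,)
b = np.ones((2, 3))      # shape (2, 3)

ba, bb = np.broadcast_arrays(a, b)

# Both views now have the broadcast shape (2, 3).
print(ba.shape, bb.shape)

# ba repeats a along axis 0 without copying, so its stride there is 0;
# bb needed no broadcasting and keeps non-zero strides on both axes.
print(ba.strides, bb.strides)

# Axes along which a was broadcast, read off from the zero strides.
broadcast_axes_a = tuple(i for i, s in enumerate(ba.strides) if s == 0)
print(broadcast_axes_a)
```

The axes collected this way are the ones the gradient has to be summed over.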
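import numpy as np

def addition_derivatives(x, y, d):
    # Broadcast views: stride 0 marks an axis along which the array was broadcast.
    bx, by = np.broadcast_arrays(x, y)
    strides = list(zip(bx.strides, by.strides))
    # Axes where only x was broadcast, and axes where only y was broadcast.
    ax = tuple(i for i, (sx, sy) in enumerate(strides) if sx == 0 and sy != 0)
    ay = tuple(i for i, (sx, sy) in enumerate(strides) if sx != 0 and sy == 0)
    # Sum the upstream gradient over the broadcast axes, then restore each shape.
    dx = np.sum(d, ax).reshape(x.shape)
    dy = np.sum(d, ay).reshape(y.shape)
    return dx, dy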
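If relying on strides feels fragile, the same reduction can be derived from shapes alone. The following is a sketch of a different technique, not the approach used in this answer; unbroadcast and addition_derivatives_by_shape are hypothetical names introduced here for illustration.

```python
import numpy as np

def unbroadcast(grad, shape):
    """Sum grad down to `shape`, reversing NumPy broadcasting (hypothetical helper)."""
    # Axes that broadcasting prepended on the left are summed away entirely.
    extra = grad.ndim - len(shape)
    if extra > 0:
        grad = grad.sum(axis=tuple(range(extra)))
    # Axes where the operand had size 1 but the result did not are summed
    # with keepdims=True so the size-1 axis is preserved.
    stretched = tuple(
        i for i, n in enumerate(shape) if n == 1 and grad.shape[i] != 1
    )
    if stretched:
        grad = grad.sum(axis=stretched, keepdims=True)
    return grad

def addition_derivatives_by_shape(x, y, d):
    # For c = x + y, each operand receives d reduced to its own shape.
    return unbroadcast(d, x.shape), unbroadcast(d, y.shape)

# Example with broadcasting in both operands.
x = np.ones((3, 1))
y = np.ones((2, 1, 4))
d = np.ones((2, 3, 4))   # made-up upstream gradient with the shape of x + y
dx, dy = addition_derivatives_by_shape(x, y, d)
print(dx.shape, dy.shape)
```

Because it only inspects shapes, this version does not depend on how the inputs are laid out in memory. It assumes d has the broadcast shape of x + y.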