我有一个函数,它被设计为接受一个参数 a
,它是一个 NumPy 数组,但我希望它也能优雅地与其他 array_like
类型一起使用。这包括 float
,被视为 0 维数组,数字列表,被视为一维数组等。
一种简单的方法来确保将其转换为 NumPy 数组,例如通过
a = np.asanyarray(a)
缺点是 asanyarray
还转换已经像 NumPy 数组一样开箱即用的对象,如 Dask 数组和 CuPy 数组,其中转换是不必要的,因为我的代码可以使用原始对象,并且有害,因为它破坏了这些对象的好处(并行/GPU 计算)。
是否有一个函数或方法可以返回任何不变的东西,像 NumPy 数组一样行走或嘎嘎声,但将其他一切都转换为 np.ndarray
?
答案 0 :(得分:0)
我要找的函数似乎不存在,我最终创建了自己的实现,该实现依赖于 __array_function__
的存在。
它似乎有效,但我不能 100% 确定它适用于所有情况。 NumPy 开发人员的意见会很有用。
def duckarray(x):
"""Ensure an argument walks and quacks like an array.
Parameters
----------
x : array_like
Something array-like.
Returns
-------
array_like
Something that walks and quacks like an array.
Notes
-----
Checks whether `x` supports the NumPy `__array_function__` protocol,
returns it unchanged if yes, and converts it via `np.asanyarray` if not.
"""
if hasattr(x, '__array_function__'):
return x
else:
return np.asanyarray(x)