如何确保参数类似于 NumPy 数组?

时间:2021-03-03 08:34:32

标签: python arrays numpy

我有一个函数,它被设计为接受一个参数 a,它是一个 NumPy 数组,但我希望它也能优雅地与其他 array_like 类型一起使用。这包括 float,被视为 0 维数组,数字列表,被视为一维数组等。

一种简单的方法来确保将其转换为 NumPy 数组,例如通过

a = np.asanyarray(a)

缺点是 asanyarray 还转换已经像 NumPy 数组一样开箱即用的对象,如 Dask 数组和 CuPy 数组,其中转换是不必要的,因为我的代码可以使用原始对象,并且有害,因为它破坏了这些对象的好处(并行/GPU 计算)。

是否有一个函数或方法可以返回任何不变的东西,像 NumPy 数组一样行走或嘎嘎声,但将其他一切都转换为 np.ndarray

1 个答案:

答案 0 :(得分:0)

我要找的函数似乎不存在,我最终创建了自己的实现,该实现依赖于 __array_function__ 的存在。

它似乎有效,但我不能 100% 确定它适用于所有情况。 NumPy 开发人员的意见会很有用。

def duckarray(x):
    """Ensure an argument walks and quacks like an array.

    Parameters
    ----------
    x : array_like
        Something array-like.

    Returns
    -------
    array_like
        Something that walks and quacks like an array.

    Notes
    -----
    Checks whether `x` supports the NumPy `__array_function__` protocol,
    returns it unchanged if yes, and converts it via `np.asanyarray` if not.
    """
    if hasattr(x, '__array_function__'):
        return x
    else:
        return np.asanyarray(x)