如何有效地检查numpy数组是否可以转换为另一个整数类型?

时间:2017-01-12 15:34:37

标签: python numpy casting

让我说我有一个整数类型的numpy数组(比如说np.int64)并想把它转换成另一种类型(比如np.int8)。如何最有效地检查操作是否安全(保留所有值)?

我提出了两种方法:

方法1:使用类型信息

def is_safe(data, new_type):
    if np.can_cast(data, new_type):
        return True    # Handle the trivial allowed cases
    type_info = np.iinfo(new_type)
    return np.all((data >= type_info.min) & (data <= type_info.max))

方法2:对所有项目使用np.can_cast

def is_safe(data, new_type):
    if np.can_cast(data, new_type):
        return True    # Handle the trivial allowed cases
    return all(np.can_cast(item, new_type) for item in np.nditer(item)) 

这两种方法似乎都是有效的(并且适用于琐碎的案例),但它们是正确有效的吗?还有另一种更好的方法吗?

P.S。为了使事情进一步复杂化,np.can_cast(np.int8, np.uint64)返回False(自然地),因此必须在某种程度上单独检查有符号和无符号整数之间的变化。

1 个答案:

答案 0 :(得分:1)

如果您已经知道该数组是NumPy整数类型,那么唯一需要检查的是该值是在目标整数范围的最小值/最大值指定的范围内。这是比通用android.os.Looper更简单的检查,它通常不知道它所喂食的东西。因此,can_cast需要更长时间。我在从np.int64到np.int8的整数0-99上测试了这个。

因此,虽然两种方法都是正确的,但如果您知道can_cast是NumPy整数数组,则第一种方法更可取。

data

将最小值和最大值分配给新变量稍快(20%左右):

>>> timeit.timeit("np.all((data >= type_info.min) & (data <= type_info.max))", setup="import numpy as np\ndata = np.array(range(100), dtype=np.int64)\ntype_info = np.iinfo(np.int8)")
6.745509549000417
>>> timeit.timeit("all(np.can_cast(item, np.uint8) for item in np.nditer(data))", setup="import numpy as np\ndata = np.array(range(100), dtype=np.int64)")
51.0065170609887