如何快速检查numpy数组的所有元素是否为浮点数?

时间:2013-10-21 04:30:32

标签: python numpy

我需要编写一个函数F,它接受一个带有dtype = object的numpy数组,并返回数组的所有元素是浮点数,整数还是字符串。例如:

F(np.array([1., 2.], dtype=object))  --> float
F(np.array(['1.', '2.'], dtype=object))  --> string
F(np.array([1, 2], dtype=object))  --> int
F(np.array([1, 2.], dtype=object))  --> float
F(np.array(['hello'], dtype=object))  --> string

F(np.array([1, 'hello'], dtype=object))  --> ERROR

任何想法如何有效地做到这一点? (==使用numpy内置函数)

非常感谢

2 个答案:

答案 0 :(得分:3)

最简单的方法是通过np.array运行内容并检查结果类型:

a = np.array([1., 2.], dtype=object)
b = np.array(['1.', '2.'], dtype=object)
c = np.array([1, 2], dtype=object)
d = np.array([1, 2.], dtype=object)
e = np.array(['hello'], dtype=object)
f = np.array([1, 'hello'], dtype=object)

>>> np.array(list(a)).dtype
dtype('float64')
>>> np.array(list(b)).dtype
dtype('S2')
>>> np.array(list(c)).dtype
dtype('int32')
>>> np.array(list(d)).dtype
dtype('float64')
>>> np.array(list(e)).dtype
dtype('S5')

如果类型不兼容,则无法引发错误,因为这不是numpy的行为:

>>> np.array(list(f)).dtype
dtype('S5')

答案 1 :(得分:1)

不确定这是对象管理最有效的方法,但是如何:

def F(a):
    unique_types = set([type(i) for i in list(a)])
    if len(unique_types) > 1:
        raise ValueError('data types not consistent')
    else:
        return unique_types.pop()