Numbapro jit计算给出了错误的结果

时间:2014-07-21 20:24:04

标签: python cuda gpu jit numba-pro

我有一段代码使用Numbapro编写一个简单的内核来对两个大小为41724的数组的内容进行平方,将它们加在一起并将其存储到另一个数组中。所有数组都具有相同的大小并且是float32。代码如下:

import numpy as np
from numba import *
from numbapro import cuda

@cuda.jit('void(float32[:],float32[:],float32[:])')
def square_add(a,b,c):
    tx = cuda.threadIdx.x
    bx = cuda.blockIdx.x
    bw = cuda.blockDim.x

    i = tx + bx * bw

    #Since the length of a is 41724 and the total
    #threads is 41*1024 = 41984, this check is necessary
    if (i>len(a)):
            return
    else:
            c[i] = a[i]*a[i] + b[i]*b[i]


a = np.array(range(0,41724),dtype = np.float32)
b = np.array(range(41724,83448),dtype=np.float32)
c = np.zeros(shape=(1,41724),dtype=np.float32)

d_a = cuda.to_device(a)
d_b = cuda.to_device(b)
d_c = cuda.to_device(c,copy=False)

#Launch the kernel; Gridsize = (1,41),Blocksize=(1,1024)
square_add[(1,41),(1,1024)](d_a,d_b,d_c)

c = d_c.copy_to_host()
print c
print len(c[0])

当我打印操作结果(数组c)时,我得到的值与我在python终端中执行完全相同的操作时完全不同。 我不知道我在这里做错了什么。

1 个答案:

答案 0 :(得分:1)

这里有两个问题。

首先,您要为CUDA内核启动指定块和网格维度,这与您选择在内核中使用的索引方案不兼容。

此:

square_add[(1,41),(1,1024)](d_a,d_b,d_c)

启动二维网格,其中所有线程在x中具有相同的块和线程尺寸,并且仅在y中变化。这意味着

tx = cuda.threadIdx.x
bx = cuda.blockIdx.x
bw = cuda.blockDim.x

i = tx + bx * bw

将为每个线程产生i=0。如果您将内核启动更改为:

square_add[(41,1),(1024,1)](d_a,d_b,d_c)

你会发现在索引中可以正常工作。

第二个是c已被声明为二维数组,但内核函数签名已被声明为一维数组。在某些情况下,numbapro运行时应该检测到这一点并引发错误。

我能够让你的例子像这样正常工作:

import numpy as np
from numba import *
from numbapro import cuda

@cuda.jit('void(float32[:],float32[:],float32[:,:])')
def square_add(a,b,c):
    tx = cuda.threadIdx.x
    bx = cuda.blockIdx.x
    bw = cuda.blockDim.x

    i = tx + bx * bw

    if (i<len(a)):
        c[0,i] = a[i]*a[i] + b[i]*b[i]

a = np.array(range(0,41724),dtype=np.float32)
b = np.array(range(41724,83448),dtype=np.float32)
c = np.zeros(shape=(1,41724),dtype=np.float32)

d_a = cuda.to_device(a)
d_b = cuda.to_device(b)
d_c = cuda.to_device(c, copy=False)

square_add[(41,1),(1024,1)](d_a,d_b,d_c)

c = d_c.copy_to_host()
print(c)
print(c.shape)

[注意我使用的是Python 3,所以这使用了新式的打印语句]

$ ipython numbatest.py 
numbapro:1: ImportWarning: The numbapro package is deprecated in favour of the accelerate package. Please update your code to use equivalent functions from accelerate.
[[  1.74089216e+09   1.74097562e+09   1.74105907e+09 ...,   8.70371021e+09
    8.70396006e+09   8.70421094e+09]]
(1, 41724)
相关问题