在进入SGEMM时,参数号15具有非法值

时间:2017-03-08 19:10:34

标签: c++ cublas

我正在尝试使用cublasSgemmStridedBatched将块对角矩阵乘以另一个矩阵。

我想执行C = B ^ T * A

其中B是块对角线,因此它存储在存储器中,块彼此叠加(非对角线块根本不存储)

我得到的错误是: 检查失败:status == CUBLAS_STATUS_SUCCESS(7 vs. 0)CUBLAS_STATUS_INVALID_VALUE

**进入SGEMM参数号15时有非法值

这是函数调用:

CUBLAS_CHECK(cublasSgemmStridedBatched(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, strideB, A, lda, strideA, &beta, C, ldc, strideC, num_blocks));

以下是一些打印件:

cuTransB = 1 has type 17cublasOperation_t
cuTransA = 0 has type 17cublasOperation_t
N = 5 has type i
M = 100 has type i
K = 8 has type i
alpha = 1 has type f
B = 0x701164000 has type PKf, has dimensions 800 by 5 in memory
B^T block has dimensions 5 by 8
ldb = 800 has type i
strideB = 8 has type x
A = 0x7012aca00 has type PKf, has dimensions 800 by 100 in memory
A block has dimensions 8 by 100
lda = 800 has type i
strideA = 8 has type x
beta = 0 has type f
C = 0x701b60000 has type Pf, has dimensions 500 by 100 in memory
C block has dimensions 5 by 100
ldc = 500 has type i
strideC = 5 has type x
num_blocks = 100 has type i

所以函数调用真的是:

CUBLAS_CHECK(cublasSgemmStridedBatched(handle, 1, 0, 5, 100, 8, 0x7fff58eb69e4, 0x701164000, 800, 8, 0x7012aca00, 800, 8, 0x7fff58eb69e0, 0x701b60000, 500, 5, 100));

我不确定SGEMM的第15个参数是什么 - 我不认为这个功能是开源的吗?我很困惑。

可能注意到或不重要的是strideB<六味地黄丸。也就是说,B块在存储器中混合。正如我所提到的,我将B初始化为800乘5矩阵,我正在考虑100个8乘5个块。

2 个答案:

答案 0 :(得分:1)

好的,看起来像是

CUBLAS_CHECK(cublasSgemmStridedBatched(Caffe::cublas_handle(), cuTransB, cuTransA, 5, 5, 5,&alpha, B, 5, 24, A, 5, 25, &beta, C, 5, ANYTHING_LESS_THAN_25, num_blocks));

会抛出同样的错误。我怀疑他们在写C时会试图阻止碰撞,所以我想,即使我的上述情况没有碰撞,因为我描述了单独的水平5×100条C,错误是造成的原因是:

strideC = 5< 5 * 100 = C块的大小

我认为C块不能在内存中混合。

答案 1 :(得分:1)

这是因为在不同的cuda版本中cublasSgemmStridedBatched的实现逻辑不同。 在cuda 9中,C不能由列分隔,这意味着strideC必须大于或等于C块的大小。 在cuda 10+中,它可以工作。你可以试试看。

相关问题