NEON SIMD中设置位的总和

时间:2015-05-30 17:33:35

标签: simd neon

我有一个对大量字节进行操作的算法。作为一个预处理步骤,我需要为给定的索引创建一个计数,该位数是到目前为止在数组中设置的频率。

我可以使用以下(伪)代码在C中执行此操作:

input: uint8_t values[COUNT];
output: uint32_t bitsum[COUNT+1][8];
       (bitsum[i][x] is the counter for the x-th bit being set in
        the PRECEEDING i bytes -- this makes bitsum[0][x] == 0)

// we skip first row
for (int i=1; i < COUNT+1; i++) {
   for (int bit=0; bit < 8; bit++) {
      bitsum[i][bit] = bitsum[i-1][bit];
      if (values[i-1] & (1 << bit) != 0) {
         bitsum[i][bit]++;
      }
   }
}

然而,我很想知道我可以使用NEON SIMD更快地实现这一目标。不幸的是,我对此很陌生,所以我无法解决这个问题(但是?)并寻求一些帮助。甚至可以在NEON中这样做吗?

更新

尝试在C中加速此代码,我相信以下方法是最快的(当然,没有展开内部for循环):

// pre-calculate lookup-table
uint16_t lookup[256][8];
for (int value=0; value < 256; value++) {
   for (int bit=0; bit < 8; bit++) {
      if (value & (1 << bit) != 0) {
         lookup[value][bit]++;
      }
   }
}

// create sum
for (int i=1; i < COUNT+1; i++) {
   for (int bit=0; bit < 8; bit++) {
      bitsum[i][bit] = bitsum[i-1][bit] + lookup[values[i-1]][bit];
   }
}

除了查找表访问外,这看起来对SIMD来说是理想的 - 至少我找不到在NEON中这样做的方法。

1 个答案:

答案 0 :(得分:1)

您可以使用VTBLVTBX指令在NEON中执行表查找,但它们仅对具有少量条目的查找表有用。在针对NEON进行优化时,通常最好寻找一种在运行时计算值的方法,而不是使用表格。

在此示例中,可以直接在运行时计算查找。该功能基本上是

int lookup(int val, int bit) { return (val & (1<<bit) >> bit); }

可轻松转换为NEON SIMD。

因此,您的函数可以使用NEON内在函数实现,如下所示:

#include <arm_neon.h>

void f(uint32_t *output, const uint8_t *input, int length)
{   

    static const uint8_t mask_vals[] = {  0x1,  0x2,  0x4,  0x8,
                                         0x10, 0x20, 0x40, 0x80 };
    /* NEON shifts are left shifts, and we want a right shift,
       so use negative numbers here */
    static const int8_t shift_vals[] = { 0, -1, -2, -3, -4, -5, -6, -7 };

    /* constants we need in the main loop */
    uint8x8_t mask    = vld1_u8(mask_vals);
    int8x8_t shift    = vld1_s8(shift_vals);

    /* accumulators for results, bits 0-3 in cumul1, bits 4-7 in cumul2 */
    uint32x4_t cumul1 = vdupq_n_u32(0);
    uint32x4_t cumul2 = vdupq_n_u32(0);

    for (int i = 0; i < length; i++)
    {   
        uint8x8_t v = vld1_dup_u8(input+i);
        /* this gives 0 or 1 in each lane, depending on whether the
           appropriate bit is set */
        uint8x8_t incr = vshl_u8(vand_u8(v, mask), shift);

        /* widen to 16 bits */
        uint16x8_t incr_w = vmovl_u8(incr);

        /* increment the accumulators */
        cumul1 = vaddw_u16(cumul1, vget_low_u16(incr_w));
        cumul2 = vaddw_u16(cumul2, vget_high_u16(incr_w));
        /* store the accumulator values */
        vst1q_u32(output + i*8, cumul1);
        vst1q_u32(output + i*8 + 4, cumul2);
    }
}

免责声明:此代码已编译,但我尚未对其进行测试。