查找SSE寄存器中最常出现的元素

时间:2017-08-03 05:56:35

标签: algorithm assembly x86 sse

有没有人想过如何计算SSE4.x中8位整数向量的模式(统计量)?为了澄清,这将是128位寄存器中的16x8位值。

我希望结果作为矢量掩码选择模值元素。即_mm_cmpeq_epi8(v, set1(mode(v)))的结果,以及标量值。

提供一些额外的背景;虽然上述问题本身就是一个有趣的问题,但我已经通过线性复杂度可以想到的大多数算法。这个课程将消除我从计算这个数字中获得的任何收益。

我希望在这里与大家一起寻找一些深刻的魔法。有可能需要近似来打破这个界限,例如“选择 a 频繁出现的元素”(NB与最多的差异),这将是值得。概率性答案也是可用的。

SSE和x86有一些非常有趣的语义。可能值得探索超优化传递。

3 个答案:

答案 0 :(得分:5)

这里可能是一个相对简单的暴力SSEx方法,请参阅下面的代码。 我们的想法是将输入向量v字节旋转1到15个位置,并比较旋转的向量 原始v表示平等。缩短依赖链并增加依赖链 指令级并行,两个计数器用于计算(垂直和)这些相等的元素: sum1sum2,因为可能会有受益于此的架构。 等元素计为-1。变量sum = sum1 + sum2包含值的总计数 介于-1和-16之间。 min_brc包含向所有元素广播的水平最小值summask = _mm_cmpeq_epi8(sum,min_brc)是请求的模式值元素的掩码 OP的中间结果。在代码的下几行中,将提取实际模式。

此解决方案肯定比标量解决方案更快。 请注意,对于AVX2,可以使用高128位通道进一步加速计算。

仅计算模式值元素的掩码需要20个周期(吞吐量)。随着实际模式的播出 在SSE寄存器中,它需要大约21.4个周期。

请注意下一个示例中的行为:  [1, 1, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]返回mask=[-1,-1,-1,-1,0,0,...,0]  模式值为1,尽管1经常发生3次。

以下代码已经过测试,但尚未经过全面测试

#include <stdio.h>
#include <x86intrin.h>
/*  gcc -O3 -Wall -m64 -march=nehalem mode_uint8.c   */
int print_vec_char(__m128i x);

__m128i mode_statistic(__m128i v){
    __m128i  sum2         = _mm_set1_epi8(-1);                    /* Each integer occurs at least one time */
    __m128i  v_rot1       = _mm_alignr_epi8(v,v,1);
    __m128i  v_rot2       = _mm_alignr_epi8(v,v,2);
    __m128i  sum1         =                   _mm_cmpeq_epi8(v,v_rot1);
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot2));

    __m128i  v_rot3       = _mm_alignr_epi8(v,v,3);
    __m128i  v_rot4       = _mm_alignr_epi8(v,v,4);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot3));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot4));

    __m128i  v_rot5       = _mm_alignr_epi8(v,v,5);
    __m128i  v_rot6       = _mm_alignr_epi8(v,v,6);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot5));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot6));

    __m128i  v_rot7       = _mm_alignr_epi8(v,v,7);
    __m128i  v_rot8       = _mm_alignr_epi8(v,v,8);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot7));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot8));

    __m128i  v_rot9       = _mm_alignr_epi8(v,v,9);
    __m128i  v_rot10      = _mm_alignr_epi8(v,v,10);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot9));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot10));

    __m128i  v_rot11      = _mm_alignr_epi8(v,v,11);
    __m128i  v_rot12      = _mm_alignr_epi8(v,v,12);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot11));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot12));

    __m128i  v_rot13      = _mm_alignr_epi8(v,v,13);
    __m128i  v_rot14      = _mm_alignr_epi8(v,v,14);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot13));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot14));

    __m128i  v_rot15      = _mm_alignr_epi8(v,v,15);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot15));
    __m128i  sum          = _mm_add_epi8(sum1,sum2);                      /* Sum contains values such as -1, -2 ,...,-16                                    */
                                                                          /* The next three instructions compute the horizontal minimum of sum */
    __m128i  sum_shft     = _mm_srli_epi16(sum,8);                        /* Shift right 8 bits, while shifting in zeros                                    */
    __m128i  min1         = _mm_min_epu8(sum,sum_shft);                   /* sum and sum_shuft are considered as unsigned integers. sum_shft is zero at the odd positions and so is min1 */ 
    __m128i  min2         = _mm_minpos_epu16(min1);                       /* Byte 0 within min2 contains the horizontal minimum of sum                      */
    __m128i  min_brc      = _mm_shuffle_epi8(min2,_mm_setzero_si128());   /* Broadcast horizontal minimum                                                   */

    __m128i  mask         = _mm_cmpeq_epi8(sum,min_brc);                  /* Mask = -1 at the byte positions where the value of v is equal to the mode of v */

    /* comment next 4 lines out if there is no need to broadcast the mode value */
    int      bitmask      = _mm_movemask_epi8(mask);
    int      indx         = __builtin_ctz(bitmask);                            /* Index of mode                            */
    __m128i  v_indx       = _mm_set1_epi8(indx);                               /* Broadcast indx                           */
    __m128i  answer       = _mm_shuffle_epi8(v,v_indx);                        /* Broadcast mode to each element of answer */ 

/* Uncomment lines below to print intermediate results, to see how it works. */
//    printf("sum         = ");print_vec_char (sum           );
//    printf("sum_shft    = ");print_vec_char (sum_shft      );
//    printf("min1        = ");print_vec_char (min1          );
//    printf("min2        = ");print_vec_char (min2          );
//    printf("min_brc     = ");print_vec_char (min_brc       );
//    printf("mask        = ");print_vec_char (mask          );
//    printf("v_indx      = ");print_vec_char (v_indx        );
//    printf("answer      = ");print_vec_char (answer        );

             return answer;   /* or return mask, or return both ....    :) */
}


int main() {
    /* To test throughput set throughput_test to 1, otherwise 0    */
    /* Use e.g. perf stat -d ./a.out to test throughput           */
    #define throughput_test 0

    /* Different test vectors  */
    int i;
    char   x1[16] = {5, 2, 2, 7, 21, 4, 7, 7, 3, 9, 2, 5, 4, 3, 5, 5};
    char   x2[16] = {5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5};
    char   x3[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    char   x4[16] = {1, 2, 3, 2, 1, 6, 7, 8, 2, 2, 2, 3, 3, 2, 15, 16};
    char   x5[16] = {1, 1, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

    printf("\n15...0      =   15  14  13  12    11  10  9   8     7   6   5   4     3   2   1   0\n\n");

    __m128i  x_vec  = _mm_loadu_si128((__m128i*)x1);

    printf("x_vec       = ");print_vec_char(x_vec        );

    __m128i  y      = mode_statistic (x_vec);

    printf("answer      = ");print_vec_char(y         );


    #if throughput_test == 1
    __m128i  x_vec1  = _mm_loadu_si128((__m128i*)x1);
    __m128i  x_vec2  = _mm_loadu_si128((__m128i*)x2);
    __m128i  x_vec3  = _mm_loadu_si128((__m128i*)x3);
    __m128i  x_vec4  = _mm_loadu_si128((__m128i*)x4);
    __m128i  x_vec5  = _mm_loadu_si128((__m128i*)x5);
    __m128i  y1, y2, y3, y4, y5;
    __asm__ __volatile__ ( "vzeroupper" : : : );   /* Remove this line on non-AVX processors */
    for (i=0;i<100000000;i++){
        y1       = mode_statistic (x_vec1);
        y2       = mode_statistic (x_vec2);
        y3       = mode_statistic (x_vec3);
        y4       = mode_statistic (x_vec4);
        y5       = mode_statistic (x_vec5);
        x_vec1   = mode_statistic (y1    );
        x_vec2   = mode_statistic (y2    );
        x_vec3   = mode_statistic (y3    );
        x_vec4   = mode_statistic (y4    );
        x_vec5   = mode_statistic (y5    );
     }
    printf("mask mode   = ");print_vec_char(y1           );
    printf("mask mode   = ");print_vec_char(y2           );
    printf("mask mode   = ");print_vec_char(y3           );
    printf("mask mode   = ");print_vec_char(y4           );
    printf("mask mode   = ");print_vec_char(y5           );
    #endif

    return 0;
}



int print_vec_char(__m128i x){
    char v[16];
    _mm_storeu_si128((__m128i *)v,x);
    printf("%3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi\n",
           v[15],v[14],v[13],v[12],v[11],v[10],v[9],v[8],v[7],v[6],v[5],v[4],v[3],v[2],v[1],v[0]);
    return 0;
}

输出:

15...0      =   15  14  13  12    11  10  9   8     7   6   5   4     3   2   1   0

x_vec       =   5   5   3   4 |   5   2   9   3 |   7   7   4  21 |   7   2   2   5
sum         =  -4  -4  -2  -2 |  -4  -3  -1  -2 |  -3  -3  -2  -1 |  -3  -3  -3  -4
min_brc     =  -4  -4  -4  -4 |  -4  -4  -4  -4 |  -4  -4  -4  -4 |  -4  -4  -4  -4
mask        =  -1  -1   0   0 |  -1   0   0   0 |   0   0   0   0 |   0   0   0  -1
answer      =   5   5   5   5 |   5   5   5   5 |   5   5   5   5 |   5   5   5   5

水平最小值是使用Evgeny Kluev的method计算的。

答案 1 :(得分:4)

对寄存器中的数据进行排序。 插入排序可以在16(15)步骤中完成,方法是将寄存器初始化为&#34; Infinity&#34;,它试图说明单调递减的数组并将新元素并行插入所有可能的位置:

// e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF 78
__m128i sorted = _mm_or_si128(my_array, const_FFFFF00);

for (int i = 1; i < 16; ++i)
{
    // Trying to insert e.g. A0, we must shift all the FF's to left
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF 78 00
    __m128i shifted = _mm_bslli_si128(sorted, 1);

    // Taking the MAX of shifted and 'A0 on all places'
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF A0 A0
    shifted = _mm_max_epu8(shifted, _mm_set1_epi8(my_array[i]));

    // and minimum of the shifted + original --
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF A0 78
    sorted = _mm_min_epu8(sorted, shifted);
}

然后计算vec[n+1] == vec[n]的掩码,将掩码移动到GPR并使用它来索引32768条目LUT以获得最佳索引位置。

在实际情况中,人们可能希望排序不仅仅是一个向量;即一次排序16个16入口向量;

__m128i input[16];      // not 1, but 16 vectors
transpose16x16(input);  // inplace vector transpose
sort(transpose);        // 60-stage network exists for 16 inputs
// linear search -- result in 'mode'
__m128i mode = input[0];
__m128i previous = mode;
__m128i count = _mm_set_epi8(0);
__m128i max_count = _mm_setzero_si128(0);
for (int i = 1; i < 16; i++)
{
   __m128i &current = input[i];
   // histogram count is off by one
   // if (current == previous) count++;
   //    else count = 0;
   // if (count > max_count)
   //    mode = current, max_count = count
   prev = _mm_cmpeq_epi8(prev, current);
   count = _mm_and_si128(_mm_sub_epi8(count, prev), prev);
   __m128i max_so_far = _mm_cmplt_epi8(max_count, count);
   mode = _mm_blendv_epi8(mode, current, max_so_far);
   max_count = _mm_max_epi8(max_count, count);
   previous = current;
}

内部循环总计每个结果的7-8个指令的摊销成本; 排序通常每级有2个指令 - 即每个结果8个指令,当16个结果需要60个阶段或120个指令时。 (这仍然将转置留作练习 - 但我认为它应该比排序快得多?)

所以,这应该是每8位结果24个指令的球场。

答案 2 :(得分:0)

用于与标量代码进行性能比较。在主要部分上非矢量化但在表清除和tmp初始化时矢量化。 (fx8150每f()呼叫168个周期(在3.7 GHz时1.0002秒内完成22M呼叫))

#include <x86intrin.h>

unsigned char tmp[16]; // extracted values are here (single instruction, store_ps)
unsigned char table[256]; // counter table containing zeroes
char f(__m128i values)
{
    _mm_store_si128((__m128i *)tmp,values);
    int maxOccurence=0;
    int currentValue=0;
    for(int i=0;i<16;i++)
    {
        unsigned char ind=tmp[i];
        unsigned char t=table[ind];
        t++;
        if(t>maxOccurence)
        {
             maxOccurence=t;
             currentValue=ind;
        }
        table[ind]=t;
    }
    for(int i=0;i<256;i++)
        table[i]=0;
    return currentValue;
}

g ++ 6.3输出:

f:                                      # @f
        movaps  %xmm0, tmp(%rip)
        movaps  %xmm0, -24(%rsp)
        xorl    %r8d, %r8d
        movq    $-15, %rdx
        movb    -24(%rsp), %sil
        xorl    %eax, %eax
        jmp     .LBB0_1
.LBB0_2:                                # %._crit_edge
        cmpl    %r8d, %esi
        cmovgel %esi, %r8d
        movb    tmp+16(%rdx), %sil
        incq    %rdx
.LBB0_1:                                # =>This Inner Loop Header: Depth=1
        movzbl  %sil, %edi
        movb    table(%rdi), %cl
        incb    %cl
        movzbl  %cl, %esi
        cmpl    %r8d, %esi
        cmovgl  %edi, %eax
        movb    %sil, table(%rdi)
        testq   %rdx, %rdx
        jne     .LBB0_2
        xorps   %xmm0, %xmm0
        movaps  %xmm0, table+240(%rip)
        movaps  %xmm0, table+224(%rip)
        movaps  %xmm0, table+208(%rip)
        movaps  %xmm0, table+192(%rip)
        movaps  %xmm0, table+176(%rip)
        movaps  %xmm0, table+160(%rip)
        movaps  %xmm0, table+144(%rip)
        movaps  %xmm0, table+128(%rip)
        movaps  %xmm0, table+112(%rip)
        movaps  %xmm0, table+96(%rip)
        movaps  %xmm0, table+80(%rip)
        movaps  %xmm0, table+64(%rip)
        movaps  %xmm0, table+48(%rip)
        movaps  %xmm0, table+32(%rip)
        movaps  %xmm0, table+16(%rip)
        movaps  %xmm0, table(%rip)
        movsbl  %al, %eax
        ret