优化,分支消除

时间:2010-08-02 12:46:37

标签: c++ optimization floating-point branch

float mixValue = ... //in range -1.0f to 1.0f
for(... ; ... ; ...  ) //long loop
{
    float inputLevel = ... //in range -1.0f to 1.0f
    if(inputLevel < 0.0 && mixValue < 0.0)
    {
        mixValue = (mixValue + inputLevel) + (mixValue*inputLevel);
    }
    else
    {
        mixValue = (mixValue + inputLevel) - (mixValue*inputLevel);
    }
}

只是一个简单的问题,我们可以在没有分支的情况下计算mixValue 吗?或任何其他优化建议,例如使用SIMD?

编辑: 只是为了获得更多信息,我结束了 使用此解决方案,基于所选答案:

const float sign[] = {-1, 1};
float mixValue = ... //in range -1.0f to 1.0f
for(... ; ... ; ...  ) //long loop
{
    float inputLevel = ... //in range -1.0f to 1.0f
    unsigned a = *(unsigned*)(&mixValue);
    unsigned b = *(unsigned*)(&inputLevel);

    float mulValue = mixValue * inputLevel * sign[(a & b) >> (8*sizeof(unsigned)-1)];
    float addValue = mixValue + inputLevel;
    mixValue = addValue + mulValue;
}
谢谢。

8 个答案:

答案 0 :(得分:4)

这个怎么样:

const float sign[] = {-1, 1};

float mixValue = ... //in range -1.0f to 1.0f
for(... ; ... ; ...  ) //long loop
{
    float inputLevel = ... //in range -1.0f to 1.0f
    int bothNegative = (inputLevel < 0.0) & (mixValue < 0.0);
    mixValue = (mixValue + inputLevel) + (sign[bothNegative]*mixValue*inputLevel);
}

编辑:迈克是正确的&amp;&amp;将介绍一个分支,并感谢Pedro证明它。我换了&amp;&amp;到&amp;现在GCC(版本4.4.0)生成无分支代码。

答案 1 :(得分:1)

float mixValue = ... //in range -1.0f to 1.0f
for(... ; ... ; ...  ) //long loop
{
     float inputLevel = ... //in range -1.0f to 1.0f
     float mulValue = mixValue * inputLevel;
     float addValue = mixValue + inputLevel;
     __int32 a = *(__int32*)(&mixValue);
     __int32 b = *(__int32*)(&inputLevel);
     __int32 c = *(__int32*)(&mulValue);
     __int32 d = c & ((a ^ b) | 0x7FFFFFFF);
     mixValue = addValue + *(float*)(&d);
}

答案 2 :(得分:1)

受Roku的回答启发(在MSVC ++ 10分支上),这似乎没有分支:

#include <iostream>

using namespace std;
const float sign[] = {-1, 1};
int main() {
    const int N = 10;
    float mixValue = -0.5F;
    for(int i = 0; i < N; i++) {
        volatile float inputLevel = -0.3F;
        int bothNegative = ((((unsigned char*)&inputLevel)[3] & 0x80) & (((unsigned char*)&mixValue)[3] & 0x80)) >> 7;
        mixValue = (mixValue + inputLevel) + (sign[bothNegative]*mixValue*inputLevel);
    }

    std::cout << mixValue << std::endl;
}

这是由IDA Pro分析的反汇编(在MSVC ++ 10上编译,发布模式):

Disassembly http://img248.imageshack.us/img248/6865/floattestbranchmine.png

答案 3 :(得分:0)

您是否使用和不使用分支对循环进行了基准测试?

至少你可以删除分支的一部分,因为mixValue在循环之外。

float multiplier(float a, float b){
  unsigned char c1Neg = reinterpret_cast<unsigned char *>(&a)[3] & 0x80;
  unsigned char c2Neg = reinterpret_cast<unsigned char *>(&b)[3] & 0x80;
  unsigned char multiplierIsNeg = c1Neg & c2Neg;
  float one = 1;
  reinterpret_cast<unsigned char *>(&one)[3] |= multiplierIsNeg;
  return -one;
}
cout << multiplier(-1,-1) << endl; // +1
cout << multiplier( 1,-1) << endl; // -1
cout << multiplier( 1, 1) << endl; // -1
cout << multiplier(-1, 1) << endl; // -1

答案 4 :(得分:0)

如果您担心过度分支,请查看Duff's Device。这应该有助于解开循环。说实话,循环展开是由优化器完成的,因此尝试手动执行可能是浪费时间。检查装配输出以查找。

如果您对阵列中的每个项目执行完全相同的操作,那么SIMD肯定会有所帮助。请注意,并非所有硬件都支持SIMD,但是像gcc这样的一些编译器会为SIMD提供内在函数,这将使您免于陷入汇编程序。

如果您使用gcc编译ARM代码,可以找到SIMD内在函数here

答案 5 :(得分:0)

就在我的头顶(我确信它可以减少):

mixValue = (mixValue + inputLevel) + (((mixValue / fabs(mixValue)) + (inputLevel / fabs(inputLevel))+1) / fabs(((mixValue / fabs(mixValue)) + (inputLevel / fabs(inputLevel))+1)))*-1*(mixValue*inputLevel);

为了澄清一下,我会分别计算一下标志:

float sign = (((mixValue / fabs(mixValue)) + (inputLevel / fabs(inputLevel))+1) / fabs(((mixValue / fabs(mixValue)) + (inputLevel / fabs(inputLevel))+1)))*-1;
mixValue = (mixValue + inputLevel) + sign*(mixValue*inputLevel);

这是浮点数学,所以你可能需要纠正一些舍入问题,但这应该让你走上正确的道路。

答案 6 :(得分:0)

查看您的代码,您会看到始终会添加mixValueinputLevel的绝对值,除非两者都是正数。

有了一些比特小知识和IEEE浮点知识,你可以摆脱条件:

// sets the first bit of f to zero => makes it positive.
void absf( float& f ) {
   assert( sizeof( float ) == sizeof( int ) );
   reinterpret_cast<int&>( f ) &= ~0x80000000;
}

// returns a first-bit = 1 if f is positive
int pos( float& f ) {
  return ~(reinterpret_cast<int&>(f) & 0x80000000) & 0x80000000;
}

// returns -fabs( f*g ) if f>0 and g>0, fabs(f*g) otherwise.    
float prod( float& f, float& g ) {
  float p = f*g;
  float& rp=p;
  int& ri = reinterpret_cast<int&>(rp);
  absf(p);
  ri |= ( pos(f) & pos(g) & 0x80000000); // first bit = + & +
  return p;
}

int main(){
 struct T { float f, g, r; 
    void test() {
       float p = prod(f,g);
       float d = (p-r)/r;
       assert( -1e-15 < d && d < 1e-15 );
    }
 };
 T vals[] = { {1,1,-1},{1,-1,1},{-1,1,1},{-1,-1,1} };
 for( T* val=vals; val != vals+4; ++val ) {
    val->test();
 }
}

最后:你的循环

for( ... ) {
    mixedResult += inputLevel + prod(mixedResult,inputLevel);
}

注意:积累的尺寸不匹配。 inputLevel是无量纲的数量,而mixedResult是您的...结果(例如在Pascal中,在Volts中,......)。您无法添加具有不同尺寸的两个数量。您可能希望mixedResult += prod( mixedResult, inputLevel )作为累加器。

答案 7 :(得分:0)

某些编译器(即MSC)也需要手动签名。

来源:

volatile float mixValue;
volatile float inputLevel;

float u   = mixValue*inputLevel;
float v   = -u;
float a[] = { v, u };

mixValue = (mixValue + inputLevel) + a[ (inputLevel<0.0) & (mixValue<0.0) ];

IntelC 11.1:

movss     xmm1, DWORD PTR [12+esp]    
mulss     xmm1, DWORD PTR [16+esp]    
movss     xmm6, DWORD PTR [12+esp]    
movss     xmm2, DWORD PTR [16+esp]    
movss     xmm3, DWORD PTR [16+esp]    
movss     xmm5, DWORD PTR [12+esp]    
xorps     xmm4, xmm4                  
movaps    xmm0, xmm4                  
subss     xmm0, xmm1                  
movss     DWORD PTR [esp], xmm0       
movss     DWORD PTR [4+esp], xmm1     
addss     xmm6, xmm2                  
xor       eax, eax                    
cmpltss   xmm3, xmm4                  
movd      ecx, xmm3                   
neg       ecx                         
cmpltss   xmm5, xmm4                  
movd      edx, xmm5                   
neg       edx                         
and       ecx, edx                    
addss     xmm6, DWORD PTR [esp+ecx*4] 
movss     DWORD PTR [12+esp], xmm6    

gcc 4.5:

flds    32(%esp)
flds    16(%esp)
fmulp   %st, %st(1)
fld     %st(0)
fchs
fstps   (%esp)
fstps   4(%esp)
flds    32(%esp)
flds    16(%esp)
flds    16(%esp)
flds    32(%esp)
fxch    %st(2)
faddp   %st, %st(3)
fldz
fcomi   %st(2), %st
fstp    %st(2)
fxch    %st(1)
seta    %dl
xorl    %eax, %eax
fcomip  %st(1), %st
fstp    %st(0)
seta    %al
andl    %edx, %eax
fadds   (%esp,%eax,4)
xorl    %eax, %eax
fstps   32(%esp)