Reputation: 5666
As part of a program that I'm writing, I need to compare two values in the form a + sqrt(b)
where a
and b
are unsigned integers. As this is part of a tight loop, I'd like this comparison to run as fast as possible. (If it matters, I'm running the code on x86-64 machines, and the unsigned integers are no larger than 10^6. Also, I know for a fact that a1<a2
.)
As a stand-alone function, this is what I'm trying to optimize. My numbers are small enough integers that double
(or even float
) can exactly represent them, but rounding error in sqrt
results must not change the outcome.
// known pre-condition: a1 < a2 in case that helps
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
return a1+sqrt(b1) < a2+sqrt(b2); // computed mathematically exactly
}
Test case: is_smaller(900000, 1000000, 900001, 998002)
should return true, but as shown in comments by @wim computing it with sqrtf()
would return false. So would (int)sqrt()
to truncate back to integer.
a1+sqrt(b1) = 901000
and a2+sqrt(b2) = 901000.00050050037512481206
. The nearest float to that is exactly 901000.
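To make the failure concrete, here is a small sketch that evaluates the test case three ways: in single precision, with each square root truncated to an integer, and in double precision. The function names are made up for illustration. On a typical x86-64 build the first two report false; the double version reports true for this input, which does not prove that double is safe for every input.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper names; only the arithmetic mirrors the question.

// Single precision: both sums round to the same float value.
bool is_smaller_float(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    float lhs = static_cast<float>(a1) + std::sqrt(static_cast<float>(b1));
    float rhs = static_cast<float>(a2) + std::sqrt(static_cast<float>(b2));
    return lhs < rhs;
}

// Truncating each square root to an integer drops the fractional part.
bool is_smaller_trunc(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1 + static_cast<unsigned>(std::sqrt(static_cast<double>(b1)))
         < a2 + static_cast<unsigned>(std::sqrt(static_cast<double>(b2)));
}

// Double precision: enough bits to separate the two sums in this case.
bool is_smaller_double(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1 + std::sqrt(static_cast<double>(b1))
         < a2 + std::sqrt(static_cast<double>(b2));
}

// Evaluates the question's test case with all three variants.
void check_test_case() {
    assert(!is_smaller_float(900000, 1000000, 900001, 998002));
    assert(!is_smaller_trunc(900000, 1000000, 900001, 998002));
    assert(is_smaller_double(900000, 1000000, 900001, 998002));
}
```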
As the sqrt()
function is generally quite expensive even on modern x86-64 when fully inlined as a sqrtsd
instruction, I'm trying to avoid calling sqrt()
as far as possible.
Removing sqrt by squaring potentially also avoids any danger of rounding errors by making all computation exact.
If instead the function was something like this ...
bool is_smaller(unsigned a1, unsigned b1, unsigned x) {
return a1+sqrt(b1) < x;
}
... then I could just do return x>=a1 && static_cast<uint64_t>(x-a1)*(x-a1)>b1;
But now since there are two sqrt(...)
terms, I cannot do the same algebraic manipulation.
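For reference, a minimal sketch of that single-square-root case written out as a function. The name is_smaller_than_int is hypothetical, and the order check is done before subtracting because an unsigned x - a1 wraps around instead of going negative.

```cpp
#include <cassert>
#include <cstdint>

// Exact test of a1 + sqrt(b1) < x using integer arithmetic only.
bool is_smaller_than_int(unsigned a1, unsigned b1, unsigned x) {
    if (x <= a1) {
        return false;              // sqrt(b1) >= 0, so the sum is at least a1
    }
    std::uint64_t d = x - a1;      // positive difference, widened before squaring
    return d * d > b1;             // sqrt(b1) < d  <=>  b1 < d*d
}

// A few hand-checked cases.
void check_single_sqrt() {
    assert(!is_smaller_than_int(3, 16, 7));  // 3 + 4 == 7, not smaller
    assert(is_smaller_than_int(3, 15, 7));   // 3 + 3.87... < 7
    assert(!is_smaller_than_int(7, 0, 3));   // a1 already exceeds x
}
```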
I could square the values twice, by using this formula:
a1 + sqrt(b1) = a2 + sqrt(b2)
<==> a1 - a2 = sqrt(b2) - sqrt(b1)
<==> (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1) * sqrt(b2)
<==> (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1 * b2)
<==> (a1 - a2) * (a1 - a2) - (b1 + b2) = - 2 * sqrt(b1 * b2)
<==> ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 2 = sqrt(b1 * b2)
<==> ((b1 + b2) - (a1 - a2) * (a1 - a2)) * ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 4 = b1 * b2
Unsigned division by 4 is cheap because it is just a bitshift, but since I square the numbers twice I will need to use 128-bit integers and I will need to introduce a few >=0
checks (because I'm comparing inequality instead of equality).
It feels like there might be a way to do this faster by applying better algebra to this problem. Is there one?
Upvotes: 45
Views: 2846
Reputation: 3968
I'm not sure that algebraic manipulations, in combination with integer arithmetic, necessarily lead to the fastest solution. You'll need many scalar multiplies in that case (which isn't very fast), and/or branch prediction may fail, which may degrade performance. Obviously you'll have to benchmark to see which solution is fastest in your particular case.
One method to make
the sqrt
a bit faster is to add the -fno-math-errno
option to gcc or clang.
In that case the compiler doesn't have to check for negative inputs in order to set errno.
With icc this is the default setting.
More performance improvement is possible by using the vectorized
sqrt
instruction sqrtpd
, instead of the scalar sqrt
instruction sqrtsd
.
Peter Cordes has shown that clang is able to auto vectorize this code,
such that it generates this sqrtpd
.
However, the success of auto-vectorization depends quite heavily on the compiler settings
and the compiler that is used (clang, gcc, icc etc.). With -march=nehalem
, or older, clang doesn't vectorize.
More reliable vectorization results are possible with the following intrinsics code, see below. For portability we only assume SSE2 support, which is the x86-64 baseline.
/* gcc -m64 -O3 -fno-math-errno smaller.c */
/* Adding e.g. -march=nehalem or -march=skylake might further */
/* improve the generated code */
/* Note that SSE2 is guaranteed to exist with x86-64 */
#include<immintrin.h>
#include<math.h>
#include<stdio.h>
#include<stdint.h>
int is_smaller_v5(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
uint64_t a64 = (((uint64_t)a2)<<32) | ((uint64_t)a1); /* Avoid too much port 5 pressure by combining 2 32 bit integers in one 64 bit integer */
uint64_t b64 = (((uint64_t)b2)<<32) | ((uint64_t)b1);
__m128i ax = _mm_cvtsi64_si128(a64); /* Move integer from gpr to xmm register */
__m128i bx = _mm_cvtsi64_si128(b64);
__m128d a = _mm_cvtepi32_pd(ax); /* Convert 2 integers to double */
__m128d b = _mm_cvtepi32_pd(bx); /* We don't need _mm_cvtepu32_pd since a,b < 1e6 */
__m128d sqrt_b = _mm_sqrt_pd(b); /* Vectorized sqrt: compute 2 sqrt-s with 1 instruction */
__m128d sum = _mm_add_pd(a, sqrt_b);
__m128d sum_lo = sum; /* a1 + sqrt(b1) in the lower 64 bits */
__m128d sum_hi = _mm_unpackhi_pd(sum, sum); /* a2 + sqrt(b2) in the lower 64 bits */
return _mm_comilt_sd(sum_lo, sum_hi);
}
int is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
return a1+sqrt(b1) < a2+sqrt(b2);
}
int main(){
unsigned a1; unsigned b1; unsigned a2; unsigned b2;
a1 = 11; b1 = 10; a2 = 10; b2 = 10;
printf("smaller? %i %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
a1 = 10; b1 = 11; a2 = 10; b2 = 10;
printf("smaller? %i %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
a1 = 10; b1 = 10; a2 = 11; b2 = 10;
printf("smaller? %i %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
a1 = 10; b1 = 10; a2 = 10; b2 = 11;
printf("smaller? %i %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
return 0;
}
See this Godbolt link for the generated assembly.
In a simple throughput test on Intel Skylake, with compiler options gcc -m64 -O3 -fno-math-errno -march=nehalem
, I found a throughput
of is_smaller_v5()
which was 2.6 times better than the original is_smaller()
: 6.8 cpu cycles vs 18 cpu cycles, with loop overhead included. However, in a (too?)
simple latency test, where the inputs a1, a2, b1, b2
depended on the result of the previous is_smaller(_v5)
, I didn't see any improvement. (39.7 cycles vs 39 cycles).
Upvotes: 2
Reputation: 4215
Possibly not better than other answers, but uses a different idea (and a mass of pre-analysis).
// Compute approximate integer square root of input in the range [0,10^6].
// Uses a piecewise linear approximation to sqrt() with bounded error in each piece:
// 0 <= x <= 784 : x/28
// 784 < x <= 7056 : 21 + x/112
// 7056 < x <= 28224 : 56 + x/252
// 28224 < x <= 78400 : 105 + x/448
// 78400 < x <= 176400 : 168 + x/700
// 176400 < x <= 345744 : 245 + x/1008
// 345744 < x <= 614656 : 336 + x/1372
// 614656 < x <= 1000000 : (784000+x)/1784
// It is the case that sqrt(x) - 7.9992711366390365897... <= pseudosqrt(x) <= sqrt(x).
unsigned pseudosqrt(unsigned x) {
    return
        x <= 78400 ?
            x <= 7056 ?
                x <= 784 ? x/28 : 21 + x/112
              : x <= 28224 ? 56 + x/252 : 105 + x/448
          : x <= 345744 ?
                x <= 176400 ? 168 + x/700 : 245 + x/1008
              : x <= 614656 ? 336 + x/1372 : (x+784000)/1784 ;
}
// known pre-conditions: a1 < a2,
// 0 <= b1 <= 1000000
// 0 <= b2 <= 1000000
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
// Try three refinements:
// 1: a1 + sqrt(b1) <= a1 + 1000,
// so is a1 + 1000 < a2 ?
// Convert to a2 - a1 > 1000 .
// 2: a1 + sqrt(b1) <= a1 + pseudosqrt(b1) + 8 and
// a2 + pseudosqrt(b2) <= a2 + sqrt(b2),
// so is a1 + pseudosqrt(b1) + 8 < a2 + pseudosqrt(b2) ?
// Convert to a2 - a1 > pseudosqrt(b1) - pseudosqrt(b2) + 8 .
// 3: Actually do the work.
// Convert to a2 - a1 > sqrt(b1) - sqrt(b2)
// Use short circuit evaluation to stop when resolved.
unsigned ad = a2 - a1;
return (ad > 1000)
|| (ad > pseudosqrt(b1) - pseudosqrt(b2) + 8)
|| ((int) ad > (int)(sqrt(b1) - sqrt(b2)));
}
(I don't have a compiler handy, so this probably contains a typo or two.)
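Since the answer was written without a compiler, a brute-force check of the stated bounds may be useful. The sketch below copies pseudosqrt (using the 784 breakpoint from the comment table) and scans the whole input range; max_pseudosqrt_gap is a hypothetical helper name, not part of the answer.

```cpp
#include <cassert>
#include <cmath>

// Piecewise linear lower bound for sqrt(x) on [0, 10^6], as in the answer.
unsigned pseudosqrt(unsigned x) {
    return
        x <= 78400 ?
            x <= 7056 ?
                x <= 784 ? x/28 : 21 + x/112
              : x <= 28224 ? 56 + x/252 : 105 + x/448
          : x <= 345744 ?
                x <= 176400 ? 168 + x/700 : 245 + x/1008
              : x <= 614656 ? 336 + x/1372 : (x+784000)/1784 ;
}

// Returns the largest value of sqrt(x) - pseudosqrt(x) over [0, 10^6]
// and asserts along the way that pseudosqrt never overestimates.
double max_pseudosqrt_gap() {
    double worst = 0.0;
    for (unsigned x = 0; x <= 1000000; ++x) {
        double exact = std::sqrt(static_cast<double>(x));
        double approx = pseudosqrt(x);
        assert(approx <= exact);
        if (exact - approx > worst) {
            worst = exact - approx;
        }
    }
    return worst;
}
```

If the bound in the comment holds, the returned gap stays below 8, which is what the "+ 8" in the second refinement relies on.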
Upvotes: 1
Reputation: 29962
Here's a version without sqrt
, though I'm not sure whether it is faster than a version which has only one sqrt
(it may depend on the distribution of values).
Here's the math (how to remove both sqrts):
ad = a2-a1
bd = b2-b1
a1+sqrt(b1) < a2+sqrt(b2) // subtract a1
sqrt(b1) < ad+sqrt(b2) // square it
b1 < ad^2+2*ad*sqrt(b2)+b2 // arrange
ad^2+bd > -2*ad*sqrt(b2)
Here, the right side is always negative. If the left side is positive, then we have to return true.
If the left side is negative, then we can square the inequality:
ad^4+bd^2+2*bd*ad^2 < 4*ad^2*b2
The key thing to notice here is that if a2>a1+1000
, then is_smaller
always returns true
(because the maximum value of sqrt(b1)
is 1000). If a2<=a1+1000
, then ad
is a small number, so ad^4
will always fit into 64 bit (there is no need for 128-bit arithmetic). Here's the code:
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
int ad = a2 - a1;
if (ad>1000) {
return true;
}
int bd = b2 - b1;
if (ad*ad+bd>0) {
return true;
}
int ad2 = ad*ad;
return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}
EDIT: As Peter Cordes noticed, the first if
is not necessary, as the second if handles it, so the code becomes smaller and faster:
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
int ad = a2 - a1;
int bd = b2 - b1;
if ((long long int)ad*ad+bd>0) {
return true;
}
int ad2 = ad*ad;
return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}
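As a sanity check, here is a sketch that runs this approach on the test case from the question and on an exact-tie case. The function is a renamed copy of the version above (is_smaller_exact is a name chosen for illustration), with the unsigned-to-int conversions written as explicit casts; it assumes a1 < a2 and values no larger than 10^6, as stated in the question.

```cpp
#include <cassert>

// Renamed copy of the sqrt-free comparison; assumes a1 < a2 and inputs <= 10^6.
bool is_smaller_exact(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = static_cast<int>(a2) - static_cast<int>(a1);
    int bd = static_cast<int>(b2) - static_cast<int>(b1);
    if (static_cast<long long>(ad) * ad + bd > 0) {
        return true;               // left side positive, right side never is
    }
    int ad2 = ad * ad;             // here ad*ad <= -bd <= 10^6, so it fits in int
    return static_cast<long long>(ad2) * ad2
         + static_cast<long long>(bd) * bd
         + 2ll * bd * ad2
         < 4ll * ad2 * b2;
}

// Hand-checked cases, including the one that defeats sqrtf().
void check_exact() {
    assert(is_smaller_exact(900000, 1000000, 900001, 998002));
    assert(!is_smaller_exact(0, 4, 1, 1));       // 0 + 2 == 1 + 1, a tie
    assert(is_smaller_exact(0, 4, 1, 2));        // 2 < 1 + 1.41...
    assert(is_smaller_exact(5, 1000000, 2000, 0));
}
```

For the question's test case the second branch is taken: ad = 1 and bd = -1998, so the final comparison is 3988009 < 3992008.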
Upvotes: 19
Reputation: 4243
There is also Newton's method for calculating integer square roots, as described here. Another approach would be to not calculate the square root at all, but to search for floor(sqrt(n)) via binary search: there are "only" about 1000 perfect squares up to 10^6. This probably performs badly, but it would be an interesting approach. I haven't measured any of these, but here are examples:
#include <iostream>
#include <array>
#include <algorithm> // std::lower_bound
#include <cassert>
#include <cmath> // sqrt
bool is_smaller_sqrt(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
return a1 + sqrt(b1) < a2 + sqrt(b2);
}
static std::array<int, 1001> squares;
template <typename C>
void squares_init(C& c)
{
for (int i = 0; i < c.size(); ++i)
c[i] = i*i;
}
inline bool greater(const int& l, const int& r)
{
return r < l;
}
inline bool is_smaller_bsearch(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
// return a1 + sqrt(b1) < a2 + sqrt(b2)
// find floor(sqrt(b1)) - binary search within 1000 elems
auto it_b1 = std::lower_bound(crbegin(squares), crend(squares), b1, greater).base();
// find floor(sqrt(b2)) - binary search within 1000 elems
auto it_b2 = std::lower_bound(crbegin(squares), crend(squares), b2, greater).base();
return (a2 - a1) > (it_b1 - it_b2);
}
unsigned int sqrt32(unsigned long n)
{
unsigned int c = 0x8000;
unsigned int g = 0x8000;
for (;;) {
if (g*g > n) {
g ^= c;
}
c >>= 1;
if (c == 0) {
return g;
}
g |= c;
}
}
bool is_smaller_sqrt32(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
return a1 + sqrt32(b1) < a2 + sqrt32(b2);
}
int main()
{
squares_init(squares);
// now can use is_smaller
assert(is_smaller_sqrt(1, 4, 3, 1) == is_smaller_sqrt32(1, 4, 3, 1));
assert(is_smaller_sqrt(1, 2, 3, 3) == is_smaller_sqrt32(1, 2, 3, 3));
assert(is_smaller_sqrt(1000, 4, 1001, 1) == is_smaller_sqrt32(1000, 4, 1001, 1));
assert(is_smaller_sqrt(1, 300, 3, 200) == is_smaller_sqrt32(1, 300, 3, 200));
}
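One caveat, assuming the goal is still the exact comparison from the question: both integer variants compare floor(sqrt(b)) values, so the fractional parts are dropped. The sketch below reuses the same bit-by-bit integer square root under illustrative names (isqrt_floor, is_smaller_floor) and applies it to the question's test case, where the exact answer is true.

```cpp
#include <cassert>

// Bit-by-bit integer square root, same scheme as sqrt32 above.
// Returns floor(sqrt(n)) for inputs that fit the 16-bit result.
unsigned isqrt_floor(unsigned n) {
    unsigned c = 0x8000;
    unsigned g = 0x8000;
    for (;;) {
        if (g * g > n) {
            g ^= c;                // candidate bit was too large, clear it
        }
        c >>= 1;
        if (c == 0) {
            return g;
        }
        g |= c;                    // try the next lower bit
    }
}

// Comparison on floored square roots only.
bool is_smaller_floor(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1 + isqrt_floor(b1) < a2 + isqrt_floor(b2);
}

// The question's test case: both sides collapse to 901000.
void check_floor_case() {
    assert(isqrt_floor(1000000) == 1000);
    assert(isqrt_floor(998002) == 999);
    assert(!is_smaller_floor(900000, 1000000, 900001, 998002));
}
```

So a floor-based comparison can serve as a cheap first filter, but it would need an exact tie-break to satisfy the original requirement.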
Upvotes: 2
Reputation: 37232
I'm tired and probably made a mistake; but I'm sure if I did someone will point it out..
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int a_diff = (int)a1 - (int)a2; // May be negative
    if(a_diff < 0) {
        if(b1 < b2) {
            return true;
        }
        double temp = a_diff+sqrt(b1);
        if(temp < 0) {
            return true;
        }
        return temp*temp < b2;
    } else {
        if(b1 >= b2) {
            return false;
        }
    }
    // return a_diff+sqrt(b1) < sqrt(b2);
    double temp = a_diff+sqrt(b1);
    return temp*temp < b2;
}
If you know a1 < a2
then it could become:
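bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    unsigned a_diff = a2-a1; // Will be positive
    if(b1 <= b2) {
        return true; // sqrt(b1) <= sqrt(b2) and a1 < a2
    }
    if(b1 < (unsigned long long)a_diff*a_diff) {
        return true; // sqrt(b1) < a_diff on its own
    }
    // return sqrt(b1) < a_diff+sqrt(b2);
    double temp = a_diff+sqrt(b2);
    return b1 < temp*temp;
}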
Upvotes: 4