Reputation: 1439
Is there a good method to compute the correctly rounded result of sqrt(a+b) for floating-point numbers a and b (same precision), where 0 <= a < +inf and 0 <= b < +inf?
In particular, for input values where the computation of a+b would overflow?
("Correctly rounded" here meaning the same as for the computation of sqrt
itself, that is, returning the representable value closest to the "true" result calculated in infinite precision.)
(Note: one obvious approach is to do the computation in a larger floating-point size and avoid the overflow that way. Unfortunately, this does not work in general (e.g. if there is no larger floating-point format supported).)
I've tried Herbie on this, but it completely gives up. It doesn't seem to sample enough points where a+b overflows to detect the problem, and doesn't seem to handle dependent sampling well either. Unfortunate, as it's usually a great tool.
What I've been doing thus far has been (pseudocode)
if a + b would overflow:
    2*sqrt(a/4 + b/4)  # Cannot overflow for finite inputs, as f::MAX/4 + f::MAX/4 <= f::MAX
else:
    ...  # handle non-overflow case. Also interesting; not quite the topic of this question.
...which appears to mostly work in practice, but is a) completely unprincipled and b) in practice occasionally returns a result that's off by epsilon in the overflow-avoidance portion (as in, e.g., the true result is x + 0.2*(x.next_larger() - x) but this returns x.next_larger() instead of x).
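For concreteness, here is the pseudocode above written out as actual f32 code (a rough sketch using numpy; the helper name and the overflow test are mine, and NaN handling is ignored):

import numpy as np

def sqrt_sum_f32(a, b):
    # Sketch of the pseudocode above, keeping every operation in f32.
    f32 = np.float32
    with np.errstate(over="ignore"):
        s = f32(a) + f32(b)
    if np.isinf(s) and np.isfinite(a) and np.isfinite(b):
        # a + b overflowed; a/4 + b/4 cannot, since f32 max/4 + f32 max/4 <= f32 max
        return f32(2.0) * np.sqrt(f32(a) / f32(4.0) + f32(b) / f32(4.0))
    return np.sqrt(s)  # non-overflow case (not quite the topic of this question)

With the inputs from the f32 session below, this returns the 1.8446744e19 value, one ulp below the brute-forced best result 1.8446746e19.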
For a quick example of the "off-by-epsilon" issue in f32:
>>> import decimal
>>> decimal.getcontext().prec = 256
>>> from decimal import Decimal as D
>>> from numpy import float32 as f32
>>> a = D(f32("6.0847234e31").astype(float))
>>> b = D(f32("3.4028235e38").astype(float))
>>> res_act = (a+b).sqrt()
>>> res_calc = D(f32("1.8446744e19").astype(float)) # 2*sqrt(a/4 + b/4) in f32 precision
>>> res_best = D(f32("1.8446746e19").astype(float)) # obtained by brute-force
>>> abs(res_calc - res_act) > abs(res_best - res_act)
True # oops
(You'll have to take my word on the result calculated in f32, as Python usually operates in f64 precision. That's also why the f32 dance.)
Upvotes: 3
Views: 546
Reputation: 9512
Here is an extreme example. Let's have u = 2^-p, where p is the float precision.
We have (1+u)^2 = (1+2u) + u^2.
If we take a = 1+2u, we have float(a) = a; a is representable in float (it is the next float after 1). And with b = u^2, float(b) = b; b is representable as a float too (it is the power of two 2^(-2p)).
The exact sqrt(a+b) is (1+u), which should be rounded to float(1+u) = 1: due to the exact tie, it is rounded down to the nearest even significand... float(a+b) = a and float(sqrt(a)) = 1, so that's OK.
But let's change the last bit of b: b = (1+2*u)*u^2; float(b) = b, b is just a scaled down by 2p bits (a factor of 2^(-2p)). We now have the exact sqrt(a+b) > 1+u, hence it should round up to float(sqrt(a+b)) = 1+2u.
We see that a bit at the 2^(-3p+1) place (three times the float precision below the leading bit) can change the correct rounding!
That means that you must NOT rely on double precision to perform this operation with correct rounding.
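To make this concrete for binary32 (p = 24, u = 2**-24), here is a small check of the second case, including what happens when the computation is routed through binary64 and rounded back (double rounding). The demonstration is mine, not part of the answer, and assumes numpy's float32/float64 are IEEE-754 binary32/binary64:

import numpy as np

u = 2.0**-24
a = np.float32(1.0 + 2.0*u)                           # a = 1 + 2u, exactly representable in binary32
b = np.float32(2.0**-48) * np.float32(1.0 + 2.0*u)    # b = (1 + 2u)*u**2, also exactly representable

# The exact sqrt(a+b) is just above 1 + u, so the correctly rounded binary32
# result is the next float above 1, i.e. 1 + 2u.
correct = np.float32(1.0 + 2.0*u)

# Via binary64: the 2**-71 bit of b is dropped when forming a + b in binary64,
# its square root is then exactly 1 + u, and the ensuing tie rounds to even,
# giving 1.0 instead of 1 + 2u.
via_double = np.float32(np.sqrt(np.float64(a) + np.float64(b)))

print(via_double == np.float32(1.0))   # True
print(via_double == correct)           # False: double rounding lost the last bit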
Upvotes: 0
Reputation: 1439
Here is an alternative method, now that @EricPostpischil and @njuffa have highlighted the actual problem (namely, double rounding).
(Note: the below is talking about "well-behaved" numbers. It does not take precision boundaries or subnormals into account, although it can be extended to do so.)
First, note that both sqrt(x) and a+b are guaranteed to return the closest representable value to the result. The problem is the double rounding. That is, we're essentially calculating round(sqrt(round(a+b))), when we want to be calculating round(sqrt(a+b)). Note the lack of the inner round in the latter.
So, how much can that inner round affect the result? Well, the inner round adds up to ±0.5 ULP of error to the result of the addition. So we have, roughly, sqrt((a+b)*(1 ± 2**-p)), assuming a p-bit mantissa.
This reduces to sqrt(a+b)*sqrt(1 ± 2**-p)... but sqrt(1 ± 2**-p) is closer to 1 than (1 ± 2**-p) is! (It's close to, but not quite, (1 ± 2**-(p+1)), as this is a finite difference; you can see this from the Taylor series around 1 (d/dx = 1/2).) The second rounding then affects the result by another ±0.5 ULP.
What this means is that we are guaranteed to be no further than 1 ULP from the "true" result. And hence a fixup that just chooses between {sqrt(a+b) - 1 ULP, sqrt(a+b), sqrt(a+b) + 1 ULP} is a viable strategy, if we can "just" figure out how to choose...
So let's see if we can come up with a comparison-based method that works in finite precision. (Note: the below is in infinite precision unless otherwise specified)
resy = float(sqrt(a+b))
resx = resy.prev_nearest()
resz = resy.next_nearest()
Note that resx < resy < resz.
Assuming we have p bits of precision in our floats, that becomes
res = sqrt(a+b) // in infinite precision
resy = float(res)
resx = resy * (1 - 2**(1-p))
resz = resy * (1 + 2**(1-p))
So let's compare resx and resy for a moment:
distx = abs(resx - res)
disty = abs(resy - res)
checkxy: distx < disty
checkxy: abs(resx - res) < abs(resy - res)
checkxy: (resx - res)**2 < (resy - res)**2
checkxy: resx**2 - 2*resx*res + res**2 < resy**2 - 2*resy*res + res**2
checkxy: resx**2 - resy**2 < 2*resx*res - 2*resy*res
checkxy: resx**2 - resy**2 < 2*res*(resx - resy)
// Assuming resx < resy
checkxy: resx+resy > 2*res
checkxy: resx+resy > 2*sqrt(a+b)
// Assuming resx+resy >= 0
checkxy: (resx+resy)**2 > 4*(a+b)
checkxy: (resy*(2 - 2**(1-p)))**2 > 4*(a+b)
checkxy: (resy**2)*((2 - 2**(1-p)))**2 > 4*(a+b)
checkxy: (resy**2)*(4 - 4*2**(1-p) + 2**(2-2p)) > 4*(a+b)
checkxy: (resy**2)*(4 - 4*2**(1-p) + 4*2**(-2p)) > 4*(a+b)
checkxy: (resy**2)*(1 - 2**(1-p) + 2**(-2p)) > a+b
...which is a check we can actually do in finite precision (although it still requires higher precision, which is annoying).
Ditto, for checkyz we get
checkyz: distz < disty
checkyz: (resy**2)*(1 + 2**(1-p) + 2**(-2p)) < a+b
From these two checks you can select the correct result. ...and then it's "just" a matter of checking / handling the edge cases I glossed over above.
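As an illustration of this selection (not a finite-precision implementation), here is a sketch that takes the actual neighbouring floats via nextafter and performs the two checks in exact rational arithmetic; the helper name is mine, and it ignores overflow, exact ties, and the other edge cases glossed over above:

import numpy as np
from fractions import Fraction

def sqrt_sum_fixup_f32(a, b):
    # Sketch only: exact Fraction comparisons stand in for the
    # finite-precision checks derived above.
    f32 = np.float32
    resy = np.sqrt(f32(a) + f32(b))           # round(sqrt(round(a+b)))
    resx = np.nextafter(resy, f32(0.0))       # one ulp down
    resz = np.nextafter(resy, f32(np.inf))    # one ulp up
    s = Fraction(float(a)) + Fraction(float(b))   # exact a + b
    # checkxy: resx is closer than resy  <=>  (resx + resy)**2 > 4*(a+b)
    if (Fraction(float(resx)) + Fraction(float(resy))) ** 2 > 4 * s:
        return resx
    # checkyz: resz is closer than resy  <=>  (resy + resz)**2 < 4*(a+b)
    if (Fraction(float(resy)) + Fraction(float(resz))) ** 2 < 4 * s:
        return resz
    return resy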
Now, in practice I don't think this is worth it compared to just doing the sqrt in higher precision in the first place, at least unless someone can come up with a better method of choosing. But it's still an interesting alternative.
Upvotes: 1
Reputation: 26175
Overflow is easily avoided via appropriate scaling by powers of two, such that the argument larger in magnitude is scaled towards unity. The hard part is producing correctly rounded results. I am not even completely convinced that performing the intermediate computation in the next larger IEEE-754 binary floating-point type guarantees that, due to potential issues with double rounding.
In the absence of a wider floating-point type, one would have to fall back to chaining multiple native-precision numbers together to perform operations with higher intermediate precision. A common scheme due to Dekker is called pair-precision. It uses pairs of floating-point numbers where the more significant part is commonly called the "head" and the less significant part is called the "tail". These two parts are normalized such that the magnitude of the tail is at most half an ulp of the magnitude of the head.
The number of effective significand bits in this scheme is 2*p+1, where p is the number of significand bits in the underlying floating-point types. The "extra" bit is represented by the sign bit of the tail. It is important to note that the exponent range is unchanged compared to the underlying base type, so we need to scale fairly aggressively towards unity to avoid encountering subnormal operands in intermediate computations. Pair-precision computation cannot guarantee correctly-rounded results. Using triplets would probably work, but requires more effort than I can afford to invest in an answer.
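As an aside, the head/tail normalization described here is the classic Fast2Sum step, which is also what the s/t computation in the code below does for the scaled sum. A minimal sketch (mine, not part of the answer's code) in Python, whose floats are IEEE-754 binary64:

def fast_two_sum(x, y):
    # Requires |x| >= |y|. Returns (head, tail) with head + tail == x + y
    # exactly, head = fl(x + y), and |tail| at most half an ulp of head.
    head = x + y
    tail = (x - head) + y
    return head, tail

head, tail = fast_two_sum(1.0, 2.0**-60)
print(head == 1.0, tail == 2.0**-60)   # True True: the tail keeps the rounded-off bits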
However, pair-precision can deliver results that are faithfully rounded and almost always correctly rounded. When FMA (fused multiply-add) is available, a Newton-Raphson based pair-precision square root producing about 2*p-1 good bits can be constructed fairly efficiently. This is what I am using in the exemplary ISO-C99 code below, which uses float mapped to IEEE-754 binary32 as the native floating-point type. Pair-precision code should be compiled with the highest adherence to the IEEE-754 standard to prevent unexpected deviations from the written order of floating-point operations. In my case I used the /fp:strict command line switch of MSVC 2019.
With tens of billions of random test vectors, my test program reports a maximum error of 0.500000179 ulp.
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>
/* compute square root of sum of two positive floating-point numbers */
float sqrt_sum_pos (float a, float b)
{
float mn, mx, res, scale_in, scale_out;
float r, s, t, u, v, w, x;
/* sort arguments according to magnitude */
mx = a < b ? b : a;
mn = a < b ? a : b;
/* select scale factor: scale argument larger in magnitude towards unity */
scale_in = (mx > 1.0f) ? 0x1.0p-64f : 0x1.0p+64f;
scale_out = (mx > 1.0f) ? 0x1.0p+32f : 0x1.0p-32f;
/* scale input arguments */
mn = mn * scale_in;
mx = mx * scale_in;
/* represent sum as a normalized pair s:t of 'float' */
s = mx + mn; // most significant bits
t = (mx - s) + mn; // least significant bits
/* compute square root of s:t. Based on Alan Karp and Peter Markstein,
"High Precision Division and Square Root", ACM TOMS, vol. 23, no. 4,
December 1997, pp. 561-589
*/
r = sqrtf (1.0f / s);
if (s == 0.0f) r = 0.0f;
x = r * s;
s = fmaf (x, -x, s);
r = 0.5f * r;
u = s + t;
v = (s - u) + t;
s = r * u;
t = fmaf (r, u, -s);
t = fmaf (r, v, t);
r = x + s;
s = (x - r) + s;
s = s + t;
t = r + s;
s = (r - t) + s;
/* Component sum of t:s represents square root with maximum error very close to 0.5 ulp */
w = s + t;
/* compensate scaling of source operands */
res = w * scale_out;
/* handle special cases: NaN, Inf */
t = a + b;
if (isinf (mx)) res = mx;
if (isnan (t)) res = t;
return res;
}
// George Marsaglia's KISS PRNG, period 2**123. Newsgroup sci.math, 21 Jan 1999
// Bug fix: Greg Rose, "KISS: A Bit Too Simple" http://eprint.iacr.org/2011/007
static uint32_t kiss_z=362436069, kiss_w=521288629;
static uint32_t kiss_jsr=123456789, kiss_jcong=380116160;
#define znew (kiss_z=36969*(kiss_z&65535)+(kiss_z>>16))
#define wnew (kiss_w=18000*(kiss_w&65535)+(kiss_w>>16))
#define MWC ((znew<<16)+wnew )
#define SHR3 (kiss_jsr^=(kiss_jsr<<13),kiss_jsr^=(kiss_jsr>>17), \
kiss_jsr^=(kiss_jsr<<5))
#define CONG (kiss_jcong=69069*kiss_jcong+1234567)
#define KISS ((MWC^CONG)+SHR3)
uint32_t float_as_uint32 (float a)
{
uint32_t r;
memcpy (&r, &a, sizeof r);
return r;
}
uint64_t double_as_uint64 (double a)
{
uint64_t r;
memcpy (&r, &a, sizeof r);
return r;
}
float uint32_as_float (uint32_t a)
{
float r;
memcpy (&r, &a, sizeof r);
return r;
}
double floatUlpErr (float res, double ref)
{
uint64_t i, j, err, refi;
int expoRef;
/* ulp error cannot be computed if either operand is NaN, infinity, zero */
if (isnan (res) || isnan (ref) || isinf (res) || isinf (ref) ||
(res == 0.0f) || (ref == 0.0f)) {
return 0.0;
}
/* Convert the float result to an "extended float". This is like a float
with 56 instead of 24 effective mantissa bits
*/
i = ((uint64_t) float_as_uint32 (res)) << 32;
/* Convert the double reference to an "extended float". If the reference is
>= 2^129, we need to clamp to the maximum "extended float". If reference
is < 2^-126, we need to denormalize because of float's limited exponent
range.
*/
refi = double_as_uint64 (ref);
expoRef = (int)(((refi >> 52) & 0x7ff) - 1023);
if (expoRef >= 129) {
j = 0x7fffffffffffffffULL;
} else if (expoRef < -126) {
j = ((refi << 11) | 0x8000000000000000ULL) >> 8;
j = j >> (-(expoRef + 126));
} else {
j = ((refi << 11) & 0x7fffffffffffffffULL) >> 8;
j = j | ((uint64_t)(expoRef + 127) << 55);
}
j = j | (refi & 0x8000000000000000ULL);
err = (i < j) ? (j - i) : (i - j);
return err / 4294967296.0;
}
int main (void)
{
float arga, argb, res, reff;
uint32_t argai, argbi, resi, refi, diff;
double ref, ulp, maxulp = 0;
unsigned long long int count = 0;
do {
/* random positive inputs */
argai = KISS & 0x7fffffff;
argbi = KISS & 0x7fffffff;
/* increase occurrence of zero, infinity */
if ((argai & 0xffff) == 0x5555) argai = 0x00000000;
if ((argbi & 0xffff) == 0x3333) argbi = 0x00000000;
if ((argai & 0xffff) == 0xaaaa) argai = 0x7f800000;
if ((argbi & 0xffff) == 0xcccc) argbi = 0x7f800000;
arga = uint32_as_float (argai);
argb = uint32_as_float (argbi);
res = sqrt_sum_pos (arga, argb);
ref = sqrt ((double)arga + (double)argb);
reff = (float)ref;
ulp = floatUlpErr (res, ref);
resi = float_as_uint32 (res);
refi = float_as_uint32 (reff);
diff = (refi > resi) ? (refi - resi) : (resi - refi);
if (diff > 1) {
/* if both source operands were NaNs, result could be either NaN,
quietened if necessary
*/
if (!(isnan (arga) && isnan (argb) &&
((resi == (argai | 0x00400000)) ||
(resi == (argbi | 0x00400000))))) {
printf ("\rerror: refi=%08x resi=%08x a=% 15.8e %08x b=% 15.8e %08x\n",
refi, resi, arga, argai, argb, argbi);
return EXIT_FAILURE;
}
}
if (ulp > maxulp) {
printf ("\rulp = %.9f @ a=%14.8e (%15.6a) b=%14.8e (%15.6a) a+b=%22.13a res=%15.6a ref=%22.13a\n",
ulp, arga, arga, argb, argb, (double)arga + argb, res, ref);
maxulp = ulp;
}
count++;
if (!(count & 0xffffff)) printf ("\r%llu", count);
} while (1);
printf ("\ntest passed\n");
return EXIT_SUCCESS;
}
Upvotes: 4