sds

Reputation: 60044

Fast access to sums of pairwise ops

Given a vector of numbers v, I can access sums of sections of this vector by using cumulative sums, i.e., instead of O(n)

v = [1,2,3,4,5]
def sum_v(i,j):
    return sum(v[i:j])

I can do O(1)

import itertools
v = [1,2,3,4,5]
cache = [0]+list(itertools.accumulate(v))
def sum_v(i,j):
    return cache[j] - cache[i]

Now, I need something similar but for pairwise instead of sum_v:

def pairwise(i,j):
    ret = 0
    for p in range(i,j):
        for q in range(p+1,j):
            ret += f(v[p], v[q])
    return ret

where f is, preferably, something relatively arbitrary (e.g., * or ^ or ...). However, something working for just product or just XOR would be good too.
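For reference, the quadratic baseline described above can be written as a small runnable sketch. The helper name `pairwise_bruteforce` and the choice to pass `v` and `f` as arguments are illustrative assumptions, not part of the original code; `itertools.combinations` yields pairs in index order, so each pair is visited once with p < q.

```python
import operator
from itertools import combinations

def pairwise_bruteforce(v, i, j, f):
    """Sum f(v[p], v[q]) over all i <= p < q < j; quadratic in j - i."""
    return sum(f(a, b) for a, b in combinations(v[i:j], 2))

# f can be any binary function, e.g. product or XOR
product_total = pairwise_bruteforce([1, 2, 3, 4, 5], 0, 5, operator.mul)
xor_total = pairwise_bruteforce([5, 3, 1, 7, 8], 0, 5, operator.xor)
```

Any faster scheme for a particular f can be checked against this baseline on small inputs.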

PS1. I am looking for a speed-up in terms of O, not generic memoization such as functools.cache.

PS2. The question is about algorithms, not implementations, and is thus language-agnostic. I tagged it python only because my examples are in python.

PS3. Obviously, one can precompute all values of pairwise, so the solution should be o(n^2) both in time and space (preferably linear).

Upvotes: 4

Views: 134

Answers (2)

kaya3

Reputation: 51093

In principle, you can always precompute every possible output in Θ(n²) space and then answer queries in Θ(1) by just looking it up in the precomputed table. Everything else is a trade-off depending on the cost of precomputation time, space, and actual computation time; the interesting question is what can be done with o(n²) space, i.e. sub-quadratic. This will generally depend on the application, and on properties of the binary operation f.

In the particular case where f is *, we can get Θ(1) lookups with only Θ(n) space. We take advantage of the fact that the sum over pairs with p < q equals the sum over all ordered pairs (which is the square of the plain sum), minus the sum over pairs with p = q (the sum of squares), divided by 2 to account for the pairs with p > q.

# input data
v = [1, 2, 3, 4, 5]
n = len(v)

# precomputation
partial_sums = [0] * (n + 1)
partial_sums_squares = [0] * (n + 1)
for i, x in enumerate(v):
    partial_sums[i + 1] = partial_sums[i] + x
    partial_sums_squares[i + 1] = partial_sums_squares[i] + x * x

# query response
def pairwise(i, j):
    s = partial_sums[j] - partial_sums[i]
    s2 = partial_sums_squares[j] - partial_sums_squares[i]
    return (s * s - s2) // 2  # s * s - s2 is always even for integers, so this is exact

More generally, this works whenever f is commutative and distributes over the accumulator operation (+ in this case). I wrote the example here without itertools, so that it is more easily translatable to other languages, since the question is meant to be language-agnostic.
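As a sanity check on the product case, the formula can be compared against brute force over every range of a small input. This is a minimal sketch, assuming integer inputs; the names `make_pairwise_product` and `brute_force` are hypothetical helpers introduced here, and the sketch uses `itertools` only for brevity.

```python
from itertools import accumulate, combinations

def make_pairwise_product(values):
    """Build prefix sums once; return query(i, j) for the range [i, j)."""
    sums = [0] + list(accumulate(values))
    squares = [0] + list(accumulate(x * x for x in values))

    def query(i, j):
        s = sums[j] - sums[i]
        s2 = squares[j] - squares[i]
        # (sum)^2 counts every ordered pair, including p == q
        return (s * s - s2) // 2

    return query

def brute_force(values, i, j):
    """Quadratic reference: sum of products over all pairs p < q in [i, j)."""
    return sum(a * b for a, b in combinations(values[i:j], 2))

values = [1, 2, 3, 4, 5]
fast = make_pairwise_product(values)
for i in range(len(values) + 1):
    for j in range(i, len(values) + 1):
        assert fast(i, j) == brute_force(values, i, j)
```

Precomputation is linear in time and space, and each query costs a constant number of arithmetic operations.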

Upvotes: 1

Abhinav Mathur

Reputation: 8111

For bitwise operations such as OR, AND, and XOR, an O(N) algorithm is possible.
Let's consider XOR for this example, but this can easily be modified for OR/AND as well.
The most important thing to note here is that the result of a bitwise operator on bit x of two numbers does not affect the result for bit y. (You can easily see that by trying something like 010 ^ 011 = 001.) So we first count the contribution made by the least significant bits of all numbers to the final sum, then the next bit, and so on. Here's a simple algorithm for that:

  1. Construct a simple table dp, where dp[i][j] = count of numbers in range [i,n) with jth bit set
l = [5, 3, 1, 7, 8]
n = len(l)
ans = 0

# maximum number of bits we need to check
max_binary_length = max(x.bit_length() for x in l)

# dp[i][j] = count of numbers in l[i:] with the jth bit set
dp = [[0] * max_binary_length for _ in range(n + 1)]
for i in range(n - 1, -1, -1):
    for j in range(max_binary_length):
        dp[i][j] = dp[i + 1][j] + ((l[i] >> j) & 1)

for j in range(max_binary_length):
    # we check the jth bits of all numbers here
    for i in range(n):
        # we need bit j of sum(l[i] ^ l[k] for k in range(i + 1, n))
        if (l[i] >> j) & 1 == 0:
            # since 0 ^ 1 = 1, we need the count of later numbers with jth bit 1
            count = dp[i + 1][j]
        else:
            # we need the count of later numbers with jth bit 0;
            # there are n - i - 1 numbers after index i
            count = (n - i - 1) - dp[i + 1][j]
        # the jth bit has a value of 2**j when set (note that ^ is XOR in Python)
        ans += count * (1 << j)
print(ans)

In most cases, integers have at most 32 bits, so this has a complexity of O(N*log2(max(l))) = O(N*32) = O(N). Note that this computes the total for the whole list rather than for an arbitrary range [i, j).
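The same per-bit counting idea can be adapted to range queries, which is what the question asks for. The following is a sketch rather than part of the original answer, assuming non-negative integers; `make_pairwise_xor` and `ones` are hypothetical names. For each bit it stores a prefix count of set bits, and a pair contributes 2**b at bit b exactly when one element has the bit set and the other does not.

```python
def make_pairwise_xor(values):
    """Build per-bit prefix counts; return query(i, j) for the range [i, j)."""
    bits = max((x.bit_length() for x in values), default=0)
    # ones[b][k] = how many of values[:k] have bit b set
    ones = [[0] * (len(values) + 1) for _ in range(bits)]
    for b in range(bits):
        row = ones[b]
        for k, x in enumerate(values):
            row[k + 1] = row[k] + ((x >> b) & 1)

    def query(i, j):
        total = 0
        for b in range(bits):
            c1 = ones[b][j] - ones[b][i]   # elements in range with bit b set
            c0 = (j - i) - c1              # elements in range with bit b clear
            total += (c1 * c0) << b        # each mixed pair contributes 2**b
        return total

    return query

xor_query = make_pairwise_xor([5, 3, 1, 7, 8])
```

With B bits per number, this uses O(N*B) time and space for precomputation and O(B) per query, which is linear and constant respectively when B is bounded (e.g. 32).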

Upvotes: 1
