Reputation: 60044
Given a vector of numbers v, I can access sums of sections of this vector by using cumulative sums, i.e., instead of O(n)
v = [1,2,3,4,5]
def sum_v(i,j):
    return sum(v[i:j])
I can do O(1)
import itertools
v = [1,2,3,4,5]
cache = [0]+list(itertools.accumulate(v))
def sum_v(i,j):
    return cache[j] - cache[i]
Now, I need something similar but for pairwise instead of sum_v:
def pairwise(i,j):
    ret = 0
    for p in range(i,j):
        for q in range(p+1,j):
            ret += f(v[p],v[q])
    return ret
where f is, preferably, something relatively arbitrary (e.g., * or ^ or ...). However, something working for just product or just XOR would be good too.
PS1. I am looking for a speed-up in terms of O, not generic memoization such as functools.cache.
PS2. The question is about algorithms, not implementations, and is thus language-agnostic. I tagged it python only because my examples are in python.
PS3. Obviously, one can precompute all values of pairwise, so the solution should be o(n^2) both in time and space (preferably linear).
Upvotes: 4
Views: 134
Reputation: 51093
In principle, you can always precompute every possible output in Θ(n²) space and then answer queries in Θ(1) by just looking it up in the precomputed table. Everything else is a trade-off depending on the cost of precomputation time, space, and actual computation time; the interesting question is what can be done with o(n²) space, i.e. sub-quadratic. This will generally depend on the application, and on properties of the binary operation f.
In the particular case where f is *, we can get Θ(1) lookups with only Θ(n) space: we'll take advantage of the fact that the sum for pairs where p < q equals the sum of all pairs, minus the sum of pairs where p = q, divided by 2 to account for the pairs where p > q.
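For reference, the identity this relies on can be written out as follows. The bounds i and j are assumed to be the half-open query range used in the question, so p and q run over i, ..., j-1:

```latex
\sum_{i \le p < q < j} v_p \, v_q
  \;=\; \frac{1}{2}\left[\left(\sum_{p=i}^{j-1} v_p\right)^{2} \;-\; \sum_{p=i}^{j-1} v_p^{2}\right]
```

The squared sum counts every ordered pair (p, q), including p = q; subtracting the sum of squares removes the p = q terms, and halving removes the p > q duplicates. Both sums on the right are differences of prefix sums, which is why each query costs Θ(1).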
# input data
v = [1, 2, 3, 4, 5]
n = len(v)
# precomputation
partial_sums = [0] * (n + 1)
partial_sums_squares = [0] * (n + 1)
for i, x in enumerate(v):
    partial_sums[i + 1] = partial_sums[i] + x
    partial_sums_squares[i + 1] = partial_sums_squares[i] + x * x
# query response
def pairwise(i, j):
    s = partial_sums[j] - partial_sums[i]
    s2 = partial_sums_squares[j] - partial_sums_squares[i]
    return (s * s - s2) / 2
More generally, this works whenever f is commutative and distributes over the accumulator operation (+ in this case). I wrote the example here without itertools, so that it is more easily translatable to other languages, since the question is meant to be language-agnostic.
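A quick way to gain confidence in the product formula is to compare it against the brute-force double loop on random data. The following is only a sketch: the helper names `build`, `pairwise_fast` and `pairwise_slow` are made up for this check, and it assumes integer inputs, which is why it uses `//` (for integers, `s * s - s2` is always even, so the division is exact).

```python
import itertools
import random

def build(v):
    # Prefix sums of the values and of their squares, each with a leading 0.
    sums = [0] + list(itertools.accumulate(v))
    squares = [0] + list(itertools.accumulate(x * x for x in v))
    return sums, squares

def pairwise_fast(sums, squares, i, j):
    # Sum of v[p] * v[q] over i <= p < q < j, in O(1) per query.
    s = sums[j] - sums[i]
    s2 = squares[j] - squares[i]
    return (s * s - s2) // 2  # exact for integers

def pairwise_slow(v, i, j):
    # Reference implementation: the O(n^2) double loop from the question.
    return sum(v[p] * v[q] for p in range(i, j) for q in range(p + 1, j))

random.seed(0)
v = [random.randint(-9, 9) for _ in range(30)]
sums, squares = build(v)
mismatches = [
    (i, j)
    for i in range(len(v) + 1)
    for j in range(i, len(v) + 1)
    if pairwise_fast(sums, squares, i, j) != pairwise_slow(v, i, j)
]
print("mismatches:", mismatches)
```

With floating-point inputs, `/ 2` would be the appropriate division instead, at the cost of the usual rounding error in the subtraction.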
Upvotes: 1
Reputation: 8111
For bitwise operations such as or, and, xor, an O(N) algorithm is possible.
Let's consider XOR for this example, but this can be easily modified for OR/AND as well.
The most important thing to note here is that the result of a bitwise operator on bit x of two numbers will not affect the result for bit y. (You can easily see that by trying something like 010 ^ 011 = 001.) So we first count the contribution made by one bit position of all numbers to the final sum, then the next bit position, and so on. Here's a simple algorithm for that:
It uses a table dp, where dp[i][j] = count of numbers in range [i,n) with jth bit set.
l = [5, 3, 1, 7, 8]
n = len(l)
max_binary_length = max(l).bit_length()  # maximum number of bits we need to check

# dp[i][j] = count of numbers in l[i:] with the jth bit set
dp = [[0] * max_binary_length for _ in range(n + 1)]
for i in range(n - 1, -1, -1):
    for j in range(max_binary_length):
        dp[i][j] = dp[i + 1][j] + ((l[i] >> j) & 1)

ans = 0
for j in range(max_binary_length):
    # we check the jth bits of all numbers here
    for i in range(n):
        # we need sum(l[i] ^ l[q] for q in range(i + 1, n)), restricted to bit j
        if (l[i] >> j) & 1 == 0:
            # since 0 ^ 1 = 1, we need the count of later numbers with jth bit 1
            count = dp[i + 1][j]
        else:
            # since 1 ^ 0 = 1, we need the count of later numbers with jth bit 0
            count = (n - i - 1) - dp[i + 1][j]
        # since we're checking the jth bit, it has a value of 2**j when set
        ans += count * (1 << j)
print(ans)
In most cases, for integers, the number of bits is <= 32. So this should have a complexity of O(N*log2(max(l))) == O(N*32) == O(N), treating the bit width as a constant.
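The same per-bit counting can be adapted to the range queries asked about in the question by storing, for each bit, a prefix count of how many elements have that bit set. The sketch below is one possible adaptation, not part of the answer above: the names `build_bit_prefix` and `pairwise_bitwise` are hypothetical, and it assumes non-negative integers. It needs O(N * bits) space and O(bits) time per query.

```python
def build_bit_prefix(v):
    # prefix[b][k] = how many of v[0:k] have bit b set (non-negative ints assumed).
    bits = max(max(v, default=0).bit_length(), 1)
    prefix = [[0] * (len(v) + 1) for _ in range(bits)]
    for b in range(bits):
        row = prefix[b]
        for k, x in enumerate(v):
            row[k + 1] = row[k] + ((x >> b) & 1)
    return prefix

def pairwise_bitwise(prefix, i, j, op="xor"):
    # Sum of (v[p] op v[q]) over i <= p < q < j, for op in {"xor", "and", "or"}.
    m = j - i
    total = 0
    for b, row in enumerate(prefix):
        ones = row[j] - row[i]
        zeros = m - ones
        if op == "xor":
            pairs = ones * zeros                  # exactly one of the two has bit b
        elif op == "and":
            pairs = ones * (ones - 1) // 2        # both have bit b
        elif op == "or":
            pairs = m * (m - 1) // 2 - zeros * (zeros - 1) // 2  # not both clear
        else:
            raise ValueError(f"unsupported op: {op}")
        total += pairs << b                       # each such pair contributes 2**b
    return total

v = [5, 3, 1, 7, 8]
prefix = build_bit_prefix(v)
print(pairwise_bitwise(prefix, 0, len(v), "xor"))  # 72, same total as the loop above
```

The key difference from the loop above is that a pair is counted once per bit from the counts of ones and zeros in the range, so no per-element pass is needed at query time.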
Upvotes: 1