user2352497

Reputation: 85

Fenwick trees to determine which interval a point falls in

Let a_0, ..., a_{n-1} be a sequence of lengths. From them we can construct the intervals [0, a_0], (a_0, a_0+a_1], (a_0+a_1, a_0+a_1+a_2], ... I store the sequence a_0, ..., a_{n-1} in a Fenwick tree.

My question: given a number m, how can I efficiently (in O(log n) time) find which interval m falls into?

For example, take a = 3, 5, 2, 7, 9, 4.

The Fenwick Tree stores 3, 8, 2, 17, 9, 13.
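The stored values can be reproduced from a with the usual linear-time Fenwick construction. This is a minimal sketch, assuming the same 0-based storage used elsewhere in this thread (tree[k - 1] holds 1-based node k); the function name build_fenwick is hypothetical.

```python
def build_fenwick(a):
    """Build a Fenwick tree over `a` in O(n), stored 0-based:
    tree[k - 1] holds the sum of the block of a ending at 1-based position k."""
    tree = list(a)
    n = len(tree)
    for k in range(1, n + 1):          # 1-based node index
        parent = k + (k & -k)          # next node whose block contains node k's block
        if parent <= n:
            tree[parent - 1] += tree[k - 1]
    return tree


print(build_fenwick([3, 5, 2, 7, 9, 4]))  # [3, 8, 2, 17, 9, 13]
```

Each node is final before it is pushed to its parent, because all of a node's children have smaller indices and are processed first.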

The intervals are [0,3],(3,8],(8,10],(10,17],(17,26],(26,30].

Given the number 9, the algorithm should return the 3rd interval (index 2 with 0-based arrays, 3 with 1-based arrays). Given the number 26, it should return the 5th interval (index 4 with 0-based arrays, 5 with 1-based arrays).
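To pin down the intended semantics, here is a simple O(n) reference that scans the prefix sums directly. It is only a baseline for checking a faster search against, not the requested O(log n) solution, and the name find_interval_linear is hypothetical. It assumes 0 <= m <= a_0 + ... + a_{n-1}.

```python
from itertools import accumulate


def find_interval_linear(a, m):
    """Return the 0-based index of the interval [0, s_0], (s_0, s_1], ...
    containing m, where s_k = a[0] + ... + a[k]."""
    for k, upper in enumerate(accumulate(a)):
        if m <= upper:                 # intervals are closed on the right
            return k
    raise ValueError("m lies beyond the last interval")


a = [3, 5, 2, 7, 9, 4]
print(find_interval_linear(a, 9))   # 2
print(find_interval_linear(a, 26))  # 4
```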

Perhaps another data structure is better suited to this operation. I am using Fenwick trees because they seem simple and efficient.

Upvotes: 0

Views: 174

Answers (2)

David Eisenstat

Reputation: 65498

We can get an O(log n)-time search operation. The trick is to integrate the binary search with the prefix sum operation.

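def get_total(tree, i):
    # Sum of the first i lengths; tree[k - 1] holds 1-based Fenwick node k.
    total = 0
    while i > 0:
        total += tree[i - 1]
        i -= i & (-i)  # drop the lowest set bit: move to the preceding block
    return total


def search(tree, total):
    # Largest power of two below len(tree); jumps of this size and smaller
    # can reach every valid answer index.
    j = 1
    while j < len(tree):
        j <<= 1
    j >>= 1
    i = -1  # i + 1 = number of intervals known to end before `total`
    while j > 0:
        # Skip the next j intervals if their combined length is still < total.
        if i + j < len(tree) and total > tree[i + j]:
            total -= tree[i + j]
            i += j
        j >>= 1
    return i + 1  # 0-based index of the interval containing `total`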


tree = [3, 8, 2, 17, 9, 13]
print('Intervals')
for i in range(len(tree)):
    print(get_total(tree, i), get_total(tree, i + 1))
print('Searches')
for total in range(31):
    print(total, search(tree, total))

Output:

Intervals
0 3
3 8
8 10
10 17
17 26
26 30
Searches
0 0
1 0
2 0
3 0
4 1
5 1
6 1
7 1
8 1
9 2
10 2
11 3
12 3
13 3
14 3
15 3
16 3
17 3
18 4
19 4
20 4
21 4
22 4
23 4
24 4
25 4
26 4
27 5
28 5
29 5
30 5
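The reason to prefer a Fenwick tree over a static prefix-sum array is that the lengths can change. Below is a minimal sketch of the standard point update, assuming the same 0-based storage as get_total and search above; update is not part of the original answer and its name is hypothetical. After changing a[1] from 5 to 10, search keeps answering correctly without any rebuild.

```python
def update(tree, i, delta):
    """Add `delta` to length a[i] (0-based), using the 0-based node storage above."""
    k = i + 1                      # 1-based node index
    while k <= len(tree):
        tree[k - 1] += delta
        k += k & (-k)              # next node whose block contains position i


def search(tree, total):
    # Same descent as in the answer above, repeated so this block runs on its own.
    j = 1
    while j < len(tree):
        j <<= 1
    j >>= 1
    i = -1
    while j > 0:
        if i + j < len(tree) and total > tree[i + j]:
            total -= tree[i + j]
            i += j
        j >>= 1
    return i + 1


tree = [3, 8, 2, 17, 9, 13]        # Fenwick tree for a = [3, 5, 2, 7, 9, 4]
update(tree, 1, 5)                 # a[1]: 5 -> 10; prefix sums now 3, 13, 15, 22, 31, 35
print(tree)                        # [3, 13, 2, 22, 9, 13]
print(search(tree, 9))             # 1  (9 now falls in (3, 13])
print(search(tree, 14))            # 2  (14 falls in (13, 15])
```

Both update and search touch O(log n) nodes, so the structure stays O(log n) per operation.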

Upvotes: 1

Juan Lopes

Reputation: 10585

If the intervals don't change frequently, you can simply binary-search the array of prefix sums. In Python, the bisect module does this for you, and each query is O(log n):

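import bisect

A = [3, 5, 2, 7, 9, 4]

# Turn the lengths into prefix sums: [3, 8, 10, 17, 26, 30]
for i in range(1, len(A)):
    A[i] += A[i-1]

# bisect_left returns the first index whose prefix sum is >= the query
print(bisect.bisect_left(A, 9))   # 2
print(bisect.bisect_left(A, 26))  # 4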

If the intervals change, you can apply the same idea on top of the Fenwick tree, but each prefix-sum lookup then costs O(log n), making a query O(log² n) overall.
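A sketch of that O(log² n) variant: an ordinary binary search whose probes read prefix sums from the Fenwick tree. get_total is the helper from the first answer, repeated here so the block is self-contained; search_log2 is a hypothetical name. The descent in the first answer avoids the extra log factor, so this mainly illustrates the trade-off.

```python
def get_total(tree, i):
    """Sum of the first i lengths (same helper as in the first answer)."""
    total = 0
    while i > 0:
        total += tree[i - 1]
        i -= i & (-i)
    return total


def search_log2(tree, m):
    """Smallest 0-based k whose interval ends at or after m.
    Each probe calls get_total (O(log n)), so a query costs O(log^2 n)."""
    lo, hi = 0, len(tree) - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if get_total(tree, mid + 1) >= m:   # interval `mid` ends at or after m
            hi = mid
        else:
            lo = mid + 1
    return lo


tree = [3, 8, 2, 17, 9, 13]
print(search_log2(tree, 9))   # 2
print(search_log2(tree, 26))  # 4
```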

Upvotes: 0
