Reputation: 15513
Given a range of integers from M to N, where M and N might not be powers of 2, is there an efficient way to count the number of times each bit is set?
For example, the range 0 to 10:
0 0000
1 0001
2 0010
3 0011
4 0100
5 0101
6 0110
7 0111
8 1000
9 1001
10 1010
I'd like the number of times each bit is set in each column, which would be 3, 4, 5, 5 in this case (leftmost column first).
Upvotes: 7
Views: 1188
Reputation: 17131
Each bit level has a repeating pattern consisting of 2^power 0s followed by 2^power 1s, where power is the bit index counted from the right, starting at 0.
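As a quick illustration of that pattern, here is a minimal sketch that prints one bit column for the numbers 0 to 7. The helper name `bit_column` is made up for this example and is not part of the answer's code.

```python
def bit_column(power, count):
    """Return bit `power` of 0, 1, ..., count-1 as a string of 0s and 1s."""
    return "".join(str((i >> power) & 1) for i in range(count))

print(bit_column(0, 8))  # 01010101  (1 zero, then 1 one, repeating)
print(bit_column(1, 8))  # 00110011  (2 zeros, then 2 ones, repeating)
print(bit_column(2, 8))  # 00001111  (4 zeros, then 4 ones)
```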
So there are three cases:
1. M and N are such that M = 0 mod 2^(power+1) and N = 2^(power+1)-1 mod 2^(power+1). In this case the answer is simply (N-M+1) / 2.
2. M and N give the same number when integer divided by 2^(power+1). In this case there are two subcases:
   a. M and N also give the same number when integer divided by 2^(power). In this case, if N mod 2^(power+1) < 2^(power) then the answer is 0, else the answer is N-M+1.
   b. Otherwise M and N lie in different halves of the same block. With M <= N, M is in the 0s half and N is in the 1s half (N mod 2^(power+1) >= 2^(power)), so the answer is N - ((N/2^(power+1))*2^(power+1) + 2^(power)) + 1 (integer division).
3. M and N give different numbers when integer divided by 2^(power+1). In this case you can combine the techniques of 1 and 2. Count the set bits between M and (M/(2^(power+1)) + 1)*(2^(power+1)) - 1 (case 2), then between (M/(2^(power+1)) + 1)*(2^(power+1)) and (N/(2^(power+1)))*2^(power+1) - 1 (case 1), and finally between (N/(2^(power+1)))*2^(power+1) and N (case 2).
If this answer has logical bugs in it, let me know; it's complicated and I may have messed something up slightly.
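For cross-checking the case analysis, one alternative formulation (a different technique from the three cases above, not the answer's own method) counts the set bits in [0, n] from the number of complete 0s-then-1s cycles plus the leftover, then subtracts two prefixes. This is a sketch assuming 0 <= M <= N; the names `ones_up_to` and `ones_in_range` are hypothetical.

```python
def ones_up_to(n, power):
    """Count integers in [0, n] whose bit `power` is set (0 if n < 0)."""
    if n < 0:
        return 0
    period = 1 << (power + 1)  # length of one 0s-then-1s cycle
    half = 1 << power          # number of 1s in each full cycle
    full_cycles, rest = divmod(n + 1, period)
    # Each full cycle contributes `half` ones; in the partial cycle,
    # only the values past the first `half` zeros have the bit set.
    return full_cycles * half + max(0, rest - half)

def ones_in_range(m, n, power):
    """Count integers in [m, n] whose bit `power` is set."""
    return ones_up_to(n, power) - ones_up_to(m - 1, power)

print([ones_in_range(0, 10, p) for p in range(4)])  # [5, 5, 4, 3]
print([ones_in_range(5, 18, p) for p in range(5)])  # [7, 7, 7, 8, 3]
```

The first list matches the question's 3, 4, 5, 5 read from the rightmost column to the leftmost.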
UPDATE: Python implementation
def case1(M, N):
    # Whole blocks of size 2**(power+1): exactly half the numbers have the bit set.
    return (N - M + 1) // 2

def case2(M, N, power):
    # M and N lie in the same block of size 2**(power+1).
    if M > N:
        return 0
    if M // 2**power == N // 2**power:
        # Same half of the block: all 0s or all 1s.
        if N % 2**(power + 1) < 2**power:
            return 0
        else:
            return N - M + 1
    else:
        if N % 2**(power + 1) >= 2**power:
            # Count from the start of the 1s half up to N.
            return N - (getNextLower(N, power + 1) + 2**power) + 1
        else:
            # Not reached when M <= N: M would be in the 1s half and N in the 0s half.
            return getNextHigher(M, power + 1) - M

def case3(M, N, power):
    # Partial block at the start + whole blocks in the middle + partial block at the end.
    return (case2(M, getNextHigher(M, power + 1) - 1, power)
            + case1(getNextHigher(M, power + 1), getNextLower(N, power + 1) - 1)
            + case2(getNextLower(N, power + 1), N, power))

def getNextLower(M, power):
    # Largest multiple of 2**power that is <= M.
    return (M // 2**power) * 2**power

def getNextHigher(M, power):
    # Smallest multiple of 2**power that is > M.
    return (M // 2**power + 1) * 2**power

def numSetBits(M, N, power):
    if M % 2**(power + 1) == 0 and N % 2**(power + 1) == 2**(power + 1) - 1:
        return case1(M, N)
    if M // 2**(power + 1) == N // 2**(power + 1):
        return case2(M, N, power)
    else:
        return case3(M, N, power)

if __name__ == "__main__":
    print(numSetBits(0, 10, 0))  # 5
    print(numSetBits(0, 10, 1))  # 5
    print(numSetBits(0, 10, 2))  # 4
    print(numSetBits(0, 10, 3))  # 3
    print(numSetBits(0, 10, 4))  # 0
    print(numSetBits(5, 18, 0))  # 7
    print(numSetBits(5, 18, 1))  # 7
    print(numSetBits(5, 18, 2))  # 7
    print(numSetBits(5, 18, 3))  # 8
    print(numSetBits(5, 18, 4))  # 3
Upvotes: 8
Reputation: 7867
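It can be kept as simple as this:
Take x1 = 0001 (for finding 1s in the rightmost column), x2 = 0010, x3 = 0100, and so on.
Now, in a single loop:
n1 = n2 = n3 = 0
for i in range(m, n + 1):  # m to n inclusive
    # Add 1 only when the bit is set; adding (i & xk) directly would add the mask value.
    n1 = n1 + ((i & x1) != 0)
    n2 = n2 + ((i & x2) != 0)
    n3 = n3 + ((i & x3) != 0)
where nk = number of 1s in the k-th column (from the right). Note that this loop visits every number from m to n.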
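A runnable sketch of the same single-loop idea, generalized to any number of columns. It adds `(i >> k) & 1` so that each number contributes exactly 0 or 1 to a column's count. The function name `count_bits_brute_force` and the `width` parameter are assumptions made for this example.

```python
def count_bits_brute_force(m, n, width):
    """Return counts[k] = how many i in [m, n] have bit k set, for k < width."""
    counts = [0] * width
    for i in range(m, n + 1):
        for k in range(width):
            counts[k] += (i >> k) & 1  # 1 if bit k of i is set, else 0
    return counts

print(count_bits_brute_force(0, 10, 4))  # [5, 5, 4, 3]
```

The running time grows with the size of the range, so this is mainly useful as a reference for checking a faster method on small inputs.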
Upvotes: 0