Jack Nock

Reputation: 15513

Count number of times each bit is set in a range of integers

Given a range of integers from M to N, where M and N might not be powers of 2, is there an efficient way to count the number of times each bit is set?

For example, the range 0 to 10:

0   0000
1   0001
2   0010
3   0011
4   0100
5   0101
6   0110
7   0111
8   1000
9   1001
10  1010

I'd like the number of times the bit in each column is set, which would be 3, 4, 5, 5 in this case (columns read from left to right).

Upvotes: 7

Views: 1188

Answers (2)

OmnipotentEntity

Reputation: 17131

Each bit position (power) follows a repeating pattern with period 2^(power+1): 2^power 0s followed by 2^power 1s.

So there are three cases:

  1. M mod 2^(power+1) = 0 and N mod 2^(power+1) = 2^(power+1)-1, i.e. [M, N] covers whole periods. Exactly half of those numbers have the bit set, so the answer is simply (N-M+1) / 2.

  2. M and N lie in the same period, i.e. M and N give the same result when integer divided by 2^(power+1). In this case there are two subcases:

    1. M and N also give the same result when integer divided by 2^power, i.e. they are in the same half of the period. If N mod 2^(power+1) < 2^power (the 0s half), the answer is 0; otherwise it is N-M+1.
    2. Otherwise they are in different halves. Since M <= N, M must be in the 0s half and N in the 1s half, so the answer is N - ((N/2^(power+1))*2^(power+1) + 2^power) + 1 (integer division), i.e. the count from the first number of the 1s block up to N.
  3. M and N give different results when integer divided by 2^(power+1), i.e. they lie in different periods. In this case you can combine the techniques of 1 and 2: use case 2 for the numbers from M to (M/2^(power+1) + 1)*2^(power+1) - 1 (the rest of M's period), case 1 for (M/2^(power+1) + 1)*2^(power+1) through (N/2^(power+1))*2^(power+1) - 1 (the whole periods in between), and case 2 again for (N/2^(power+1))*2^(power+1) through N. The answer is the sum of the three counts.
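To make case 3 concrete, here is a small sketch that computes the three segments for M = 5, N = 18, power = 1 and counts each one by brute force. `split_into_segments` and `brute_count` are hypothetical helpers for illustration only, not part of the answer's implementation, and the sketch assumes 0 <= M <= N.

```python
def split_into_segments(m, n, power):
    """Split [m, n] into the three pieces used by case 3.

    Hypothetical helper; assumes m and n lie in different periods
    of length 2**(power + 1). The middle piece may be empty.
    """
    period = 2 ** (power + 1)
    first_end = (m // period + 1) * period - 1   # last number of m's period
    last_start = (n // period) * period          # first number of n's period
    return [(m, first_end), (first_end + 1, last_start - 1), (last_start, n)]


def brute_count(lo, hi, power):
    """Count numbers in [lo, hi] whose bit `power` is set (0 if lo > hi)."""
    return sum(1 for i in range(lo, hi + 1) if (i >> power) & 1)


segments = split_into_segments(5, 18, 1)
print(segments)                                          # [(5, 7), (8, 15), (16, 18)]
print([brute_count(lo, hi, 1) for lo, hi in segments])   # [2, 4, 1]
print(brute_count(5, 18, 1))                             # 7
```

The middle segment (8 to 15) spans two whole periods of length 4, so case 1 gives (15 - 8 + 1) / 2 = 4 directly, matching the brute-force count.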

If this answer has logical bugs in it, let me know; it's complicated and I may have messed something up slightly.

UPDATE:

Python implementation:

# Count the numbers in [M, N] whose bit `power` is set (Python 3).

def case1(M, N):
    # [M, N] covers whole periods of 2**(power+1): exactly half have the bit set
    return (N - M + 1) // 2

def case2(M, N, power):
    # M and N lie in the same period of length 2**(power+1)
    if M > N:
        return 0
    if M // 2**power == N // 2**power:
        # same half of the period: all 0s or all 1s
        if N % 2**(power+1) < 2**power:
            return 0
        else:
            return N - M + 1
    else:
        # M is in the 0s half and N in the 1s half (because M <= N),
        # so count from the first number of the 1s block up to N
        return N - (getNextLower(N, power+1) + 2**power) + 1

def case3(M, N, power):
    # rest of M's period + whole periods in between + start of N's period
    return (case2(M, getNextHigher(M, power+1) - 1, power)
            + case1(getNextHigher(M, power+1), getNextLower(N, power+1) - 1)
            + case2(getNextLower(N, power+1), N, power))

def getNextLower(M, power):
    # largest multiple of 2**power that is <= M
    return (M // 2**power) * 2**power

def getNextHigher(M, power):
    # smallest multiple of 2**power that is > M
    return (M // 2**power + 1) * 2**power

def numSetBits(M, N, power):
    if M % 2**(power+1) == 0 and N % 2**(power+1) == 2**(power+1) - 1:
        return case1(M, N)
    if M // 2**(power+1) == N // 2**(power+1):
        return case2(M, N, power)
    else:
        return case3(M, N, power)

if __name__ == "__main__":
    # counts for bits 0..4, least significant bit first
    print([numSetBits(0, 10, p) for p in range(5)])   # [5, 5, 4, 3, 0]
    print([numSetBits(5, 18, p) for p in range(5)])   # [7, 7, 7, 8, 3]
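The same periodic pattern also yields a closed form that avoids the case analysis entirely. This is a different technique from the three-case split (a prefix count: count the numbers in [0, N] with the bit set, then subtract the count for [0, M-1]). It is a minimal sketch assuming 0 <= M <= N; `count_upto`, `count_range` and `column_counts` are hypothetical names.

```python
def count_upto(n, power):
    """Numbers in [0, n] whose bit `power` is set (0 when n < 0)."""
    if n < 0:
        return 0
    period = 2 ** (power + 1)   # the bit pattern repeats every `period` numbers
    half = 2 ** power           # ...as `half` 0s followed by `half` 1s
    full_periods, remainder = divmod(n + 1, period)
    # each full period contributes `half` set bits; the partial period
    # contributes only the part that reaches into its 1s block
    return full_periods * half + max(0, remainder - half)


def count_range(m, n, power):
    """Numbers in [m, n] whose bit `power` is set (assumes 0 <= m <= n)."""
    return count_upto(n, power) - count_upto(m - 1, power)


def column_counts(m, n):
    """Per-column counts, most significant column first, as in the question."""
    width = max(n.bit_length(), 1)
    return [count_range(m, n, p) for p in reversed(range(width))]


print(column_counts(0, 10))                        # [3, 4, 5, 5]
print([count_range(5, 18, p) for p in range(5)])   # [7, 7, 7, 8, 3]
```

Each bit costs O(1) arithmetic, so the whole table takes time proportional to the number of bits in N rather than to N - M.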

Upvotes: 8

theharshest

Reputation: 7867

It can be kept as simple as this:

Take masks x1 = 0001 (for counting 1s in the rightmost column), x2 = 0010, x3 = 0100, and so on.

Now, in a single loop:

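m, n = 0, 10                          # example range from the question
x1, x2, x3 = 0b0001, 0b0010, 0b0100   # one mask per column
n1 = n2 = n3 = 0
for i in range(m, n + 1):             # m..n inclusive
    # add 1 when the bit is set (adding i & x2 itself would add 2, not 1)
    n1 += 1 if i & x1 else 0
    n2 += 1 if i & x2 else 0
    n3 += 1 if i & x3 else 0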

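where ni is the number of 1s in the i-th column (from the right). Note that this visits every number in the range, so it takes time proportional to N - M + 1.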
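The loop above generalizes to any number of columns by keeping a list of counters instead of separate variables. A minimal sketch, assuming 0 <= M <= N; `brute_force_column_counts` is a hypothetical name. It uses shift-and-mask, `(i >> p) & 1`, so each set bit adds exactly 1.

```python
def brute_force_column_counts(m, n):
    """Count how often each bit is set over [m, n] in one pass.

    Runs in O((n - m + 1) * width) time, so it is only practical for
    modest ranges. Assumes 0 <= m <= n.
    """
    width = max(n.bit_length(), 1)
    counts = [0] * width                 # counts[p] = times bit p is set
    for i in range(m, n + 1):
        for p in range(width):
            counts[p] += (i >> p) & 1    # adds 0 or 1, never the mask value
    return counts[::-1]                  # most significant column first


print(brute_force_column_counts(0, 10))  # [3, 4, 5, 5]
```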

Upvotes: 0
