Blaise

Reputation: 43

Generating a list of integers with a given number of set bits and a given sum of bit indices

I would like to efficiently generate a list of integers (preferably ordered) with the following defining properties:

  1. All integers have the same number of set bits N.

  2. All integers have the same sum of bit indices K.

To be precise, the binary representation of an integer $I$ is:

$I=\sum_{j=0}^M c_j 2^j$ where $c_j=0$ or $1$

The number of set bits is:

$N(I)=\sum_{j=0}^M c_j$

The sum of bit indices is:

$K(I)=\sum_{j=0}^M j c_j$
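The two definitions can be evaluated directly. A minimal C++ sketch, assuming 32-bit unsigned words (so M is at most 31); bitCount and indexSum are illustrative names, not part of any existing code:

```cpp
#include <cstdint>

// N(I): number of set bits, i.e. the sum of the c_j.
int bitCount(uint32_t I) {
    int n = 0;
    for (int j = 0; j < 32; ++j)
        n += (I >> j) & 1u;
    return n;
}

// K(I): sum of the indices j for which c_j = 1.
int indexSum(uint32_t I) {
    int k = 0;
    for (int j = 0; j < 32; ++j)
        if ((I >> j) & 1u) k += j;
    return k;
}
```

For example, $I = 11 = 1011_2$ has bits 0, 1 and 3 set, so $N(I) = 3$ and $K(I) = 0 + 1 + 3 = 4$.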

I have an inefficient way to generate the list: a do/for loop over integers that advances with a "snoob" function (smallest next integer with the same number of set bits) and, at each step, checks whether the new integer has the correct value of K.

This is grossly inefficient because, in general, starting from an integer I with the correct N and K values, the snoob successor of I does not have the correct K, and many snoob steps are needed to reach the next integer with both N and K equal to the chosen values. Using snoob gives an ordered list, which is handy for binary (dichotomic) search but not strictly required.
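To make the inefficiency concrete: snoob visits every pattern with N set bits in increasing order, while K moves up and down along the way. For N = 16 and Nphi = 31 the loop therefore scans all $\binom{32}{16} = 601\,080\,390$ patterns. A small C++ sketch, using the well-known snoob bit trick from Hacker's Delight; snoob, indexSum and kSequence are illustrative names:

```cpp
#include <cstdint>
#include <vector>

// Smallest integer greater than x with the same number of set bits
// (classic "snoob" trick; x must be nonzero).
uint32_t snoob(uint32_t x) {
    uint32_t smallest = x & (~x + 1u);   // lowest set bit
    uint32_t ripple = x + smallest;      // carry into the next 0 bit
    uint32_t ones = x ^ ripple;          // bits that changed
    ones = (ones >> 2) / smallest;       // right-align the surplus ones
    return ripple | ones;
}

// Sum of bit indices K(x).
int indexSum(uint32_t x) {
    int k = 0;
    for (int j = 0; x != 0; ++j, x >>= 1)
        if (x & 1u) k += j;
    return k;
}

// K values of the first `count` patterns with n set bits, in snoob order.
std::vector<int> kSequence(int n, int count) {
    std::vector<int> ks;
    uint32_t x = (1u << n) - 1u;  // smallest pattern with n bits (0 < n < 32)
    for (int i = 0; i < count; ++i) {
        ks.push_back(indexSum(x));
        x = snoob(x);
    }
    return ks;
}
```

For N = 3 the first six patterns are 7, 11, 13, 14, 19, 21 with K = 3, 4, 5, 6, 5, 6: K is not monotonic along the snoob order, so the patterns with a given K are interleaved with many that must be rejected.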

Counting the number of elements in this list is easily done by recursion, viewing it as a partition-counting problem. Here is a recursive Fortran 90 function that does the job:

=======================================================================
recursive function BoundedPartitionNumberQ(N, M, D)  result (res)
implicit none

  ! number of partitions of N into M distinct integers, bounded by D
  ! appropriate for Fermi counting rules

   integer(8) :: N, M, D, Nmin
   integer(8) :: res
    
    Nmin = M*(M+1)/2       ! the Fermi sea
    
    if(N < Nmin) then
        res = 0

    else if((N == Nmin) .and. (D >= M)) then
        res = 1

    else if(D < M) then
       res = 0

    else if(D == M)  then
       if(N == Nmin) then
              res = 1
       else 
              res = 0  
       endif

    else if(M == 0) then
       res = 0

     else

     res = BoundedPartitionNumberQ(N-M,M-1,D-1)+BoundedPartitionNumberQ(N-M,M,D-1)

     endif

    end function BoundedPartitionNumberQ
========================================================================================
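Mapping bit index j to part j + 1 turns a set of N distinct indices in [0, Nphi] with index sum K into a partition of K + N into N distinct parts bounded by Nphi + 1, so the list length should equal BoundedPartitionNumberQ(K+N, N, Nphi+1). The same number can also be tabulated bottom-up, avoiding the exponential recursion. A minimal C++ sketch; the name countPatterns and the use of 64-bit counters are assumptions:

```cpp
#include <cstdint>
#include <vector>

// Number of integers with exactly n set bits among positions 0..width-1
// whose bit indices sum to k. 0/1-knapsack style DP over bit positions:
// ways[m][s] = number of ways to choose m positions with index sum s.
uint64_t countPatterns(int width, int n, int k) {
    if (n < 0 || k < 0 || n > width) return 0;
    std::vector<std::vector<uint64_t>> ways(n + 1, std::vector<uint64_t>(k + 1, 0));
    ways[0][0] = 1;
    for (int j = 0; j < width; ++j)      // consider bit position j
        for (int m = n; m >= 1; --m)     // descending m: each j used at most once
            for (int s = k; s >= j; --s)
                ways[m][s] += ways[m - 1][s - j];
    return ways[n][k];
}
```

For the question's parameters the call would be countPatterns(32, 16, 248); the table has only (N+1)(K+1) entries, so the count is essentially free compared with generating the list.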

My present solution becomes inefficient when I want to generate lists with several $10^7$ elements. Ultimately I want to stay within C/C++/Fortran and reach lists with up to a few $10^9$ elements.

My present f90 code is the following:


program test
implicit none

integer(8) :: Nparticles
integer(8) :: Nmax, TmpL, CheckL, Nphi
integer(8) :: i, k, counter
integer(8) :: NextOne

Nphi = 31        ! word size is Nphi+1
Nparticles = 16  ! number of bit set

print*,Nparticles,Nphi

Nmax = ishft(1_8, Nphi + 1) - ishft(1_8, Nphi + 1 - Nparticles)

i = ishft(1_8, Nparticles) - 1   ! 1_8 keeps the shift in 64-bit arithmetic

counter = 0

! integer CheckL is the sum of bit indices

CheckL = Nparticles*Nphi/2  ! the value of the sum giving the largest list

do while(i .le. Nmax)   ! we increment the integer

    TmpL = 0

    do k=0,Nphi
        if (btest(i,k)) TmpL = TmpL + k
    end do

    if (TmpL == CheckL) then    ! we check whether the sum of bit indices is OK

        counter = counter + 1

    end if

    i = NextOne(i)   ! a version of "snoob" described below

end do

print*,counter

end program

!==========================================================================
function NextOne (state)
implicit none

integer(8) :: bit    
integer(8) :: counter 
integer(8) :: NextOne,state,pstate

bit     =  1
counter = -1
  
!  find first one bit 

do  while (iand(bit,state) == 0)

    bit = ishft(bit,1)

end do

!  find next zero bit 

do  while (iand(bit,state) /= 0)
    
    counter = counter + 1
    bit = ishft(bit,1)

end do

if (bit == 0) then 

    print*,'overflow in NextOne'
    NextOne = not(0)
  
else 

    state = iand(state,not(bit-1))  ! clear lower bits i &= (~(bit-1));

    pstate = ishft(1_8,counter)-1 ! needed by IBM/Zahir compiler

 !  state = ior(state,ior(bit,ishft(1,counter)-1)) ! short version OK with gcc

    state = ior(state,ior(bit,pstate))

    NextOne = state

end if

end function NextOne

Upvotes: 4

Views: 525

Answers (2)

OffBy0x01

Reputation: 318

Since you mentioned C/C++/Fortran, I've tried to keep this relatively language-agnostic and easily transferable, but I have also included faster builtin alternatives where applicable.

All integers have the same number of set bits N

It follows that all valid integers are permutations of N set bits.

First, we must generate the initial/min permutation:

uint32_t firstPermutation(uint32_t n){
    // Fill the first n bits (on the right); valid for n < 32
    return (1u << n) - 1u;
}

Next, we compute the final (maximum) permutation, which marks the stop point:

uint32_t lastPermutation(uint32_t n){
    // Fill the last n bits (on the left)
    return (0xFFFFFFFF >> n) ^ 0xFFFFFFFF;
}

Finally, we need a way to get the next permutation.

uint32_t nextPermutation(uint32_t n){
    uint32_t t = (n | (n - 1)) + 1;
    return t | ((((t & -t) / (n & -n)) >> 1) - 1);
}

// or with builtins (an alternative to the version above, not an overload):
uint32_t nextPermutation(uint32_t p){
    uint32_t t = (p | (p - 1));
    return (t + 1) | (((~t & -~t) - 1) >> (__builtin_ctz(p) + 1));
}

All integers have the same sum of bit indices K

Assuming these are 32-bit integers, you can use a De Bruijn sequence to quickly find the index of the first (lowest) set bit, fsb. Similar sequences exist for other types and bit counts and can be adapted in the same way.

By stripping the current fsb (n &= n - 1), we can apply the same technique to find the index of the next fsb, and so on.

int sumIndices(uint32_t n){
    const int MultiplyDeBruijnBitPosition[32] = {
      0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
      31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9
    };

    int sum = 0;
    // Get fsb idx
    do sum += MultiplyDeBruijnBitPosition[((uint32_t)((n & -n) * 0x077CB531U)) >> 27];        
    // strip fsb
    while (n &= n-1);   

    return sum;
}

// or with builtin (alternative to the version above)
int sumIndices(uint32_t n){
    int sum = 0;
    while (n) {
        sum += __builtin_ctz(n);   // index of the lowest set bit; n != 0 here
        n &= n - 1;                // strip it
    }
    return sum;
}

Finally, we can iterate over each permutation, checking if the sum of all indices matches the specified K value.

uint32_t p = firstPermutation(n);
uint32_t lp = lastPermutation(n);

// test p before advancing, so the first permutation is not skipped
while (true) {
    if (sumIndices(p) == k){
        std::cout << "p:" << p << '\n';
    }
    if (p == lp) break;
    p = nextPermutation(p);
}

You could easily adapt the 'handler' code to start from a given integer instead, using its N and K values.
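For instance, resuming from an arbitrary valid integer p0 only requires reading N and K off p0 and advancing until the next match. A sketch under the same 32-bit assumptions; nextMatch is a hypothetical helper, and it relies on the GCC/Clang builtins __builtin_popcount and __builtin_ctz:

```cpp
#include <cstdint>

// Next integer with the same number of set bits (Bit Twiddling Hacks variant).
uint32_t nextPermutation(uint32_t v) {
    uint32_t t = (v | (v - 1)) + 1;
    return t | ((((t & -t) / (v & -v)) >> 1) - 1);
}

// Sum of bit indices, stripping the lowest set bit each time.
int sumIndices(uint32_t v) {
    int sum = 0;
    while (v) {
        sum += __builtin_ctz(v);
        v &= v - 1;
    }
    return sum;
}

// Smallest integer > p0 with the same N and K as p0, or 0 if none fits in 32 bits.
uint32_t nextMatch(uint32_t p0) {
    const int n = __builtin_popcount(p0);
    if (n == 32) return 0;                        // all ones: nothing above it
    const int k = sumIndices(p0);
    const uint32_t last = ~(0xFFFFFFFFu >> n);    // highest pattern with n bits
    uint32_t p = p0;
    while (p != last) {
        p = nextPermutation(p);
        if (sumIndices(p) == k) return p;
    }
    return 0;
}
```

Each call still walks over every intermediate pattern, so it inherits the cost discussed in the question; it only changes the starting point.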


[Image: benchmark of intrinsic vs. self-implemented versions]

Upvotes: 2

user555045

Reputation: 64904

A basic recursive implementation could be:

void listIntegersWithWeight(int currentBitCount, int currentWeight, uint32_t pattern, int index, int n, int k, std::vector<uint32_t> &res)
{
    if (currentBitCount > n ||
        currentWeight > k)
        return;

    if (index < 0)
    {
        if (currentBitCount == n && currentWeight == k)
            res.push_back(pattern);
    }
    else
    {
        listIntegersWithWeight(currentBitCount, currentWeight, pattern, index - 1, n, k, res);
        listIntegersWithWeight(currentBitCount + 1, currentWeight + index, pattern | (1u << index), index - 1, n, k, res);
    }
}

That is not my suggestion, just a starting point. On my PC, for n = 16, k = 248, both this version and the iterative version take almost (but not quite) 9 seconds: almost exactly the same time, though that is just a coincidence. More pruning can be done:

  • currentBitCount + index + 1 < n: if the number of set bits cannot reach n with the unfilled positions that are left, continuing is pointless.
  • currentWeight + (index * (index + 1) / 2) < k: if the sum of positions cannot reach k, continuing is pointless.

Together:

void listIntegersWithWeight(int currentBitCount, int currentWeight, uint32_t pattern, int index, int n, int k, std::vector<uint32_t> &res)
{
    if (currentBitCount > n || 
        currentWeight > k ||
        currentBitCount + index + 1 < n ||
        currentWeight + (index * (index + 1) / 2) < k)
        return;

    if (index < 0)
    {
        if (currentBitCount == n && currentWeight == k)
            res.push_back(pattern);
    }
    else
    {
        listIntegersWithWeight(currentBitCount, currentWeight, pattern, index - 1, n, k, res);
        listIntegersWithWeight(currentBitCount + 1, currentWeight + index, pattern | (1u << index), index - 1, n, k, res);
    }
}

On my PC with the same parameters, this only takes half a second. It can probably be improved further.
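One possible refinement, sketched here and not benchmarked: with r = n - currentBitCount bits still to place among positions 0..index, the remaining weight is at least r(r-1)/2 (the r lowest positions) and at most r*index - r(r-1)/2 (the r highest positions). Testing both bounds tightens the two weight checks above. The function name listTight is hypothetical; as in the original, trying the "bit clear" branch first emits the patterns in ascending order.

```cpp
#include <cstdint>
#include <vector>

// Depth-first enumeration from bit `index` down to bit 0, pruning with
// exact bounds on the weight the remaining r bits can still contribute.
void listTight(int count, int weight, uint32_t pattern, int index,
               int n, int k, std::vector<uint32_t> &res) {
    const int r = n - count;                          // bits still to place
    if (r < 0 || r > index + 1) return;               // too many bits / no room
    const int minRest = r * (r - 1) / 2;              // r lowest positions
    const int maxRest = r * index - r * (r - 1) / 2;  // r highest positions
    if (weight + minRest > k || weight + maxRest < k) return;
    if (r == 0) {                                     // here weight == k
        res.push_back(pattern);
        return;
    }
    // Leave bit `index` clear first (smaller values), then set it.
    listTight(count, weight, pattern, index - 1, n, k, res);
    listTight(count + 1, weight + index, pattern | (1u << index), index - 1, n, k, res);
}
```

For the question's parameters the initial call would be listTight(0, 0, 0u, 31, 16, 248, res) on 32-bit words.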

Upvotes: 1
