Jay Mody

Reputation: 4033

Generating a random integer in a range while excluding a number in the given range

I was going through the BERT repo and found the following piece of code:

for _ in range(10):
    random_document_index = rng.randint(0, len(all_documents) - 1)
    if random_document_index != document_index:
        break

The idea here is to generate a random integer on [0, len(all_documents) - 1] that cannot equal document_index. Because len(all_documents) is supposed to be a very large number, the first iteration is almost guaranteed to produce a valid integer, but just to be safe, they try up to 10 times. I can't help but think there has to be a better way to do this.

I found this answer which is easy enough to implement in python:

random_document_index = rng.randint(0, len(all_documents) - 2)
random_document_index += 1 if random_document_index >= document_index else 0

I was just wondering if there's a better way to achieve this in Python using the built-in functions (or even with NumPy), or if this is the best you can do.

Upvotes: 0

Views: 147

Answers (2)

jmd_dk

Reputation: 13090

Had len(all_documents) been small, a pretty solution would be to materialize all valid numbers (e.g. in a list) and use random.choice(). Since your len(all_documents) is supposedly large, this solution would waste a lot of memory.
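For a small range, that idea might look like the sketch below. The helper name choice_excluding and its rng parameter are hypothetical, made up for illustration; they are not part of the BERT code.

```python
import random


def choice_excluding(n, excluded, rng=random):
    """Pick uniformly from range(n) without `excluded` (hypothetical helper)."""
    # Builds the whole candidate list: O(n) time and memory on every call.
    candidates = [i for i in range(n) if i != excluded]
    # rng.choice() raises IndexError if no candidate is left.
    return rng.choice(candidates)
```

Each call rebuilds the list, which is why this only makes sense when n is small.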

A more memory-efficient solution is to stick with the original strategy. It is very reasonable for large len(all_documents), where a single iteration is very likely to be enough, though the hard-coded 10 is ugly. A pretty one-line solution would be to use the walrus operator introduced in Python 3.8:

while (random_document_index := rng.randint(0, len(all_documents) - 1)) == document_index: pass
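The same retry loop can be wrapped in a function, which also leaves room for a guard: with a single document whose index is the excluded one, the loop would never terminate. This is only a sketch; randint_excluding and its parameters are hypothetical names, and any object with a randrange() method (the random module or a random.Random instance) is assumed for rng.

```python
import random


def randint_excluding(n, excluded, rng=random):
    """Draw from range(n), retrying until the result is not `excluded`."""
    if n <= 0:
        raise ValueError("n must be positive")
    if n == 1 and excluded == 0:
        # The only candidate is the excluded value; retrying would loop forever.
        raise ValueError("no valid value to draw")
    while True:
        candidate = rng.randrange(n)
        if candidate != excluded:
            return candidate
```

For large n the expected number of draws is n / (n - 1), so the loop almost always exits on the first pass.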

Upvotes: 2

norok2

Reputation: 26886

Perhaps a more elegant way of picking integers with holes is to use random.choice():

import random


seq = [0, 1, 3, 4, 6, 7]
random.choice(seq)

The drawback is that it requires a sequence. A plain list may not be efficient in your case, and it is generally inefficient when the range is much larger than the number of invalid values. In that case, a more efficient approach is to create a custom sequence that only needs to know the "holes".


EDIT

Such an implementation could take the form of a non-contiguous range (without step support) with invalid numbers, implementing the Sequence interface:

class NonContRange(object):
    def __init__(self, start, stop, invalid=None):
        self.start = start
        self.stop = stop
        self.invalid = invalid if invalid else set()

    def __len__(self):
        return self.stop - self.start - len(self.invalid)

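    def __getitem__(self, i):
        # Only non-negative indices are supported (enough for random.choice()).
        if not 0 <= i < len(self):
            raise IndexError('NonContRange index out of range')
        offset = 0
        # Walk the holes in ascending order, shifting past each one at or below the candidate.
        for invalid in sorted(self.invalid):
            if invalid <= self.start + i + offset:
                offset += 1
        return self.start + i + offset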

    def __iter__(self):
        for i in range(self.start, self.stop):
            if i not in self.invalid:
                yield i

    def __reversed__(self):
        for i in range(self.stop - 1, self.start - 1, -1):
            if i not in self.invalid:
                yield i

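    def index(self, x):
        # Values outside [start, stop) are not in the sequence either.
        if x in self.invalid or not (self.start <= x < self.stop):
            raise ValueError(f'{x} not in sequence.')
        else:
            offset = sum(1 for y in self.invalid if y < x)
            return x - self.start - offset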

    def count(self, x):
        return 0 if x in self.invalid or not (self.start <= x < self.stop) else 1

    def __str__(self):
        return f'NonContRange({self.start}, {self.stop}, ¬{sorted(self.invalid)})'

A few tests:

seq = NonContRange(10, 20, {12, 15, 16})
print(seq)
# NonContRange(10, 20, ¬[12, 15, 16])
print(list(seq))
# [10, 11, 13, 14, 17, 18, 19]
print(list(reversed(seq)))
# [19, 18, 17, 14, 13, 11, 10]
print([seq[i] for i in range(len(seq))])
# [10, 11, 13, 14, 17, 18, 19]
print(seq.count(19))
# 1
print(seq.count(12))
# 0

and this can be safely used with random.choice():

import random


invalid = {12, 17}
seq = NonContRange(10, 20, invalid)
print(all(random.choice(seq) not in invalid for _ in range(10000)))
# True

This is of course very nice in the general case, but for your specific situation it looks more like killing a fly with a cannonball.
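For the single excluded index in the question, a lighter sketch is the shift trick quoted there: draw from a range that is one element shorter and step over the hole. The name randint_skip is hypothetical, and the code assumes 0 <= excluded < n.

```python
import random
from collections import Counter


def randint_skip(n, excluded, rng=random):
    """One draw from range(n - 1), shifted past `excluded` (assumes 0 <= excluded < n)."""
    if n < 2 or not 0 <= excluded < n:
        raise ValueError("need n >= 2 and 0 <= excluded < n")
    draw = rng.randrange(n - 1)
    # Draws at or above the hole move up by one, so `excluded` is never returned.
    return draw + 1 if draw >= excluded else draw


rng = random.Random(0)
counts = Counter(randint_skip(5, 2, rng) for _ in range(10_000))
print(sorted(counts))
# [0, 1, 3, 4]
```

Each of the n - 1 draws maps to exactly one valid value, so the result stays uniform and only one random number is consumed per call.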

Upvotes: 1
