Reputation: 4033
I was going through the BERT repo and found the following piece of code:
for _ in range(10):
    random_document_index = rng.randint(0, len(all_documents) - 1)
    if random_document_index != document_index:
        break
The idea here is to generate a random integer in [0, len(all_documents) - 1] that cannot equal document_index. Because len(all_documents) is supposed to be a very large number, the first iteration is almost guaranteed to produce a valid value, but just to be safe, they try up to 10 times. I can't help but think there has to be a better way to do this.
I found this answer, which is easy enough to implement in Python:
random_document_index = rng.randint(0, len(all_documents) - 2)
random_document_index += 1 if random_document_index >= document_index else 0
I was just wondering if there's a better way to achieve this in Python using built-in functions (or even with numpy), or if this is the best you can do.
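For reference, here is the two-line shift trick above wrapped in a small function, as a minimal sketch. The name randint_excluding is made up for illustration, and it assumes n >= 2 and 0 <= excluded < n:

```python
import random

def randint_excluding(rng, n, excluded):
    """Return a uniform random integer in [0, n - 1] that is never `excluded`.

    Hypothetical helper; assumes n >= 2 and 0 <= excluded < n.
    """
    # Draw from a range that is one element shorter than the full range...
    value = rng.randint(0, n - 2)
    # ...then shift every value at or above the hole up by one.
    return value + 1 if value >= excluded else value

rng = random.Random(0)
samples = [randint_excluding(rng, 5, 2) for _ in range(10_000)]
```

Each of the n - 1 allowed values maps back to exactly one draw, so the result stays uniform and needs only a single call to randint.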
Upvotes: 0
Views: 147
Reputation: 13090
Had len(all_documents) been small, a pretty solution would be to materialize all valid numbers (e.g. in a list) and use random.choice(). Since your len(all_documents) is supposedly large, this solution would waste a lot of memory.
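For the small case, that could look roughly like this. It is only a sketch: choice_excluding is a made-up name, and the list it builds takes O(n) memory, which is exactly the cost mentioned above:

```python
import random

def choice_excluding(rng, n, excluded):
    """Pick uniformly from range(n) minus `excluded` by materializing the candidates.

    Hypothetical helper; only sensible when n is small.
    """
    # Build every valid index explicitly, then let choice() pick one.
    candidates = [i for i in range(n) if i != excluded]
    return rng.choice(candidates)

rng = random.Random(0)
picks = [choice_excluding(rng, 6, 3) for _ in range(1_000)]
```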
A more memory-efficient solution is to stick with the original strategy. It's really very reasonable for large len(all_documents), where a single iteration is very likely to be enough, though the hard-coded 10 is ugly. A pretty one-line solution would be to make use of the walrus operator introduced in Python 3.8:
while (random_document_index := rng.randint(0, len(all_documents) - 1)) == document_index: pass
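One caveat with an unbounded retry loop: if there is only one document, every draw equals document_index and the loop never ends. A sketch of the same idea as a function with a guard for that case (randint_rejecting is a made-up name, and it assumes 0 <= excluded < n):

```python
import random

def randint_rejecting(rng, n, excluded):
    """Redraw from [0, n - 1] until the value differs from `excluded`.

    Hypothetical helper; assumes 0 <= excluded < n.
    """
    if n < 2:
        # With a single candidate, every draw would hit the excluded value.
        raise ValueError("need at least two candidates to exclude one")
    while True:
        value = rng.randint(0, n - 1)
        if value != excluded:
            return value

rng = random.Random(1)
draws = [randint_rejecting(rng, 3, 1) for _ in range(1_000)]
```

For large n the loop almost always exits on the first pass, so the expected cost is barely more than one randint call.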
Upvotes: 2
Reputation: 26886
Perhaps a more elegant way of picking integers with holes is to use random.choice():
import random
seq = [0, 1, 3, 4, 6, 7]
random.choice(seq)
The drawback is that it requires a sequence. A plain list may not be efficient in your case, and it is generally not efficient when the size of the range is much larger than the number of invalid values. In that case, a more efficient approach is to create a custom sequence that stores only the "holes".
Such an implementation would take the form of a non-contiguous range (without step support) with invalid numbers, implementing the Sequence interface:
class NonContRange(object):
    def __init__(self, start, stop, invalid=None):
        self.start = start
        self.stop = stop
        self.invalid = invalid if invalid else set()

    def __len__(self):
        return self.stop - self.start - len(self.invalid)

    def __getitem__(self, i):
        offset = 0
        for invalid in sorted(self.invalid):
            if invalid <= self.start + i + offset:
                offset += 1
        return self.start + i + offset

    def __iter__(self):
        for i in range(self.start, self.stop):
            if i not in self.invalid:
                yield i

    def __reversed__(self):
        for i in range(self.stop - 1, self.start - 1, -1):
            if i not in self.invalid:
                yield i

    def index(self, x):
        if x in self.invalid:
            raise ValueError(f'{x} not in sequence.')
        else:
            offset = sum(1 for y in self.invalid if y < x)
            return x - self.start - offset

    def count(self, x):
        return 0 if x in self.invalid or not (self.start <= x < self.stop) else 1

    def __str__(self):
        return f'NonContRange({self.start}, {self.stop}, ¬{sorted(self.invalid)})'
A few tests:
seq = NonContRange(10, 20, {12, 15, 16})
print(seq)
# NonContRange(10, 20, ¬[12, 15, 16])
print(list(seq))
# [10, 11, 13, 14, 17, 18, 19]
print(list(reversed(seq)))
# [19, 18, 17, 14, 13, 11, 10]
print([seq[i] for i in range(len(seq))])
# [10, 11, 13, 14, 17, 18, 19]
print(seq.count(19))
# 1
print(seq.count(12))
# 0
and this can be safely used with random.choice():
import random
invalid = {12, 17}
seq = NonContRange(10, 20, invalid)
print(all(random.choice(seq) not in invalid for _ in range(10000)))
# True
This is of course very nice in the general case, but for your specific situation it looks more like killing a fly with a cannonball.
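If a full Sequence is more than you need, the lookup in __getitem__ can be pulled out into a single function. This is only a sketch under assumptions: randrange_excluding is a made-up name, and holes outside [start, stop) are simply ignored:

```python
import random

def randrange_excluding(rng, start, stop, invalid):
    """Return a uniform random integer in [start, stop) that is not in `invalid`.

    Hypothetical helper; raises ValueError (from randrange) if nothing is left.
    """
    # Keep only the holes that actually fall inside the range, in ascending order.
    holes = sorted(h for h in set(invalid) if start <= h < stop)
    # Draw from a range shortened by the number of holes...
    value = rng.randrange(start, stop - len(holes))
    # ...then step over each hole at or below the running value.
    for hole in holes:
        if hole <= value:
            value += 1
        else:
            break  # remaining holes are larger still, so none can apply
    return value

rng = random.Random(0)
values = [randrange_excluding(rng, 10, 20, {12, 17}) for _ in range(10_000)]
```

With a single hole this reduces to the two-line shift from the question.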
Upvotes: 1