Reputation: 1055
I'm developing an A/B testing framework using Django. I want to assign a variant number based on the bucket_id in the request's cookies.
bucket_id is set by the front end as an integer in the range 0-99.
So far, I have created a function named get_bucket_range:
def get_bucket_range(data):
    range_bucket = []
    first_val = 0
    next_val = 0
    for i, v in enumerate(data.split(",")):
        v = int(v)
        if i == 0:
            first_val = v
            range_bucket.append([0, first_val])
        elif i == 1:
            range_bucket.append([first_val, first_val + v])
            next_val = first_val + v
        else:
            range_bucket.append([next_val, next_val + v])
            next_val = next_val + v
    return range_bucket
The input to get_bucket_range is a comma-delimited string of weights, one per variant. For example, data = "25,25,50" means we have 3 variants, with the first variant's weight being 25, and so on.
I then created a function named assign_variant to assign the variant:
def assign_variant(range_bucket, num):
    for i in range(len(range_bucket)):
        if num in range(range_bucket[i][0], range_bucket[i][1]):
            return i
This function takes 2 parameters: range_bucket, the result of get_bucket_range, and num, the bucket_id from the cookies.
With this function I can determine which variant id a given bucket_id belongs to.
For example, we have 25 as bucket_id, with data = "25,25,50". This means our bucket_id should belong to variant id 1. Or in the case that we have 25 as bucket_id, with data = "10,10,10,70". This should mean that our bucket_id will belong to variant id 2.
However, it feels like neither of my functions is Pythonic or optimised. Does anyone have suggestions for how I could improve my code?
Upvotes: 0
Views: 1677
Reputation: 427
You can greatly reduce the length of your functions with the itertools.accumulate and bisect.bisect functions. The first function accumulates all the weights into running sums (10,10,10,70 becomes 10,20,30,100), and the second function gives you the index where an element would be inserted, which in your case is equivalent to the index of the group it belongs to.
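from itertools import accumulate
from bisect import bisect

def get_bucket_range(data):
    # Cumulative upper bounds, e.g. "10,10,10,70" -> [10, 20, 30, 100]
    return list(accumulate(map(int, data.split(','))))

def assign_variant(range_bucket, num):
    # Index of the first cumulative bound greater than num
    return bisect(range_bucket, num)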
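As a quick sanity check, here is a self-contained sketch that runs the two functions on the examples from the question. It assumes bucket_id is always an integer in 0-99 and that the weights sum to 100; the print calls are only for illustration.

```python
from bisect import bisect
from itertools import accumulate


def get_bucket_range(data):
    # Cumulative upper bounds, e.g. "10,10,10,70" -> [10, 20, 30, 100]
    return list(accumulate(map(int, data.split(','))))


def assign_variant(range_bucket, num):
    # bisect is bisect_right, so a value equal to a bound
    # falls into the next variant (half-open ranges [low, high))
    return bisect(range_bucket, num)


print(assign_variant(get_bucket_range("25,25,50"), 25))     # 1
print(assign_variant(get_bucket_range("10,10,10,70"), 25))  # 2
```

Note that a num greater than or equal to the total weight returns len(range_bucket), which is not a valid variant index, so you may want to validate bucket_id before calling this.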
Upvotes: 1
Reputation: 4209
Your functions could look like this for example:
def get_bucket_range(data):
    last = 0
    range_bucket = []
    for v in map(int, data.split(',')):
        range_bucket.append([last, last+v])
        last += v
    return range_bucket

def assign_variant(range_bucket, num):
    for i, (low, high) in enumerate(range_bucket):
        if low <= num < high:
            return i
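One thing to be aware of: if num falls outside every range, the loop ends without returning and the function implicitly returns None. Below is a sketch that makes that case explicit. Raising ValueError is my assumption about what you would want here; returning a default variant would work just as well.

```python
def get_bucket_range(data):
    last = 0
    range_bucket = []
    for v in map(int, data.split(',')):
        range_bucket.append([last, last + v])
        last += v
    return range_bucket


def assign_variant(range_bucket, num):
    for i, (low, high) in enumerate(range_bucket):
        if low <= num < high:
            return i
    # Assumption: an unmatched bucket_id is treated as an error
    raise ValueError(f"bucket_id {num} is outside all ranges")


print(assign_variant(get_bucket_range("25,25,50"), 25))  # 1
```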
Upvotes: 1