I meant to model Thompson sampling, but the following code gives `Domain error in arguments.` I searched for it and found someone saying it may be because Beta's parameters are negative, but here the parameters should always be positive. I have no idea how to fix it.
```python
import numpy as np
from scipy import stats

class TS():
    def __init__(self, alpha, beta, n):
        self.alpha = alpha
        self.beta = beta
        self.n = n
        self.value = [0, 0, 0]  # estimator
        self.prob = [0.4, 0.6, 0.8]  # true success probability of each arm

    def generate(self):
        for j in range(self.n):
            tmp = [0, 0, 0]
            # draw one sample from each arm's Beta distribution
            for i in range(0, 3):
                tmp[i] = stats.beta.rvs(self.alpha[i], self.beta[i])
            # play the arm with the largest sample
            max_index = tmp.index(max(tmp))
            r = np.random.choice([0, 1], p=(1 - self.prob[max_index], self.prob[max_index]))
            self.alpha[max_index] += r
            self.beta[max_index] -= (1 - r)
        print(self.value)

one = TS([1, 1, 1], [1, 1, 1], 100)
one.generate()
```
Error:

```
ValueError Traceback (most recent call last)
<ipython-input-18-9df20b1a6a3b> in <module>()
23 print(self.value)
24 one=TS([1,1,1],[1,1,1],100)
---> 25 one.generate()

<ipython-input-18-9df20b1a6a3b> in generate(self)
16 tmp=[0,0,0]
17 for i in range(0,3):
---> 18 tmp[i]=stats.beta.rvs(self.alpha[i],self.beta[i])
19 max_index=tmp.index(max(tmp))
20 r=np.random.choice([0,1],p=(1-self.prob[max_index],self.prob[max_index]))

C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py in rvs(self, *args, **kwds)
938 cond = logical_and(self._argcheck(*args), (scale >= 0))
939 if not np.all(cond):
--> 940 raise ValueError("Domain error in arguments.")
941
942 if np.all(scale == 0):

ValueError: Domain error in arguments.
```
Answer:
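`self.beta[max_index]` should be increased by `1 - r`, not decreased. Each time the chosen arm returns `r = 0`, your line subtracts 1 from its beta parameter, so an arm that starts at `beta = 1` drops to 0 after its first failure. `scipy.stats.beta` requires both shape parameters to be strictly positive, so the next round's call to `stats.beta.rvs` raises `Domain error in arguments.` Change the line that updates `self.beta[max_index]` to

```python
self.beta[max_index] += 1 - r
```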
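To see where the zero comes from, this small trace (plain Python; the reward sequence is made up purely for illustration) follows a single arm's parameters under the question's update and under the corrected one, starting from the same Beta(1, 1) prior:

```python
# Trace one arm's Beta parameters under two update rules.
# The reward sequence is hypothetical, chosen only to include failures.
rewards = [1, 0, 0, 1]

def trace(update_beta):
    alpha, beta = 1, 1  # Beta(1, 1) prior, as in TS([1,1,1],[1,1,1],100)
    history = []
    for r in rewards:
        alpha += r                   # successes always raise alpha
        beta = update_beta(beta, r)  # failures change beta via the given rule
        history.append((alpha, beta))
    return history

buggy = trace(lambda b, r: b - (1 - r))  # question's rule: beta -= 1 - r
fixed = trace(lambda b, r: b + (1 - r))  # corrected rule: beta += 1 - r

print(buggy)  # [(2, 1), (2, 0), (2, -1), (3, -1)]
print(fixed)  # [(2, 1), (2, 2), (2, 3), (3, 3)]
```

With the original rule, beta reaches 0 after the first failure and goes negative after that; with the corrected rule it only grows.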
See Algorithm 2 on page 15 of https://web.stanford.edu/~bvr/pubs/TS_Tutorial.pdf.
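For completeness, here is a minimal end-to-end sketch of the corrected sampler for a Bernoulli bandit. It is one possible way to write it, under a few assumptions: it swaps `scipy.stats.beta.rvs` and `np.random.choice` for the standard library's `random.Random.betavariate` and `random.Random.random` so it runs without SciPy or NumPy, and the `prob` and `seed` parameters, the `pulls` counter, and filling `self.value` with posterior means `a / (a + b)` are my own illustrative additions, not part of the question or the tutorial.

```python
import random


class TS:
    """Thompson sampling for a Bernoulli bandit with Beta posteriors (sketch)."""

    def __init__(self, alpha, beta, n, prob=(0.4, 0.6, 0.8), seed=None):
        self.alpha = list(alpha)   # Beta "success" parameter per arm
        self.beta = list(beta)     # Beta "failure" parameter per arm
        self.n = n                 # number of rounds to play
        self.prob = list(prob)     # true success probability per arm (hidden from the agent)
        self.pulls = [0] * len(self.prob)     # hypothetical: times each arm was played
        self.value = [0.0] * len(self.prob)   # hypothetical: posterior-mean estimate per arm
        self.rng = random.Random(seed)        # seeded generator for reproducible runs

    def generate(self):
        for _ in range(self.n):
            # Sample a plausible success rate for every arm from its posterior.
            samples = [self.rng.betavariate(a, b) for a, b in zip(self.alpha, self.beta)]
            # Play the arm whose sample is largest.
            k = samples.index(max(samples))
            # Observe a Bernoulli reward with that arm's true probability.
            r = 1 if self.rng.random() < self.prob[k] else 0
            # Conjugate update: alpha counts successes, beta counts failures.
            # Both only ever increase, so they stay >= 1 and remain valid.
            self.alpha[k] += r
            self.beta[k] += 1 - r
            self.pulls[k] += 1
        self.value = [a / (a + b) for a, b in zip(self.alpha, self.beta)]
        print(self.value)


one = TS([1, 1, 1], [1, 1, 1], 100, seed=0)
one.generate()
print(one.pulls)
```

Because every round adds exactly 1 to `alpha[k] + beta[k]` for the played arm, neither parameter can reach 0. Over many rounds most pulls would typically concentrate on the 0.8 arm, though a single 100-round run is not guaranteed to show that clearly.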