anmo


"Domain error in arguments." with "stats.beta.rvs"

I'm trying to model Thompson sampling, but the following code raises "Domain error in arguments." I searched for the error and found someone saying it may happen because Beta's parameters are negative, but here the parameters should always be positive. I have no idea how to fix it.

import numpy as np
from scipy import stats
class TS():
    def __init__(self,alpha,beta,n):
        self.alpha=alpha
        self.beta=beta
        self.n=n
        self.value=[0,0,0]#estimator
        self.prob=[0.4,0.6,0.8]
    def generate(self):
        for j in range(self.n):
            tmp=[0,0,0]
            for i in range(0,3):
                tmp[i]=stats.beta.rvs(self.alpha[i],self.beta[i])
            max_index=tmp.index(max(tmp))
            r=np.random.choice([0,1],p=(1-self.prob[max_index],self.prob[max_index]))
            self.alpha[max_index]+=r
            self.beta[max_index]-=(1-r)
        print(self.value)
one=TS([1,1,1],[1,1,1],100)
one.generate()

error:

ValueError                                Traceback (most recent call last)
<ipython-input-18-9df20b1a6a3b> in <module>()
     23         print(self.value)
     24 one=TS([1,1,1],[1,1,1],100)
---> 25 one.generate()

<ipython-input-18-9df20b1a6a3b> in generate(self)
     16             tmp=[0,0,0]
     17             for i in range(0,3):
---> 18                 tmp[i]=stats.beta.rvs(self.alpha[i],self.beta[i])
     19             max_index=tmp.index(max(tmp))
     20             r=np.random.choice([0,1],p=(1-self.prob[max_index],self.prob[max_index]))

C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py in rvs(self, *args, **kwds)
    938         cond = logical_and(self._argcheck(*args), (scale >= 0))
    939         if not np.all(cond):
--> 940             raise ValueError("Domain error in arguments.")
    941 
    942         if np.all(scale == 0):

ValueError: Domain error in arguments.
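The error can be reproduced in isolation. The sketch below hand-traces the question's update rule for a single arm that starts at Beta(1, 1) and gets one failure (r = 0), then passes the resulting parameters to scipy.stats.beta, which requires both shape parameters to be strictly positive. The variable names are illustrative only.

```python
from scipy import stats

# Hand-trace of the question's update rule for one arm that starts at
# Beta(1, 1) and receives a single failure (r = 0).
alpha, beta = 1, 1
r = 0
alpha += r
beta -= (1 - r)  # the question's rule: beta falls from 1 to 0
print(alpha, beta)  # 1 0

# scipy.stats.beta needs both shape parameters > 0, so beta = 0 is rejected.
raised = False
try:
    stats.beta.rvs(alpha, beta)
except ValueError as exc:
    raised = True
    print("ValueError:", exc)
```

The exact wording of the message varies between SciPy versions, but the exception type is ValueError.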


Answers (1)

Warren Weckesser


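On a failure (r = 0) your code decreases self.beta[max_index] by 1, so the first failure of an arm takes its beta parameter from 1 to 0. On the next draw, stats.beta.rvs receives a shape parameter of 0, and it requires both shape parameters to be strictly positive. self.beta[max_index] should be increased by 1 - r instead. Change the line that updates self.beta[max_index] to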

        self.beta[max_index] += 1 - r
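A sketch of why the update takes this form: with a Beta(α, β) prior on an arm's success probability θ and one Bernoulli reward r ∈ {0, 1}, the standard conjugate update gives

```latex
p(\theta \mid r) \propto \theta^{r}(1-\theta)^{1-r}\,\theta^{\alpha-1}(1-\theta)^{\beta-1}
= \theta^{(\alpha+r)-1}(1-\theta)^{(\beta+1-r)-1},
\qquad\text{so}\qquad
\theta \mid r \sim \mathrm{Beta}(\alpha + r,\ \beta + 1 - r).
```

Both parameters can only grow, so starting from α = β = 1 they stay at least 1 and stats.beta.rvs never sees a non-positive argument.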

See Algorithm 2 on page 15 of https://web.stanford.edu/~bvr/pubs/TS_Tutorial.pdf.
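For reference, here is a minimal corrected version of the question's class, assuming Bernoulli rewards and a NumPy Generator as the random source. The seed parameter, the pulls counter, and filling self.value with each arm's posterior mean α/(α+β) are illustrative assumptions, since the question does not say what its estimator should hold.

```python
import numpy as np
from scipy import stats


class TS:
    """Beta-Bernoulli Thompson sampling over three arms."""

    def __init__(self, alpha, beta, n, seed=None):
        # Copy the priors so the caller's lists are not mutated.
        self.alpha = list(alpha)
        self.beta = list(beta)
        self.n = n
        self.prob = [0.4, 0.6, 0.8]  # true success probability of each arm
        self.rng = np.random.default_rng(seed)  # hypothetical seed parameter
        self.pulls = [0, 0, 0]  # how often each arm was chosen
        self.value = [0, 0, 0]  # estimator, filled in by generate()

    def generate(self):
        for _ in range(self.n):
            # Draw one sample from each arm's Beta posterior.
            samples = [stats.beta.rvs(a, b, random_state=self.rng)
                       for a, b in zip(self.alpha, self.beta)]
            k = int(np.argmax(samples))  # first maximum, like tmp.index(max(tmp))
            # Bernoulli reward: 1 with probability prob[k], else 0.
            r = int(self.rng.random() < self.prob[k])
            # Conjugate update: a success adds 1 to alpha, a failure adds 1 to beta.
            self.alpha[k] += r
            self.beta[k] += 1 - r
            self.pulls[k] += 1
        # Assumed estimator: posterior mean of each arm's success probability.
        self.value = [a / (a + b) for a, b in zip(self.alpha, self.beta)]
        return self.value


one = TS([1, 1, 1], [1, 1, 1], 1000, seed=0)
print(one.generate())
print(one.pulls)
```

With this update, each round adds exactly 1 to either alpha or beta of the chosen arm, so no parameter can ever drop below its prior value.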

