Reputation: 1
```python
import numpy as np
Q = np.loadtxt(open("D:\data_homework_4\Q.csv","rb"), delimiter = ",", skiprows = 0)
b = np.loadtxt(open("D:\data_homework_4\B.csv","rb"), delimiter = ",", skiprows = 0)
def f(x):
    return 1/2 * x.T @ Q @ x + b.T @ x
def gradient(x):
    return Q @ x - b
n = 2000
x_t = np.zeros((100, 1))
alpha = 0.1
beta = 0.3
eta_t = 3.887e-6
for t in range(n + 2):
    g_t = gradient(x_t)
    k = 0
    while True:
        if f(x_t - beta**k * g_t) <= f(x_t) - alpha * beta**k * np.linalg.norm(g_t)**2:
            eta_t = beta**k
            break
        k += 1
    x_t -= eta_t * g_t
print(x_t)
```
Line 21 is:
```python
if f(x_t - beta**k * g_t) <= f(x_t) - alpha * beta**k * np.linalg.norm(g_t)**2:
```
Q is 100x100, b is 100x1, x is 100x1. I looked up similar errors, but none of them are like mine. Can somebody help me with this error? Thank you.
Upvotes: 0
Views: 159
Reputation: 1441
As the error says ("The truth value of an array with more than one element is ambiguous"), the `if` condition compares two arrays that each hold multiple values. The comparison is evaluated element-wise, so it produces an array of booleans, and `if` doesn't know how to collapse that into a single truth value. That's why it asks you to use `any()` or `all()`.
Try this, for example:
```python
import numpy as np
arr = np.array([2,3])
arr1 = np.array([1,4])
arr, arr1
if (arr<arr1):  # raises the same ValueError
    pass
```
It gives the same error that you have.
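Under the hood, `if cond:` calls `bool(cond)`. A small sketch (the names `single`, `pair` and `message` are just for illustration) showing that this only fails once the array holds more than one element:

```python
import numpy as np

# A one-element comparison collapses to a single truth value without complaint.
single = np.array([2]) < np.array([1])
print(bool(single))  # False

# A two-element comparison yields one truth value per element ...
pair = np.array([2, 3]) < np.array([1, 4])
print(pair)  # [False  True]

# ... and bool() refuses to choose between them, which is the ValueError above.
try:
    bool(pair)
except ValueError as err:
    message = str(err)
    print(message)
```

So a 1x1 result from `f` would be accepted by `if`; the error means the comparison produced more than one element.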
To fix that, I've added `.all()`, so that every element has to satisfy the `<` condition:
```python
import numpy as np
arr = np.array([2,3])
arr1 = np.array([1,4])
arr, arr1
if (arr<arr1).all():  # True only if every element of arr is smaller
    pass
```
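For comparison, here is a short sketch of what `any()` and `all()` return for the same pair of arrays (`mask` is just an illustrative name):

```python
import numpy as np

arr = np.array([2, 3])
arr1 = np.array([1, 4])
mask = arr < arr1  # element-wise result: [False, True]

# any(): True if at least one element satisfies the condition.
print(mask.any())  # True

# all(): True only if every element satisfies the condition.
print(mask.all())  # False
```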
Think about what makes sense in your case and use that, whether it's `any()` or `all()`.
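In the question's code specifically, `f` is meant to return a single number, so it is worth asking why the comparison yields an array at all. One likely cause, assuming `B.csv` holds a single column (or row) of 100 numbers: by default `np.loadtxt` squeezes that to shape `(100,)`, not `(100, 1)`. Then `Q @ x - b` broadcasts `(100, 1)` against `(100,)` into a 100x100 array, and `f` of the resulting matrix is a matrix too. The sketch below uses random stand-in data instead of the real CSV files (an assumption), reshapes `b` into a column, and makes `f` return a Python float with `.item()`:

```python
import numpy as np

rng = np.random.default_rng(0)

# Random stand-ins for Q.csv and B.csv (assumption: not the real data).
A = rng.standard_normal((100, 100))
Q = A @ A.T + 100 * np.eye(100)  # symmetric positive definite, 100x100
b_1d = rng.standard_normal(100)  # shape (100,), like loadtxt on a one-column file
x = np.zeros((100, 1))

# With a 1-D b, Q @ x - b broadcasts (100, 1) against (100,) into (100, 100).
g_bad = Q @ x - b_1d
print(g_bad.shape)  # (100, 100)

# Make b a column vector; np.loadtxt(..., ndmin=2) would also give shape (100, 1).
b = b_1d.reshape(-1, 1)

def f(x):
    # .item() turns the 1x1 result into a float, so comparisons give one bool.
    return (0.5 * x.T @ Q @ x + b.T @ x).item()

def gradient(x):
    # Kept as in the question (see the note on the sign below).
    return Q @ x - b

g = gradient(x)
print(g.shape)  # (100, 1)

alpha, beta, k = 0.1, 0.3, 0
# The Armijo test from the question now compares two scalars: no ValueError.
armijo_ok = f(x - beta**k * g) <= f(x) - alpha * beta**k * np.linalg.norm(g)**2
print(armijo_ok)
```

Separately, for a symmetric `Q` the gradient of `f` as written is `Q @ x + b`, while `gradient()` returns `Q @ x - b`; depending on the intended objective, one of the two signs probably needs to change for the line search to behave as expected.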
Upvotes: 1