Reputation: 2206
I am using scipy.optimize.fmin to optimize the Rosenbrock function:
import numpy as np
import scipy.optimize  # import the submodule explicitly; plain `import scipy` may not load it
def rosen(x):
    """The Rosenbrock function"""
    return sum(100.0*(x[1:]-x[:-1]**2.0)**2.0 + (1-x[:-1])**2.0)
x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
scipy.optimize.fmin(rosen, x0, full_output=True)
This returns a tuple (xopt, fopt, iter, funcalls, warnflag): the parameters that minimize the function, the function minimum, the number of iterations, the number of function calls, and a warning flag.
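For reference, a minimal sketch of unpacking that tuple. The names xopt, fopt, n_iter, n_calls and warnflag are just local variable names chosen here, and disp=False only silences the convergence summary fmin prints by default.

```python
import numpy as np
from scipy.optimize import fmin

def rosen(x):
    """The Rosenbrock function."""
    return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])

# With full_output=True, fmin returns a 5-tuple instead of just the minimizer.
xopt, fopt, n_iter, n_calls, warnflag = fmin(rosen, x0, full_output=True, disp=False)

# warnflag is 0 on success, 1 if maxfun was reached, 2 if maxiter was reached.
print(xopt, fopt, n_iter, n_calls, warnflag)
```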
However, I would like to graph the values at each step. For example, I would like to plot the iteration number along the x-axis and the running minimum value along the y-axis.
Upvotes: 0
Views: 1110
Reputation: 47
Thanks Randy for the answer.
Because some optimization methods pass more arguments to the callback (for example, 'trust-constr' calls it with both the current point and a state object), I found this solution more robust:
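import numpy as np
steps = []
def save_step(*args):
    # Keep only the array arguments (the current point) and ignore extras,
    # such as the state object that 'trust-constr' also passes.
    for arg in args:
        if isinstance(arg, np.ndarray):
            steps.append(arg)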
An example of usage:
import numpy as np
from scipy.optimize import minimize
def f(x):
    x1 = x[0]
    x2 = x[1]
    # Negated so that minimize() finds the maximum of the original concave quadratic
    return -(-4 * x1 * x1 - 4 * x2 * x2 + 4 * x1 * x2 + 8 * x1 + 20 * x2)
def gradient(x):
    x1 = x[0]
    x2 = x[1]
    return np.array([
        -(-8 * x1 + 4 * x2 + 8),
        -(-8 * x2 + 4 * x1 + 20)
    ])
start_point = np.zeros(2)
result = minimize(
    fun=f,
    x0=start_point,
    method='trust-constr',
    jac=gradient,
    callback=save_step
)
print(result)
print(steps)
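The point of the *args callback is that one function works across methods whose callbacks receive different arguments. A minimal sketch, assuming a reasonably recent NumPy/SciPy: it wraps the callback in a hypothetical make_recorder() helper (not part of SciPy) so each run gets its own list, runs both Nelder-Mead (one-argument callback) and 'trust-constr' (point plus state object) on the same f, and evaluates f at the recorded points. The analytic minimizer of f is (3, 4), with f = -52, so the last recorded point of each run should land close to it.

```python
import numpy as np
from scipy.optimize import minimize

def f(x):
    # Negated concave quadratic from the example above; its minimum is at (3, 4).
    x1, x2 = x
    return -(-4 * x1 * x1 - 4 * x2 * x2 + 4 * x1 * x2 + 8 * x1 + 20 * x2)

def gradient(x):
    x1, x2 = x
    return np.array([-(-8 * x1 + 4 * x2 + 8), -(-8 * x2 + 4 * x1 + 20)])

def make_recorder():
    """Hypothetical helper: return (steps, callback) where callback keeps ndarray args."""
    steps = []

    def save_step(*args):
        for arg in args:
            if isinstance(arg, np.ndarray):
                # Copy defensively in case a solver reuses the array it passes in.
                steps.append(np.copy(arg))

    return steps, save_step

histories = {}
for method in ("Nelder-Mead", "trust-constr"):
    steps, save_step = make_recorder()
    minimize(f, np.zeros(2), method=method,
             jac=gradient if method == "trust-constr" else None,
             callback=save_step)
    histories[method] = np.array(steps)

for method, history in histories.items():
    values = np.array([f(x) for x in history])  # objective at each recorded point
    print(method, len(history), "steps, last point", history[-1], "f =", values[-1])
```

Returning the list from the helper, rather than appending to a module-level steps list, keeps runs of different methods from mixing their histories.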
Upvotes: 0
Reputation: 14847
fmin takes an optional callback argument, which it calls after each iteration as callback(xk) with the current parameter vector xk, so you can write a simple one that records the values at each step:
def save_step(xk):
    # append() mutates the list in place, so no `global` statement is needed
    steps.append(xk)
steps = []
scipy.optimize.fmin(rosen, x0, full_output=True, callback=save_step)
print(np.array(steps)[:10])
Output:
[[ 1.339 0.721 0.824 1.71 1.236 ]
[ 1.339 0.721 0.824 1.71 1.236 ]
[ 1.339 0.721 0.824 1.71 1.236 ]
[ 1.339 0.721 0.824 1.71 1.236 ]
[ 1.2877696 0.7417984 0.8013696 1.587184 1.3580544 ]
[ 1.28043136 0.76687744 0.88219136 1.3994944 1.29688704]
[ 1.28043136 0.76687744 0.88219136 1.3994944 1.29688704]
[ 1.28043136 0.76687744 0.88219136 1.3994944 1.29688704]
[ 1.35935594 0.83266045 0.8240753 1.02414244 1.38852256]
[ 1.30094767 0.80530982 0.85898166 1.0331386 1.45104273]]
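fmin can also return the per-iteration points itself: with retall=True it appends allvecs, the solution at each iteration, to the returned tuple, which avoids the global list. A minimal sketch of going from those points to the plot the question asks for: evaluate rosen at each point, take the running minimum with np.minimum.accumulate, and plot it against the iteration index. plot_running_min is a hypothetical helper name, and matplotlib is assumed to be installed; it is imported inside the function so the numeric part runs without it.

```python
import numpy as np
from scipy.optimize import fmin

def rosen(x):
    """The Rosenbrock function."""
    return np.sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0)

x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])

# With full_output=True and retall=True the tuple gains a sixth element, allvecs.
xopt, fopt, n_iter, n_calls, warnflag, allvecs = fmin(
    rosen, x0, full_output=True, retall=True, disp=False)

fvals = np.array([rosen(x) for x in allvecs])  # objective value at each recorded point
running_min = np.minimum.accumulate(fvals)     # best value seen so far
iterations = np.arange(len(running_min))       # index into allvecs

def plot_running_min(iterations, running_min):
    """Hypothetical helper: iteration number on x, running minimum on y."""
    import matplotlib.pyplot as plt  # lazy import: only needed when plotting
    fig, ax = plt.subplots()
    ax.plot(iterations, running_min)
    ax.set_xlabel("iteration")
    ax.set_ylabel("running minimum of rosen(x)")
    ax.set_yscale("log")  # the values span several orders of magnitude
    plt.show()
```

Calling plot_running_min(iterations, running_min) should display the curve. For Nelder-Mead the recorded point is the best vertex of the simplex, so fvals should already be non-increasing and the running minimum matches it; the accumulate step matters for methods whose recorded iterates are not guaranteed to decrease the objective. The same two lines work on a steps list collected with a callback.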
Upvotes: 4