Reputation: 113
Suppose I have the following simple example of a function of several variables
import numpy as np
import tensorflow as tf

@tf.function
def f(A, Y, X):
    AX = tf.matmul(A, X)
    norm = tf.norm(Y - AX)
    return norm

N = 2
A = tf.Variable(np.array([[1., 2.], [3., 4.]]))
Y = tf.Variable(np.identity(N))
X = tf.Variable(np.zeros((N, N)))
How do I find the X that minimizes f with TensorFlow?
I would be interested in a generic solution that works with a function declared as above, and when there is more than one variable to optimize.
Upvotes: 1
Views: 2822
Reputation: 59701
avanwyk is essentially right, although note that: 1) you can directly use the minimize method of the optimizer for simplicity; 2) if you only want to optimize X, you should make sure that it is the only variable you are updating.
import tensorflow as tf

@tf.function
def f(A, Y, X):
    AX = tf.matmul(A, X)
    norm = tf.norm(Y - AX)
    return norm
# Input data
N = 2
A = tf.Variable([[1., 2.], [3., 4.]], dtype=tf.float32)
Y = tf.Variable(tf.eye(N, dtype=tf.float32))
X = tf.Variable(tf.zeros((N, N), tf.float32))
# Initial function value
print(f(A, Y, X).numpy())
# 1.4142135
# Setup a stochastic gradient descent optimizer
opt = tf.keras.optimizers.SGD(learning_rate=0.01)
# Define loss function and variables to optimize
loss_fn = lambda: f(A, Y, X)
var_list = [X]
# Optimize for a fixed number of steps
for _ in range(1000):
    opt.minimize(loss_fn, var_list)
# Optimized function value
print(f(A, Y, X).numpy())
# 0.14933111
# Optimized variable
print(X.numpy())
# [[-2.0012102 0.98504114]
# [ 1.4754106 -0.5111093 ]]
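For the generic case the question asks about, the same pattern extends to several variables: put each variable you want updated in var_list. A minimal sketch (updating both X and Y purely for illustration, and using GradientTape with apply_gradients rather than the minimize helper):

```python
import tensorflow as tf

@tf.function
def f(A, Y, X):
    return tf.norm(Y - tf.matmul(A, X))

A = tf.Variable([[1., 2.], [3., 4.]], dtype=tf.float32)
Y = tf.Variable(tf.eye(2, dtype=tf.float32))
X = tf.Variable(tf.zeros((2, 2), tf.float32))

opt = tf.keras.optimizers.SGD(learning_rate=0.01)
# Every variable listed here gets updated; A is left out, so it stays fixed.
var_list = [X, Y]
for _ in range(1000):
    with tf.GradientTape() as tape:
        loss = f(A, Y, X)
    grads = tape.gradient(loss, var_list)
    opt.apply_gradients(zip(grads, var_list))
```

Since A is excluded from var_list, it receives no updates even though the tape records operations involving it.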
Upvotes: 5
Reputation: 700
Assuming Tensorflow 2, you can use a Keras optimizer:
import numpy as np
import tensorflow as tf

@tf.function
def f(A, Y, X):
    AX = tf.matmul(A, X)
    norm = tf.norm(Y - AX)
    return norm

N = 2
A = tf.Variable(np.array([[1., 2.], [3., 4.]]))
Y = tf.Variable(np.identity(N))
X = tf.Variable(np.zeros((N, N)))
optimizer = tf.keras.optimizers.SGD()
for iteration in range(100):
    with tf.GradientTape() as tape:
        loss = f(A, Y, X)
    print(loss)
    grads = tape.gradient(loss, [A, Y, X])
    optimizer.apply_gradients(zip(grads, [A, Y, X]))
print(A, Y, X)
That will work for any differentiable function. For non-differentiable functions you could look at other optimization techniques, such as genetic algorithms or swarm optimization (NEAT has implementations of these: https://neat-python.readthedocs.io/en/latest/).
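As a toy illustration of a derivative-free technique (a plain random search, much simpler than the libraries mentioned above, written with NumPy only), one could minimize the same f without any gradients:

```python
import numpy as np

def f(A, Y, X):
    return np.linalg.norm(Y - A @ X)

rng = np.random.default_rng(0)
A = np.array([[1., 2.], [3., 4.]])
Y = np.eye(2)
X = np.zeros((2, 2))

best = f(A, Y, X)
for _ in range(5000):
    # Propose a small random perturbation of X; keep it only if it lowers f.
    candidate = X + rng.normal(scale=0.1, size=X.shape)
    value = f(A, Y, candidate)
    if value < best:
        X, best = candidate, value
```

No gradient information is used, so the same loop works even when f is not differentiable; the price is far slower convergence than the gradient-based answers above.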
Upvotes: 4