I'm trying to train a neural network to approximate a known scalar function of two variables; however, no matter the parameters of my training, the network always just ends up simply predicting the average value of the true outputs.
I am using an MLP and have tried:
My loss function is MSE and always plateaus to a value of about 5.14.
Regardless of changes I make, I get the following results:
Where the blue surface is the function to be approximated, and the green surface is the MLP approximation of the function, having a value that is roughly the average of the true function over that domain (the true average is 2.15 with a square of 4.64 - not far from the loss plateau value).
I feel like I could be missing something very obvious and have just been looking at it for too long. Any help is greatly appreciated! Thanks
I've attached my code here (I'm using JAX):
import jax.numpy as jnp
from jax import grad, jit, vmap, random, value_and_grad
import flax
import flax.linen as nn
import optax
seed = 2
key, data_key = random.split(random.PRNGKey(seed))
x1, x2, y= generate_data(data_key) # Data generation function
# Using Flax - define an MLP
class MLP(nn.Module):
features: Sequence[int]
def __call__(self, x):
for feat in self.features[:-1]:
x = nn.relu(nn.Dense(feat)(x))
x = nn.Dense(self.features[-1])(x)
return x
# Define function that returns JITted loss function
def make_mlp_loss(input_data, true_y):
def mlp_loss(params):
pred_y = model.apply(params, input_data)
loss_vector = jnp.square(true_y.reshape(-1) - pred_y)
return jnp.average(loss_vector)
# Outer scope incapsulation saves the data and true output
return jit(mlp_loss)
# Concatenate independent variable vectors to be proper input shape
input_data = jnp.hstack((x1.reshape(-1, 1), x2.reshape(-1, 1)))
# Create loss function with data and true output
mlp_loss = make_mlp_loss(input_data, y)
# Create function that returns loss and gradient
loss_and_grad = value_and_grad(mlp_loss)
# Example architectures I've tried
architectures = [[16, 16, 1], [8, 16, 1], [16, 8, 1], [8, 16, 8, 1], [32, 32, 1]]
# Only using one seed but iterated over several
for seed in [645]:
for architecture in architectures:
# Create model
model = MLP(architecture)
# Initialize model with random parameters
key, params_key = random.split(key)
dummy = jnp.ones((1000, 2))
params = model.init(params_key, dummy)
# Create optimizer
opt = optax.adam(learning_rate=0.01) #sgd
opt_state = opt.init(params)
epochs = 50
for i in range(epochs):
# Get loss and gradient
curr_loss, curr_grad = loss_and_grad(params)
if i % 5 == 0:
# Update
updates, opt_state = opt.update(curr_grad, opt_state)
params = optax.apply_updates(params, updates)
print(f"Architecture: {architecture}\nLoss: {curr_loss}\nSeed: {seed}\n\n")
