SleekEagle

Reputation: 73

Neural Network Always Predicting Average Value

I'm trying to train a neural network to approximate a known scalar function of two variables; however, no matter the training parameters, the network always ends up simply predicting the average value of the true outputs.

I am using an MLP and have tried:

- several different architectures (listed in the code below)
- different random seeds
- a different optimizer (SGD as well as Adam)

My loss function is MSE and always plateaus at a value of about 5.14.

Regardless of changes I make, I get the following results:

[3D surface plot of the true function and the MLP approximation]

The blue surface is the function to be approximated, and the green surface is the MLP approximation, which sits at roughly the average value of the true function over that domain (the true average is 2.15, whose square is 4.64 - not far from the loss plateau value).
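
For reference, here is the quick baseline check I am comparing the plateau against (a minimal sketch; y here is the vector of true outputs returned by generate_data in the code below):

import jax.numpy as jnp

y_flat = y.reshape(-1)
mean_y = jnp.mean(y_flat)                       # roughly 2.15 for my data
mse_of_mean = jnp.mean((y_flat - mean_y) ** 2)  # MSE of a constant predictor that always outputs the mean
print(mean_y, mean_y ** 2, mse_of_mean)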

I feel like I could be missing something very obvious and have just been looking at it for too long. Any help is greatly appreciated! Thanks

I've attached my code here (I'm using JAX):

from typing import Sequence

import jax.numpy as jnp
from jax import grad, jit, vmap, random, value_and_grad
import flax
import flax.linen as nn
import optax


seed = 2
key, data_key = random.split(random.PRNGKey(seed))
x1, x2, y = generate_data(data_key)  # Data generation function

# Using Flax - define an MLP
class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

# Define function that returns JITted loss function
def make_mlp_loss(input_data, true_y):

  def mlp_loss(params):
    pred_y = model.apply(params, input_data)
    loss_vector = jnp.square(true_y.reshape(-1) - pred_y)
    return jnp.average(loss_vector)

  # Outer-scope encapsulation: the closure captures input_data and true_y
  return jit(mlp_loss)


# Concatenate independent variable vectors to be proper input shape
input_data = jnp.hstack((x1.reshape(-1, 1), x2.reshape(-1, 1)))
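# input_data has shape (num_samples, 2): one row per (x1, x2) pair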

# Create loss function with data and true output
mlp_loss = make_mlp_loss(input_data, y)

# Create function that returns loss and gradient
loss_and_grad = value_and_grad(mlp_loss)

# Example architectures I've tried
architectures = [[16, 16, 1], [8, 16, 1], [16, 8, 1], [8, 16, 8, 1], [32, 32, 1]]

# Only using one seed here, but I have iterated over several
for seed in [645]:
  for architecture in architectures:
    # Create model
    model = MLP(architecture)
    
    # Initialize model with random parameters
    key, params_key = random.split(key)
    dummy = jnp.ones((1000, 2))  # Dummy batch used only to infer parameter shapes
    params = model.init(params_key, dummy)

    # Create optimizer
    opt = optax.adam(learning_rate=0.01)  # also tried optax.sgd
    opt_state = opt.init(params)

    
    epochs = 50
    for i in range(epochs):
      # Get loss and gradient 
      curr_loss, curr_grad = loss_and_grad(params)
      if i % 5 == 0:
        print(curr_loss)

      # Update
      updates, opt_state = opt.update(curr_grad, opt_state)
      params = optax.apply_updates(params, updates)
      
    print(f"Architecture: {architecture}\nLoss: {curr_loss}\nSeed: {seed}\n\n")

Upvotes: 2

Views: 784

Answers (0)
