Parisa asghari

Reputation: 9

solving an ODE using neural networks

I want to solve this ODE using neural networks: du/dt + 2u + t = 0, with initial condition u(0) = 1 and t between 0 and 2. I want to use PyTorch and automatic differentiation, but I don't know how to calculate du/dt in PyTorch. I want to define the loss function as below and minimize it to find the optimal weights and biases of the network. u_hat is the approximate solution of the ODE, represented by the neural network. The residual is R = du_hat/dt + 2*u_hat + t, and the loss function is sum(Ri^2), where Ri is the residual evaluated at the points t = 0, 0.5, 1, 1.5, 2.

I don't know how to write this code in PyTorch.

Upvotes: 1

Views: 493

Answers (1)

MuhammedYunus

Reputation: 5105

In the example below, I first solve the ODE using a standard solver scipy.integrate.solve_ivp. The solution is used to train the network, as it gives us a target y for each t. The net will learn parameters such that given t and y0, it will closely match the reference y.
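
Because this ODE is linear, it also has a closed-form solution, u(t) = 0.75*exp(-2t) - t/2 + 1/4, which can serve as an independent check on the solver's output. A minimal sketch using only the standard library; the names `u_exact` and `residual` are introduced here for illustration, and the derivative is approximated with a central finite difference:

```python
import math

def u_exact(t):
    # Closed-form solution of du/dt + 2u + t = 0 with u(0) = 1
    return 0.75 * math.exp(-2.0 * t) - 0.5 * t + 0.25

def residual(t, h=1e-6):
    # Central finite difference for du/dt, then substitute into the ODE
    du_dt = (u_exact(t + h) - u_exact(t - h)) / (2.0 * h)
    return du_dt + 2.0 * u_exact(t) + t

# Evaluate at the points listed in the question
for t in [0.0, 0.5, 1.0, 1.5, 2.0]:
    print(f't={t:.1f}  u={u_exact(t):+.4f}  residual={residual(t):+.1e}')
```

The residual should be close to zero at every point, up to finite-difference error.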

After training, you can supply t and y0 to the network, and it will output the estimated solution y_hat for each t.

Note that this example is somewhat minimal - you'd usually want to be evaluating the model on samples it hasn't seen (a validation set), otherwise it might just be memorising the training data without being able to generalise to unseen t (though maybe it's not an issue for your use-case).
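
A hold-out split along those lines could look like the sketch below. The 80/20 ratio, the seed, and the names `train_idx` and `val_idx` are assumptions for illustration; the training code in this answer uses all 100 points.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the split is reproducible

n_samples = 100
t_array = np.linspace(0, 2, num=n_samples)

# Shuffle the sample indices, then reserve 20% of them for validation
indices = rng.permutation(n_samples)
n_val = n_samples // 5
val_idx = np.sort(indices[:n_val])
train_idx = np.sort(indices[n_val:])

t_train, t_val = t_array[train_idx], t_array[val_idx]
print(len(t_train), 'training points,', len(t_val), 'validation points')
```

The same index arrays can be used to slice the target values, so the loss on `t_val` is computed only for points the optimizer never saw.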

Net comprises 501 parameters
[epoch  100/ 500] loss: 0.00044
[epoch  200/ 500] loss: 0.00015
[epoch  300/ 500] loss: 0.00011
[epoch  400/ 500] loss: 0.00008
[epoch  500/ 500] loss: 0.00004

import torch
from torch import nn
from torch import optim

import numpy as np
import matplotlib.pyplot as plt

#Input data
n_samples = 100
t_array = np.linspace(0, 2, num=n_samples)
y0 = 1.0

#ODE function
# dy/dt + 2y + t = 0 --> dy/dt = -(2y + t)
def dy_dt(t, y):
    return -(2 * y + t)

#Solve using scipy
# The solution will be used to train the neural network
from scipy.integrate import solve_ivp
solved = solve_ivp(dy_dt, [t_array.min(), t_array.max()], np.array([y0]), t_eval=t_array)
solved_y = solved.y.ravel()

plt.plot(t_array, solved_y, color='cadetblue', linewidth=3, label='RK45 solver')
plt.xlabel('t')
plt.ylabel('y')
plt.gcf().set_size_inches(8, 3)

#
# Define the ODE net
#
class ODENet(nn.Module):
    def __init__(self, model_size=20, output_dim=1, activation=nn.ReLU):
        super().__init__()
        
        self.map_inputs = nn.Sequential(nn.Linear(2, model_size), activation())
        
        self.hidden_mapping = nn.Sequential(
            nn.Linear(model_size, model_size),
            activation()
        )
        
        self.output = nn.Linear(model_size, output_dim)
    
    def forward(self, x):
        # t, y0 = x[:, 0], x[:, 1]
        mapped_inputs = self.map_inputs(x)
        hidden = self.hidden_mapping(mapped_inputs)
        y_hat = self.output(hidden)
        return y_hat
    
print('Net comprises', sum([p.numel() for p in ODENet().parameters()]), 'parameters')

#Define the loss
# could alternatively pick from PyTorch's provided losses
def mse_loss(pred, target):
    return torch.mean((pred - target) ** 2)

#
# Create model
#
torch.manual_seed(0) #reproducible results

model = ODENet()
optimizer = optim.NAdam(model.parameters())

#Prepare the input data
# Convert to tensors
t_tensor = torch.Tensor(t_array).float().reshape(-1, 1)
y0_tensor = torch.Tensor([y0] * len(t_array)).float().reshape(-1, 1)
solved_y_tensor = torch.Tensor(solved_y).float()

# Combine inputs into a single matrix to make manipulation more compact
t_y0 = torch.cat([t_tensor, y0_tensor], dim=1)

#Scale the input features
# Will help the net's convergence, though not always absolutely necessary
t_y0 = (t_y0 - t_y0.mean(dim=0)) / (t_y0.std(dim=0) + 1e-10)

#
#Train model
#
n_epochs = 500
for epoch in range(n_epochs):
    model.train()
    
    y_hat = model(t_y0).flatten()

    #Loss, derivative, and step optimizer
    optimizer.zero_grad()
    loss = mse_loss(y_hat, solved_y_tensor)
    loss.backward()
    optimizer.step()
    
    #Print losses
    if ((epoch + 1) % 100) == 0 or (epoch + 1 == n_epochs):
        print(
            f'[epoch {epoch + 1:>4d}/{n_epochs:>4d}]',
            f'loss: {loss.item():>7.5f}'
        )

#Get the final predictions, and overlay onto the solver's solution
model.eval()
with torch.no_grad():
    predictions = model(t_y0)
    
plt.plot(t_array, predictions.numpy().ravel(), color='sienna', linewidth=3, linestyle=':', label='ODENet')
plt.legend()

#Optional formatting
for spine in ['right', 'top']:
    plt.gca().spines[spine].set_visible(False)
plt.gca().spines['bottom'].set_bounds(t_array.min(), t_array.max())
plt.gca().spines['left'].set_bounds(-0.75, 1)
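
For the residual-based loss described in the question, du/dt can be obtained with torch.autograd.grad, so no reference solution is needed. The following is a minimal sketch under several assumptions that are not in the code above: a Tanh activation (smooth, so the derivative is not piecewise constant as it would be with ReLU), a trial solution u_hat(t) = 1 + t*N(t) that builds in the initial condition u(0) = 1, and the Adam optimizer with an arbitrary learning rate and step count.

```python
import torch
from torch import nn

torch.manual_seed(0)  # reproducible initialisation

# Collocation points from the question; requires_grad lets autograd differentiate w.r.t. t
t = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0]).reshape(-1, 1).requires_grad_(True)

# Small network N(t); the layer sizes are an arbitrary choice
net = nn.Sequential(nn.Linear(1, 20), nn.Tanh(), nn.Linear(20, 1))

def u_hat(t):
    # Trial solution: equals 1 at t = 0 for any network weights
    return 1.0 + t * net(t)

def residual_loss(t):
    u = u_hat(t)
    # du/dt at every point; create_graph=True keeps the graph so the loss
    # can itself be backpropagated to the network weights
    du_dt = torch.autograd.grad(
        u, t, grad_outputs=torch.ones_like(u), create_graph=True
    )[0]
    residual = du_dt + 2 * u + t  # R = du_hat/dt + 2*u_hat + t
    return (residual ** 2).sum()  # loss = sum(Ri^2)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
initial_loss = residual_loss(t).item()

for step in range(2000):
    optimizer.zero_grad()
    loss = residual_loss(t)
    loss.backward()
    optimizer.step()

final_loss = residual_loss(t).item()
print(f'residual loss: {initial_loss:.4f} -> {final_loss:.6f}')
```

With only five collocation points the network is constrained at very few locations, so using a denser grid of t values (as in the supervised example above) is likely to give a better fit between the points.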

Upvotes: 1
