Nafisa Mehtaj

Reputation: 1

Optimizing Physics Informed Neural Network

I am trying to solve the 2D wave equation using a physics-informed neural network (PINN). My loss function has four terms in total. I ran into two issues while training the model: the total loss drops very fast, and the PDE loss is very close to zero from the very first epoch. How can I fix this?

[Figure: training loss curves]

def guided_wave_2D(self, x, y, t):
    output = self.model(torch.cat([x, y, t], dim=1))
    phi = output[:,0].to(device)
    zi3 = output[:,1].to(device)

    # first derivative of phi wrt spatial and temporal coordinates (x,y,t)
    dphidx = torch.autograd.grad(phi, x, torch.ones_like(phi).to(device), create_graph=True)[0]
    dphidy = torch.autograd.grad(phi, y, torch.ones_like(phi).to(device), create_graph=True)[0]
    dphidt = torch.autograd.grad(phi, t, torch.ones_like(phi).to(device), create_graph=True)[0]

    # first derivative of zi wrt spatial and temporal coordinates (x,y,t)
    dzi3dx = torch.autograd.grad(zi3, x, torch.ones_like(zi3).to(device), create_graph=True)[0]
    dzi3dy = torch.autograd.grad(zi3, y, torch.ones_like(zi3).to(device), create_graph=True)[0]
    dzi3dt = torch.autograd.grad(zi3, t, torch.ones_like(zi3).to(device), create_graph=True)[0]

    # second derivative of phi wrt spatial and temporal coordinates (x,y,t)
    d2phidx2 = torch.autograd.grad(dphidx, x, torch.ones_like(dphidx).to(device), create_graph=True)[0]
    d2phidy2 = torch.autograd.grad(dphidy, y, torch.ones_like(dphidy).to(device), create_graph=True)[0]
    d2phidt2 = torch.autograd.grad(dphidt, t, torch.ones_like(dphidt).to(device), create_graph=True)[0]

    # second derivative of zi wrt spatial and temporal coordinates (x,y,t)
    d2zi3dx2 = torch.autograd.grad(dzi3dx, x, torch.ones_like(dzi3dx).to(device), create_graph=True)[0]
    d2zi3dy2 = torch.autograd.grad(dzi3dy, y, torch.ones_like(dzi3dy).to(device), create_graph=True)[0]
    d2zi3dt2 = torch.autograd.grad(dzi3dt, t, torch.ones_like(dzi3dt).to(device), create_graph=True)[0]

    # displacement calculation

    u_sim_pred = dphidx + dzi3dy
    v_sim_pred = -dzi3dx + dphidy

    # PDE function

    f_phi = (self.cp_inverse**2)*(d2phidt2) - d2phidx2 - d2phidy2
    f_zi3 = (self.cs_inverse**2)*(d2zi3dt2) - d2zi3dx2 - d2zi3dy2

    return u_sim_pred, v_sim_pred, f_phi, f_zi3

def loss_cal(self):
    # compute the data, PDE, boundary, and initial-condition losses
    u_sim_pred, v_sim_pred, f_phi, f_zi3 = self.guided_wave_2D(self.x, self.y, self.t)

    # loss data
    loss_u = self.loss(u_sim_pred, self.u)
    loss_v = self.loss(v_sim_pred, self.v)
    loss_data = loss_u + loss_v

    # loss pde
    loss_pde_phi = self.loss(f_phi, torch.zeros_like(f_phi))
    loss_pde_zi3 = self.loss(f_zi3, torch.zeros_like(f_zi3))
    loss_pde = loss_pde_phi + loss_pde_zi3

    # loss boundary: zero displacement (u = 0) is enforced on all four boundaries

    # left boundary
    left_x_values = torch.full((self.train_sample_num, 1), self.lb, device=device, requires_grad=True)
    u_sim_pred_lb, v_sim_pred_lb, f_phi_lb, f_zi3_lb = self.guided_wave_2D(left_x_values, self.y, self.t)
    loss_lb = self.loss(u_sim_pred_lb, torch.zeros_like(u_sim_pred_lb))

    # right boundary
    right_x_values = torch.full((self.train_sample_num, 1), self.rb, device=device, requires_grad=True)
    u_sim_pred_rb, v_sim_pred_rb, f_phi_rb, f_zi3_rb = self.guided_wave_2D(right_x_values, self.y, self.t)
    loss_rb = self.loss(u_sim_pred_rb, torch.zeros_like(u_sim_pred_rb))

    # upper boundary
    upper_y_value = torch.full((self.train_sample_num, 1), self.upb, device=device, requires_grad=True)
    u_sim_pred_upb, v_sim_pred_upb, f_phi_upb, f_zi3_upb = self.guided_wave_2D(self.x, upper_y_value, self.t)
    loss_upb = self.loss(u_sim_pred_upb, torch.zeros_like(u_sim_pred_upb))

    # lower boundary
    lower_y_value = torch.full((self.train_sample_num, 1), self.lob, device=device, requires_grad=True)
    u_sim_pred_lob, v_sim_pred_lob, f_phi_lob, f_zi3_lob = self.guided_wave_2D(self.x, lower_y_value, self.t)
    loss_lob = self.loss(u_sim_pred_lob, torch.zeros_like(u_sim_pred_lob))

    # summing up all loss values of all boundaries
    loss_boundary = loss_lb + loss_rb + loss_upb + loss_lob

    # initial condition
    # find the indices of the training samples where t equals the initial time
    # (torch.isclose avoids an exact floating-point comparison)
    indices = torch.where(torch.isclose(self.t, torch.full_like(self.t, self.ic)))[0]

    # accessing values of input and output parameters where t == 0
    x_ic = self.x[indices]
    y_ic = self.y[indices]
    t_ic = self.t[indices]
    u_ic = self.u[indices]
    v_ic = self.v[indices]

    u_sim_pred_ic, v_sim_pred_ic, f_phi_ic, f_zi3_ic = self.guided_wave_2D(x_ic, y_ic, t_ic)
    loss_ic = self.loss(u_sim_pred_ic, u_ic) + self.loss(v_sim_pred_ic, v_ic)

    # total loss
    loss_f = loss_data + loss_pde + loss_boundary + loss_ic

    return loss_f, loss_data, loss_pde, loss_boundary, loss_ic

def training(self, X_train, y_train, batch_size):
    train_loss_values = []
    loss_data_values = []
    loss_pde_values = []
    loss_boundary_values = []
    loss_ic_values = []
    epoch_count = []

    # Split data into batches
    X_batches = torch.split(X_train, batch_size)
    y_batches = torch.split(y_train, batch_size)

    for batch_num, (batch_of_X_train, batch_of_y_train) in enumerate(zip(X_batches, y_batches), 1):

      print(f"Batch {batch_num}")
      # Prepare batch tensors (inputs need gradients for the PDE derivatives; targets do not)
      self.x = batch_of_X_train[:, 0:1].clone().detach().float().to(device).requires_grad_(True)
      self.y = batch_of_X_train[:, 1:2].clone().detach().float().to(device).requires_grad_(True)
      self.t = batch_of_X_train[:, 2:3].clone().detach().float().to(device).requires_grad_(True)
      self.u = batch_of_y_train[:, 0:1].clone().detach().float().to(device)
      self.v = batch_of_y_train[:, 1:2].clone().detach().float().to(device)

      for epoch in range(self.epoch):
        print(f"Epoch {epoch + 1}")
        loss_total, loss_data, loss_pde, loss_boundary, loss_ic = self.loss_cal()
        self.model.train()
        self.optimizer.zero_grad()
        loss_total.backward()
        self.optimizer.step()

        if epoch % 5 == 0:
            print(f"Batch {batch_num}, Epoch {epoch + 1}/{self.epoch}, Loss: {loss_total.item()}")

        # Store loss values for plotting
        train_loss_values.append(loss_total.item())
        loss_data_values.append(loss_data.item())
        loss_pde_values.append(loss_pde.item())
        loss_boundary_values.append(loss_boundary.item())
        loss_ic_values.append(loss_ic.item())
        epoch_count.append(len(epoch_count) + 1)

      torch.cuda.empty_cache()

Upvotes: -1

Views: 59

Answers (1)

saluisto

Reputation: 1

Have you considered multiplying the PDE loss by a constant factor? For example, I tried this approach in one of my projects:

model.compile("adam", lr=0.003,loss_weights=[1, 100])

The second loss term here is multiplied by a factor of 100 (in my case, this term is associated with observations). This adjustment significantly improved performance.
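That compile call appears to be DeepXDE's Model.compile API, but the same idea carries over to your plain PyTorch code: weight the terms before summing them in loss_cal. Here is a minimal sketch, reusing the loss variables from your question; the weight values are illustrative, not tuned:

# weights for the individual loss terms (illustrative values, tune for your problem)
w_data, w_pde, w_bc, w_ic = 1.0, 100.0, 1.0, 1.0

# weighted total loss, replacing the plain sum in loss_cal
loss_f = (w_data * loss_data
          + w_pde * loss_pde
          + w_bc * loss_boundary
          + w_ic * loss_ic)

A common heuristic is to choose the weights so that all terms have a comparable magnitude at the start of training.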

Upvotes: 0
