Reputation: 355
I ran into an issue when training this model, which is meant to convert Celsius to Fahrenheit.
Please help me review my code below:
# this file is a simple example of how to use pytorch to train a model
# to predict fahrenheit from celsius
import numpy as np
import torch
import torch.nn as nn
celsius_arr = [-40, -10, 0, 8, 15, 22, 38]
fahrenheit_arr = [-40, 14, 32, 46, 59, 72, 100]
# convert arrays to numpy arrays
celsius_arr = np.array(celsius_arr, dtype=np.float32)
fahrenheit_arr = np.array(fahrenheit_arr, dtype=np.float32)
# build a model to predict fahrenheit from celsius
# 1. define the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear = nn.Linear(1, 1)
    def forward(self, x):
        x = self.linear(x)
        return x
# 2. construct the loss and optimizer
model = Net()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 3. training loop
for epoch in range(1000):
    inputs = torch.from_numpy(celsius_arr)
    labels = torch.from_numpy(fahrenheit_arr)
    # forward pass
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print(f'epoch {epoch}, loss = {loss.item():.4f}')
# 4. inference
print(model(torch.tensor([100.0])))
# 5. save the model
torch.save(model.state_dict(), 'celsius_fahrenheit_model.pth')
Upvotes: 0
Views: 232
Reputation: 6135
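It looks like the following two arrays are meant to hold 7 samples with one feature each. As 1-D arrays of shape (7,), however, they are treated as a single sample with 7 features, which does not match nn.Linear(1, 1).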
# convert arrays to numpy arrays
celsius_arr = np.array(celsius_arr, dtype=np.float32)
fahrenheit_arr = np.array(fahrenheit_arr, dtype=np.float32)
You need to unsqueeze them (a.k.a. expand_dims) so that they have shape (7, 1): one row per sample (7 in your case) and one feature per row.
celsius_arr = np.expand_dims(celsius_arr, axis=1)
fahrenheit_arr = np.expand_dims(fahrenheit_arr, axis=1)
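To see what expand_dims changes, here is a minimal shape check using only NumPy. The name celsius_2d is just for illustration, and reshape(-1, 1) is shown as an equivalent alternative that the answer above does not mention.

```python
import numpy as np

# Same data as in the question, as a 1-D float32 array
celsius_arr = np.array([-40, -10, 0, 8, 15, 22, 38], dtype=np.float32)
print(celsius_arr.shape)  # (7,)

# Add a trailing axis: 7 samples, 1 feature each
celsius_2d = np.expand_dims(celsius_arr, axis=1)
print(celsius_2d.shape)  # (7, 1)

# reshape(-1, 1) produces the same layout
same = celsius_arr.reshape(-1, 1)
print(np.array_equal(celsius_2d, same))  # True
```

The values are unchanged; only the array layout differs.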
Alternatively, convert to tensors first and use torch.unsqueeze:
celsius_arr = torch.from_numpy(celsius_arr).unsqueeze(1)
fahrenheit_arr = torch.from_numpy(fahrenheit_arr).unsqueeze(1)
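As a rough sanity check, the sketch below runs one forward pass and one loss computation with the unsqueezed tensors. It assumes a bare nn.Linear(1, 1) as a stand-in for the Net class in the question, and it only verifies shapes; it does not train the model.

```python
import numpy as np
import torch
import torch.nn as nn

celsius_arr = np.array([-40, -10, 0, 8, 15, 22, 38], dtype=np.float32)
fahrenheit_arr = np.array([-40, 14, 32, 46, 59, 72, 100], dtype=np.float32)

# Shape (7,) -> (7, 1): 7 samples with 1 feature each
inputs = torch.from_numpy(celsius_arr).unsqueeze(1)
labels = torch.from_numpy(fahrenheit_arr).unsqueeze(1)

# Assumption: a plain linear layer standing in for Net from the question
model = nn.Linear(1, 1)
criterion = nn.MSELoss()

outputs = model(inputs)            # one prediction per sample
loss = criterion(outputs, labels)  # outputs and labels have matching shapes

print(inputs.shape)   # torch.Size([7, 1])
print(outputs.shape)  # torch.Size([7, 1])
print(loss.dim())     # 0, i.e. a scalar loss
```

Because outputs and labels both have shape (7, 1), MSELoss compares them element by element without any broadcasting.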
Upvotes: 1