Reputation:
I am training a sample model on dummy data and get the following error, even though I believe I have set everything up properly: No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.
The error appears when I start training. Is the problem caused by the way I feed the dummy data into the network, or is there another reason?
import torch
from torch import nn, optim
import pytorch_lightning as pl
from torch.utils.data import DataLoader

class ImageClassifier(pl.LightningModule):
    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.learning_rate = learning_rate
        self.conv_layer1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        output = self.conv_layer1(x)
        print(output.shape)
        return output

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        output = self(inputs)
        accuracy = self.binary_accuracy(output, targets)
        loss = self.loss(output, targets)
        self.log('train_accuracy', accuracy, prog_bar=True)
        self.log('train_loss', loss)
        return {'loss': loss, "training_accuracy": accuracy}

    def test_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.inputs(inputs)
        accuracy = self.binary_accuracy(outputs, targets)
        loss = self.loss(outputs, targets)
        self.log('test_accuracy', accuracy)
        return {"test_loss": loss, "test_accuracy": accuracy}

    def configure_optimizer(self):
        params = self.parameters()
        optimizer = optim.Adam(params=params, lr=self.learning_rate)
        return optimizer

    def binary_accuracy(self, outputs, inputs):
        _, outputs = torch.max(outputs, 1)
        correct_results_sum = (outputs == targets).sum().float()
        acc = correct_results_sum / targets.shape[0]
        return acc

model = ImageClassifier()
Input = DataLoader(torch.randn(1, 3, 28, 28))
trainer = pl.Trainer(max_epochs=10, progress_bar_refresh_rate=1)
trainer.fit(model, train_dataloader=Input)
Upvotes: 1
Views: 9078
Reputation: 21
I had the same problem and then realized that a misspelled method name can cause this error. Make sure you type method names exactly as the library expects, and that you import each package and use it correctly.
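One way to catch this kind of typo early is to compare the method names on your class against the hook names listed in the error message. The snippet below is only a rough sketch using the standard library: `find_hook_typos`, `REQUIRED_HOOKS` and the `Broken` class are hypothetical names made up for illustration, not part of Lightning, and the 0.8 similarity cutoff is an arbitrary choice.

```python
import difflib

# Hook names copied from the error message quoted in the question.
REQUIRED_HOOKS = ("training_step", "train_dataloader", "configure_optimizers")


def find_hook_typos(cls, hooks=REQUIRED_HOOKS, cutoff=0.8):
    """Map near-miss method names defined directly on `cls` to the hook they resemble."""
    typos = {}
    for name, value in vars(cls).items():
        # Skip non-methods and names that already match a hook exactly.
        if not callable(value) or name in hooks:
            continue
        close = difflib.get_close_matches(name, hooks, n=1, cutoff=cutoff)
        if close:
            typos[name] = close[0]
    return typos


class Broken:  # hypothetical stand-in for a LightningModule subclass
    def training_step(self, batch, batch_idx):
        pass

    def configure_optimizer(self):  # missing the trailing "s"
        pass


print(find_hook_typos(Broken))
# {'configure_optimizer': 'configure_optimizers'}
```

Because it only inspects names defined directly on the class, the check would not flag hooks inherited from a parent class.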
Upvotes: 1
Reputation: 115
In your code, the method is named `configure_optimizer()`, so no `configure_optimizers()` method is defined. The error comes from the misspelled method name: rename it to `configure_optimizers()` (with a trailing "s").
Upvotes: 2