sv98bc

Reputation: 21

PyTorch: 'CrossEntropyLoss' object has no attribute 'item'

I am currently working with a CNN model,

model = CNN(height=96, width=96, channels=3)

and I want to monitor its cross-entropy loss.

criterion = nn.CrossEntropyLoss()

The Trainer class is defined as follows:

class Trainer:
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        criterion: nn.Module,
        optimizer: Optimizer,
        summary_writer: SummaryWriter,
        device: torch.device,
    ):
        self.model = model.to(device)
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.summary_writer = summary_writer
        self.step = 0

    def train(
        self,
        epochs: int,
        val_frequency: int,
        print_frequency: int = 20,
        log_frequency: int = 5,
        start_epoch: int = 0
    ):
        self.model.train()
        for epoch in range(start_epoch, epochs):
            self.model.train()
            data_load_start_time = time.time()
            for batch, labels in self.train_loader:
                batch = batch.to(self.device)
                labels = labels.to(self.device)
                data_load_end_time = time.time()
                loss=self.criterion
                logits=self.model.forward(batch)

                with torch.no_grad():
                    preds = logits
                    accuracy = compute_accuracy(labels, preds)

                data_load_time = data_load_end_time - data_load_start_time
                step_time = time.time() - data_load_end_time
                if ((self.step + 1) % log_frequency) == 0:
                    self.log_metrics(epoch, accuracy, loss, data_load_time, step_time)
                if ((self.step + 1) % print_frequency) == 0:
                    self.print_metrics(epoch, accuracy, loss, data_load_time, step_time)

                self.step += 1
                data_load_start_time = time.time()

            self.summary_writer.add_scalar("epoch", epoch, self.step)
            if ((epoch + 1) % val_frequency) == 0:
                self.validate()
                self.model.train()

The method that logs the loss is:

    def log_metrics(self, epoch, accuracy, loss, data_load_time, step_time):
        self.summary_writer.add_scalar("epoch", epoch, self.step)
        self.summary_writer.add_scalars(
                "accuracy",
                {"train": accuracy},
                self.step
        )
        self.summary_writer.add_scalars(
                "loss",
                {"train": float(loss.item())},
                self.step
        )
        self.summary_writer.add_scalar(
                "time/data", data_load_time, self.step
        )
        self.summary_writer.add_scalar(
                "time/data", step_time, self.step
        )

I keep getting the attribute error "'CrossEntropyLoss' object has no attribute 'item'". I have tried several things, such as removing "item()" from different parts of the code and switching to other loss functions like MSELoss. Any solution or direction would be highly appreciated. Thank you.

Edit-1:

Here is the error traceback:

Traceback (most recent call last):
  File "/Users/xyz/main.py", line 316, in <module>
    main(parser.parse_args())
 File "/Users/xyz/main.py", line 128, in main
    log_frequency=args.log_frequency,
  File "/Users/xyz/main.py", line 198, in train
    self.log_metrics(epoch, accuracy, loss, data_load_time, step_time)
  File "/Users/xyz/main.py", line 232, in log_metrics
    {"train": float(loss.item)},
  File "/Users/xyz/main.py", line 585, in __getattr__
    type(self).__name__, name))
AttributeError: 'CrossEntropyLoss' object has no attribute 'item'

Upvotes: 0

Views: 3472

Answers (1)

Sergii Dymchenko

Reputation: 7229

It looks like the loss in the call self.log_metrics(epoch, accuracy, loss, data_load_time, step_time) is the criterion itself (CrossEntropyLoss object), not the result of calling it.
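To make the distinction concrete, here is a minimal sketch (the tensor shapes and label values are made up for illustration, not taken from your code). The criterion is an nn.Module, which has no item attribute; only the tensor you get back from calling it does.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Hypothetical data: 4 samples, 3 classes (not from the question).
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 1])

# The loss module itself is not a tensor, so it has no .item().
module_has_item = hasattr(criterion, "item")

# Calling the module on (logits, labels) returns a 0-dim tensor.
loss = criterion(logits, labels)

# .item() converts that 0-dim tensor to a plain Python float.
loss_value = loss.item()
```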

Your training loop needs to call the criterion on the logits and labels to compute the loss. I don't see that call in the code you provided.
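A rough sketch of what one training step could look like, assuming the usual PyTorch pattern. The nn.Linear model, the SGD optimizer, the random data and the train_step helper are stand-ins I made up so the snippet runs on its own; in your Trainer the equivalent lines would use self.model, self.criterion and self.optimizer. The backward pass and optimizer step are not in your posted loop either, so I included them here on the assumption that you want the model to actually train.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for the question's CNN, criterion and optimizer (assumptions).
model = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Hypothetical batch: 16 samples, 8 features, 3 classes.
batch = torch.randn(16, 8)
labels = torch.randint(0, 3, (16,))


def train_step(batch, labels):
    """Run one optimisation step and return the loss tensor."""
    logits = model(batch)             # forward pass via __call__, not .forward()
    loss = criterion(logits, labels)  # call the criterion; this is the missing step
    optimizer.zero_grad()             # clear gradients from the previous step
    loss.backward()                   # backpropagate
    optimizer.step()                  # update the parameters
    return loss


loss = train_step(batch, labels)

# This is the value that log_metrics can safely pass to add_scalars.
loss_value = float(loss.item())
```

With loss computed this way, the existing float(loss.item()) in log_metrics should work unchanged.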

Upvotes: 1
