Pavol Travnik

DataLoader - PyTorch inconsistency between cpu and mps - Apple Silicon

I got stuck on this inconsistency with DataLoader in PyTorch on a Mac with Apple Silicon. If I use cpu, y is interpreted correctly. However, if I use mps, the loader always returns a label vector of the correct length, but every element in it is taken from the first element, y[0].

import torch
from torch.utils.data import TensorDataset, random_split, DataLoader


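# Run on the Apple Silicon GPU through the MPS backend
device = torch.device("mps")
# Both the inputs and the labels are moved to the device before building the dataset
X = torch.tensor([[[0.5,0.4], [0,0]],[[0.3,0.2], [0,0]],[[0.5,0.2], [0,0]],[[0.2,0.2], [0,0]]], dtype=torch.float32).to(device)
y = torch.tensor([1,0,0,0], dtype=torch.float32).to(device)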

print(X.shape)
print(y.shape)
print(y)
dataset = TensorDataset(X, y)
train_size = int(0.5 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)

for i, (batch_data, batch_labels) in enumerate(train_loader):
    print(batch_data)
    print(batch_labels)
    break

On mps, batch_labels is always a tensor filled entirely with ones or entirely with zeros, depending on the first value in y:

torch.Size([4, 2, 2])
torch.Size([4])
tensor([1., 0., 1., 0.], device='mps:0')
tensor([[[0.5000, 0.2000],
         [0.0000, 0.0000]],

        [[0.5000, 0.4000],
         [0.0000, 0.0000]]], device='mps:0')
tensor([1., 1.], device='mps:0')

Maybe it is related to the PyTorch "General MPS op coverage tracking issue" #77764.
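A minimal sketch for checking whether a particular install shows this symptom, reusing the toy X and y from the question. It builds the dataset from device tensors, reads one batch with shuffle=False so the batch order equals the subset's `indices`, and compares the batch against CPU copies of X and y. It falls back to CPU when `torch.backends.mps.is_available()` returns False. Using shuffle=False is an assumption made only to get a predictable order; it is not confirmed that the bug reproduces identically without shuffling.

```python
import torch
from torch.utils.data import TensorDataset, random_split, DataLoader

# Prefer MPS; fall back to CPU so the check also runs on other machines.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# CPU reference copies of the toy data from the question.
X = torch.tensor([[[0.5, 0.4], [0, 0]], [[0.3, 0.2], [0, 0]],
                  [[0.5, 0.2], [0, 0]], [[0.2, 0.2], [0, 0]]], dtype=torch.float32)
y = torch.tensor([1, 0, 0, 0], dtype=torch.float32)

# Build the dataset from device tensors, as in the question.
dataset = TensorDataset(X.to(device), y.to(device))
train_dataset, test_dataset = random_split(dataset, [2, 2])

# With shuffle=False, the first batch follows train_dataset.indices in order.
loader = DataLoader(train_dataset, batch_size=10, shuffle=False)
batch_data, batch_labels = next(iter(loader))

# Compare against the rows the subset should contain.
idx = train_dataset.indices
data_ok = torch.equal(batch_data.cpu(), X[idx])
labels_ok = torch.equal(batch_labels.cpu(), y[idx])
print("device:", device, "| data ok:", data_ok, "| labels ok:", labels_ok)
```

On an affected setup the symptom described above would show up as `labels ok: False` while `data ok` stays True.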


Answers (1)

Pavol Travnik

It is a bug in older versions of torch; 2.0.1 works just fine. I updated torch to 2.0.1.
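A small sketch for confirming the installed version against the 2.0.1 release mentioned above, plus a common data-loading layout: keep the dataset on the CPU and move each batch to the device inside the loop, so collation happens on CPU tensors. The helper `version_tuple` is hypothetical (written here to avoid depending on extra packages), and it is an assumption, not verified on an affected version, that this layout avoids the bug; upgrading, as the answer says, is the actual fix.

```python
import re
import torch
from torch.utils.data import TensorDataset, DataLoader

def version_tuple(version: str) -> tuple:
    """Hypothetical helper: '2.0.1+cpu' -> (2, 0, 1), using leading digits of each part."""
    release = version.split("+")[0]
    nums = []
    for piece in release.split(".")[:3]:
        m = re.match(r"\d+", piece)
        nums.append(int(m.group()) if m else 0)
    while len(nums) < 3:
        nums.append(0)
    return tuple(nums)

# The answer reports the problem is gone in torch 2.0.1.
installed = version_tuple(str(torch.__version__))
print("torch", torch.__version__, "| at least 2.0.1:", installed >= (2, 0, 1))

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Dataset tensors stay on the CPU; only each batch is moved to the device.
X = torch.tensor([[[0.5, 0.4], [0, 0]], [[0.3, 0.2], [0, 0]],
                  [[0.5, 0.2], [0, 0]], [[0.2, 0.2], [0, 0]]], dtype=torch.float32)
y = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
loader = DataLoader(TensorDataset(X, y), batch_size=2, shuffle=False)

seen = []
for batch_data, batch_labels in loader:
    batch_data = batch_data.to(device)
    batch_labels = batch_labels.to(device)
    seen.append(batch_labels.cpu())  # collect labels back on CPU for inspection

labels_back = torch.cat(seen)
print(labels_back)
```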
