dorien

Reputation: 5407

'Subset' object is not an iterator when updating torchtext's legacy IMDB dataset

I'm updating a PyTorch network from legacy code to the current code, following documentation such as that here.

I used to have:

import torch
from torchtext import data
from torchtext import datasets

# setting the seed so our random output is actually deterministic
SEED = 1234

torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# defining our input fields (text) and labels.
# We use the spaCy tokenizer because it provides strong support for tokenization in languages other than English
TEXT = data.Field(tokenize = 'spacy', include_lengths = True)
LABEL = data.LabelField(dtype = torch.float)

train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

import random
train_data, valid_data = train_data.split(random_state = random.seed(SEED))

example = next(iter(test_data))
example.text

MAX_VOCAB_SIZE = 25_000

TEXT.build_vocab(train_data, 
                 max_size = MAX_VOCAB_SIZE, 
                 vectors = "glove.6B.100d", 
                 unk_init = torch.Tensor.normal_) # how to initialize words not covered by GloVe

LABEL.build_vocab(train_data)

Now in the new code I am struggling to add the validation set. All goes well until here:

from torchtext.datasets import IMDB
train_data, test_data = IMDB(split=('train', 'test'))

I can print the outputs, and while they look different (problems later on?), they have all the info. I can print train_data fine with next(train_data).

Then after I do:

test_size = int(len(train_data)/2)
train_data, valid_data = torch.utils.data.random_split(train_data, [test_size, test_size])

It tells me:

next(train_data)

TypeError: 'Subset' object is not an iterator

This makes me think I am not applying random_split correctly. How do I correctly create the validation set for this dataset without causing issues later on?

Upvotes: 1

Views: 480

Answers (1)

Alexey Birukov

Reputation: 1680

Try next(iter(train_data)). It seems one has to create an iterator over the dataset explicitly, and use a DataLoader when efficiency is required.
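
A minimal sketch of that idea (assuming a recent torchtext where IMDB yields (label, text) pairs; materializing the dataset with list() so that random_split's Subset objects can index into it is a common workaround, not something stated in the question):

import torch
from torch.utils.data import DataLoader
from torchtext.datasets import IMDB

# materialize the streaming dataset into a list so it supports
# len() and indexing, which random_split's Subset objects rely on
train_list = list(IMDB(split='train'))

valid_size = len(train_list) // 2
train_size = len(train_list) - valid_size
train_data, valid_data = torch.utils.data.random_split(train_list, [train_size, valid_size])

# a Subset is not an iterator itself, so create one explicitly
label, text = next(iter(train_data))

# for actual training, wrap the split in a DataLoader for batching and shuffling
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)

Note that train_data now yields (label, text) tuples rather than the Example objects of the legacy API, so the vocabulary and batching code further down changes as well.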

Upvotes: 1
