Reputation: 5407
I'm updating a PyTorch network from legacy torchtext code to the current API, following documentation such as that here.
I used to have:
import torch
from torchtext import data
from torchtext import datasets
# setting the seed so our random output is actually deterministic
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
# defining our input fields (text) and labels.
# We use the Spacy function because it provides strong support for tokenization in languages other than English
TEXT = data.Field(tokenize = 'spacy', include_lengths = True)
LABEL = data.LabelField(dtype = torch.float)
from torchtext import datasets
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
import random
train_data, valid_data = train_data.split(random_state = random.seed(SEED))
example = next(iter(test_data))
example.text
MAX_VOCAB_SIZE = 25_000
TEXT.build_vocab(train_data,
                 max_size = MAX_VOCAB_SIZE,
                 vectors = "glove.6B.100d",
                 unk_init = torch.Tensor.normal_) # how to initialize unseen words not in glove
LABEL.build_vocab(train_data)
Now in the new code I am struggling to add the validation set. All goes well until here:
from torchtext.datasets import IMDB
train_data, test_data = IMDB(split=('train', 'test'))
I can print the outputs; while they look different (problems later on?), they have all the info. I can print an example fine with next(train_data).
Then after I do:
test_size = int(len(train_dataset)/2)
train_data, valid_data = torch.utils.data.random_split(train_dataset, [test_size,test_size])
It tells me:
next(train_data)
TypeError: 'Subset' object is not an iterator
This makes me think I am not applying random_split correctly. How do I correctly create the validation set for this dataset without causing issues?
Upvotes: 1
Views: 480
Reputation: 1680
Try next(iter(train_data)). random_split returns Subset objects, which are not iterators, so you have to create an iterator over the dataset explicitly. Use a DataLoader when efficient, batched iteration is required.
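As a rough illustration of the above, here is a minimal sketch. It does not download IMDB: `full_train` is a hypothetical stand-in (a plain list of `(label, text)` pairs), and the 80/20 ratio and batch size are assumptions chosen for the example. With the real data you would presumably first materialize the training split into something indexable, for instance with `list(...)`, before passing it to `random_split`.

```python
import torch
from torch.utils.data import DataLoader, random_split

SEED = 1234

# Hypothetical stand-in for the IMDB training data: a plain list of
# (label, text) pairs, which supports len() and integer indexing.
full_train = [(i % 2, f"review number {i}") for i in range(10)]

# Split 80/20 (an assumed ratio) with a seeded generator so the
# split is reproducible across runs.
n_valid = int(0.2 * len(full_train))
n_train = len(full_train) - n_valid
train_data, valid_data = random_split(
    full_train,
    [n_train, n_valid],
    generator=torch.Generator().manual_seed(SEED),
)

# A Subset is not an iterator, so wrap it in iter() before calling next().
first_example = next(iter(train_data))

# For batched iteration, wrap the Subset in a DataLoader.
train_loader = DataLoader(train_data, batch_size=4, shuffle=True)
labels, texts = next(iter(train_loader))
```

Here `first_example` is a single `(label, text)` pair, while `labels` is a tensor and `texts` a sequence of strings for one batch. Variable-length tokenized text would additionally need a custom `collate_fn`, which is left out of this sketch.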
Upvotes: 1