I'd like to create a custom PyTorch dataset of ZCA-whitened CIFAR-10 that I can subsequently load using torchvision's torchvision.datasets.CIFAR10(). So far, I can successfully whiten the data (see code below), but I don't know how to save the data to disk in a manner that allows it to be loaded using torchvision.datasets.CIFAR10(). How do I do this?
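Code to ZCA-whiten CIFAR-10:
import numpy as np
import torchvision

# zca_whitening_matrix(X) is assumed to be defined elsewhere; it takes an array
# of shape (n_features, n_samples) and returns an (n_features, n_features) matrix.

# whiten CIFAR-10 training data
trainset = torchvision.datasets.CIFAR10(
    root='./datasets',
    train=True,
    download=False)
# trainset.data is a uint8 array of shape (50000, 32, 32, 3); flatten each image
train_data = trainset.data.reshape(-1, 32*32*3).astype(np.float64)
# center with the training mean (ZCA whitening assumes zero-mean data)
train_mean = train_data.mean(axis=0)
train_data -= train_mean
zca_matrix = zca_whitening_matrix(train_data.T)
whitened_training_data = np.matmul(zca_matrix, train_data.T).T
whitened_training_data = whitened_training_data.reshape((-1, 32, 32, 3))

# whiten CIFAR-10 testing data with the training-set mean and ZCA matrix
testset = torchvision.datasets.CIFAR10(
    root='./datasets',
    train=False,
    download=False)
testdata = testset.data.reshape(-1, 32*32*3).astype(np.float64) - train_mean
whitened_test_data = np.matmul(zca_matrix, testdata.T).T
whitened_test_data = whitened_test_data.reshape((-1, 32, 32, 3))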
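The snippet calls zca_whitening_matrix without showing it. A minimal sketch of such a helper, assuming it receives data of shape (n_features, n_samples) as in the call above; it builds W = U diag(1/sqrt(lambda + eps)) U^T from an eigendecomposition of the feature covariance. The epsilon default and the clipping of tiny negative eigenvalues are assumptions, and this is not necessarily the helper the question actually uses.

```python
import numpy as np


def zca_whitening_matrix(X, epsilon=1e-5):
    """Return a ZCA whitening matrix for X.

    X has shape (n_features, n_samples); each row is one variable.
    epsilon (an assumed default) keeps 1/sqrt(eigenvalue) finite.
    """
    # Feature-by-feature covariance; np.cov centers each row internally.
    sigma = np.cov(X, rowvar=True)
    # sigma is symmetric, so eigh returns real eigenvalues and orthonormal eigenvectors.
    eigvals, eigvecs = np.linalg.eigh(sigma)
    # Round-off can make near-zero eigenvalues slightly negative; clip them.
    eigvals = np.clip(eigvals, 0.0, None)
    # W = U diag(1 / sqrt(lambda + eps)) U^T (symmetric, unlike PCA whitening).
    return eigvecs @ np.diag(1.0 / np.sqrt(eigvals + epsilon)) @ eigvecs.T


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Correlated toy data: 5 features, 10000 samples.
    mixing = np.tril(np.ones((5, 5)))
    X = mixing @ rng.normal(size=(5, 10000))
    W = zca_whitening_matrix(X)
    # The covariance of the whitened data should be close to the identity.
    print(np.round(np.cov(W @ X), 3))
```

Using eigh rather than a general SVD relies on the covariance being symmetric positive semi-definite; for this input both give the same matrix.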
Is the best way really just to save the NumPy arrays, as shown in the question "PyTorch: How to use DataLoaders for custom Datasets"?
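Saving the arrays is a reasonable route, because torchvision.datasets.CIFAR10 itself is a poor fit for whitened data: it verifies the MD5 checksums of the original batch files under <root>/cifar-10-batches-py, and its __getitem__ passes each image through PIL.Image.fromarray, which rejects float64 RGB arrays. A minimal sketch, assuming a hypothetical file name whitened_cifar10.npz and hypothetical names save_whitened and WhitenedCIFAR10, that stores the whitened images together with their labels and serves them through a small Dataset returning (image, target) pairs like CIFAR10 does:

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


def save_whitened(path, train_images, train_labels, test_images, test_labels):
    # Hypothetical helper: np.savez stores each keyword as a separate array in one file.
    np.savez(
        path,
        train_images=train_images,
        train_labels=np.asarray(train_labels, dtype=np.int64),
        test_images=test_images,
        test_labels=np.asarray(test_labels, dtype=np.int64),
    )


class WhitenedCIFAR10(Dataset):
    """Hypothetical Dataset over the .npz layout written by save_whitened."""

    def __init__(self, path, train=True, transform=None):
        split = "train" if train else "test"
        with np.load(path) as archive:
            # Stored as (N, 32, 32, 3); permute to (N, 3, 32, 32) for conv layers.
            images = archive[f"{split}_images"].astype(np.float32)
            self.data = torch.from_numpy(images).permute(0, 3, 1, 2).contiguous()
            self.targets = torch.from_numpy(archive[f"{split}_labels"])
        self.transform = transform

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        img, target = self.data[index], int(self.targets[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target


if __name__ == "__main__":
    # Random stand-ins for whitened_training_data / whitened_test_data.
    rng = np.random.default_rng(0)
    path = os.path.join(tempfile.mkdtemp(), "whitened_cifar10.npz")
    save_whitened(path,
                  rng.normal(size=(8, 32, 32, 3)), rng.integers(0, 10, 8),
                  rng.normal(size=(4, 32, 32, 3)), rng.integers(0, 10, 4))
    loader = DataLoader(WhitenedCIFAR10(path, train=True), batch_size=4, shuffle=True)
    images, labels = next(iter(loader))
    print(images.shape, labels.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4])
```

With the question's variables, the save step would be save_whitened('whitened_cifar10.npz', whitened_training_data, trainset.targets, whitened_test_data, testset.targets). Any transform passed in receives a float tensor rather than a PIL image, so PIL-based torchvision transforms such as ToTensor() do not apply.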