Reputation: 2139
import torch
T = torch.FloatTensor(range(0,10 ** 6)) # 1M
#case 1:
torch.save(T, 'junk.pt')
# results in a 4 MB file size
#case 2:
torch.save(T[-20:], 'junk2.pt')
# results in a 4 MB file size
#case 3:
torch.save(torch.FloatTensor(T[-20:]), 'junk3.pt')
# results in a 4 MB file size
#case 4:
torch.save(torch.FloatTensor(T[-20:].tolist()), 'junk4.pt')
# results in a 405 Bytes file size
My questions are:
In case 3 the resulting file size seems surprising as we are creating a new tensor. Why is this new tensor not just the slice?
Is case 4 the optimal method for saving just part (a slice) of a tensor?
More generally, if I want to 'trim' a very large 1-dimensional tensor by removing the first half of its values in order to save memory, do I have to proceed as in case 4, or is there a more direct and less computationally costly way that does not involve creating a Python list?
Upvotes: 5
Views: 6009
Reputation: 32972
(i) In case 3 the resulting file size seems surprising as we are creating a new tensor. Why is this new tensor not just the slice?
Slicing creates a view of the tensor, which shares the underlying data but contains information about the memory offsets used for the visible data. This avoids having to copy the data frequently, which makes a lot of operations much more efficient. See PyTorch - Tensor Views for a list of affected operations.
You are dealing with one of the few cases where the underlying data matters. To save the tensor, torch.save has to write out the underlying data in full; otherwise the offsets stored in the view would no longer be valid.
torch.FloatTensor does not create a copy of the tensor if it's not necessary. You can verify that their underlying data is still the same (they have the exact same memory location):
torch.FloatTensor(T[-20:]).storage().data_ptr() == T.storage().data_ptr()
# => True
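To make the view behaviour a bit more concrete, here is a minimal sketch that avoids `storage()` and only relies on `Tensor.storage_offset()`, `Tensor.data_ptr()` and `Tensor.element_size()`. The variable names (`view`, `offset`, `same_buffer`, `visible_in_T`) are just illustrative, and `torch.arange` is used as a stand-in for the `FloatTensor(range(...))` construction from the question:

```python
import torch

T = torch.arange(10**6, dtype=torch.float32)  # 1M float32 values
view = T[-20:]                                 # slicing returns a view, not a copy

# The view starts 999,980 elements into the same underlying buffer.
offset = view.storage_offset()

# Its data pointer is the original pointer shifted by that many elements.
same_buffer = view.data_ptr() == T.data_ptr() + offset * T.element_size()

# Because the memory is shared, writing through the view changes T as well.
view[0] = -1.0
visible_in_T = T[-20].item() == -1.0

print(offset, same_buffer, visible_in_T)
```

The non-zero offset is the "information about the memory offsets" mentioned above; it only makes sense relative to the full buffer, which is why the whole buffer ends up in the saved file.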
(ii) Is case 4 the optimal method for saving just part (a slice) of a tensor?
(iii) More generally, if I want to 'trim' a very large 1-dimensional tensor by removing the first half of its values in order to save memory, do I have to proceed as in case 4, or is there a more direct and less computationally costly way that does not involve creating a Python list?
You will most likely not get around copying the data of the slice, but you can at least avoid creating a Python list from it and then a new tensor from that list, by using torch.Tensor.clone instead:
torch.save(T[-20:].clone(), 'junk5.pt')
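If you want to check the effect without writing files, a rough sketch like the following compares the serialized sizes in memory. `saved_size` is a hypothetical helper written for this example (not part of PyTorch), and the exact byte counts depend on the PyTorch version, so none are stated here. The last lines show the same idea applied to the "trim the first half" case from (iii), on a small tensor named `big` purely for illustration:

```python
import io
import torch

def saved_size(t):
    """Hypothetical helper: bytes torch.save writes for tensor t."""
    buf = io.BytesIO()
    torch.save(t, buf)
    return buf.getbuffer().nbytes

T = torch.arange(10**6, dtype=torch.float32)

view = T[-20:]       # shares the 1M-element buffer with T
copy = view.clone()  # owns a fresh buffer holding only 20 elements

size_view = saved_size(view)  # includes the full underlying buffer
size_copy = saved_size(copy)  # only the 20 cloned values plus overhead

print(size_view, size_copy)

# Trimming in place: keep the second half and drop the reference to the
# original, so the old buffer can be freed once nothing else points to it.
big = torch.arange(1000, dtype=torch.float32)
big = big[big.numel() // 2:].clone()
```

Note that the trimming only frees memory if no other tensor or view still references the original buffer.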
Upvotes: 8