jds

Reputation: 8269

In PyTorch, what makes a tensor have non-contiguous memory?

According to this SO answer and this PyTorch discussion, PyTorch's view function works only on contiguous memory, while reshape does not. In the second link, the author even claims:

[view] will raise an error on a non-contiguous tensor.

But when does a tensor have non-contiguous memory?
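Here is a minimal snippet of the behaviour I am asking about; I am only assuming that a transposed tensor counts as a non-contiguous case:

import torch

a = torch.arange(12).reshape(3, 4)   # contiguous
b = a.t()                            # transposed, presumably non-contiguous

print(b.reshape(-1).shape)           # works, returns a flattened result
b.view(-1)                           # raises RuntimeError on the non-contiguous tensor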

Upvotes: 9

Views: 7404

Answers (2)

jdhao

Reputation: 28459

I think the term contiguous memory in your title is a bit misleading. As I understand it, contiguous in PyTorch means whether the neighboring elements of a tensor are actually next to each other in memory. Let's take a simple example:

import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # x is contiguous
y = torch.transpose(x, 0, 1)              # y is non-contiguous

According to the documentation of transpose():

Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.

The resulting out tensor shares its underlying storage with the input tensor, so changing the content of one would change the content of the other.

So x and y in the above example share the same memory space. But if you check their contiguity with is_contiguous(), you will find that x is contiguous and y is not. Clearly, contiguity does not refer to whether a tensor owns a contiguous block of memory.

Since x is contiguous, x[0][0] and x[0][1] are next to each other in memory, but y[0][0] and y[0][1] are not. That is what contiguous means.
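You can verify this yourself with stride() and data_ptr(); a minimal sketch:

import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.transpose(x, 0, 1)

print(x.data_ptr() == y.data_ptr())   # True: same underlying storage
print(x.stride(), x.is_contiguous())  # (3, 1) True: row elements sit next to each other
print(y.stride(), y.is_contiguous())  # (1, 3) False: y[0][0] and y[0][1] are 3 elements apart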

Upvotes: 4

Jatentaki

Reputation: 13113

This is a very good answer, which explains the topic in the context of NumPy. PyTorch works essentially the same way. Its docs don't generally mention whether function outputs are (non-)contiguous, but that is something you can guess from the kind of operation (with some experience and understanding of the implementation). As a rule of thumb, most operations preserve contiguity because they construct new tensors. You may see non-contiguous outputs when the operation reuses the input's storage and only changes its striding. A couple of examples below:

import torch

t = torch.randn(10, 10)

def check(ten):
    print(ten.is_contiguous())

check(t) # True

# in PyTorch, flip returns a fresh copy (unlike NumPy's negative-stride view),
# so the result is contiguous
check(torch.flip(t, (0,))) # True

# if we take every 2nd element, adjacent elements in the resulting array
# are not adjacent in the input array
check(t[::2]) # False

# if we transpose, we lose contiguity, as in case of NumPy
check(t.transpose(0, 1)) # False

# if we transpose twice, we first lose and then regain contiguity
check(t.transpose(0, 1).transpose(0, 1)) # True

In general, if you have a non-contiguous tensor t, you can make it contiguous by calling t = t.contiguous(). If t is already contiguous, the call to t.contiguous() is essentially a no-op, so you can do it without risking a big performance hit.
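For instance, a small sketch reusing the transposed-tensor case from above:

import torch

t = torch.randn(10, 10).transpose(0, 1)
print(t.is_contiguous())   # False

t = t.contiguous()         # makes a contiguous copy of the data
print(t.is_contiguous())   # True
t.view(100)                # view() now works on the contiguous copy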

Upvotes: 10
