mcb

Reputation: 418

Repeating a PyTorch tensor without copying memory

Does PyTorch support repeating a tensor without allocating significantly more memory?

Assume we have a tensor

import torch

t = torch.ones((1, 1000, 1000))
t10 = t.repeat(10, 1, 1)  # allocates a new tensor of shape (10, 1000, 1000)

Repeating t 10 times takes 10x the memory. Is there a way to create a tensor t10 without allocating significantly more memory?

Here is a related question, but without answers.

Upvotes: 10

Views: 5555

Answers (1)

jodag

Reputation: 22214

You can use Tensor.expand, which returns a view of the original tensor instead of a copy:

import torch

t = torch.ones((1, 1000, 1000))
t10 = t.expand(10, 1000, 1000)  # view of t; no new data is allocated

Keep in mind that t10 is only a view of t: it shares t's memory rather than holding its own copy. For example, a change to t10[0,0,0] also changes t[0,0,0] and every element of t10[:,0,0].
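A minimal sketch of how to check this sharing yourself, assuming a CPU float32 tensor as in the example. The variable names same_storage, strides and all_updated are only illustrative.

```python
import torch

t = torch.ones((1, 1000, 1000))
t10 = t.expand(10, 1000, 1000)

# The expanded view points at the same memory as t ...
same_storage = t10.data_ptr() == t.data_ptr()
# ... and the expanded dimension has stride 0, so all 10 "copies" alias one slice.
strides = t10.stride()

# Writing a single element through the view changes the underlying data.
t10[0, 0, 0] = 5.0
all_updated = bool((t10[:, 0, 0] == 5.0).all())

print(same_storage, strides, t[0, 0, 0].item(), all_updated)
```

The stride of 0 on the first dimension is what lets the view present ten slices while storing only one.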

Other than direct access, most operations on t10 that have to produce a new tensor will copy the data, which breaks the link to t and allocates memory for the full expanded shape. For example: changing the device (.cpu(), .to(device=...), .cuda()), changing the datatype (.float(), .long(), .to(dtype=...)), or calling .contiguous().
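As a rough illustration, here is what happens with .contiguous(), again assuming the CPU float32 tensor from the example. The names c, view_bytes and copy_bytes are hypothetical and only used for this sketch.

```python
import torch

t = torch.ones((1, 1000, 1000))
t10 = t.expand(10, 1000, 1000)

# The expanded view is not contiguous, so .contiguous() must materialize it.
c = t10.contiguous()

# Data actually stored behind the view vs. behind the materialized copy.
view_bytes = t.numel() * t.element_size()
copy_bytes = c.numel() * c.element_size()

# The copy lives in its own memory, so writes to it no longer reach t.
c[0, 0, 0] = 7.0

print(t10.is_contiguous(), c.is_contiguous())
print(view_bytes, copy_bytes)
print(t[0, 0, 0].item())
```

If the memory saving matters, keep the tensor in its expanded form for as long as possible and only materialize it where an operation really requires its own copy.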

Upvotes: 16
