Reputation: 1143
How can I use LibTorch to expand a tensor of the shape 42, 358 into a shape of 10, 42, 358?
I know how to do this in PyTorch (aka Torch).
torch.ones(42, 358).expand(10, -1, -1).shape
returns
torch.Size([10, 42, 358])
In LibTorch I have a tensor of the same size I am trying to "expand" in the same way.
auto expanded_state_batch = state_batch.expand(10, -1, -1);
I get the following error...
error: no matching function for call to ‘at::Tensor::expand(int, int, int)’
335 | auto expanded_state_batch = state_batch.expand(10, -1, -1);
| ~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~
In file included from /home/iii/tor/m_gym/libtorch/include/ATen/core/Tensor.h:3,
from /home/iii/tor/m_gym/libtorch/include/ATen/Tensor.h:3,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/function_hook.h:3,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/cpp_hook.h:2,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/variable.h:6,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/autograd.h:3,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/autograd.h:3,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/all.h:7,
from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/torch.h:3,
from /home/iii/tor/m_gym/multiv_normal.cpp:2:
/home/iii/tor/m_gym/libtorch/include/ATen/core/TensorBody.h:2372:19: note: candidate: ‘at::Tensor at::Tensor::expand(at::IntArrayRef, bool) const’
2372 | inline at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit) const {
| ^~~~~~
/home/iii/tor/m_gym/libtorch/include/ATen/core/TensorBody.h:2372:19: note: candidate expects 2 arguments, 3 provided
It says that .expand only takes two arguments but three were given. I've tried a few combinations and I always get an error.
Exactly what I'm doing here is concatenating the 42, 358 tensor ten times into a new tensor. I could do this in a loop with torch::cat, but this would be uglier.
Upvotes: 0
Views: 691
Reputation: 3553
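at::Tensor::expand expects an at::IntArrayRef as its first argument, as the compiler tells you: the three separate ints do not match the (at::IntArrayRef, bool) candidate, but a braced list converts to an at::IntArrayRef. Hence you want to write something like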
auto expanded_state_batch = state_batch.expand({10, -1, -1});
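To see why the braces make the difference, here is a minimal sketch that compiles without LibTorch. `IntArrayView` and `expanded_shape` are hypothetical stand-ins I made up for illustration (they are not ATen's real `at::IntArrayRef` or `Tensor::expand`); the only thing assumed to carry over is that the real type can also be built implicitly from a braced list, and that `-1` means "keep this dimension".

```cpp
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for at::IntArrayRef: a non-owning view over int64_t values.
class IntArrayView {
public:
    // Implicit constructor from a braced list. This is what lets {10, -1, -1}
    // bind to a single parameter, whereas (10, -1, -1) is three arguments.
    IntArrayView(std::initializer_list<std::int64_t> list)
        : data_(list.begin()), size_(list.size()) {}

    std::size_t size() const { return size_; }
    std::int64_t operator[](std::size_t i) const { return data_[i]; }

private:
    const std::int64_t* data_;
    std::size_t size_;
};

// Hypothetical stand-in for the shape rule of expand: trailing dimensions are
// aligned, -1 keeps the existing size, and new leading dimensions must be explicit.
std::vector<std::int64_t> expanded_shape(const std::vector<std::int64_t>& current,
                                         IntArrayView requested) {
    if (requested.size() < current.size()) {
        throw std::invalid_argument("requested rank is smaller than tensor rank");
    }
    const std::size_t offset = requested.size() - current.size();
    std::vector<std::int64_t> result(requested.size());
    for (std::size_t i = 0; i < requested.size(); ++i) {
        const std::int64_t want = requested[i];
        if (i < offset) {
            // A brand-new leading dimension has no existing size to keep.
            if (want == -1) {
                throw std::invalid_argument("-1 is not allowed for a new leading dimension");
            }
            result[i] = want;
        } else {
            const std::int64_t have = current[i - offset];
            if (want == -1 || want == have) {
                result[i] = have;
            } else if (have == 1) {
                result[i] = want;  // size-1 dimensions can be broadcast
            } else {
                throw std::invalid_argument("cannot expand a dimension whose size is not 1");
            }
        }
    }
    return result;
}
```

With `const std::vector<std::int64_t> state_shape{42, 358};`, the call `expanded_shape(state_shape, {10, -1, -1})` compiles, while `expanded_shape(state_shape, 10, -1, -1)` fails with the same kind of "no matching function" error as in the question.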
Upvotes: 1