Ant

How can I expand a tensor in LibTorch? (The C++ version of PyTorch)

How can I use LibTorch to expand a tensor of the shape 42, 358 into a shape of 10, 42, 358?

I know how to do this in PyTorch:

    torch.ones(42, 358).expand(10, -1, -1).shape

returns

    torch.Size([10, 42, 358])

In LibTorch I have a tensor of the same size I am trying to "expand" in the same way.

    auto expanded_state_batch = state_batch.expand(10, -1, -1);

I get the following error...

    error: no matching function for call to ‘at::Tensor::expand(int, int, int)’
      335 |         auto expanded_state_batch = state_batch.expand(10, -1, -1);
          |                                     ~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~
    In file included from /home/iii/tor/m_gym/libtorch/include/ATen/core/Tensor.h:3,
                     from /home/iii/tor/m_gym/libtorch/include/ATen/Tensor.h:3,
                     from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/function_hook.h:3,
                     from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/cpp_hook.h:2,
                     from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/variable.h:6,
                     from /home/iii/tor/m_gym/libtorch/include/torch/csrc/autograd/autograd.h:3,
                     from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/autograd.h:3,
                     from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/all.h:7,
                     from /home/iii/tor/m_gym/libtorch/include/torch/csrc/api/include/torch/torch.h:3,
                     from /home/iii/tor/m_gym/multiv_normal.cpp:2:
    /home/iii/tor/m_gym/libtorch/include/ATen/core/TensorBody.h:2372:19: note: candidate: ‘at::Tensor at::Tensor::expand(at::IntArrayRef, bool) const’
     2372 | inline at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit) const {
          |                   ^~~~~~
    /home/iii/tor/m_gym/libtorch/include/ATen/core/TensorBody.h:2372:19: note:   candidate expects 2 arguments, 3 provided

It says that .expand takes only two arguments (an at::IntArrayRef and a bool), but three were given. I've tried a few other combinations and I always get an error.
What I'm effectively trying to do is stack the 42, 358 tensor ten times into a new tensor. I could do this in a loop with torch::cat, but that would be uglier.


Answers (1)

trialNerror

As the compiler tells you, Tensor::expand expects an at::IntArrayRef (plus an optional bool), not separate integers. Hence you want to write something like

    auto expanded_state_batch = state_batch.expand({10, -1, -1});
