I'm trying to build a Tensor class in C++. This is only a personal project to get some practice in C++, and it mostly works, but now I've hit some C++ problems that I can't quite understand. Here's the main structure of my Tensor class, with a few functions omitted that shouldn't be relevant:
#include <vector>
template<typename T> struct Tensor {
    // Support for up to 5 dimensions
    T& at(std::vector<std::size_t> indices); // Calls the other at() functions depending on the vector size
    T& at(std::size_t d1);
    T& at(std::size_t d1, std::size_t d2);
    T& at(std::size_t d1, std::size_t d2, std::size_t d3);
    T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4);
    T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4, std::size_t d5);
    Tensor transpose();
    Tensor matmul(Tensor& rhs) {
        rhs.at(0);
        return Tensor();
    }
};
When testing matmul() for multiplying two tensors, the following code works:
void works() {
    Tensor<float> tensor1;
    Tensor<float> tensor2 = tensor1.transpose();
    Tensor<float> tensor3 = tensor1.matmul(tensor2);
}
However, the following code, where I don't explicitly create tensor2 as the transpose of tensor1, fails:
void fails() {
    Tensor<float> tensor1;
    Tensor<float> tensor = tensor1.matmul(tensor1.transpose());
}
The first error the compiler reports is:
tensor.test.cpp:26:60: error: cannot bind non-const lvalue reference of type 'Tensor<float>&' to an rvalue of type 'Tensor<float>'
26 | Tensor<float> tensor = tensor1.matmul(tensor1.transpose());
| ~~~~~~~~~~~~~~~~~^~
From my googling and limited understanding of C++, I've tried to change the definition of matmul() to:
Tensor matmul(const Tensor& rhs) { // <-- const added
However, if I do this, I get a different error:
tensor.hpp: In instantiation of 'Tensor<T> Tensor<T>::matmul(const Tensor<T>&) [with T = float]':
tensor.test.cpp:21:43: required from here
tensor.hpp:13:15: error: passing 'const Tensor<float>' as 'this' argument discards qualifiers [-fpermissive]
13 | rhs.at(0);
| ~~~~~~^~~
EDIT 1: Following the comments, using Tensor matmul(const Tensor& rhs); and adding a const overload of every at() function solved my problem. However, the code now looks like this:
T& at(std::vector<std::size_t> indices);
T& at(std::size_t d1);
T& at(std::size_t d1, std::size_t d2);
T& at(std::size_t d1, std::size_t d2, std::size_t d3);
T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4);
T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4, std::size_t d5);
const T& at(std::vector<std::size_t> indices) const;
const T& at(std::size_t d1) const;
const T& at(std::size_t d1, std::size_t d2) const;
const T& at(std::size_t d1, std::size_t d2, std::size_t d3) const;
const T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4) const;
const T& at(std::size_t d1, std::size_t d2, std::size_t d3, std::size_t d4, std::size_t d5) const;
which means quite a lot of duplicated code. I wonder whether this can be improved.
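Your method at is not const. Does it modify the Tensor?
If not, you can say so thus:
T& at(std::size_t d1, std::size_t d2) const;
and it should then return a const T& as well: if the elements are stored in a member such as a std::vector, a const member function cannot hand out a plain T& to one of them (that would not compile).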
But the problem you're having is that you're calling a non-const member function (at) on a const Tensor.
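A minimal sketch of the const accessor in isolation. It assumes, hypothetically, that the elements are stored row-major in a std::vector<T> member named data with rows and cols members (the question does not show the storage), and the free function sum merely stands in for something like matmul(const Tensor& rhs):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical 2-D tensor, reduced to what is needed to show the fix.
// Storage layout (row-major std::vector) is an assumption.
template <typename T>
struct Tensor {
    std::size_t rows = 0, cols = 0;
    std::vector<T> data;

    // Read-only access: a const member function returning const T&.
    // Inside this function data is const, so returning a plain T&
    // to one of its elements would not compile.
    const T& at(std::size_t i, std::size_t j) const {
        return data.at(i * cols + j);  // vector::at throws std::out_of_range on a bad index
    }
};

// Takes the tensor by const reference, like matmul(const Tensor& rhs).
template <typename T>
T sum(const Tensor<T>& t) {
    T total{};
    for (std::size_t i = 0; i < t.rows; ++i)
        for (std::size_t j = 0; j < t.cols; ++j)
            total += t.at(i, j);  // OK: at() is const, so it is callable on t
    return total;
}
```

Without the trailing const on at(), the call t.at(i, j) inside sum would produce the same "discards qualifiers" error as in the question.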
> From my googling and limited understanding of C++, I've tried to change the definition of matmul() to
> Tensor matmul(const Tensor& rhs); // <-- const added
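Unless matmul modifies rhs, it indeed shouldn't take a non-const reference. And the parameter being a non-const reference is indeed the reason why the example didn't work: tensor1.transpose() returns a temporary (an rvalue), and a non-const lvalue reference cannot bind to a temporary, whereas a const reference can. So, you have come to the correct solution.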
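The binding rule can be checked on its own. This sketch uses std::string and three hypothetical helpers (by_ref, by_cref and make_temp, none of them from the question); make_temp returns by value just as transpose() does, and std::is_invocable_v (C++17) confirms at compile time which calls are allowed:

```cpp
#include <cstddef>
#include <string>
#include <type_traits>

// Hypothetical helpers, used only to show which arguments each parameter type accepts.
void by_ref(std::string& s) { s += "!"; }                        // may modify the caller's object
std::size_t by_cref(const std::string& s) { return s.size(); }   // read-only access
std::string make_temp() { return "temp"; }                       // returns a temporary, like transpose()

// A non-const lvalue reference binds only to lvalues...
static_assert(std::is_invocable_v<decltype(by_ref), std::string&>, "lvalue binds to std::string&");
static_assert(!std::is_invocable_v<decltype(by_ref), std::string>, "rvalue does not bind to std::string&");
// ...while a const lvalue reference also binds to temporaries.
static_assert(std::is_invocable_v<decltype(by_cref), std::string>, "rvalue binds to const std::string&");

std::size_t length_of_temporary() {
    // by_ref(make_temp());       // error: cannot bind non-const lvalue reference to an rvalue
    return by_cref(make_temp());  // OK: the temporary lives until the end of the full-expression
}
```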
> The at(left) which calls the at() of this works.
That's because matmul is still a non-const member function. However, does that make sense? Just as the function doesn't modify rhs, does it modify *this? I bet not. And as long as the function isn't const, you cannot do this:
const Tensor tensor;
Tensor tensor3 = tensor.matmul(some_other_tensor);
Does that restriction make sense? Probably not.
So, you probably should make it a const member function. That, however, would prevent the non-const at from working on *this as well.
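A sketch of a 2-D Tensor whose transpose() and matmul() are both const member functions, assuming row-major std::vector storage and hypothetical rows/cols members (the question's class supports up to 5 dimensions and is not shown in full). Because both at() overloads exist, the const one is picked for *this and rhs, while the non-const one writes into the result:

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

template <typename T>
struct Tensor {
    std::size_t rows = 0, cols = 0;
    std::vector<T> data;  // row-major storage (assumption)

    Tensor() = default;
    Tensor(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c) {}

    const T& at(std::size_t i, std::size_t j) const { return data.at(i * cols + j); }
    T& at(std::size_t i, std::size_t j) { return data.at(i * cols + j); }

    // const: transposing does not modify *this.
    Tensor transpose() const {
        Tensor out(cols, rows);
        for (std::size_t i = 0; i < rows; ++i)
            for (std::size_t j = 0; j < cols; ++j)
                out.at(j, i) = at(i, j);  // out: non-const at(); *this: const at()
        return out;
    }

    // const member function taking a const reference: works on const tensors
    // and accepts temporaries such as the result of transpose().
    Tensor matmul(const Tensor& rhs) const {
        if (cols != rhs.rows) throw std::invalid_argument("matmul: inner dimensions differ");
        Tensor out(rows, rhs.cols);
        for (std::size_t i = 0; i < rows; ++i)
            for (std::size_t j = 0; j < rhs.cols; ++j) {
                T acc{};
                for (std::size_t k = 0; k < cols; ++k)
                    acc += at(i, k) * rhs.at(k, j);  // both calls use the const at()
                out.at(i, j) = acc;
            }
        return out;
    }
};
```

With both functions const, tensor1.matmul(tensor1.transpose()) compiles, and so does calling matmul on a const Tensor.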
> I've tried setting the input parameters of all at() functions to const without success.
That doesn't help, because the problem is not the parameters but the object at is called on. What you must do is make at a const member function. But to do that, it must return a const reference.
However, you might also need an at that returns a non-const reference. So how do you solve that? I recommend checking how the standard library deals with this. Let's look at std::vector, for example:
reference at( size_type pos );
const_reference at( size_type pos ) const;
Would you look at that: there are two overloads, one const and the other non-const. You can follow this example in your Tensor class.
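Following that pattern does not have to double the code from EDIT 1. A sketch, assuming C++17, row-major std::vector storage and a hypothetical shape member: a variadic template at() replaces the five fixed-arity overloads, and the non-const overload forwards to the const one through const_cast, a common idiom for avoiding duplicated const/non-const member functions. The at(std::vector<std::size_t>) overload from the question is left out here:

```cpp
#include <array>
#include <cstddef>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>

template <typename T>
struct Tensor {
    std::vector<std::size_t> shape;  // e.g. {2, 3, 4} (assumption)
    std::vector<T> data;             // row-major storage (assumption)

    // One const overload handles any number of indices.
    template <typename... Idx>
    const T& at(Idx... idx) const {
        static_assert((std::is_integral_v<Idx> && ...), "indices must be integral");
        if (sizeof...(Idx) != shape.size())
            throw std::out_of_range("wrong number of indices");
        const std::array<std::size_t, sizeof...(Idx)> ix{static_cast<std::size_t>(idx)...};
        std::size_t offset = 0;
        for (std::size_t d = 0; d < ix.size(); ++d) {
            if (ix[d] >= shape[d]) throw std::out_of_range("index out of range");
            offset = offset * shape[d] + ix[d];  // row-major offset
        }
        return data[offset];
    }

    // The non-const overload reuses the const one instead of duplicating it.
    // The const_cast only removes the const added by std::as_const, so it
    // never writes through a reference to a genuinely const object.
    template <typename... Idx>
    T& at(Idx... idx) {
        return const_cast<T&>(std::as_const(*this).at(idx...));
    }
};
```

For a non-const Tensor, overload resolution prefers the non-const at(); for a const Tensor (or a const Tensor&), only the const one is viable.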