Reputation: 174
I'm trying to write a templated function taking an Eigen::Tensor as an argument. The same approach that works for Eigen::Matrix etc. does not work here.
Eigen recommends writing such functions against a common base class: https://eigen.tuxfamily.org/dox/TopicFunctionTakingEigenTypes.html
A minimal example for Eigen::Matrix that compiles:
#include <Eigen/Dense>

template <typename Derived>
void func(Eigen::MatrixBase<Derived>& a)
{
    a *= 2;
}

int main()
{
    Eigen::Matrix<int, 2, 2> matrix;
    func(matrix);
}
And the minimal example for Eigen::Tensor that does not compile:
#include <unsupported/Eigen/CXX11/Tensor>

template <typename Derived>
void func(Eigen::TensorBase<Derived>& a)
{
    a *= 2;
}

int main()
{
    Eigen::Tensor<int, 1> tensor;
    func(tensor);
}
$ g++ -std=c++11 -I /usr/include/eigen3 eigen_tensor_func.cpp
eigen_tensor_func.cpp: In function ‘int main()’:
eigen_tensor_func.cpp:12:16: error: no matching function for call to ‘func(Eigen::Tensor<int, 1>&)’
func(tensor);
^
eigen_tensor_func.cpp:4:6: note: candidate: ‘template<class Derived> void func(Eigen::TensorBase<Derived>&)’
void func(Eigen::TensorBase<Derived>& a)
^~~~
eigen_tensor_func.cpp:4:6: note: template argument deduction/substitution failed:
eigen_tensor_func.cpp:12:16: note: ‘Eigen::TensorBase<Derived>’ is an ambiguous base class of ‘Eigen::Tensor<int, 1>’
func(tensor);
Upvotes: 4
Views: 317
Reputation: 18807
The Tensor module is still far from being fully compatible with the Eigen/Core functionality (which also means that the documentation of the core functionality does not necessarily apply to the Tensor module).
The first major difference is that TensorBase takes two template arguments instead of one, i.e., you need to write TensorBase<Derived, Eigen::WriteAccessors>. Also, some functionality is either not implemented at all, or TensorBase does not properly forward it. The following works with the current trunk (2019-04-03):
template <typename Derived>
void func(Eigen::TensorBase<Derived, Eigen::WriteAccessors>& a)
{
    // a *= 2;                       // operator*=(Scalar) not implemented
    // a = 2*a;                      // operator=(...) not implemented/forwarded
    a *= a;                          // ok
    a *= 2*a;                        // ok
    a *= 0*a+2;                      // ok
    // a.derived() = 2*a;            // derived() is not public
    static_cast<Derived&>(a) = a*2;  // ok
}
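To illustrate why naming the second template argument helps, here is a small sketch that does not use Eigen at all. All names in it (Base, MyTensor, ReadOnly, Write, scale, run_demo) are hypothetical stand-ins. It assumes, going by the "ambiguous base class" note in the question's error output, that the writable base class itself derives from a read-only base class with the same Derived type, so the tensor has two base classes of the same template. With the access level spelled out in the parameter, only one of them can match:

```cpp
#include <cassert>

// Hypothetical stand-ins for illustration only; these are NOT Eigen's classes.
enum AccessLevel { ReadOnly, Write };

template <typename Derived, int Level>
struct Base;

// Read-only layer of the mock hierarchy.
template <typename Derived>
struct Base<Derived, ReadOnly> {};

// Writable layer; it inherits from the read-only layer (assumed layout).
template <typename Derived, int Level>
struct Base : Base<Derived, ReadOnly> {};

// MyTensor therefore has two bases that are specializations of Base:
// Base<MyTensor, Write> and Base<MyTensor, ReadOnly>.
struct MyTensor : Base<MyTensor, Write> {
    int value = 1;
};

// The access level is fixed to Write, so only Base<MyTensor, Write> matches
// and Derived is deduced as MyTensor.
template <typename Derived>
void scale(Base<Derived, Write>& a)
{
    // Same downcast as in the answer's last line.
    static_cast<Derived&>(a).value *= 2;
}

// Small driver: builds a mock tensor, doubles it once, returns the result.
int run_demo()
{
    MyTensor t;
    scale(t);
    return t.value;
}
```

Calling run_demo() doubles the initial value of 1 once, so it returns 2. The mock only mirrors the inheritance layout and says nothing about which operators the real TensorBase implements.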
Upvotes: 3