baptiste

Reputation: 1169

C++ Matrix multiplication type detection

In my C++ code I have a Matrix class and some operators to multiply matrices. The class is templated, which means I can have int, float, double, ... matrices.

My operator overload is fairly classic, I guess:

    template <typename T, typename U>
    Matrix<T>& operator*(const Matrix<T>& a, const Matrix<U>& b)
    {
      assert(a.cols() == b.rows() && "You have to multiply a MxN matrix with a NxP one to get a MxP matrix\n");
      Matrix<T> *c = new Matrix<T>(a.rows(), b.cols());
      for (int ci=0 ; ci<c->rows() ; ++ci)
      {
        for (int cj=0 ; cj<c->cols() ; ++cj)
        {
          c->at(ci,cj)=0;
          for (int k=0 ; k<a.cols() ; ++k)
          {
            c->at(ci,cj) += (T)(a.at(ci,k)*b.at(k,cj));
          }
        }
      }
      return *c;
    }

In this code I return a matrix of the same type as the first parameter, i.e. Matrix<int> * Matrix<float> = Matrix<int>. How can I detect the more precise of the two types, so that I don't lose precision, i.e. so that Matrix<int> * Matrix<float> = Matrix<float>? Is there a clever way to do it?

Upvotes: 0

Views: 208

Answers (1)

Barry

Reputation: 303107

What you want is just the type you get when you multiply a T by a U. That type is given by:

    #include <utility> // std::declval

    template <class T, class U>
    using product_type = decltype(std::declval<T>() * std::declval<U>());
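A handful of compile-time checks is a quick way to convince yourself that this picks the more precise type. This is a small self-contained C++11 sketch; note that small integer types such as char and short are promoted to int before the multiplication, so int is what the alias reports for them.

```cpp
#include <type_traits> // std::is_same
#include <utility>     // std::declval

// The type of the expression T * U, computed without constructing any objects.
template <class T, class U>
using product_type = decltype(std::declval<T>() * std::declval<U>());

// The usual arithmetic conversions pick the wider type, in either order.
static_assert(std::is_same<product_type<int, float>, float>::value, "int * float is float");
static_assert(std::is_same<product_type<float, int>, float>::value, "float * int is float");
static_assert(std::is_same<product_type<int, double>, double>::value, "int * double is double");
static_assert(std::is_same<product_type<float, double>, double>::value, "float * double is double");

// Types smaller than int are promoted to int first.
static_assert(std::is_same<product_type<char, short>, int>::value, "char * short is int");
```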

You can just use that as an extra defaulted template parameter:

    template <typename T, typename U, typename R = product_type<T, U>>
    Matrix<R> operator*(const Matrix<T>& a, const Matrix<U>& b) {
        ...
    }
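Filling in the body, a complete sketch might look like this. The Matrix<T> here is a hypothetical minimal stand-in for the question's class (only the rows(), cols() and at() members the question's code uses, backed by a row-major std::vector), so substitute your own class. The result is returned by value, so there is no new.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical minimal Matrix<T>, row-major storage in a std::vector.
template <typename T>
class Matrix {
public:
    Matrix(std::size_t r, std::size_t c) : rows_(r), cols_(c), data_(r * c, T()) {}

    std::size_t rows() const { return rows_; }
    std::size_t cols() const { return cols_; }

    T& at(std::size_t i, std::size_t j) { return data_[i * cols_ + j]; }
    const T& at(std::size_t i, std::size_t j) const { return data_[i * cols_ + j]; }

private:
    std::size_t rows_, cols_;
    std::vector<T> data_;
};

template <class T, class U>
using product_type = decltype(std::declval<T>() * std::declval<U>());

// R defaults to the type of T * U; the result is returned by value.
template <typename T, typename U, typename R = product_type<T, U>>
Matrix<R> operator*(const Matrix<T>& a, const Matrix<U>& b) {
    // An MxN matrix times an NxP matrix gives an MxP matrix.
    assert(a.cols() == b.rows() && "inner dimensions must agree");
    Matrix<R> c(a.rows(), b.cols());  // every element starts at R(), i.e. zero
    for (std::size_t i = 0; i < c.rows(); ++i)
        for (std::size_t j = 0; j < c.cols(); ++j)
            for (std::size_t k = 0; k < a.cols(); ++k)
                c.at(i, j) += a.at(i, k) * b.at(k, j);
    return c;
}

// Matrix<int> * Matrix<float> yields Matrix<float>.
static_assert(std::is_same<decltype(std::declval<Matrix<int>>() * std::declval<Matrix<float>>()),
                           Matrix<float>>::value,
              "int matrix times float matrix is a float matrix");
```

Because the result is an ordinary local object returned by value, the caller owns it like any other value and nothing has to be deleted.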

In C++03 you can accomplish the same thing with a giant series of overloads and lots of small helper types, like so (this is how Boost does it):

    template <int I> struct arith;
    template <int I, typename T> struct arith_helper {
        typedef T type;
        typedef char (&result_type)[I];
    };

    template <> struct arith<1> : arith_helper<1, bool> { };
    template <> struct arith<2> : arith_helper<2, char> { };
    template <> struct arith<3> : arith_helper<3, signed char> { };
    template <> struct arith<4> : arith_helper<4, short> { };
    // ... lots more

We then can write:

    template <class T, class U>
    class common_type {
    private:
        static arith<1>::result_type select(arith<1>::type );
        static arith<2>::result_type select(arith<2>::type );
        static arith<3>::result_type select(arith<3>::type );
        // ...

        static bool cond();
    public:
        typedef typename arith<sizeof(select(cond() ? T() : U() ))>::type type;
    };

Assuming you write out all the arithmetic types (integral and floating-point), you can then use typename common_type<T, U>::type where I used product_type before.
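For reference, here is a complete C++03 sketch of that approach. The list of types and their numbering are my own choice, not necessarily what Boost uses, and long long is left out because it is C++11. Since C++03 has no std::is_same, a minimal hypothetical is_same helper is included for the compile-time check at the end. One difference from product_type: when both types are the same small type, the conditional operator does not promote, so common_type<char, char>::type is char while char * char is int.

```cpp
namespace legacy {  // hypothetical namespace, keeps clear of std::common_type

// arith<I>::type is the I-th type in the list; result_type is a reference to
// char[I], so sizeof applied to a call returning it yields I.
template <int I> struct arith;
template <int I, typename T> struct arith_helper {
    typedef T type;
    typedef char (&result_type)[I];
};

// The C++03 arithmetic types; each index only has to be unique.
template <> struct arith<1>  : arith_helper<1,  bool> { };
template <> struct arith<2>  : arith_helper<2,  char> { };
template <> struct arith<3>  : arith_helper<3,  signed char> { };
template <> struct arith<4>  : arith_helper<4,  unsigned char> { };
template <> struct arith<5>  : arith_helper<5,  short> { };
template <> struct arith<6>  : arith_helper<6,  unsigned short> { };
template <> struct arith<7>  : arith_helper<7,  int> { };
template <> struct arith<8>  : arith_helper<8,  unsigned int> { };
template <> struct arith<9>  : arith_helper<9,  long> { };
template <> struct arith<10> : arith_helper<10, unsigned long> { };
template <> struct arith<11> : arith_helper<11, float> { };
template <> struct arith<12> : arith_helper<12, double> { };
template <> struct arith<13> : arith_helper<13, long double> { };

template <class T, class U>
class common_type {
private:
    // One overload per type, declared but never defined: they are only used
    // inside sizeof, which does not evaluate its operand.
    static arith<1>::result_type  select(arith<1>::type);
    static arith<2>::result_type  select(arith<2>::type);
    static arith<3>::result_type  select(arith<3>::type);
    static arith<4>::result_type  select(arith<4>::type);
    static arith<5>::result_type  select(arith<5>::type);
    static arith<6>::result_type  select(arith<6>::type);
    static arith<7>::result_type  select(arith<7>::type);
    static arith<8>::result_type  select(arith<8>::type);
    static arith<9>::result_type  select(arith<9>::type);
    static arith<10>::result_type select(arith<10>::type);
    static arith<11>::result_type select(arith<11>::type);
    static arith<12>::result_type select(arith<12>::type);
    static arith<13>::result_type select(arith<13>::type);

    static bool cond();
public:
    // The conditional operator brings T() and U() to a common type via the
    // usual arithmetic conversions; the exact-match overload of select then
    // encodes that type as a size, which indexes back into arith.
    typedef typename arith<sizeof(select(cond() ? T() : U()))>::type type;
};

// Hypothetical helper: a minimal is_same, since C++03 lacks std::is_same.
template <class A, class B> struct is_same { enum { value = 0 }; };
template <class A> struct is_same<A, A> { enum { value = 1 }; };

// C++03-style static assertion: a negative array size fails to compile.
typedef char int_times_float_is_float
    [is_same<common_type<int, float>::type, float>::value ? 1 : -1];

}  // namespace legacy
```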

If this isn't a demonstration of how cool C++11 is, I don't know what is.


Note that operator* should return by value, not by reference. Allocating the result with new and returning *c leaks memory, since nothing ever deletes it.

Upvotes: 9
