Jerry Kakol

How to partially specialize a template based on the relation between its two integer parameters

I want to design an m x n matrix class (a template parameterized by m rows and n columns) and need to specialize it so that it offers exactly the operations that are mathematically possible, based on three conditions:

  1. m > n
  2. m == n
  3. m < n: no specialization, that is, the basic or default implementation

The template signature is simply:

template <size_t m, size_t n, typename T = double> class MatrixBase
{
....
};

How do I do that? Can it be done with type traits, or should I use std::conditional<> or std::enable_if<>? Conceptually, I want to add methods to a class without subclassing it or creating an inheritance hierarchy. I want to reserve the derivation tree for other things, namely the data storage within the matrix.

So I would like a matrix declared as, for instance, MatrixBase<4, 4, float> to have (by virtue of specialization) a method called inverse(), while matrices declared with m != n don't. Similarly, matrices with m > n should get extra methods of their own.

Answers (1)

Max

It can be done with std::enable_if on a member function template, without any inheritance:

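#include <cstddef>
#include <type_traits>

using std::size_t;

template <size_t m, size_t n, typename T = double>
class MatrixBase
{
public:

    // The enable_if_t in the return type depends on T1, so it is only
    // evaluated when inverse() is called. MatrixBase<4, 3> can still be
    // instantiated; it just has no callable inverse().
    template <typename T1 = T>
    std::enable_if_t<m == n, MatrixBase<m, m, T1>> inverse() const
    {
        // Calculate inverse
        return{};
    }
};

int main()
{
    auto msquare = MatrixBase<4, 4>();
    auto mrect = MatrixBase<4, 3>();
    msquare.inverse();    // No error
    // mrect.inverse();   // Compilation error: no matching function, 4 != 3
    (void)mrect;

    return 0;
}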

You can also use enable_if to select a partial specialization of the class itself:

#include <cstddef>
#include <type_traits>

using std::size_t;

template <size_t m, size_t n, typename T = double, typename E = void>
class MatrixBase
{
public:

    template <typename T1 = T>
    std::enable_if_t<m == n, MatrixBase<m, m, T1>> inverse() const
    {
        // Calculate inverse
        return{};
    }
};

template <size_t m, size_t n, typename T>
class MatrixBase<m, n, T, std::enable_if_t<m == n, void>>
{
public:

    static bool m_equals_n()
    {
        return true;
    }

};

template <size_t m, size_t n, typename T>
class MatrixBase<m, n, T, std::enable_if_t<n < m, void>>
{
public:

    static bool m_greater_than_n()
    {
        return true;
    }

};


template <size_t m, size_t n, typename T>
class MatrixBase<m, n, T, std::enable_if_t<m < n, void>>
{
public:

    static bool m_less_than_n()
    {
        return true;
    }

};


int main(int argc, const char *argv[])
{
    auto msquare = MatrixBase<4, 4>();
    auto m_4_3 = MatrixBase<4, 3>();
    auto m_3_4 = MatrixBase<3, 4>();

    msquare.m_equals_n();
    // msquare.m_greater_than_n();  // Compilation error
    m_4_3.m_greater_than_n();
    m_3_4.m_less_than_n();

    return 0;
}
