Jedediah Heal

Reputation: 123

Overload C++ template class method by its template classes

I am writing a template class that manages a union of 2 classes. All the functions are pretty simple except the Get() functions. It looks like this:

UnionPair.hpp

template <class First, class Second>
class UnionPair {
 public:
  UnionPair() : state_(State::kEmpty){};
  ~UnionPair(){};

  void Reset();

  void Set(std::unique_ptr<First>&& first);
  void Set(std::unique_ptr<Second>&& second);

  template <First>
  First *Get();

  template <Second>
  Second *Get();

 private:
  enum class State { kEmpty, kFirst, kSecond } state_;
  union {
    std::unique_ptr<First> first_;
    std::unique_ptr<Second> second_;
  };
};

UnionPair.cpp

template <class First, class Second>
template <First>
First *UnionPair<First, Second>::Get() {
  return first_.get();
}

template <class First, class Second>
template <Second>
Second *UnionPair<First, Second>::Get() {
  return second_.get();
}

I'm trying to make the Get() functions template methods that can only be called with the template classes of its instantiated UnionPair object, so that they can be overloaded by the template class they're called with. The above code does not compile.

This is how I'd like to call them, but I get a compiler error when I try:

// Definitions of structs A, B and C here

UnionPair<A, B> pair;
pair.Set(std::unique_ptr(new A()));
pair.Get<A>(); // should return a pointer to A (causes a compiler error right now)
pair.Get<C>(); // should cause a compiler error

I'm not really sure how to implement this method in the .cpp file, or if it's even possible. I've been reading about specialization within templates, but haven't seen any examples similar to mine, or any where the specialization type MUST be one of the object's template classes.

What is wrong with my usage/definition? Is what I'm trying to do possible? If not, are there any alternatives? What I DON'T want to do is have a separate getter for First and Second, like First *GetFirst(), because that relies on the order of the classes in the object instantiation and isn't clear.

Side note: if there is a C++ library that manages the state, setting and getting of union members, I'd consider using it, but I'd still like to figure out my question.
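Regarding the side note: since C++17 the standard library provides std::variant, which stores one of several alternatives and records which one is active. The following is a minimal sketch of the same Set/Get interface on top of it, assuming C++17. VariantPair and VariantPairDemo are hypothetical names, and unlike a throwing getter this sketch returns nullptr when the requested type is not the active one.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <variant>

// Hypothetical VariantPair (sketch, assumes C++17). std::variant tracks the
// active alternative, so no state enum, placement new or explicit destructor
// calls are needed.
template <class First, class Second>
class VariantPair {
 public:
  void Reset() { storage_ = std::monostate{}; }
  void Set(std::unique_ptr<First> value) { storage_ = std::move(value); }
  void Set(std::unique_ptr<Second> value) { storage_ = std::move(value); }

  // Returns nullptr when T is not the active alternative. If T is neither
  // First nor Second, std::get_if<std::unique_ptr<T>> does not compile.
  template <class T>
  T* Get() {
    auto* p = std::get_if<std::unique_ptr<T>>(&storage_);
    return p ? p->get() : nullptr;
  }

 private:
  // std::monostate plays the role of the question's kEmpty state.
  std::variant<std::monostate, std::unique_ptr<First>, std::unique_ptr<Second>>
      storage_;
};

struct A { int value = 1; };  // placeholder types for illustration
struct B { int value = 2; };

// Hypothetical usage example.
inline void VariantPairDemo() {
  VariantPair<A, B> pair;
  assert(pair.Get<A>() == nullptr);  // starts empty
  pair.Set(std::make_unique<A>());
  assert(pair.Get<A>() != nullptr && pair.Get<A>()->value == 1);
  assert(pair.Get<B>() == nullptr);  // B is not the active alternative
  pair.Reset();
  assert(pair.Get<A>() == nullptr);
}
```

A call such as pair.Get<int>() is rejected at compile time, because std::unique_ptr<int> is not one of the variant's alternatives.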

Upvotes: 1

Views: 702

Answers (1)

Remy Lebeau

Reputation: 596256

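First off, you generally can't split template code into separate .h and .cpp files; the definitions must be visible wherever the template is instantiated, so they belong in the header (unless you explicitly instantiate every combination you use in the .cpp file):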

Why can templates only be implemented in the header file?

Second, which member of a union holds data is not known until runtime, so the compiler cannot validate the template parameter of Get() at compile time, at least not the way you want. At compile time you can only reject a type the union can never hold (such as C); whether it currently holds First or Second can only be checked at runtime, after the union has been assigned.
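To make that distinction concrete, here is a small sketch that separates the two checks. ActiveTag and ActiveTagDemo are hypothetical names, not part of the code below: ActiveTag is a payload-free stand-in that only records which member is active. A static_assert rejects any type other than First or Second during compilation, while which of the two is active is answered at runtime.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical ActiveTag: keeps only the "which member is active" state.
template <class First, class Second>
class ActiveTag {
 public:
  void SetFirst() { active_ = State::kFirst; }
  void SetSecond() { active_ = State::kSecond; }

  template <class T>
  bool Holds() const {
    // Compile-time check: depends only on the types involved, so a call
    // such as Holds<int>() is rejected while compiling.
    static_assert(std::is_same<T, First>::value || std::is_same<T, Second>::value,
                  "T must be First or Second");
    // Runtime check: which member is active depends on the calls made so
    // far, so it can only be answered while the program runs.
    return active_ == (std::is_same<T, First>::value ? State::kFirst : State::kSecond);
  }

 private:
  enum class State { kEmpty, kFirst, kSecond };
  State active_ = State::kEmpty;
};

struct A {};  // placeholder types for illustration
struct B {};

// Hypothetical usage example.
inline void ActiveTagDemo() {
  ActiveTag<A, B> tag;
  assert(!tag.Holds<A>() && !tag.Holds<B>());  // nothing set yet
  tag.SetFirst();
  assert(tag.Holds<A>() && !tag.Holds<B>());
}
```

Calling tag.Holds<int>() on an ActiveTag<A, B> fails the static_assert; the UnionPair code below reaches the same compile-time rejection through std::enable_if and throws for the runtime case instead.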

Try something like this:

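#include <memory>
#include <new>          // placement new
#include <stdexcept>    // std::domain_error
#include <type_traits>
#include <utility>      // std::move

template <class First, class Second>
class UnionPair {
 static_assert(!std::is_same<First, Second>::value, "First and Second can't be the same type");

 public:
  UnionPair();
  UnionPair(std::unique_ptr<First>&& value);
  UnionPair(std::unique_ptr<Second>&& value);
  ~UnionPair() { Reset(); }  // destroy whichever member is currently alive

  void Reset();

  void Set(std::unique_ptr<First>&& value);
  void Set(std::unique_ptr<Second>&& value);

  // T must be First or Second; any other type fails to compile because
  // no InternalGet<T>() overload is enabled for it.
  template <class T>
  T* Get();

 private:
  enum class State { kEmpty, kFirst, kSecond } state_;

  // At most one member is alive at a time, as recorded by state_. The
  // compiler neither constructs nor destroys union members here, so Set()
  // constructs them with placement new and Reset() destroys them explicitly.
  union {
    std::unique_ptr<First> first_;
    std::unique_ptr<Second> second_;
  };

  // std::enable_if removes whichever overload does not match T, so
  // InternalGet<First>() and InternalGet<Second>() each resolve to exactly one.
  template<class T>
  typename std::enable_if<std::is_same<T, First>::value, T*>::type
    InternalGet();

  template<class T>
  typename std::enable_if<std::is_same<T, Second>::value, T*>::type
    InternalGet();
};

template <class First, class Second>
UnionPair<First, Second>::UnionPair()
    : state_(State::kEmpty)
{
}

template <class First, class Second>
UnionPair<First, Second>::UnionPair(std::unique_ptr<First>&& value)
    : UnionPair()
{
    Set(std::move(value));
}

template <class First, class Second>
UnionPair<First, Second>::UnionPair(std::unique_ptr<Second>&& value)
    : UnionPair()
{
    Set(std::move(value));
}

template <class First, class Second>
void UnionPair<First, Second>::Reset()
{
  // End the active member's lifetime by running its destructor (which also
  // deletes the owned object), so Set() can reuse the storage for either member.
  using FirstPtr = std::unique_ptr<First>;
  using SecondPtr = std::unique_ptr<Second>;
  if (state_ == State::kFirst)
    first_.~FirstPtr();
  else if (state_ == State::kSecond)
    second_.~SecondPtr();
  state_ = State::kEmpty;
}

template <class First, class Second>
void UnionPair<First, Second>::Set(std::unique_ptr<First>&& value)
{
    Reset();
    // No member is alive after Reset(), so construct first_ in place
    // instead of assigning to it (assignment needs a live object).
    new (&first_) std::unique_ptr<First>(std::move(value));
    state_ = State::kFirst;
}

template <class First, class Second>
void UnionPair<First, Second>::Set(std::unique_ptr<Second>&& value)
{
    Reset();
    // Same as above, for second_.
    new (&second_) std::unique_ptr<Second>(std::move(value));
    state_ = State::kSecond;
}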

template <class First, class Second>
template <class T>
typename std::enable_if<std::is_same<T, First>::value, T*>::type
  UnionPair<First, Second>::InternalGet() {
    if (state_ == State::kFirst)
        return first_.get();
    throw std::domain_error("wrong state");
}

template <class First, class Second>
template <class T>
typename std::enable_if<std::is_same<T, Second>::value, T*>::type
  UnionPair<First, Second>::InternalGet() {
    if (state_ == State::kSecond)
        return second_.get();
    throw std::domain_error("wrong state");
}

template <class First, class Second>
template <class T>
T* UnionPair<First, Second>::Get() {
    return InternalGet<T>();
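}

// Usage, inside a function (structs A, B and C defined elsewhere):
UnionPair<A, B> pair;
pair.Set(std::unique_ptr<A>(new A));
pair.Get<A>(); // OK: returns the stored A*
pair.Get<B>(); // compiles, but throws std::domain_error at runtime
pair.Get<C>(); // compiler error: no InternalGet<C>() overload is enabled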
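As for why the question's own declarations fail: in `template <First> First *Get();`, the `template <First>` part does not mean "Get may only be called with the type First". It declares an unnamed non-type template parameter whose type is First, so the template argument has to be a value of that type. The sketch below shows the difference with First = int; Holder is a hypothetical name, and the parameter is named V only so it can be used.

```cpp
// Hypothetical Holder, for illustration only. `template <First V>` has the
// same form as the question's `template <First>`: V is a compile-time
// *value* of type First, not a type.
template <class First>
struct Holder {
  template <First V>
  constexpr First Get() const { return V; }
};

// Holder<int> accepts an int value as the template argument...
static_assert(Holder<int>{}.Get<42>() == 42, "V is the value 42");
// ...but Holder<int>{}.Get<int>() would not compile: int is a type,
// and a value of type int is expected.
```

Before C++20, a class type such as A is not allowed as the type of a non-type template parameter at all, so declarations like `template <A>` are rejected once UnionPair<A, B> is instantiated. The answer's `template <class T>` takes a type instead and picks the matching overload with std::enable_if.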

Upvotes: 2
