iggy12345

Reputation: 1383

How to define a custom key equivalence predicate for an std::unordered_set?

I'm trying to define a key equivalence predicate for a hash set I'm using, but the compiler keeps telling me that function template "UniqueTable::table_key_equiv_pred" is not a type name and I don't understand why.

I've defined my equivalence function as this:

template<class t>
bool UniqueTable::table_key_equiv_pred(t a, node b) const
{
    return a == b;
}

template<>
bool UniqueTable::table_key_equiv_pred<int>(int a, node b) const
{
    return b && a == b->level;
}

template<>
bool UniqueTable::table_key_equiv_pred<node>(node a, node b) const
{
    return b && a && a->level == b->level && a->one == b->one && a->zero == b->zero;
}

And then instantiating a set using:

auto hashtable = new std::unordered_set<node, std::hash<node>, table_key_equiv_pred>();

How do I define a predicate to be used as a type parameter to an std::unordered_set?

Upvotes: 1

Views: 1070

Answers (1)

Indiana Kernick

Reputation: 5331

The third template parameter of std::unordered_set must be the type of a comparison object. table_key_equiv_pred is a function template, not a type, which is what the compiler is complaining about. The default is std::equal_to<Key>, which is defined similarly to this:

template <typename T>
struct equal_to {
  bool operator()(const T &lhs, const T &rhs) const {
    return lhs == rhs;
  }
};

You want to be able to search for nodes in a set using ints or any other type. Doing this with std::unordered_set is not impossible, just kind of a pain. The problem is that std::unordered_set compares hashes first, then compares the actual objects second. If you want to search for a node using an int, then the two must have the same hash and compare equal.

From the code in the question, I've gathered that a node is more than just a level, so the hash of 4 is not going to be equal to the hash of a node with level == 4 (unless you've defined std::hash<node> to just return level). This means that the lookup concludes that 4 is not in the set before the custom comparison object is even called, even if a level 4 node is present.

Being able to search for a node that compares equal to anything is tricky with a std::unordered_set because you need a "transparent" hash function. I suggest you switch over to a binary tree set (std::set). This set compares elements for order (using a less-than comparison). This means that you only need to define a less-than comparison function.

It's rather uncommon to actually create a custom comparison object and do something like this: std::set<node, MyCustomLessThan>. What is usually done is defining a custom less-than operator for the type and then letting the default comparison object (in this case it's std::less) call the operator. Now you want to use the "transparent" less.

std::less<int> compares two integers. std::less<node> compares two nodes. std::less<> compares anything to anything and is called transparent less. std::less<> is defined similar to this:

template <>
struct less<void> {
  template <typename Left, typename Right>
  bool operator()(const Left &lhs, const Right &rhs) const {
    return lhs < rhs;
  }
};

So the data structure you want is std::set<node, std::less<>> but you still need to define the comparison between integers and nodes. You can do this by overloading the less-than operator.

// I'm not confident that these implementations are correct
// Nullable objects make this pretty tricky!
bool operator<(const node &lhs, const int rhs) {
  return !lhs || lhs->level < rhs;
}
bool operator<(const int lhs, const node &rhs) {
  return rhs && lhs < rhs->level;
}
bool operator<(const node &lhs, const node &rhs) {
  // I'll probably get this wrong so I'll leave this up to you!
}

Here's a full example (note that this uses C++14 heterogeneous lookup and the C++17 if statement with an initializer):

#include <functional>
#include <iostream>
#include <set>

struct Node {
  int level;
  int thing;
};

bool operator<(const Node lhs, const int rhs) {
  return lhs.level < rhs;
}
bool operator<(const int lhs, const Node rhs) {
  return lhs < rhs.level;
}
bool operator<(const Node lhs, const Node rhs) {
  // A better way to do this is with std::tie
  // see https://stackoverflow.com/a/16090720/4093378
  if (lhs.level < rhs.level) return true;
  if (lhs.level > rhs.level) return false;
  if (lhs.thing < rhs.thing) return true;
  if (lhs.thing > rhs.thing) return false;
  return false;
}

int main() {
  std::set<Node, std::less<>> set;
  set.insert({1, 91});
  set.insert({1, 87});
  set.insert({2, 43});
  set.insert({2, 10});

  // find a level 2 node
  // there's more than 1 level 2 node; the standard only promises some match,
  // although common implementations return the first one
  // see https://en.cppreference.com/w/cpp/container/set/find
  if (auto iter = set.find(2); iter != set.end()) {
    // typically prints "2 - 10"
    std::cout << iter->level << " - " << iter->thing << '\n';
  }

  // get a range of all the level 1 nodes
  // see https://en.cppreference.com/w/cpp/container/set/equal_range
  auto range = set.equal_range(1);
  for (auto iter = range.first; iter != range.second; ++iter) {
    // prints "1 - 87" then "1 - 91"
    std::cout << iter->level << " - " << iter->thing << '\n';
  }

  // a neat thing about std::set is that all the elements are kept in order
  // that's how std::set differs from std::unordered_set!
  // this prints "1 - 87", "1 - 91", "2 - 10" and then "2 - 43"
  for (const Node &node : set) {
    std::cout << node.level << " - " << node.thing << '\n';
  }
}

If you'd prefer to use std::unordered_set instead then you'll have to use a transparent hash function. This is a C++20 feature. Read the documentation I've linked for more information.

In order to search for a node with a given level, hashes for the level and the node must be compatible. This means that the hash of level 5 must be the same as the hash of a level 5 node.

Here's a full example (note that this uses C++20 features):

#include <iostream>
#include <unordered_set>

struct Node {
  int level;
  int thing;
};

struct NodeEqual {
  struct is_transparent {};

  bool operator()(const Node lhs, const int rhs) const noexcept {
    return lhs.level == rhs;
  }
  bool operator()(const int lhs, const Node rhs) const noexcept {
    return lhs == rhs.level;
  }
  bool operator()(const Node lhs, const Node rhs) const noexcept {
    return lhs.level == rhs.level && lhs.thing == rhs.thing;
  }
};

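struct NodeHash {
  // C++20 enables heterogeneous lookup only when both the hash and the
  // equality predicate declare is_transparent
  using is_transparent = void;

  size_t operator()(const Node node) const noexcept {
    // note that the hash depends only on level
    // if lots of nodes have the same level,
    // then lots of nodes will have the same hash
    // this could lead to lots of equality comparisons
    return node.level;
  }
  size_t operator()(const int level) const noexcept {
    return level;
  }
};

int main() {
  // NodeEqual must be passed explicitly; the default std::equal_to<Node>
  // is not transparent and Node has no operator==
  std::unordered_set<Node, NodeHash, NodeEqual> set;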
  set.insert({1, 91});
  set.insert({1, 87});
  set.insert({2, 43});
  set.insert({2, 10});

  // find a level 2 node
  // see https://en.cppreference.com/w/cpp/container/unordered_set/find
  if (auto iter = set.find(2); iter != set.end()) {
    // prints either "2 - 43" or "2 - 10"
    std::cout << iter->level << " - " << iter->thing << '\n';
  }

  // getting a range of all the level 1 nodes isn't actually possible
  // there are a couple of tweaks required to make it possible
  // however, these tweaks will further reduce the performance of std::unordered_set
}

I can't actually test this example (because it's C++20) so I'm not sure if it's correct. I can only hope that I've read the docs properly!

Unless level is close to unique, std::set might actually be faster than std::unordered_set here, because every node with the same level has the same hash and ends up in the same bucket.

Upvotes: 2
