emilib
hash_set_robin_hood.hpp
1 // By Emil Ernerfeldt 2014-2016
2 // LICENSE:
3 // This software is dual-licensed to the public domain and under the following
4 // license: you are granted a perpetual, irrevocable license to copy, modify,
5 // publish, and distribute this file as you see fit.
6 
7 // This is heavily inspired by https://gist.github.com/ssylvan/5538011
8 
9 #pragma once
10 
#include <cstddef>    // size_t
#include <cstdlib>    // std::malloc
#include <cstring>    // std::memset
#include <functional> // std::hash
#include <iterator>
#include <new>        // placement new and std::bad_alloc
#include <utility>    // std::move and std::pair
15 
namespace emilib {

// Equality functor equivalent to std::equal_to: compares two values with operator==.
template<typename T>
struct HashSetRHEqualTo
{
    constexpr bool operator()(const T &lhs, const T &rhs) const
    {
        return lhs == rhs;
    }
};

// A cache-friendly hash set with open addressing, linear probing and power-of-two capacity.
//
// Keys and their hash codes live in two parallel malloc'ed arrays (_keys / _hashes).
// A slot is "unused" when its hash code is kUnusedHashCode, "deleted" (a tombstone)
// when kTombStoneFlag is set, and "filled" otherwise. Insertion uses Robin Hood
// displacement: an element that has probed further takes the slot of one that has
// probed less.
//
// KeyT    : element type.
// HasherT : default-constructible functor, size_t-convertible result for a const KeyT&.
// EqT     : default-constructible functor deciding whether two keys are equal.
template <typename KeyT, typename HasherT = std::hash<KeyT>, typename EqT = HashSetRHEqualTo<KeyT>>
class HashSetRH
{
private:
    using MyType = HashSetRH<KeyT, HasherT, EqT>;

    // Returned by the bucket lookup helpers when no filled bucket matches.
    static constexpr size_t kInvalidBucket = static_cast<size_t>(-1);

public:
    using size_type       = size_t;
    using value_type      = KeyT;
    using reference       = KeyT&;
    using const_reference = const KeyT&;
    using HashCode        = size_t;

    static const size_t   kMaxLoadFactorPercent = 90;
    static const HashCode kUnusedHashCode       = 0;
    // Highest bit of the hash code marks an erased slot.
    static const HashCode kTombStoneFlag        = HashCode(1) << HashCode(8 * sizeof(HashCode) - 1);

    // Forward iterator over the filled buckets of a mutable set.
    class iterator
    {
    public:
        using iterator_category = std::forward_iterator_tag;
        using difference_type   = size_t;
        using distance_type     = size_t;
        using value_type        = KeyT;
        using pointer           = value_type*;
        using reference         = value_type&;

        iterator() : _set(nullptr), _bucket(0) { }

        iterator(MyType* hash_set, size_t bucket) : _set(hash_set), _bucket(bucket)
        {
        }

        iterator& operator++()
        {
            goto_next_element();
            return *this;
        }

        iterator operator++(int)
        {
            const size_t old_index = _bucket;
            goto_next_element();
            return iterator(_set, old_index);
        }

        reference operator*() const
        {
            return _set->_keys[_bucket];
        }

        pointer operator->() const
        {
            return _set->_keys + _bucket;
        }

        // Iterators are compared by bucket index only.
        bool operator==(const iterator& rhs) const
        {
            return _bucket == rhs._bucket;
        }

        bool operator!=(const iterator& rhs) const
        {
            return _bucket != rhs._bucket;
        }

    private:
        // Advance to the next filled bucket, or to _num_buckets (== end()).
        void goto_next_element()
        {
            do {
                ++_bucket;
            } while (_bucket < _set->_num_buckets && !MyType::_is_filled_hash(_set->_hashes[_bucket]));
        }

    public:
        MyType* _set;
        size_t  _bucket;
    };

    // Forward iterator over the filled buckets of a const set.
    class const_iterator
    {
    public:
        using iterator_category = std::forward_iterator_tag;
        using difference_type   = size_t;
        using distance_type     = size_t;
        using value_type        = const KeyT;
        using pointer           = value_type*;
        using reference         = value_type&;

        const_iterator() : _set(nullptr), _bucket(0) { }

        // Implicit conversion from a mutable iterator.
        const_iterator(iterator proto) : _set(proto._set), _bucket(proto._bucket)
        {
        }

        const_iterator(const MyType* hash_set, size_t bucket) : _set(hash_set), _bucket(bucket)
        {
        }

        const_iterator& operator++()
        {
            goto_next_element();
            return *this;
        }

        const_iterator operator++(int)
        {
            const size_t old_index = _bucket;
            goto_next_element();
            return const_iterator(_set, old_index);
        }

        reference operator*() const
        {
            return _set->_keys[_bucket];
        }

        pointer operator->() const
        {
            return _set->_keys + _bucket;
        }

        // Iterators are compared by bucket index only.
        bool operator==(const const_iterator& rhs) const
        {
            return _bucket == rhs._bucket;
        }

        bool operator!=(const const_iterator& rhs) const
        {
            return _bucket != rhs._bucket;
        }

    private:
        // Advance to the next filled bucket, or to _num_buckets (== end()).
        void goto_next_element()
        {
            do {
                ++_bucket;
            } while (_bucket < _set->_num_buckets && !MyType::_is_filled_hash(_set->_hashes[_bucket]));
        }

    public:
        const MyType* _set;
        size_t        _bucket;
    };

    // ------------------------------------------------------------------------

    HashSetRH() = default;

    // Copy: reserves room for other.size() elements, then re-inserts each key.
    HashSetRH(const HashSetRH& other)
    {
        reserve(other.size());
        insert(other.begin(), other.end());
    }

    // Move: this set starts empty and swaps contents with `other`.
    HashSetRH(HashSetRH&& other)
    {
        *this = std::move(other);
    }

    HashSetRH& operator=(const HashSetRH& other)
    {
        // Without this guard, self-assignment would clear() the very
        // elements we are about to copy.
        if (this != &other) {
            clear();
            reserve(other.size());
            insert(other.begin(), other.end());
        }
        return *this;
    }

    // Move assignment is implemented as a swap; `other` receives our old contents.
    HashSetRH& operator=(HashSetRH&& other)
    {
        this->swap(other);
        return *this;
    }

    ~HashSetRH()
    {
        for (size_t bucket = 0; bucket < _num_buckets; ++bucket) {
            if (_is_filled_hash(_hashes[bucket])) {
                _keys[bucket].~KeyT();
            }
        }
        std::free(_hashes);
        std::free(_keys);
    }

    void swap(HashSetRH& other)
    {
        std::swap(_hasher,       other._hasher);
        std::swap(_eq,           other._eq);
        std::swap(_keys,         other._keys);
        std::swap(_hashes,       other._hashes);
        std::swap(_num_buckets,  other._num_buckets);
        std::swap(_num_filled,   other._num_filled);
        std::swap(_rehash_limit, other._rehash_limit);
        std::swap(_mask,         other._mask);
    }

    // -------------------------------------------------------------

    iterator begin()
    {
        return iterator(this, _first_filled_bucket());
    }

    const_iterator begin() const
    {
        return const_iterator(this, _first_filled_bucket());
    }

    iterator end()
    { return iterator(this, _num_buckets); }

    const_iterator end() const
    { return const_iterator(this, _num_buckets); }

    size_t size() const
    {
        return _num_filled;
    }

    bool empty() const
    {
        return _num_filled == 0;
    }

    size_t bucket_count() const
    {
        return _num_buckets;
    }

    // ------------------------------------------------------------

    // Returns an iterator to `key`, or end() if it is not in the set.
    iterator find(const KeyT& key)
    {
        const size_t bucket = _find_filled_bucket(_hash_key(key), key);
        if (bucket == kInvalidBucket) {
            return end();
        }
        return iterator(this, bucket);
    }

    const_iterator find(const KeyT& key) const
    {
        const size_t bucket = _find_filled_bucket(_hash_key(key), key);
        if (bucket == kInvalidBucket) {
            return end();
        }
        return const_iterator(this, bucket);
    }

    bool contains(const KeyT& k) const
    {
        return _find_filled_bucket(_hash_key(k), k) != kInvalidBucket;
    }

    // Returns 1 if `k` is in the set, 0 otherwise.
    size_t count(const KeyT& k) const
    {
        return contains(k) ? 1 : 0;
    }

    // -----------------------------------------------------

    // Insert an element, unless it already exists.
    // Returns a pair consisting of an iterator to the inserted element
    // (or to the element that prevented the insertion)
    // and a bool denoting whether the insertion took place.
    std::pair<iterator, bool> insert(KeyT key)
    {
        const HashCode hash = _hash_key(key);
        const size_t existing_pos = _find_filled_bucket(hash, key);
        if (existing_pos != kInvalidBucket) {
            return {iterator{this, existing_pos}, false};
        }
        _check_expand_need();
        const size_t pos = _insert_unique_no_expand_check(hash, std::move(key));
        return {iterator{this, pos}, true};
    }

    // Constructs a KeyT from `args` and inserts it (see insert(KeyT)).
    template<class... Args>
    std::pair<iterator, bool> emplace(Args&&... args)
    {
        return insert(KeyT(std::forward<Args>(args)...));
    }

    // Inserts a copy of every element in [begin, end).
    void insert(const_iterator begin, const_iterator end)
    {
        for (; begin != end; ++begin) {
            insert(*begin);
        }
    }

    // Same as insert(KeyT), but contains(key) MUST be false
    void insert_unique(KeyT key)
    {
        _check_expand_need();
        _insert_unique_no_expand_check(_hash_key(key), std::move(key));
    }

    // -------------------------------------------------------

    /* Erase an element from the hash set.
       return false if element was not found */
    bool erase(const KeyT& key)
    {
        const size_t pos = _find_filled_bucket(_hash_key(key), key);
        if (pos == kInvalidBucket) {
            return false;
        }
        _erase_bucket(pos);
        return true;
    }

    /* Erase an element using an iterator.
       Returns an iterator to the next element (or end()). */
    iterator erase(iterator it)
    {
        _erase_bucket(it._bucket);
        return ++it;
    }

    // Remove all elements, keeping full capacity.
    void clear()
    {
        if (_num_buckets == 0) {
            return; // nothing allocated yet; avoid memset on a null pointer
        }
        for (size_t bucket = 0; bucket < _num_buckets; ++bucket) {
            if (_is_filled_hash(_hashes[bucket])) {
                _keys[bucket].~KeyT();
            }
        }
        _num_filled = 0;
        static_assert(kUnusedHashCode == 0, "");
        std::memset(_hashes, 0, _num_buckets * sizeof(HashCode));
    }

    // Make room for this many elements.
    // Grows to the smallest power-of-two bucket count (at least 4) that keeps the
    // load at or below kMaxLoadFactorPercent, and re-inserts all filled elements
    // (tombstones are dropped). Throws std::bad_alloc if allocation fails.
    void reserve(size_t num_elems)
    {
        if (num_elems <= _rehash_limit) {
            return;
        }

        const size_t min_required_buckets = 1 + num_elems * 100 / kMaxLoadFactorPercent;
        size_t num_buckets = 4;
        while (num_buckets < min_required_buckets) { num_buckets *= 2; }

        auto new_hashes = static_cast<HashCode*>(std::malloc(num_buckets * sizeof(HashCode)));
        auto new_keys   = static_cast<KeyT*>(std::malloc(num_buckets * sizeof(KeyT)));

        if (new_hashes == nullptr || new_keys == nullptr) {
            std::free(new_hashes);
            std::free(new_keys);
            throw std::bad_alloc();
        }

        const size_t old_num_buckets = _num_buckets;
        HashCode*    old_hashes      = _hashes;
        KeyT*        old_keys        = _keys;

        _num_filled   = 0;
        _num_buckets  = num_buckets;
        _rehash_limit = (_num_buckets * kMaxLoadFactorPercent) / 100;
        _mask         = _num_buckets - 1;
        _hashes       = new_hashes;
        _keys         = new_keys;

        static_assert(kUnusedHashCode == 0, "");
        std::memset(_hashes, 0, _num_buckets * sizeof(HashCode));

        // Move every filled element over, reusing its stored hash code.
        for (size_t src_bucket = 0; src_bucket < old_num_buckets; ++src_bucket) {
            const HashCode src_hash = old_hashes[src_bucket];
            if (_is_filled_hash(src_hash)) {
                KeyT& src_key = old_keys[src_bucket];
                _insert_unique_no_expand_check(src_hash, std::move(src_key));
                src_key.~KeyT();
            }
        }

        std::free(old_hashes);
        std::free(old_keys);
    }

private:
    // Can we fit another element?
    void _check_expand_need()
    {
        reserve(_num_filled + 1);
    }

    // Index of the first filled bucket, or _num_buckets if there is none.
    size_t _first_filled_bucket() const
    {
        size_t bucket = 0;
        while (bucket < _num_buckets && !_is_filled_hash(_hashes[bucket])) {
            ++bucket;
        }
        return bucket;
    }

    // Destroy the key in a filled bucket and turn the slot into a tombstone.
    void _erase_bucket(size_t pos)
    {
        _keys[pos].~KeyT();
        _hashes[pos] |= kTombStoneFlag;
        --_num_filled;
    }

    // Hash `key` into a code that is neither kUnusedHashCode nor tombstone-flagged.
    // The full-width HashCode is returned: narrowing it could turn a non-zero
    // hash into kUnusedHashCode.
    HashCode _hash_key(const KeyT& key) const
    {
        HashCode hash = static_cast<HashCode>(_hasher(key));

        // Ensure kTombStoneFlag is cleared:
        hash &= ~kTombStoneFlag;

        // Ensure that we never return kUnusedHashCode as a hash:
        if (hash == kUnusedHashCode) {
            hash = 1;
        }

        return hash;
    }

    static bool _is_deleted_hash(HashCode hash)
    {
        return (hash & kTombStoneFlag) != 0;
    }

    static bool _is_filled_hash(HashCode hash)
    {
        return hash != kUnusedHashCode && !_is_deleted_hash(hash);
    }

    // The bucket a hash code would occupy if nothing were in its way.
    size_t _desired_pos(HashCode hash) const
    {
        return hash & _mask;
    }

    // How many steps `pos` is past the desired bucket of `hash` (with wrap-around).
    size_t _probe_distance(HashCode hash, HashCode pos) const
    {
        return (pos + _num_buckets - _desired_pos(hash)) & _mask;
    }

    // Move-construct `key` into the raw storage at `pos` and mark the slot filled.
    void _construct(size_t pos, HashCode hash, KeyT&& key)
    {
        new (&_keys[pos]) KeyT(std::move(key));
        _hashes[pos] = hash;
        ++_num_filled;
    }

    // Insert a key known not to be in the set, without growing the table.
    // Returns the bucket where the *new* key ended up. With Robin Hood
    // displacement the loop may finish by placing a displaced element, so the
    // bucket of the new key is remembered at its first placement.
    size_t _insert_unique_no_expand_check(HashCode hash, KeyT&& key)
    {
        size_t new_key_pos = kInvalidBucket;
        size_t pos  = _desired_pos(hash);
        size_t dist = 0;
        for (;;) {
            if (_hashes[pos] == kUnusedHashCode) {
                _construct(pos, hash, std::move(key));
                if (new_key_pos == kInvalidBucket) { new_key_pos = pos; }
                return new_key_pos;
            }

            // If the existing elem has probed less than us, then swap places with existing
            // elem, and keep going to find another slot for that elem.
            const size_t existing_elem_probe_dist = _probe_distance(_hashes[pos], pos);
            if (existing_elem_probe_dist < dist) {
                if (_is_deleted_hash(_hashes[pos])) {
                    // A tombstone holds no live key, so we can simply take the slot.
                    _construct(pos, hash, std::move(key));
                    if (new_key_pos == kInvalidBucket) { new_key_pos = pos; }
                    return new_key_pos;
                }

                std::swap(hash, _hashes[pos]);
                std::swap(key, _keys[pos]);
                dist = existing_elem_probe_dist;
                if (new_key_pos == kInvalidBucket) { new_key_pos = pos; }
            }

            pos = (pos + 1) & _mask;
            ++dist;
        }
    }

    // Find the filled bucket holding `key` (whose hash code is `hash`).
    // Returns kInvalidBucket ((size_t)-1) if the key is not present.
    size_t _find_filled_bucket(HashCode hash, const KeyT& key) const
    {
        if (_num_buckets == 0) { return kInvalidBucket; }
        size_t pos  = _desired_pos(hash);
        size_t dist = 0;
        for (;;) {
            const HashCode slot_hash = _hashes[pos];
            if (slot_hash == kUnusedHashCode) {
                return kInvalidBucket;
            }
            if (dist > _probe_distance(slot_hash, pos)) {
                // The resident probed less than we have; stop searching here.
                return kInvalidBucket;
            }
            // A tombstoned slot never equals `hash` (its flag bit is set), so the
            // destroyed key in it is never touched.
            if (slot_hash == hash && _eq(_keys[pos], key)) {
                return pos;
            }

            pos = (pos + 1) & _mask;
            ++dist;
        }
    }

private:
    // TODO: __restrict
    HasherT   _hasher;
    EqT       _eq;
    KeyT*     _keys         = nullptr;
    HashCode* _hashes       = nullptr;
    size_t    _num_buckets  = 0; // capacity
    size_t    _num_filled   = 0; // size
    size_t    _rehash_limit = 0; // reserve() grows the table when asked for more elements than this
    size_t    _mask         = 0; // _num_buckets minus one
};

} // namespace emilib
Definition: hash_set_robin_hood.hpp:108
Definition: hash_set_robin_hood.hpp:20
Definition: hash_set_robin_hood.hpp:46
Definition: hash_set_robin_hood.hpp:30
Definition: coroutine.hpp:18