diff --git a/include/radix_cpp.h b/include/radix_cpp.h index b0dee5a..b7de277 100644 --- a/include/radix_cpp.h +++ b/include/radix_cpp.h @@ -13,8 +13,6 @@ #endif namespace radix_cpp { - inline constexpr uint32_t flag_is_assigned = 1; - inline uint8_t append(uint8_t key, size_t digit) noexcept { return static_cast(digit); } @@ -837,23 +835,24 @@ namespace radix_cpp { std::vector free_list_; }; - size_t get_load_factor() const noexcept { return 100 * num_entries_ / table_size_; } + size_t get_load_factor(size_t nodes_to_insert = 0) const noexcept { return 100 * (num_entries_ + nodes_to_insert) / table_size_; } std::pair create_nodes_for_key(key_type key) { if (!nodes_) { init(bucket_count); - } else if (get_load_factor() >= max_load_factor100) { // Check the load factor + } + auto n = keysize(key); + // Make sure the hash has enough space for the whole key + while (get_load_factor(n) >= max_load_factor100) { resize(table_size_ * 2); } - num_inserts_++; - auto n = keysize(key); auto depth = static_cast(n); auto nodes_start = get_nodes_start(), nodes_end = get_nodes_end(); iterator it = end(); Node * final_node = nullptr; - + // insert digits from least significant to most significant // even if keysize is zero, add at least one digit (for empty strings) for ( size_t i = 0; i < (n == 0 ? 1 : n); i++, depth-- ) { @@ -861,7 +860,7 @@ namespace radix_cpp { auto hash0 = calc_unordered_hash(depth, prefix_key); auto hash = calc_final_hash(hash0, ordinal); auto node_initial = read_node(hash); - + auto node = node_initial; while ( 1 ) { if (!node->is_assigned()) { @@ -881,10 +880,10 @@ namespace radix_cpp { } break; } - + key = std::move(prefix_key); } - + return std::pair(final_node, it); } // getFirstConst returns the key from value_type for either set or map @@ -937,22 +936,25 @@ namespace radix_cpp { auto end = nodes_ + table_size_; for (; node != end; node++) { if (node->is_assigned()) { - auto hash0 = calc_unordered_hash(node->get_depth_lsb(), node->get_prefix_key()); + // get the least significant byte of depth from node, and the other bytes from the prefix key + size_t depth = ((keysize(node->get_prefix_key()) + 1) & ~UINT64_C(0xff)) | node->get_depth_lsb(); + auto hash0 = calc_unordered_hash(depth, node->get_prefix_key()); auto hash = calc_final_hash(hash0, node->get_ordinal()); auto new_node = new_nodes + (hash & new_mask); - + while ( 1 ) { if (new_node->is_assigned()) { if (++new_node == new_nodes_end) new_node = new_nodes; num_insert_collisions_++; } else { - std::swap(*node, *new_node); + new (static_cast(new_node)) Node(std::move(*node)); node->~Node(); break; } } } } + std::free(nodes_); nodes_ = new_nodes; table_size_ = new_size; diff --git a/tests/test.cpp b/tests/test.cpp index 9c46274..d301026 100644 --- a/tests/test.cpp +++ b/tests/test.cpp @@ -382,6 +382,35 @@ TEST_CASE( "count", "[set_count]") { REQUIRE(S.count(1) == 1); } +TEST_CASE( "long keys", "[long_keys]") { + std::string k1, k2, k3; + for (size_t i = 0; i < 400; i++) k1 += 'a'; + for (size_t i = 0; i < 800; i++) k2 += 'a'; + for (size_t i = 0; i < 1000; i++) k3 += 'z'; + radix_cpp::set S; + S.insert(k2); + S.insert(k1); + S.insert(""); + S.insert(k3); + S.insert("abc"); + REQUIRE(S.size() == 5); + auto it = S.begin(); + REQUIRE(*it++ == ""); + REQUIRE(*it++ == k1); + REQUIRE(*it++ == k2); + REQUIRE(*it++ == "abc"); + REQUIRE(*it++ == k3); + REQUIRE(it == S.end()); + S.erase(k1); + S.erase(k2); + S.erase(k3); + REQUIRE(S.size() == 2); + it = S.begin(); + REQUIRE(*it++ == ""); + REQUIRE(*it++ == "abc"); + REQUIRE(it == S.end()); +} + #if 0 TEST_CASE( "erase by range", "[erase_by_range]") { radix_cpp::set S;