add param bucket_size for bucketing based on results
krunal1313 committed Dec 24, 2024
1 parent 0399aba commit 8ccd2ce
Showing 3 changed files with 102 additions and 14 deletions.
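
In short: `_text_match(buckets: N)` splits the bucketed prefix of the results into N equal blocks, while the new `_text_match(bucket_size: N)` makes every block exactly N hits wide; hits inside a block all take the block's anchor text-match score, so the next sort field decides their relative order. A minimal usage sketch (mirroring the new test added below):

    // Bucket text-match scores in groups of 3, then break ties by points.
    std::vector<sort_by> sort_fields = {
        sort_by("_text_match(bucket_size: 3)", "DESC"),
        sort_by("points", "DESC"),
    };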
18 changes: 13 additions & 5 deletions include/field.h
@@ -552,6 +552,11 @@ struct sort_by {
     linear,
 };
 
+enum text_match_params_t {
+    buckets,
+    bucket_size
+};
+
 struct eval_t {
     filter_node_t** filter_trees = nullptr; // Array of filter_node_t pointers.
     std::vector<uint32_t*> eval_ids_vec;
@@ -565,6 +570,7 @@ struct sort_by {
 
 // for text_match score bucketing
 uint32_t text_match_buckets;
+uint32_t text_match_bucket_size;
 
 // geo related fields
 int64_t geopoint;
@@ -589,21 +595,21 @@ struct sort_by {
 sort_by_type_t type{};
 
 sort_by(const std::string & name, const std::string & order):
-        name(name), order(order), text_match_buckets(0), geopoint(0), exclude_radius(0), geo_precision(0),
-        missing_values(normal) {
+        name(name), order(order), text_match_buckets(0), text_match_bucket_size(0), geopoint(0), exclude_radius(0),
+        geo_precision(0), missing_values(normal) {
 }
 
 sort_by(std::vector<std::string> eval_expressions, std::vector<int64_t> scores, std::string order):
-        eval_expressions(std::move(eval_expressions)), order(std::move(order)), text_match_buckets(0), geopoint(0), exclude_radius(0),
-        geo_precision(0), missing_values(normal) {
+        eval_expressions(std::move(eval_expressions)), order(std::move(order)), text_match_buckets(0), text_match_bucket_size(0),
+        geopoint(0), exclude_radius(0), geo_precision(0), missing_values(normal) {
     name = sort_field_const::eval;
     eval.scores = std::move(scores);
     type = eval_expression;
 }
 
 sort_by(const std::string &name, const std::string &order, uint32_t text_match_buckets, int64_t geopoint,
         uint32_t exclude_radius, uint32_t geo_precision) :
-        name(name), order(order), text_match_buckets(text_match_buckets),
+        name(name), order(order), text_match_buckets(text_match_buckets), text_match_bucket_size(0),
         geopoint(geopoint), exclude_radius(exclude_radius), geo_precision(geo_precision),
         missing_values(normal) {
     type = geopoint_field;
@@ -616,6 +622,7 @@ struct sort_by {
 eval_expressions = other.eval_expressions;
 order = other.order;
 text_match_buckets = other.text_match_buckets;
+text_match_bucket_size = other.text_match_bucket_size;
 geopoint = other.geopoint;
 exclude_radius = other.exclude_radius;
 geo_precision = other.geo_precision;
@@ -642,6 +649,7 @@ struct sort_by {
 eval_expressions = other.eval_expressions;
 order = other.order;
 text_match_buckets = other.text_match_buckets;
+text_match_bucket_size = other.text_match_bucket_size;
 geopoint = other.geopoint;
 exclude_radius = other.exclude_radius;
 geo_precision = other.geo_precision;
23 changes: 16 additions & 7 deletions src/collection.cpp
@@ -1238,18 +1238,22 @@ Option<bool> Collection::validate_and_standardize_sort_fields(const std::vector<
 std::vector<std::string> match_parts;
 const std::string& match_config = sort_field_std.name.substr(paran_start+1, sort_field_std.name.size() - paran_start - 2);
 StringUtils::split(match_config, match_parts, ":");
-if(match_parts.size() != 2 || match_parts[0] != "buckets") {
+if(match_parts.size() != 2 || (match_parts[0] != "buckets" && match_parts[0] != "bucket_size")) {
     return Option<bool>(400, "Invalid sorting parameter passed for _text_match.");
 }
 
 if(!StringUtils::is_uint32_t(match_parts[1])) {
-    return Option<bool>(400, "Invalid value passed for _text_match `buckets` configuration.");
+    return Option<bool>(400, "Invalid value passed for _text_match `buckets` or `bucket_size` configuration.");
 }
 
 sort_field_std.name = actual_field_name;
-sort_field_std.text_match_buckets = std::stoll(match_parts[1]);
 sort_field_std.type = sort_by::text_match;
 
+if(match_parts[0] == magic_enum::enum_name(sort_by::text_match_params_t::buckets)) {
+    sort_field_std.text_match_buckets = std::stoll(match_parts[1]);
+} else if(match_parts[0] == magic_enum::enum_name(sort_by::text_match_params_t::bucket_size)) {
+    sort_field_std.text_match_bucket_size = std::stoll(match_parts[1]);
+}
 } else if(actual_field_name == sort_field_const::vector_query) {
     const std::string& vector_query_str = sort_field_std.name.substr(paran_start + 1,
                                                                      sort_field_std.name.size() - paran_start -
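
Aside: `magic_enum::enum_name` yields an enumerator's identifier as a compile-time string, which is what lets the parser above compare the raw `match_parts[0]` token against the `text_match_params_t` names without a hand-written lookup table. A standalone sketch of that mapping (assumes the magic_enum single-header library and C++17):

    #include <magic_enum.hpp>
    #include <string_view>

    enum class text_match_params_t { buckets, bucket_size };

    // enum_name is constexpr, so the mapping can even be checked at compile time.
    static_assert(magic_enum::enum_name(text_match_params_t::bucket_size) == "bucket_size");

    bool is_bucket_size(std::string_view token) {
        return token == magic_enum::enum_name(text_match_params_t::bucket_size);
    }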
@@ -2710,21 +2714,26 @@ Option<nlohmann::json> Collection::search(collection_search_args_t& coll_args) c
 // apply bucketing on text match score
 int match_score_index = -1;
 for(size_t i = 0; i < sort_fields_std.size(); i++) {
-    if(sort_fields_std[i].name == sort_field_const::text_match && sort_fields_std[i].text_match_buckets != 0) {
+    if(sort_fields_std[i].name == sort_field_const::text_match &&
+       (sort_fields_std[i].text_match_buckets != 0 || sort_fields_std[i].text_match_bucket_size != 0)) {
         match_score_index = i;
         break;
     }
 }
 
-if(match_score_index >= 0 && sort_fields_std[match_score_index].text_match_buckets > 0) {
+if(match_score_index >= 0 && (sort_fields_std[match_score_index].text_match_buckets > 0
+                              || sort_fields_std[match_score_index].text_match_bucket_size > 0)) {
+
     size_t num_buckets = sort_fields_std[match_score_index].text_match_buckets;
+    size_t bucket_size = sort_fields_std[match_score_index].text_match_bucket_size;
+
     const size_t max_kvs_bucketed = std::min<size_t>(Index::DEFAULT_TOPSTER_SIZE, raw_result_kvs.size());
 
-    if(max_kvs_bucketed >= num_buckets) {
+    if((num_buckets > 0 && max_kvs_bucketed >= num_buckets) || (bucket_size > 0 && max_kvs_bucketed >= bucket_size)) {
         spp::sparse_hash_map<uint64_t, int64_t> result_scores;
 
         // only first `max_kvs_bucketed` elements are bucketed to prevent pagination issues past 250 records
-        size_t block_len = (max_kvs_bucketed / num_buckets);
+        size_t block_len = num_buckets > 0 ? (max_kvs_bucketed / num_buckets) : bucket_size;
         size_t i = 0;
         while(i < max_kvs_bucketed) {
             int64_t anchor_score = raw_result_kvs[i][0]->scores[raw_result_kvs[i][0]->match_score_index];
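
The hunk above folds both parameters into a single `block_len`. A self-contained sketch of the resulting rescoring pass (hypothetical helper names, not Typesense's API):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // With `buckets`, the bucketed prefix splits into that many equal blocks;
    // with `bucket_size`, every block has a fixed length.
    std::vector<int64_t> bucket_scores(std::vector<int64_t> scores,
                                       size_t num_buckets, size_t bucket_size) {
        size_t block_len = num_buckets > 0 ? (scores.size() / num_buckets) : bucket_size;
        if(block_len == 0) {
            return scores; // guarded upstream by the max_kvs_bucketed checks
        }
        for(size_t i = 0; i < scores.size(); ) {
            int64_t anchor_score = scores[i]; // first hit of the block
            for(size_t j = 0; j < block_len && i < scores.size(); j++, i++) {
                scores[i] = anchor_score;     // whole block shares the anchor score
            }
        }
        return scores;
    }

For example, bucket_scores({90, 80, 70, 60, 50, 40}, 0, 3) yields {90, 90, 90, 60, 60, 60}, so the next sort field orders the hits within each bucket of three.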
75 changes: 73 additions & 2 deletions test/collection_sorting_test.cpp
@@ -1739,7 +1739,7 @@ TEST_F(CollectionSortingTest, TextMatchBucketRanking) {
                  spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
                  "<mark>", "</mark>", {3}, 1000, true);
 ASSERT_FALSE(res_op.ok());
-ASSERT_EQ("Invalid value passed for _text_match `buckets` configuration.", res_op.error());
+ASSERT_EQ("Invalid value passed for _text_match `buckets` or `bucket_size` configuration.", res_op.error());
 
 // handle negative value
 sort_fields[0] = sort_by("_text_match(buckets: -1)", "DESC");
@@ … @@ TEST_F(CollectionSortingTest, TextMatchBucketRanking) {
                  spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
                  "<mark>", "</mark>", {3}, 1000, true);
 ASSERT_FALSE(res_op.ok());
-ASSERT_EQ("Invalid value passed for _text_match `buckets` configuration.", res_op.error());
+ASSERT_EQ("Invalid value passed for _text_match `buckets` or `bucket_size` configuration.", res_op.error());
 
 collectionManager.drop_collection("coll1");
 }
@@ -3218,4 +3218,75 @@ TEST_F(CollectionSortingTest, DecayFunctionsTest) {
 ASSERT_EQ(1728387250, results["hits"][3]["document"]["timestamp"].get<size_t>());
 ASSERT_EQ("0", results["hits"][4]["document"]["id"]);
 ASSERT_EQ(1728383250, results["hits"][4]["document"]["timestamp"].get<size_t>());
 }
+
+TEST_F(CollectionSortingTest, TextMatchBucketSizeRanking) {
+    std::vector<field> fields = {field("title", field_types::STRING, false),
+                                 field("description", field_types::STRING, false),
+                                 field("points", field_types::INT32, false),};
+
+    Collection *coll1 = collectionManager.create_collection("coll1", 1, fields, "points").get();
+
+    nlohmann::json doc1;
+    doc1["id"] = "0";
+    doc1["title"] = "Mark Antony";
+    doc1["description"] = "Counsellor";
+    doc1["points"] = 100;
+
+    nlohmann::json doc2;
+    doc2["id"] = "1";
+    doc2["title"] = "Marks Spencer";
+    doc2["description"] = "Sales Expert";
+    doc2["points"] = 200;
+
+    nlohmann::json doc3;
+    doc3["id"] = "2";
+    doc3["title"] = "Mark Twain";
+    doc3["description"] = "Writer";
+    doc3["points"] = 100;
+
+    nlohmann::json doc4;
+    doc4["id"] = "3";
+    doc4["title"] = "Mark Zuckerberg";
+    doc4["description"] = "Entrepreneur";
+    doc4["points"] = 300;
+
+    nlohmann::json doc5;
+    doc5["id"] = "4";
+    doc5["title"] = "Marks Henry";
+    doc5["description"] = "Wrestler";
+    doc5["points"] = 200;
+
+    nlohmann::json doc6;
+    doc6["id"] = "5";
+    doc6["title"] = "Mark Hughes";
+    doc6["description"] = "Football Coach";
+    doc6["points"] = 200;
+
+    ASSERT_TRUE(coll1->add(doc1.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc2.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc3.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc4.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc5.dump()).ok());
+    ASSERT_TRUE(coll1->add(doc6.dump()).ok());
+
+    sort_fields = {
+        sort_by("_text_match(bucket_size: 3)", "DESC"),
+        sort_by("points", "DESC"),
+    };
+
+    auto results = coll1->search("mark", {"title"},
+                                 "", {}, sort_fields, {2}, 10,
+                                 1, FREQUENCY, {true},
+                                 10, spp::sparse_hash_set<std::string>(),
+                                 spp::sparse_hash_set<std::string>(), 10, "", 30, 4, "title", 20, {}, {}, {}, 0,
+                                 "<mark>", "</mark>", {3}, 1000, true).get();
+
+    ASSERT_EQ(6, results["hits"].size());
+    ASSERT_EQ("3", results["hits"][0]["document"]["id"].get<std::string>());
+    ASSERT_EQ("5", results["hits"][1]["document"]["id"].get<std::string>());
+    ASSERT_EQ("4", results["hits"][2]["document"]["id"].get<std::string>());
+    ASSERT_EQ("1", results["hits"][3]["document"]["id"].get<std::string>());
+    ASSERT_EQ("2", results["hits"][4]["document"]["id"].get<std::string>());
+    ASSERT_EQ("0", results["hits"][5]["document"]["id"].get<std::string>());
+}
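
Reading the assertions: with `bucket_size: 3` the six hits fall into two buckets of three, each bucket sharing its anchor text-match score, so the `points` tiebreaker orders the first bucket as ids 3 (300), 5 (200), 4 (200) and the second as ids 1 (200), 2 (100), 0 (100).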
