Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Copeland Fusion for Hybrid Search #915

Open
wants to merge 11 commits into
base: mainline
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/marqo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pip install -r requirements.dev.txt
3. Pull and run the Vespa docker image
```bash
docker run --detach --name vespa --hostname vespa-tutorial \
--publish 8080:8080 --publish 19071:19071 \
--publish 8080:8080 --publish 19071:19071 --publish 2181:2181 \
OwenPendrighElliott marked this conversation as resolved.
Show resolved Hide resolved
vespaengine/vespa:latest
```

Expand Down
9 changes: 5 additions & 4 deletions src/marqo/core/models/hybrid_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class RetrievalMethod(str, Enum):


class RankingMethod(str, Enum):
Copeland = 'copeland'
RRF = 'rrf'
Tensor = 'tensor'
Lexical = 'lexical'
Expand Down Expand Up @@ -69,23 +70,23 @@ def validate_properties(cls, values):

# score_modifiers_lexical can only be defined for Lexical, RRF, Copeland, NormalizeLinear
if values.get('scoreModifiersLexical') is not None:
if not (values.get('rankingMethod') in [RankingMethod.Lexical, RankingMethod.RRF] or
if not (values.get('rankingMethod') in [RankingMethod.Lexical, RankingMethod.RRF, RankingMethod.Copeland] or
values.get('retrievalMethod') == RetrievalMethod.Lexical):
raise ValueError(
"'scoreModifiersLexical' can only be defined for 'lexical', 'rrf' ranking methods or "
"'lexical' retrieval method.") # TODO: re-add normalize_linear

# score_modifiers_tensor can only be defined for Tensor, RRF, Copeland, NormalizeLinear
if values.get('scoreModifiersTensor') is not None:
if values.get('rankingMethod') not in [RankingMethod.Tensor, RankingMethod.RRF]:
if values.get('rankingMethod') not in [RankingMethod.Tensor, RankingMethod.RRF, RankingMethod.Copeland]:
raise ValueError(
"'scoreModifiersTensor' can only be defined for 'tensor', 'rrf', ranking methods") # TODO: re-add normalize_linear

# if retrievalMethod == Disjunction, then ranking_method must be RRF, Copeland, NormalizeLinear
if values.get('retrievalMethod') == RetrievalMethod.Disjunction:
if values.get('rankingMethod') not in [RankingMethod.RRF]:
if values.get('rankingMethod') not in [RankingMethod.RRF, RankingMethod.Copeland]:
raise ValueError(
"For retrievalMethod: disjunction, rankingMethod must be: rrf") # TODO: re-add normalize_linear
"For retrievalMethod: disjunction, rankingMethod must be: rrf or copeland") # TODO: re-add normalize_linear

# if retrievalMethod is Lexical or Tensor, then ranking_method must be Tensor, Lexical
if values.get('retrievalMethod') in [RetrievalMethod.Lexical, RetrievalMethod.Tensor]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,10 +578,10 @@ def _to_vespa_hybrid_query(self, marqo_query: MarqoHybridQuery) -> Dict[str, Any

query = {k: v for k, v in query.items() if v is not None}

if marqo_query.hybrid_parameters.rankingMethod in {RankingMethod.RRF}: # TODO: Add NormalizeLinear
if marqo_query.hybrid_parameters.rankingMethod in [RankingMethod.RRF]: # TODO: Add NormalizeLinear
query["marqo__hybrid.alpha"] = marqo_query.hybrid_parameters.alpha
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If `alpha` and `rrfK` aren't relevant for Copeland, we need validation to catch this.


if marqo_query.hybrid_parameters.rankingMethod in {RankingMethod.RRF}:
if marqo_query.hybrid_parameters.rankingMethod in [RankingMethod.RRF]:
query["marqo__hybrid.rrf_k"] = marqo_query.hybrid_parameters.rrfK

return query
Expand Down
102 changes: 99 additions & 3 deletions tests/tensor_search/integ_tests/test_hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,60 @@ def test_hybrid_search_score_modifiers(self):
self.assertEqual(hybrid_res["hits"][-1]["_id"], "doc10") # lowest score (score*-10*3)
self.assertAlmostEqual(hybrid_res["hits"][-1]["_lexical_score"], base_lexical_score * -10 * 3)
self.assertAlmostEqual(hybrid_res["hits"][-1]["_tensor_score"], base_tensor_score * -10 * 3)

with self.subTest("retrieval: disjunction, ranking: copeland"):
hybrid_res = tensor_search.search(
config=self.config,
index_name=index.name,
text="HELLO WORLD",
search_method="HYBRID",
hybrid_parameters=HybridParameters(
retrievalMethod=RetrievalMethod.Disjunction,
rankingMethod=RankingMethod.Copeland,
scoreModifiersLexical={
"multiply_score_by": [
{"field_name": "mult_field_1", "weight": 10},
{"field_name": "mult_field_2", "weight": -10}
],
"add_to_score": [
{"field_name": "add_field_1", "weight": 5}
]
},
scoreModifiersTensor={
"multiply_score_by": [
{"field_name": "mult_field_1", "weight": 10},
{"field_name": "mult_field_2", "weight": -10}
],
"add_to_score": [
{"field_name": "add_field_1", "weight": 5}
]
},
verbose=True
),
result_count=10
)
self.assertIn("hits", hybrid_res)

# Score without score modifiers
self.assertEqual(hybrid_res["hits"][3]["_id"], "doc6") # (score)
base_lexical_score = hybrid_res["hits"][3]["_lexical_score"]
base_tensor_score = hybrid_res["hits"][3]["_tensor_score"]

self.assertEqual(hybrid_res["hits"][0]["_id"], "doc9") # highest score (score*10*3)
self.assertAlmostEqual(hybrid_res["hits"][0]["_lexical_score"], base_lexical_score * 10 * 3)
self.assertEqual(hybrid_res["hits"][0]["_tensor_score"], base_tensor_score * 10 * 3)

self.assertEqual(hybrid_res["hits"][1]["_id"], "doc8") # (score*10*2)
self.assertAlmostEqual(hybrid_res["hits"][1]["_lexical_score"], base_lexical_score * 10 * 2)
self.assertAlmostEqual(hybrid_res["hits"][1]["_tensor_score"], base_tensor_score * 10 * 2)

self.assertEqual(hybrid_res["hits"][2]["_id"], "doc7") # (score + 5*1)
self.assertAlmostEqual(hybrid_res["hits"][2]["_lexical_score"], base_lexical_score + 5*1)
self.assertAlmostEqual(hybrid_res["hits"][2]["_tensor_score"], base_tensor_score + 5*1)

self.assertEqual(hybrid_res["hits"][-1]["_id"], "doc10") # lowest score (score*-10*3)
self.assertAlmostEqual(hybrid_res["hits"][-1]["_lexical_score"], base_lexical_score * -10 * 3)
self.assertAlmostEqual(hybrid_res["hits"][-1]["_tensor_score"], base_tensor_score * -10 * 3)


def test_hybrid_search_lexical_tensor_with_lexical_score_modifiers_succeeds(self):
Expand Down Expand Up @@ -975,6 +1029,7 @@ def test_hybrid_search_with_filter(self):

test_cases = [
(RetrievalMethod.Disjunction, RankingMethod.RRF),
(RetrievalMethod.Disjunction, RankingMethod.Copeland),
(RetrievalMethod.Lexical, RankingMethod.Lexical),
(RetrievalMethod.Tensor, RankingMethod.Tensor)
]
Expand Down Expand Up @@ -1025,7 +1080,7 @@ def test_hybrid_search_with_images(self):
)
)

with self.subTest("disjunction text search"):
with self.subTest("disjunction rrf text search"):
hybrid_res = tensor_search.search(
config=self.config,
index_name=index.name,
Expand All @@ -1043,7 +1098,7 @@ def test_hybrid_search_with_images(self):
self.assertEqual(hybrid_res["hits"][0]["_id"], "hippo text")
self.assertEqual(hybrid_res["hits"][1]["_id"], "hippo text low relevance")

with self.subTest("disjunction image search"):
with self.subTest("disjunction rrf image search"):
hybrid_res = tensor_search.search(
config=self.config,
index_name=index.name,
Expand All @@ -1061,6 +1116,44 @@ def test_hybrid_search_with_images(self):
self.assertEqual(hybrid_res["hits"][0]["_id"], "hippo image")
self.assertEqual(hybrid_res["hits"][1]["_id"], "random image")
self.assertEqual(hybrid_res["hits"][2]["_id"], "hippo text")

with self.subTest("disjunction copeland text search"):
hybrid_res = tensor_search.search(
config=self.config,
index_name=index.name,
text="hippo",
search_method="HYBRID",
hybrid_parameters=HybridParameters(
retrievalMethod="disjunction",
rankingMethod="copeland",
verbose=True
),
result_count=4
)

self.assertIn("hits", hybrid_res)
self.assertEqual(hybrid_res["hits"][0]["_id"], "hippo text")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we not checking the score to detect score regression?

self.assertEqual(hybrid_res["hits"][1]["_id"], "hippo text low relevance")
self.assertEqual(hybrid_res["hits"][2]["_id"], "random text")

with self.subTest("disjunction copeland image search"):
hybrid_res = tensor_search.search(
config=self.config,
index_name=index.name,
text="https://raw.githubusercontent.com/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png",
search_method="HYBRID",
hybrid_parameters=HybridParameters(
retrievalMethod="disjunction",
rankingMethod="copeland",
verbose=True
),
result_count=4
)

self.assertIn("hits", hybrid_res)
self.assertEqual(hybrid_res["hits"][0]["_id"], "hippo image")
self.assertEqual(hybrid_res["hits"][1]["_id"], "random image")
self.assertEqual(hybrid_res["hits"][2]["_id"], "hippo text")

def test_hybrid_search_structured_opposite_retrieval_and_ranking(self):
"""
Expand Down Expand Up @@ -1370,7 +1463,7 @@ def test_hybrid_search_invalid_parameters_fails(self):
({
"retrievalMethod": "disjunction",
"rankingMethod": "lexical"
}, "rankingMethod must be: rrf"),
}, "rankingMethod must be: rrf or copeland"),
({
"retrievalMethod": "tensor",
"rankingMethod": "rrf"
Expand Down Expand Up @@ -1491,6 +1584,7 @@ def test_hybrid_search_structured_invalid_fields_fails(self):
# Non-lexical field
test_cases = [
("disjunction", "rrf"),
("disjunction", "copeland"),
("lexical", "lexical"),
("lexical", "tensor")
]
Expand All @@ -1513,6 +1607,7 @@ def test_hybrid_search_structured_invalid_fields_fails(self):
# Non-tensor field
test_cases = [
("disjunction", "rrf"),
("disjunction", "copeland"),
("tensor", "tensor"),
("tensor", "lexical")
]
Expand Down Expand Up @@ -1618,6 +1713,7 @@ def test_hybrid_search_none_query_wrong_retrieval_or_ranking_fails(self):
"""
custom_vector = [0.655 for _ in range(16)]
test_cases = [
(RetrievalMethod.Disjunction, RankingMethod.Copeland),
(RetrievalMethod.Disjunction, RankingMethod.RRF),
(RetrievalMethod.Tensor, RankingMethod.Lexical),
(RetrievalMethod.Lexical, RankingMethod.Tensor),
Expand Down
1 change: 1 addition & 0 deletions tests/tensor_search/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def test_pagination_hybrid(self):
self.assertFalse(r['errors'], "Errors in add documents call")

test_cases = [
("disjunction", "copeland"),
("disjunction", "rrf"),
("lexical", "tensor"),
("tensor", "lexical"),
Expand Down
Loading
Loading