updated mindsmall
orionw committed Nov 9, 2024
1 parent 47e80ba commit ad0a3db
Showing 3 changed files with 15 additions and 15 deletions.
mteb/abstasks/AbsTaskRetrieval.py (7 additions, 7 deletions)
```diff
@@ -84,10 +84,8 @@ class AbsTaskRetrieval(AbsTask):
     def __init__(self, **kwargs):
         self.top_ranked = None
         self.instructions = None
-        if isinstance(self, AbsTaskRetrieval):
-            super(AbsTaskRetrieval, self).__init__(**kwargs)  # noqa
-        else:
-            super().__init__(**kwargs)
+        # there could be multiple options, so do this even if multilingual
+        super(AbsTaskRetrieval, self).__init__(**kwargs)  # noqa
 
     def load_data(self, **kwargs):
         if self.data_loaded:
@@ -195,7 +193,8 @@ def _evaluate_subset(
 
         save_predictions = kwargs.get("save_predictions", False)
         export_errors = kwargs.get("export_errors", False)
-        if save_predictions or export_errors:
+        save_qrels = kwargs.get("save_qrels", False)
+        if save_predictions or export_errors or save_qrels:
             output_folder = Path(kwargs.get("output_folder", "results"))
             if not os.path.isdir(output_folder):
                 os.makedirs(output_folder)
@@ -219,7 +218,7 @@ def _evaluate_subset(
             with open(qrels_save_path, "w") as f:
                 json.dump(results, f)
 
-            # save qrels also
+        if save_qrels:
             with open(
                 output_folder / f"{self.metadata.name}_{hf_subset}_qrels.json", "w"
             ) as f:
@@ -230,8 +229,9 @@ def _evaluate_subset(
             results,
             retriever.k_values,
             ignore_identical_ids=self.ignore_identical_ids,
-            task_name=self.metadata.name,
+            task_name=self.metadata.name
         )
+
         mrr, naucs_mrr = retriever.evaluate_custom(
             relevant_docs, results, retriever.k_values, "mrr"
         )
```
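For context, here is a minimal sketch of how the new `save_qrels` flag would reach `_evaluate_subset`, assuming the usual `mteb` entry point forwards extra keyword arguments down to the task; the model and folder choices are illustrative:

```python
import mteb
from sentence_transformers import SentenceTransformer

# Hypothetical usage: extra kwargs to run() ride along to _evaluate_subset,
# so save_qrels sits next to the existing save_predictions flag.
model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model choice
tasks = mteb.get_tasks(tasks=["MindSmallReranking"])
evaluation = mteb.MTEB(tasks=tasks)
evaluation.run(
    model,
    output_folder="results",
    save_predictions=True,  # existing flag: saves the run's predictions
    save_qrels=True,        # new flag: writes {task}_{subset}_qrels.json
)
```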
mteb/evaluation/evaluators/utils.py (7 additions, 3 deletions)
```diff
@@ -560,8 +560,8 @@ def add_task_specific_scores(
         task_scores.update(p_mrr_and_consolidated_scores)
 
     if task_name in ["MindSmallReranking"]:
-        take_max_over_subqueries = max_over_subqueries(qrels, results, scores)
-        task_scores["max_over_subqueries"] = take_max_over_subqueries
+        take_max_over_subqueries = max_over_subqueries(qrels, results, k_values)
+        task_scores.update(take_max_over_subqueries)
 
     return task_scores
 
@@ -699,6 +699,7 @@ def max_over_subqueries(qrels, results, k_values):
         query_keys["_".join(key.split("_")[:-1])].append(key)
 
     new_results = {}
+    new_qrels = {}
     for query_id_base, query_ids in query_keys.items():
         doc_scores = defaultdict(float)
         for query_id_full in query_ids:
@@ -709,10 +710,12 @@ def max_over_subqueries(qrels, results, k_values):
                 doc_scores[doc_id] = max(score, doc_scores[doc_id])
 
         new_results[query_id_base] = doc_scores
+        new_qrels[query_id_base] = qrels[query_id_full]  # all the same
+
 
     # now we have the new results, we can compute the scores
     _, ndcg, _map, recall, precision, naucs = calculate_retrieval_scores(
-        new_results, qrels, k_values
+        new_results, new_qrels, k_values
     )
     score_dict = make_score_dict(ndcg, _map, recall, precision, {}, naucs, {}, {})
     return {"max_over_subqueries_" + k: v for k, v in score_dict.items()}
@@ -723,6 +726,7 @@ def calculate_retrieval_scores(results, qrels, k_values):
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])
    precision_string = "P." + ",".join([str(k) for k in k_values])
+
    evaluator = pytrec_eval.RelevanceEvaluator(
        qrels, {map_string, ndcg_string, recall_string, precision_string}
    )
```
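The `new_qrels` change above is the substantive fix: `new_results` is keyed by base query IDs while the original `qrels` is keyed by expanded `{id}_{subquery}` IDs, so scoring one against the other would align no keys. Building `new_qrels` from any of the (identical) per-subquery entries restores the match. A self-contained toy version of the max-over-subqueries collapse, with made-up data:

```python
from collections import defaultdict

# Expanded run: "q1_0" and "q1_1" are subqueries of the same base query "q1".
results = {
    "q1_0": {"d1": 0.2, "d2": 0.9},
    "q1_1": {"d1": 0.7, "d3": 0.4},
}

# Group expanded IDs under their base ID, mirroring the helper in the diff.
query_keys = defaultdict(list)
for key in results:
    query_keys["_".join(key.split("_")[:-1])].append(key)

new_results = {}
for base_id, subquery_ids in query_keys.items():
    # defaultdict(float) starts at 0.0, which assumes non-negative scores,
    # as in the original helper.
    doc_scores = defaultdict(float)
    for full_id in subquery_ids:
        for doc_id, score in results[full_id].items():
            doc_scores[doc_id] = max(score, doc_scores[doc_id])
    new_results[base_id] = dict(doc_scores)

print(new_results)  # {'q1': {'d1': 0.7, 'd2': 0.9, 'd3': 0.4}}
```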
mteb/tasks/Reranking/eng/MindSmallReranking.py (1 addition, 5 deletions)
```diff
@@ -141,8 +141,7 @@ def load_data(self, **kwargs):
         all_queries = []
         all_positives = []
         all_negatives = []
-        all_ids = []
-        all_instance_indices = []  # Renamed for clarity
+        all_instance_indices = []
         all_subquery_indices = []
 
         # First pass: expand queries while maintaining relationships
@@ -157,9 +156,6 @@ def load_data(self, **kwargs):
                 all_queries.append(query)
                 all_positives.append(positives)  # Same positives for each subquery
                 all_negatives.append(negatives)  # Same negatives for each subquery
-                all_ids.append(
-                    f"{instance.get('id', current_instance_idx)}_{subquery_idx}"
-                )
                 all_instance_indices.append(current_instance_idx)
                 all_subquery_indices.append(subquery_idx)
 
```
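Dropping `all_ids` works because the `(instance, subquery)` index pairs already encode everything the string IDs carried. A toy sketch of the expansion pass, using made-up data but the field names from the diff:

```python
# Each instance contributes one row per subquery; positives and negatives are
# shared across a given instance's subqueries, and the index pairs let
# max_over_subqueries regroup the expanded rows later.
instances = [
    {"query": ["impeachment hearing", "senate vote"],
     "positives": ["p1"], "negatives": ["n1", "n2"]},
    {"query": ["nba trade rumors"],
     "positives": ["p2"], "negatives": ["n3"]},
]

all_queries, all_positives, all_negatives = [], [], []
all_instance_indices, all_subquery_indices = [], []

for instance_idx, instance in enumerate(instances):
    for subquery_idx, subquery in enumerate(instance["query"]):
        all_queries.append(subquery)
        all_positives.append(instance["positives"])  # same per subquery
        all_negatives.append(instance["negatives"])  # same per subquery
        all_instance_indices.append(instance_idx)
        all_subquery_indices.append(subquery_idx)

assert len(all_queries) == 3  # two subqueries plus one
```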
