Commit
Merge branch 'main' into test-chunking
violenil authored Sep 20, 2024
2 parents da3dcee + 8237bf2 commit 2ddabe5
Showing 4 changed files with 188 additions and 13 deletions.
159 changes: 159 additions & 0 deletions chunked_pooling/chunked_eval_tasks.py
@@ -295,3 +295,162 @@ def load_data(self, **kwargs):
self.relevant_docs = {self._EVAL_SPLIT: qrels}

self.data_loaded = True


class LEMBNeedleRetrievalChunked(AbsTaskChunkedRetrieval):
"""
modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBNeedleRetrieval.py
"""

_EVAL_SPLIT = [
"test_256",
"test_512",
"test_1024",
"test_2048",
"test_4096",
"test_8192",
"test_16384",
"test_32768",
]

metadata = TaskMetadata(
name="LEMBNeedleRetrievalChunked",
dataset={
"path": "dwzhu/LongEmbed",
"revision": "6e346642246bfb4928c560ee08640dc84d074e8c",
"name": "needle",
},
reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
description=("needle subset of dwzhu/LongEmbed dataset."),
type="Retrieval",
category="s2p",
modalities=["text"],
eval_splits=_EVAL_SPLIT,
eval_langs=["eng-Latn"],
main_score="ndcg_at_1",
date=("2000-01-01", "2023-12-31"),
domains=["Academic", "Blog", "Written"],
task_subtypes=["Article retrieval"],
license="not specified",
annotations_creators="derived",
dialect=[],
sample_creation="found",
bibtex_citation="""
@article{zhu2024longembed,
title={LongEmbed: Extending Embedding Models for Long Context Retrieval},
author={Zhu, Dawei and Wang, Liang and Yang, Nan and Song, Yifan and Wu, Wenhao and Wei, Furu and Li, Sujian},
journal={arXiv preprint arXiv:2404.12096},
year={2024}
}
""",
descriptive_stats={
"n_samples": {
"test_256": 150,
"test_512": 150,
"test_1024": 150,
"test_2048": 150,
"test_4096": 150,
"test_8192": 150,
"test_16384": 150,
"test_32768": 150,
},
"avg_character_length": {
"test_256": {
"average_document_length": 1013.22,
"average_query_length": 60.48,
"num_documents": 100,
"num_queries": 50,
"average_relevant_docs_per_query": 1.0,
},
"test_512": {
"average_document_length": 2009.96,
"average_query_length": 57.3,
"num_documents": 100,
"num_queries": 50,
"average_relevant_docs_per_query": 1.0,
},
"test_1024": {
"average_document_length": 4069.9,
"average_query_length": 58.28,
"num_documents": 100,
"num_queries": 50,
"average_relevant_docs_per_query": 1.0,
},
"test_2048": {
"average_document_length": 8453.82,
"average_query_length": 59.92,
"num_documents": 100,
"num_queries": 50,
"average_relevant_docs_per_query": 1.0,
},
"test_4096": {
"average_document_length": 17395.8,
"average_query_length": 55.86,
"num_documents": 100,
"num_queries": 50,
"average_relevant_docs_per_query": 1.0,
},
"test_8192": {
"average_document_length": 35203.82,
"average_query_length": 59.6,
"num_documents": 100,
"num_queries": 50,
"average_relevant_docs_per_query": 1.0,
},
"test_16384": {
"average_document_length": 72054.8,
"average_query_length": 59.12,
"num_documents": 100,
"num_queries": 50,
"average_relevant_docs_per_query": 1.0,
},
"test_32768": {
"average_document_length": 141769.8,
"average_query_length": 58.34,
"num_documents": 100,
"num_queries": 50,
"average_relevant_docs_per_query": 1.0,
},
},
},
)

def load_data(self, **kwargs):
if self.data_loaded:
return

self.corpus = {}
self.queries = {}
self.relevant_docs = {}

for split in self._EVAL_SPLIT:
context_length = int(split.split("_")[1])
query_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
"queries"
] # dict_keys(['qid', 'text'])
query_list = query_list.filter(
lambda x: x["context_length"] == context_length
)
queries = {row["qid"]: row["text"] for row in query_list}

corpus_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
"corpus"
] # dict_keys(['doc_id', 'text'])
corpus_list = corpus_list.filter(
lambda x: x["context_length"] == context_length
)
corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list}

qrels_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
"qrels"
] # dict_keys(['qid', 'doc_id'])
qrels_list = qrels_list.filter(
lambda x: x["context_length"] == context_length
)
qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list}

self.corpus[split] = corpus
self.queries[split] = queries
self.relevant_docs[split] = qrels

self.data_loaded = True
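
For every split, load_data() reloads the queries, corpus, and qrels tables and filters each down to the rows whose context_length matches the length encoded in the split name. A minimal standalone sketch of that filtering, assuming the dataset layout the code above relies on:

    # Sketch (not part of this commit): reproduce the per-split filtering
    # that load_data() performs, here for the test_256 split.
    import datasets

    dataset_args = {
        "path": "dwzhu/LongEmbed",
        "name": "needle",
        "revision": "6e346642246bfb4928c560ee08640dc84d074e8c",
    }
    queries = datasets.load_dataset(**dataset_args)["queries"]
    queries_256 = queries.filter(lambda x: x["context_length"] == 256)
    print(len(queries_256))  # the descriptive stats above suggest 50 queries
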
23 changes: 20 additions & 3 deletions chunked_pooling/mteb_chunked_eval.py
@@ -84,13 +84,27 @@ def evaluate(
             )

             scores[hf_subset] = self._evaluate_monolingual(
-                model, corpus, queries, relevant_docs, hf_subset, **kwargs
+                model,
+                corpus,
+                queries,
+                relevant_docs,
+                hf_subset,
+                encode_kwargs=encode_kwargs,
+                **kwargs,
             )

         return scores

     def _evaluate_monolingual(
-        self, model, corpus, queries, relevant_docs, lang=None, batch_size=1, **kwargs
+        self,
+        model,
+        corpus,
+        queries,
+        relevant_docs,
+        lang=None,
+        batch_size=1,
+        encode_kwargs=None,
+        **kwargs,
     ):
         # split corpus into chunks
         if not self.chunked_pooling_enabled:
@@ -101,7 +115,10 @@ def _evaluate_monolingual(
             # determine the maximum number of documents to consider in a ranking
             max_k = int(max(k_values) / max_chunks)
             retriever = RetrievalEvaluator(
-                model, k_values=k_values, batch_size=batch_size, **kwargs
+                model,
+                k_values=k_values,
+                encode_kwargs=(encode_kwargs or dict()),
+                **kwargs,
             )
             results = retriever(corpus, queries)
         else:
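Taken together, these two hunks thread encode_kwargs from evaluate() down into the RetrievalEvaluator instead of dropping it. Defaulting the new parameter to None and normalizing with "encode_kwargs or dict()" also sidesteps a shared mutable default argument; note that the rewritten RetrievalEvaluator call no longer forwards batch_size, presumably because the batch size now travels inside encode_kwargs (as run_chunked_eval.py below suggests). A hypothetical sketch of the pattern, with invented names:

    # Hypothetical illustration of the keyword-threading pattern above:
    # default to None, normalize once, then forward explicitly.
    def outer(encode_kwargs=None, **kwargs):
        encode_kwargs = encode_kwargs or dict()  # no shared mutable default
        return inner(encode_kwargs=encode_kwargs, **kwargs)

    def inner(*, encode_kwargs, **kwargs):
        return encode_kwargs.get("batch_size", 1)
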
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -10,7 +10,7 @@ dependencies = [
"llama-index-embeddings-huggingface==0.3.1",
"llama-index==0.11.10",
"click==8.1.7",
"einops==0.8.0",
"einops==0.6.1",
]
version = "0.0.0"

17 changes: 8 additions & 9 deletions run_chunked_eval.py
@@ -3,11 +3,7 @@
 from mteb import MTEB
 from transformers import AutoModel, AutoTokenizer

-from chunked_pooling.chunked_eval_tasks import (FiQA2018Chunked,
-                                                LEMBWikimQARetrievalChunked,
-                                                NFCorpusChunked, QuoraChunked,
-                                                SciFactChunked,
-                                                TRECCOVIDChunked)
+from chunked_pooling.chunked_eval_tasks import *
 from chunked_pooling.wrappers import load_model

 DEFAULT_CHUNKING_STRATEGY = 'fixed'
@@ -28,9 +24,12 @@
     help='The chunking strategy to be applied.',
 )
 @click.option(
-    '--task-name', default='SciFactChunked', help='The evaluationtask to perform.'
+    '--task-name', default='SciFactChunked', help='The evaluation task to perform.'
 )
-def main(model_name, strategy, task_name):
+@click.option(
+    '--eval-split', default='test', help='The name of the evaluation split in the task.'
+)
+def main(model_name, strategy, task_name, eval_split):
     try:
         task_cls = globals()[task_name]
     except:
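
A side effect worth spelling out (inferred from the code, not stated in the commit): main() resolves the task class via globals()[task_name], so the wildcard import above makes every task class in chunked_eval_tasks, including the new LEMBNeedleRetrievalChunked, reachable through --task-name without editing the import list again.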
@@ -72,7 +71,7 @@ def main(model_name, strategy, task_name):
     evaluation.run(
         model,
         output_folder='results-chunked-pooling',
-        eval_splits=['test'],
+        eval_splits=[eval_split],
         overwrite_results=True,
         batch_size=BATCH_SIZE,
         encode_kwargs={'batch_size': BATCH_SIZE},
@@ -98,7 +97,7 @@ def main(model_name, strategy, task_name):
     evaluation.run(
         model,
         output_folder='results-normal-pooling',
-        eval_splits=['test'],
+        eval_splits=[eval_split],
         overwrite_results=True,
         encode_kwargs={'batch_size': BATCH_SIZE},
     )
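With the new option, a run against one of the LongEmbed length-specific splits might look like the following (an illustrative invocation using only the flags visible in this diff; all other options keep their defaults):

    python run_chunked_eval.py --task-name LEMBNeedleRetrievalChunked --eval-split test_256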
