From 8237bf2c36d46e10d93220d6ee904d7d53210f83 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Michael=20G=C3=BCnther?=
Date: Fri, 20 Sep 2024 11:12:30 +0200
Subject: [PATCH] feat: add needle retrieval (#9)

* feat: add support for v3 model

* feat: support passing model args

* feat: add needle dataset; fix encoder args; add einops dep
---
 chunked_pooling/chunked_eval_tasks.py | 159 ++++++++++++++++++++++++++
 chunked_pooling/mteb_chunked_eval.py  |  23 +++-
 pyproject.toml                        |   3 +-
 run_chunked_eval.py                   |  22 ++--
 4 files changed, 191 insertions(+), 16 deletions(-)

diff --git a/chunked_pooling/chunked_eval_tasks.py b/chunked_pooling/chunked_eval_tasks.py
index 34f223f..23dbcbf 100644
--- a/chunked_pooling/chunked_eval_tasks.py
+++ b/chunked_pooling/chunked_eval_tasks.py
@@ -295,3 +295,162 @@ def load_data(self, **kwargs):
         self.relevant_docs = {self._EVAL_SPLIT: qrels}
 
         self.data_loaded = True
+
+
+class LEMBNeedleRetrievalChunked(AbsTaskChunkedRetrieval):
+    """
+    modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBNeedleRetrieval.py
+    """
+
+    _EVAL_SPLIT = [
+        "test_256",
+        "test_512",
+        "test_1024",
+        "test_2048",
+        "test_4096",
+        "test_8192",
+        "test_16384",
+        "test_32768",
+    ]
+
+    metadata = TaskMetadata(
+        name="LEMBNeedleRetrievalChunked",
+        dataset={
+            "path": "dwzhu/LongEmbed",
+            "revision": "6e346642246bfb4928c560ee08640dc84d074e8c",
+            "name": "needle",
+        },
+        reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
+        description=("needle subset of dwzhu/LongEmbed dataset."),
+        type="Retrieval",
+        category="s2p",
+        modalities=["text"],
+        eval_splits=_EVAL_SPLIT,
+        eval_langs=["eng-Latn"],
+        main_score="ndcg_at_1",
+        date=("2000-01-01", "2023-12-31"),
+        domains=["Academic", "Blog", "Written"],
+        task_subtypes=["Article retrieval"],
+        license="not specified",
+        annotations_creators="derived",
+        dialect=[],
+        sample_creation="found",
+        bibtex_citation="""
+        @article{zhu2024longembed,
+          title={LongEmbed: Extending Embedding Models for Long Context Retrieval},
+          author={Zhu, Dawei and Wang, Liang and Yang, Nan and Song, Yifan and Wu, Wenhao and Wei, Furu and Li, Sujian},
+          journal={arXiv preprint arXiv:2404.12096},
+          year={2024}
+        }
+        """,
+        descriptive_stats={
+            "n_samples": {
+                "test_256": 150,
+                "test_512": 150,
+                "test_1024": 150,
+                "test_2048": 150,
+                "test_4096": 150,
+                "test_8192": 150,
+                "test_16384": 150,
+                "test_32768": 150,
+            },
+            "avg_character_length": {
+                "test_256": {
+                    "average_document_length": 1013.22,
+                    "average_query_length": 60.48,
+                    "num_documents": 100,
+                    "num_queries": 50,
+                    "average_relevant_docs_per_query": 1.0,
+                },
+                "test_512": {
+                    "average_document_length": 2009.96,
+                    "average_query_length": 57.3,
+                    "num_documents": 100,
+                    "num_queries": 50,
+                    "average_relevant_docs_per_query": 1.0,
+                },
+                "test_1024": {
+                    "average_document_length": 4069.9,
+                    "average_query_length": 58.28,
+                    "num_documents": 100,
+                    "num_queries": 50,
+                    "average_relevant_docs_per_query": 1.0,
+                },
+                "test_2048": {
+                    "average_document_length": 8453.82,
+                    "average_query_length": 59.92,
+                    "num_documents": 100,
+                    "num_queries": 50,
+                    "average_relevant_docs_per_query": 1.0,
+                },
+                "test_4096": {
+                    "average_document_length": 17395.8,
+                    "average_query_length": 55.86,
+                    "num_documents": 100,
+                    "num_queries": 50,
+                    "average_relevant_docs_per_query": 1.0,
+                },
+                "test_8192": {
+                    "average_document_length": 35203.82,
+                    "average_query_length": 59.6,
+                    "num_documents": 100,
+                    "num_queries": 50,
+                    "average_relevant_docs_per_query": 1.0,
+                },
+                "test_16384": {
+                    "average_document_length": 72054.8,
+                    "average_query_length": 59.12,
+                    "num_documents": 100,
+                    "num_queries": 50,
+                    "average_relevant_docs_per_query": 1.0,
+                },
+                "test_32768": {
+                    "average_document_length": 141769.8,
+                    "average_query_length": 58.34,
+                    "num_documents": 100,
+                    "num_queries": 50,
+                    "average_relevant_docs_per_query": 1.0,
+                },
+            },
+        },
+    )
+
+    def load_data(self, **kwargs):
+        if self.data_loaded:
+            return
+
+        self.corpus = {}
+        self.queries = {}
+        self.relevant_docs = {}
+
+        for split in self._EVAL_SPLIT:
+            context_length = int(split.split("_")[1])
+            query_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
+                "queries"
+            ]  # dict_keys(['qid', 'text'])
+            query_list = query_list.filter(
+                lambda x: x["context_length"] == context_length
+            )
+            queries = {row["qid"]: row["text"] for row in query_list}
+
+            corpus_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
+                "corpus"
+            ]  # dict_keys(['doc_id', 'text'])
+            corpus_list = corpus_list.filter(
+                lambda x: x["context_length"] == context_length
+            )
+            corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list}
+
+            qrels_list = datasets.load_dataset(**self.metadata_dict["dataset"])[
+                "qrels"
+            ]  # dict_keys(['qid', 'doc_id'])
+            qrels_list = qrels_list.filter(
+                lambda x: x["context_length"] == context_length
+            )
+            qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list}
+
+            self.corpus[split] = corpus
+            self.queries[split] = queries
+            self.relevant_docs[split] = qrels
+
+        self.data_loaded = True
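
The class above follows LongEmbed's needle-in-a-haystack design: each split fixes a single context length, every query has exactly one relevant document (hence main_score "ndcg_at_1"), and load_data() filters the shared queries/corpus/qrels splits down to that length. For reference, a minimal standalone sketch of the structures load_data() builds for one split; the dataset path, split names, and column names are taken from the code above, the rest is illustrative:

    import datasets

    # Rebuild the per-split dicts for test_256, mirroring load_data() above.
    context_length = 256
    needle = datasets.load_dataset("dwzhu/LongEmbed", name="needle")

    subsets = {
        key: needle[key].filter(lambda x: x["context_length"] == context_length)
        for key in ("queries", "corpus", "qrels")
    }

    queries = {row["qid"]: row["text"] for row in subsets["queries"]}
    corpus = {row["doc_id"]: {"text": row["text"]} for row in subsets["corpus"]}
    qrels = {row["qid"]: {row["doc_id"]: 1} for row in subsets["qrels"]}  # one relevant doc per query

Unlike load_data(), which calls datasets.load_dataset() once per subset and per split, the sketch loads the dataset a single time and filters it per key; for a given split the resulting dictionaries are the same.
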
"average_document_length": 72054.8, + "average_query_length": 59.12, + "num_documents": 100, + "num_queries": 50, + "average_relevant_docs_per_query": 1.0, + }, + "test_32768": { + "average_document_length": 141769.8, + "average_query_length": 58.34, + "num_documents": 100, + "num_queries": 50, + "average_relevant_docs_per_query": 1.0, + }, + }, + }, + ) + + def load_data(self, **kwargs): + if self.data_loaded: + return + + self.corpus = {} + self.queries = {} + self.relevant_docs = {} + + for split in self._EVAL_SPLIT: + context_length = int(split.split("_")[1]) + query_list = datasets.load_dataset(**self.metadata_dict["dataset"])[ + "queries" + ] # dict_keys(['qid', 'text']) + query_list = query_list.filter( + lambda x: x["context_length"] == context_length + ) + queries = {row["qid"]: row["text"] for row in query_list} + + corpus_list = datasets.load_dataset(**self.metadata_dict["dataset"])[ + "corpus" + ] # dict_keys(['doc_id', 'text']) + corpus_list = corpus_list.filter( + lambda x: x["context_length"] == context_length + ) + corpus = {row["doc_id"]: {"text": row["text"]} for row in corpus_list} + + qrels_list = datasets.load_dataset(**self.metadata_dict["dataset"])[ + "qrels" + ] # dict_keys(['qid', 'doc_id']) + qrels_list = qrels_list.filter( + lambda x: x["context_length"] == context_length + ) + qrels = {row["qid"]: {row["doc_id"]: 1} for row in qrels_list} + + self.corpus[split] = corpus + self.queries[split] = queries + self.relevant_docs[split] = qrels + + self.data_loaded = True diff --git a/chunked_pooling/mteb_chunked_eval.py b/chunked_pooling/mteb_chunked_eval.py index abf46b0..fd75210 100644 --- a/chunked_pooling/mteb_chunked_eval.py +++ b/chunked_pooling/mteb_chunked_eval.py @@ -84,13 +84,27 @@ def evaluate( ) scores[hf_subset] = self._evaluate_monolingual( - model, corpus, queries, relevant_docs, hf_subset, **kwargs + model, + corpus, + queries, + relevant_docs, + hf_subset, + encode_kwargs=encode_kwargs, + **kwargs, ) return scores def _evaluate_monolingual( - self, model, corpus, queries, relevant_docs, lang=None, batch_size=1, **kwargs + self, + model, + corpus, + queries, + relevant_docs, + lang=None, + batch_size=1, + encode_kwargs=None, + **kwargs, ): # split corpus into chunks if not self.chunked_pooling_enabled: @@ -101,7 +115,10 @@ def _evaluate_monolingual( # determine the maximum number of documents to consider in a ranking max_k = int(max(k_values) / max_chunks) retriever = RetrievalEvaluator( - model, k_values=k_values, batch_size=batch_size, **kwargs + model, + k_values=k_values, + encode_kwargs=(encode_kwargs or dict()), + **kwargs, ) results = retriever(corpus, queries) else: diff --git a/pyproject.toml b/pyproject.toml index 66e4b08..e717de2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,8 @@ dependencies = [ "datasets==2.19.1", "llama-index-embeddings-huggingface==0.3.1", "llama-index==0.11.10", - "click==8.1.7" + "click==8.1.7", + "einops==0.6.1" ] version = "0.0.0" diff --git a/run_chunked_eval.py b/run_chunked_eval.py index 6f70b79..8656a13 100644 --- a/run_chunked_eval.py +++ b/run_chunked_eval.py @@ -4,14 +4,9 @@ from transformers import AutoModel, AutoTokenizer from mteb import MTEB -from chunked_pooling.chunked_eval_tasks import ( - SciFactChunked, - TRECCOVIDChunked, - FiQA2018Chunked, - NFCorpusChunked, - QuoraChunked, - LEMBWikimQARetrievalChunked, -) +from chunked_pooling.chunked_eval_tasks import * + +from chunked_pooling.wrappers import load_model from chunked_pooling.wrappers import load_model @@ -34,9 +29,12 @@ help='The 
diff --git a/run_chunked_eval.py b/run_chunked_eval.py
index 6f70b79..8656a13 100644
--- a/run_chunked_eval.py
+++ b/run_chunked_eval.py
@@ -4,14 +4,9 @@
 
 from transformers import AutoModel, AutoTokenizer
 from mteb import MTEB
 
-from chunked_pooling.chunked_eval_tasks import (
-    SciFactChunked,
-    TRECCOVIDChunked,
-    FiQA2018Chunked,
-    NFCorpusChunked,
-    QuoraChunked,
-    LEMBWikimQARetrievalChunked,
-)
+from chunked_pooling.chunked_eval_tasks import *
+
+from chunked_pooling.wrappers import load_model
 
 from chunked_pooling.wrappers import load_model
@@ -34,9 +29,12 @@
     help='The chunking strategy to be applied.',
 )
 @click.option(
-    '--task-name', default='SciFactChunked', help='The evaluationtask to perform.'
+    '--task-name', default='SciFactChunked', help='The evaluation task to perform.'
+)
+@click.option(
+    '--eval-split', default='test', help='The name of the evaluation split in the task.'
 )
-def main(model_name, strategy, task_name):
+def main(model_name, strategy, task_name, eval_split):
     try:
         task_cls = globals()[task_name]
     except:
@@ -78,7 +76,7 @@ def main(model_name, strategy, task_name):
     evaluation.run(
         model,
         output_folder='results-chunked-pooling',
-        eval_splits=['test'],
+        eval_splits=[eval_split],
         overwrite_results=True,
         batch_size=BATCH_SIZE,
         encode_kwargs={'batch_size': BATCH_SIZE},
@@ -104,7 +102,7 @@ def main(model_name, strategy, task_name):
     evaluation.run(
         model,
         output_folder='results-normal-pooling',
-        eval_splits=['test'],
+        eval_splits=[eval_split],
         overwrite_results=True,
         batch_size=BATCH_SIZE,
         encode_kwargs={'batch_size': BATCH_SIZE},
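
With the new option, a needle run looks like 'python run_chunked_eval.py --task-name LEMBNeedleRetrievalChunked --eval-split test_256'; note that the needle task has no plain 'test' split, so the default --eval-split value must be overridden for it. Programmatically, the same run reduces to roughly the following sketch; the model is a placeholder, and the constructor arguments main() passes to the task class (chunking strategy and related settings) are omitted:

    from mteb import MTEB
    from sentence_transformers import SentenceTransformer
    from chunked_pooling.chunked_eval_tasks import LEMBNeedleRetrievalChunked

    # Placeholder model; main() instead builds one via load_model(model_name).
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    evaluation = MTEB(tasks=[LEMBNeedleRetrievalChunked()])
    evaluation.run(
        model,
        output_folder="results-chunked-pooling",
        eval_splits=["test_256"],         # what --eval-split now selects
        overwrite_results=True,
        encode_kwargs={"batch_size": 1},  # forwarded through the new encode_kwargs plumbing
    )
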