From 9e23fd66bc13ba7b088816a49a2ea032304d0f28 Mon Sep 17 00:00:00 2001
From: Michael
Date: Mon, 23 Sep 2024 11:48:27 +0200
Subject: [PATCH 1/7] feat: add semantic chunking to eval script; add wrapper for minilm

---
 chunked_pooling/mteb_chunked_eval.py |  2 ++
 chunked_pooling/wrappers.py          | 11 +++++++++--
 run_chunked_eval.py                  |  1 +
 3 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/chunked_pooling/mteb_chunked_eval.py b/chunked_pooling/mteb_chunked_eval.py
index 8da9a54..2433e7f 100644
--- a/chunked_pooling/mteb_chunked_eval.py
+++ b/chunked_pooling/mteb_chunked_eval.py
@@ -25,6 +25,7 @@ def __init__(
         chunk_size: Optional[int] = None,
         n_sentences: Optional[int] = None,
         model_has_instructions: bool = False,
+        embedding_model_name: Optional[str] = None,  # for semantic chunking
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -45,6 +46,7 @@ def __init__(
         self.chunking_args = {
             'chunk_size': chunk_size,
             'n_sentences': n_sentences,
+            'embedding_model_name': embedding_model_name,
         }
 
     def load_data(self, **kwargs):
diff --git a/chunked_pooling/wrappers.py b/chunked_pooling/wrappers.py
index 44984c5..a4bb0ef 100644
--- a/chunked_pooling/wrappers.py
+++ b/chunked_pooling/wrappers.py
@@ -2,6 +2,7 @@
 
 import torch
 import torch.nn as nn
+from sentence_transformers import SentenceTransformer
 from transformers import AutoModel
 
 
@@ -61,7 +62,10 @@ def has_instructions():
         return True
 
 
-MODEL_WRAPPERS = {'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper}
+MODEL_WRAPPERS = {
+    'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper,
+    'sentence-transformers/all-MiniLM-L6-v2': SentenceTransformer,
+}
 MODELS_WITHOUT_PROMPT_NAME_ARG = [
     'jinaai/jina-embeddings-v2-small-en',
     'jinaai/jina-embeddings-v2-base-en',
@@ -82,7 +86,10 @@ def wrapper(self, *args, **kwargs):
 def load_model(model_name, **model_kwargs):
     if model_name in MODEL_WRAPPERS:
         model = MODEL_WRAPPERS[model_name](model_name, **model_kwargs)
-        has_instructions = MODEL_WRAPPERS[model_name].has_instructions()
+        if hasattr(MODEL_WRAPPERS[model_name], 'has_instructions'):
+            has_instructions = MODEL_WRAPPERS[model_name].has_instructions()
+        else:
+            has_instructions = False
     else:
         model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
         has_instructions = False
diff --git a/run_chunked_eval.py b/run_chunked_eval.py
index 4f0057c..a3c855d 100644
--- a/run_chunked_eval.py
+++ b/run_chunked_eval.py
@@ -44,6 +44,7 @@ def main(model_name, strategy, task_name, eval_split):
         'n_sentences': DEFAULT_N_SENTENCES,
         'chunking_strategy': strategy,
         'model_has_instructions': has_instructions,
+        'embedding_model_name': model_name,
     }
 
     if torch.cuda.is_available():
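Note (PATCH 1/7): the hasattr() fallback above is what lets a plain SentenceTransformer sit in MODEL_WRAPPERS next to custom wrapper classes without defining the has_instructions() hook. A minimal usage sketch, assuming load_model() returns the (model, has_instructions) pair it computes, as the calling code in run_chunked_eval.py suggests:

    from chunked_pooling.wrappers import load_model

    # JinaEmbeddingsV3Wrapper defines has_instructions(), so the flag is True.
    model, has_instructions = load_model('jinaai/jina-embeddings-v3')

    # SentenceTransformer has no has_instructions attribute, so the new
    # hasattr() check falls back to False instead of raising AttributeError.
    model, has_instructions = load_model('sentence-transformers/all-MiniLM-L6-v2')
    assert has_instructions is False
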
From 23511102c234b316813e9d876b7d699a7fe26bbc Mon Sep 17 00:00:00 2001
From: Michael
Date: Mon, 23 Sep 2024 15:02:09 +0200
Subject: [PATCH 2/7] fix: gaps in semantic chunking

---
 chunked_pooling/chunking.py    |  8 ++++----
 tests/test_chunking_methods.py | 28 +++++++++++++++++++++++++---
 2 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/chunked_pooling/chunking.py b/chunked_pooling/chunking.py
index beabb8c..facf1b0 100644
--- a/chunked_pooling/chunking.py
+++ b/chunked_pooling/chunking.py
@@ -31,6 +31,7 @@ def _setup_semantic_chunking(self, embedding_model_name):
         self.embed_model = HuggingFaceEmbedding(
             model_name=self.embedding_model_name,
             trust_remote_code=True,
+            embed_batch_size=1,
         )
         self.splitter = SemanticSplitterNodeParser(
             embed_model=self.embed_model,
@@ -71,13 +72,12 @@ def chunk_semantically(
             start_chunk_index = bisect.bisect_left(
                 [offset[0] for offset in token_offsets], char_start
             )
-            end_chunk_index = (
-                bisect.bisect_right([offset[1] for offset in token_offsets], char_end)
-                - 1
+            end_chunk_index = bisect.bisect_right(
+                [offset[1] for offset in token_offsets], char_end
             )
 
             # Add the chunk span if it's within the tokenized text
-            if start_chunk_index < len(token_offsets) and end_chunk_index < len(
+            if start_chunk_index < len(token_offsets) and end_chunk_index <= len(
                 token_offsets
             ):
                 chunk_spans.append((start_chunk_index, end_chunk_index))
diff --git a/tests/test_chunking_methods.py b/tests/test_chunking_methods.py
index d99cc17..ff21fc5 100644
--- a/tests/test_chunking_methods.py
+++ b/tests/test_chunking_methods.py
@@ -100,14 +100,36 @@ def test_chunk_by_tokens():
 
 def test_chunk_semantically():
     chunker = Chunker(chunking_strategy="semantic")
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-    chunks = chunker.chunk(
+    tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-small-en')
+    tokens = tokenizer.encode_plus(
+        EXAMPLE_TEXT_1, add_special_tokens=False, return_offsets_mapping=True
+    )
+    boundary_cues = chunker.chunk(
         EXAMPLE_TEXT_1,
         tokenizer=tokenizer,
         chunking_strategy='semantic',
         embedding_model_name='jinaai/jina-embeddings-v2-small-en',
     )
-    assert len(chunks) > 0
+
+    # check if it returns boundary cues
+    assert len(boundary_cues) > 0
+
+    # test if boundaries are at the end of sentences
+    for start_token_idx, end_token_idx in boundary_cues:
+        assert (
+            EXAMPLE_TEXT_1[tokens.offset_mapping[end_token_idx - 1][0]] in PUNCTATIONS
+        )
+        decoded_text_chunk = tokenizer.decode(
+            tokens.input_ids[start_token_idx:end_token_idx]
+        )
+
+    # check that the boundary cues are continuous (no token is missing)
+    assert all(
+        [
+            boundary_cues[i][1] == boundary_cues[i + 1][0]
+            for i in range(len(boundary_cues) - 1)
+        ]
+    )
 
 
 def test_empty_input():
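Note (PATCH 2/7): the fix replaces the inclusive end index (bisect_right(...) - 1) with an exclusive one, so each chunk span ends exactly where the next begins; the new test asserts this continuity. A standalone sketch of the invariant, with made-up token offsets:

    import bisect

    # Hypothetical character spans (start, end) of four tokens.
    token_offsets = [(0, 5), (6, 11), (12, 17), (18, 23)]

    def span_for(char_start, char_end):
        # Mirrors the fixed logic: bisect_right yields an exclusive end,
        # so (start, end) covers tokens start..end-1 with no off-by-one gap.
        start = bisect.bisect_left([o[0] for o in token_offsets], char_start)
        end = bisect.bisect_right([o[1] for o in token_offsets], char_end)
        return start, end

    spans = [span_for(0, 11), span_for(12, 23)]  # two adjacent "sentences"
    assert spans == [(0, 2), (2, 4)]
    # Exclusive ends make consecutive spans share a boundary: no token is lost.
    assert all(spans[i][1] == spans[i + 1][0] for i in range(len(spans) - 1))
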
From f707e584530c45c86ca619fffde248be79163bc8 Mon Sep 17 00:00:00 2001
From: Michael
Date: Mon, 23 Sep 2024 15:49:59 +0200
Subject: [PATCH 3/7] feat: add option to pass custom model for chunking

---
 run_chunked_eval.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/run_chunked_eval.py b/run_chunked_eval.py
index a3c855d..ff49da0 100644
--- a/run_chunked_eval.py
+++ b/run_chunked_eval.py
@@ -29,7 +29,13 @@
 @click.option(
     '--eval-split', default='test', help='The name of the evaluation split in the task.'
 )
-def main(model_name, strategy, task_name, eval_split):
+@click.option(
+    '--chunking-model',
+    default=None,
+    required=False,
+    help='The name of the model used for semantic chunking.',
+)
+def main(model_name, strategy, task_name, eval_split, chunking_model):
     try:
         task_cls = globals()[task_name]
     except:
@@ -44,7 +50,7 @@ def main(model_name, strategy, task_name, eval_split):
         'n_sentences': DEFAULT_N_SENTENCES,
         'chunking_strategy': strategy,
         'model_has_instructions': has_instructions,
-        'embedding_model_name': model_name,
+        'embedding_model_name': chunking_model if chunking_model else model_name,
     }
 
     if torch.cuda.is_available():

From 4ca4204a2c93f4bec132099be9bf75517d85f952 Mon Sep 17 00:00:00 2001
From: admin
Date: Tue, 24 Sep 2024 09:28:19 +0000
Subject: [PATCH 4/7] feat: support nomic ai model

---
 chunked_pooling/wrappers.py | 72 ++++++++++++++++++++++++++++++++-----
 1 file changed, 64 insertions(+), 8 deletions(-)

diff --git a/chunked_pooling/wrappers.py b/chunked_pooling/wrappers.py
index a4bb0ef..074461b 100644
--- a/chunked_pooling/wrappers.py
+++ b/chunked_pooling/wrappers.py
@@ -4,6 +4,16 @@
 import torch.nn as nn
 from sentence_transformers import SentenceTransformer
 from transformers import AutoModel
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+
+
+def construct_document(doc):
+    if isinstance(doc, str):
+        return doc
+    elif 'title' in doc:
+        return f'{doc["title"]} {doc["text"].strip()}'
+    else:
+        return doc['text'].strip()
 
 
 class JinaEmbeddingsV3Wrapper(nn.Module):
@@ -31,7 +41,7 @@ def encode_corpus(
         *args,
         **kwargs,
     ):
-        _sentences = [self._construct_document(sentence) for sentence in sentences]
+        _sentences = [construct_document(sentence) for sentence in sentences]
         return self._model.encode(_sentences, *args, task=self.tasks[1], **kwargs)
 
     def get_instructions(self):
@@ -45,13 +55,57 @@ def forward(self, *args, **kwargs):
         )
         return self._model.forward(*args, adapter_mask=adapter_mask, **kwargs)
 
-    def _construct_document(self, doc):
-        if isinstance(doc, str):
-            return doc
-        elif 'title' in doc:
-            return f'{doc["title"]} {doc["text"].strip()}'
-        else:
-            return doc['text'].strip()
+    @property
+    def device(self):
+        return self._model.device
+
+    @staticmethod
+    def has_instructions():
+        return True
+
+
+class NomicAIWrapper(nn.Module):
+    def __init__(self, model_name, **model_kwargs):
+        super().__init__()
+        self._model = SentenceTransformer(
+            model_name, trust_remote_code=True, **model_kwargs
+        )
+        self.instructions = ['search_query: ', 'search_document: ']
+
+    def get_instructions(self):
+        return self.instructions
+
+    def forward(self, *args, **kwargs):
+        # TODO combine kwargs into input
+        model_output = self._model.forward(kwargs)
+        base_model_output = BaseModelOutputWithPooling(
+            last_hidden_state=model_output['token_embeddings'],
+            pooler_output=model_output['sentence_embedding'],
+            attentions=model_output['attention_mask'],
+        )
+        return base_model_output
+
+    def encode_queries(
+        self,
+        sentences: Union[str, List[str]],
+        *args,
+        **kwargs,
+    ):
+        return self._model.encode(
+            [self.instructions[0] + s for s in sentences], *args, **kwargs
+        )
+
+    def encode_corpus(
+        self,
+        sentences: Union[str, List[str]],
+        *args,
+        **kwargs,
+    ):
+        return self._model.encode(
+            [self.instructions[1] + construct_document(s) for s in sentences],
+            *args,
+            **kwargs,
+        )
 
     @property
     def device(self):
@@ -65,7 +119,9 @@ def has_instructions():
 MODEL_WRAPPERS = {
     'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper,
     'sentence-transformers/all-MiniLM-L6-v2': SentenceTransformer,
+    'nomic-ai/nomic-embed-text-v1': NomicAIWrapper,
 }
+
 MODELS_WITHOUT_PROMPT_NAME_ARG = [
     'jinaai/jina-embeddings-v2-small-en',
     'jinaai/jina-embeddings-v2-base-en',
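Note (PATCH 3-4/7): PATCH 3 decouples the chunking model from the embedding model; when --chunking-model is unset, the embedding model itself is used for semantic chunking. PATCH 4 hard-codes in NomicAIWrapper the task prefixes that nomic-embed-text expects and registers the wrapper so load_model() picks it up. An illustrative sketch (again assuming load_model() returns the (model, has_instructions) pair; the query and document are invented):

    from chunked_pooling.wrappers import load_model

    model, has_instructions = load_model('nomic-ai/nomic-embed-text-v1')
    assert has_instructions  # NomicAIWrapper.has_instructions() returns True

    # encode_queries prepends 'search_query: '; encode_corpus prepends
    # 'search_document: ' after construct_document() joins title and text.
    q_emb = model.encode_queries(['who wrote the iliad?'])
    d_emb = model.encode_corpus([{'title': 'Iliad', 'text': 'An epic poem ...'}])
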
From b66a13c1230264ea689010185374475fb627a7a7 Mon Sep 17 00:00:00 2001
From: admin
Date: Wed, 25 Sep 2024 08:58:44 +0000
Subject: [PATCH 5/7] feat: add additional cmd args

---
 chunked_pooling/mteb_chunked_eval.py | 19 +++++++++++++++++++
 run_chunked_eval.py                  | 26 ++++++++++++++++++++++++--
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/chunked_pooling/mteb_chunked_eval.py b/chunked_pooling/mteb_chunked_eval.py
index 2433e7f..b119deb 100644
--- a/chunked_pooling/mteb_chunked_eval.py
+++ b/chunked_pooling/mteb_chunked_eval.py
@@ -26,6 +26,7 @@ def __init__(
         n_sentences: Optional[int] = None,
         model_has_instructions: bool = False,
         embedding_model_name: Optional[str] = None,  # for semantic chunking
+        truncate_max_length: Optional[int] = 8192,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -48,6 +49,7 @@ def __init__(
             'n_sentences': n_sentences,
             'embedding_model_name': embedding_model_name,
         }
+        self.truncate_max_length = truncate_max_length
 
     def load_data(self, **kwargs):
         self.retrieval_task.load_data(**kwargs)
@@ -97,6 +99,21 @@ def evaluate(
 
         return scores
 
+    def _truncate_documents(self, corpus):
+        for k, v in corpus.items():
+            if 'title' in v:
+                raise NotImplementedError(
+                    'Currently truncation is only implemented for documents without titles'
+                )
+            tokens = self.tokenizer(
+                v['text'],
+                return_offsets_mapping=True,
+                max_length=self.truncate_max_length,
+            )
+            last_token_span = tokens.offset_mapping[-2]
+            v['text'] = v['text'][: last_token_span[1]]
+        return corpus
+
     def _evaluate_monolingual(
         self,
         model,
@@ -108,6 +125,8 @@ def _evaluate_monolingual(
         encode_kwargs=None,
         **kwargs,
     ):
+        if self.truncate_max_length:
+            corpus = self._truncate_documents(corpus)
         # split corpus into chunks
         if not self.chunked_pooling_enabled:
             corpus = self._apply_chunking(corpus, self.tokenizer)
diff --git a/run_chunked_eval.py b/run_chunked_eval.py
index ff49da0..3dbdbc0 100644
--- a/run_chunked_eval.py
+++ b/run_chunked_eval.py
@@ -35,7 +35,27 @@
     required=False,
     help='The name of the model used for semantic chunking.',
 )
-def main(model_name, strategy, task_name, eval_split, chunking_model):
+@click.option(
+    '--truncate-max-length',
+    default=None,
+    type=int,
+    help='Maximum number of tokens; by default, no truncation is done.',
+)
+@click.option(
+    '--chunk-size',
+    default=DEFAULT_CHUNK_SIZE,
+    type=int,
+    help='Number of tokens per chunk for fixed strategy.',
+)
+def main(
+    model_name,
+    strategy,
+    task_name,
+    eval_split,
+    chunking_model,
+    truncate_max_length,
+    chunk_size,
+):
     try:
         task_cls = globals()[task_name]
     except:
@@ -46,7 +66,7 @@ def main(
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
     chunking_args = {
-        'chunk_size': DEFAULT_CHUNK_SIZE,
+        'chunk_size': chunk_size,
         'n_sentences': DEFAULT_N_SENTENCES,
         'chunking_strategy': strategy,
         'model_has_instructions': has_instructions,
@@ -64,6 +84,7 @@ def main(
             chunked_pooling_enabled=True,
             tokenizer=tokenizer,
             prune_size=None,
+            truncate_max_length=truncate_max_length,
             **chunking_args,
         )
     ]
@@ -90,6 +111,7 @@ def main(
             chunked_pooling_enabled=False,
             tokenizer=tokenizer,
             prune_size=None,
+            truncate_max_length=truncate_max_length,
            **chunking_args,
         )
     ]
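Note (PATCH 5/7): _truncate_documents() cuts each document at the character offset of its last in-window token rather than at a raw character count, so truncation respects token boundaries. A simplified standalone sketch of the idea; unlike the patch, it passes truncation=True explicitly so the max_length cut is unambiguous, and it uses index -2 to skip the trailing special token, assuming a fast tokenizer that appends one:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        'jinaai/jina-embeddings-v2-small-en', trust_remote_code=True
    )

    def truncate_text(text: str, max_length: int) -> str:
        tokens = tokenizer(
            text,
            return_offsets_mapping=True,  # requires a fast tokenizer
            max_length=max_length,
            truncation=True,
        )
        # offset_mapping[-1] belongs to the trailing special token and is
        # (0, 0); [-2] is the last real token, so cut at its end offset.
        last_token_span = tokens.offset_mapping[-2]
        return text[: last_token_span[1]]

    short = truncate_text('some long document ' * 1000, max_length=128)
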
From 13f546c62fd269b44caf025878bb03af186522a4 Mon Sep 17 00:00:00 2001
From: admin
Date: Wed, 25 Sep 2024 09:30:06 +0000
Subject: [PATCH 6/7] feat: add arg for n_sentences

---
 run_chunked_eval.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/run_chunked_eval.py b/run_chunked_eval.py
index 3dbdbc0..88494bd 100644
--- a/run_chunked_eval.py
+++ b/run_chunked_eval.py
@@ -47,6 +47,12 @@
     type=int,
     help='Number of tokens per chunk for fixed strategy.',
 )
+@click.option(
+    '--n-sentences',
+    default=DEFAULT_N_SENTENCES,
+    type=int,
+    help='Number of sentences per chunk for sentence strategy.',
+)
 def main(
     model_name,
     strategy,
@@ -55,6 +61,7 @@ def main(
     chunking_model,
     truncate_max_length,
     chunk_size,
+    n_sentences,
 ):
     try:
         task_cls = globals()[task_name]
@@ -67,7 +74,7 @@ def main(
 
     chunking_args = {
         'chunk_size': chunk_size,
-        'n_sentences': DEFAULT_N_SENTENCES,
+        'n_sentences': n_sentences,
         'chunking_strategy': strategy,
         'model_has_instructions': has_instructions,
         'embedding_model_name': chunking_model if chunking_model else model_name,

From 9568a2249ea432937ee38a6eee020ff81516c342 Mon Sep 17 00:00:00 2001
From: admin
Date: Wed, 25 Sep 2024 12:46:24 +0000
Subject: [PATCH 7/7] feat: add trivia-qa-eval

---
 chunked_pooling/chunked_eval_tasks.py | 66 +++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/chunked_pooling/chunked_eval_tasks.py b/chunked_pooling/chunked_eval_tasks.py
index 23dbcbf..7417fa7 100644
--- a/chunked_pooling/chunked_eval_tasks.py
+++ b/chunked_pooling/chunked_eval_tasks.py
@@ -454,3 +454,69 @@ def load_data(self, **kwargs):
             self.relevant_docs[split] = qrels
 
         self.data_loaded = True
+
+
+class TriviaQAChunked(AbsTaskChunkedRetrieval):
+
+    _EVAL_SPLIT = ["test"]
+
+    metadata = TaskMetadata(
+        name='TriviaQAChunked',
+        description=('Retrieval dataset derived from TriviaQA for chunked evaluation.'),
+        reference='https://nlp.cs.washington.edu/triviaqa',
+        dataset={
+            'path': 'mandarjoshi/trivia_qa',
+            'revision': '0f7faf33a3908546c6fd5b73a660e0f8ff173c2f',
+            'name': 'rc',
+        },
+        type='Retrieval',
+        category='s2p',
+        eval_splits=['test'],
+        eval_langs=['eng-Latn'],
+        main_score='ndcg_at_10',
+        date=None,
+        form=None,
+        domains=None,
+        task_subtypes=None,
+        license=None,
+        socioeconomic_status=None,
+        annotations_creators=None,
+        dialect=None,
+        text_creation=None,
+        bibtex_citation=None,
+        n_samples=None,
+        avg_character_length=None,
+    )
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def load_data(self, **kwargs):
+        if self.data_loaded:
+            return
+        self.corpus = {}
+        self.queries = {}
+        self.relevant_docs = {}
+
+        data = datasets.load_dataset(**self.metadata_dict["dataset"])
+        for split in self._EVAL_SPLIT:
+            corpus = {}
+            queries = {}
+            relevant_docs = {}
+            for i, row in enumerate(data[split]):
+                if len(row['entity_pages']['wiki_context']) < 1:
+                    continue
+                queries[f'q{i}'] = row['question']
+                relevant_docs[f'q{i}'] = dict()
+                for j, (title, content) in enumerate(
+                    zip(
+                        row['entity_pages']['title'],
+                        row['entity_pages']['wiki_context'],
+                    )
+                ):
+                    corpus[f'c{i}-{j}'] = {'title': title, 'text': content}
+                    relevant_docs[f'q{i}'][f'c{i}-{j}'] = 1
+            self.corpus[split] = corpus
+            self.queries[split] = queries
+            self.relevant_docs[split] = relevant_docs
+        self.data_loaded = True
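Note (PATCH 6-7/7): with --n-sentences exposed and the new task in place, an evaluation can be assembled the same way run_chunked_eval.py builds its tasks. A sketch under those assumptions (argument values are illustrative; loading the TriviaQA 'rc' config downloads a large dataset):

    from transformers import AutoTokenizer

    from chunked_pooling.chunked_eval_tasks import TriviaQAChunked
    from chunked_pooling.wrappers import load_model

    model_name = 'jinaai/jina-embeddings-v2-small-en'
    model, has_instructions = load_model(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Mirrors the task construction in run_chunked_eval.py's main().
    task = TriviaQAChunked(
        chunked_pooling_enabled=True,
        tokenizer=tokenizer,
        prune_size=None,
        truncate_max_length=8192,
        chunk_size=256,
        n_sentences=5,
        chunking_strategy='semantic',
        model_has_instructions=has_instructions,
        embedding_model_name=model_name,
    )
    task.load_data()
    print(len(task.queries['test']), len(task.corpus['test']))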