
Commit

feat: add semantic chunking to eval script; add wrapper for minilm (#11)
* feat: add semantic chunking to eval script; add wrapper for minilm

* fix: gaps in semantic chunking

* feat: add option to pass custom model for chunking

* refactor: add second model to semantic chunking test
guenthermi authored Sep 23, 2024
1 parent d5a0fa6 commit 70f81cb
Showing 5 changed files with 54 additions and 12 deletions.
chunked_pooling/chunking.py: 8 changes (4 additions, 4 deletions)
@@ -31,6 +31,7 @@ def _setup_semantic_chunking(self, embedding_model_name):
         self.embed_model = HuggingFaceEmbedding(
             model_name=self.embedding_model_name,
             trust_remote_code=True,
+            embed_batch_size=1,
         )
         self.splitter = SemanticSplitterNodeParser(
             embed_model=self.embed_model,
@@ -71,13 +72,12 @@ def chunk_semantically(
         start_chunk_index = bisect.bisect_left(
             [offset[0] for offset in token_offsets], char_start
         )
-        end_chunk_index = (
-            bisect.bisect_right([offset[1] for offset in token_offsets], char_end)
-            - 1
+        end_chunk_index = bisect.bisect_right(
+            [offset[1] for offset in token_offsets], char_end
         )

         # Add the chunk span if it's within the tokenized text
-        if start_chunk_index < len(token_offsets) and end_chunk_index < len(
+        if start_chunk_index < len(token_offsets) and end_chunk_index <= len(
             token_offsets
         ):
             chunk_spans.append((start_chunk_index, end_chunk_index))
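
Note: the two hunks above work together. Without the "- 1", end_chunk_index becomes an exclusive bound, so every (start_chunk_index, end_chunk_index) pair is a half-open token span, and the guard must then allow end_chunk_index == len(token_offsets) for a chunk that ends at the last token. This is what closes the gaps named in the commit message. A minimal standalone sketch of the invariant (toy offsets, not code from this repository):

    import bisect

    # hypothetical (char_start, char_end) offsets for six tokens
    token_offsets = [(0, 5), (6, 11), (12, 17), (18, 23), (24, 29), (30, 35)]

    def token_span(char_start, char_end):
        start = bisect.bisect_left([s for s, _ in token_offsets], char_start)
        # exclusive end: no "- 1", so spans are half-open like Python slices
        end = bisect.bisect_right([e for _, e in token_offsets], char_end)
        return start, end

    first = token_span(0, 17)    # covers tokens 0..2 -> (0, 3)
    second = token_span(18, 35)  # covers tokens 3..5 -> (3, 6)
    assert first[1] == second[0]  # adjacent chunks tile the tokens, no gap

The new continuity assertion in tests/test_chunking_methods.py (last file below) checks exactly this property on real chunker output.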
chunked_pooling/mteb_chunked_eval.py: 2 changes (2 additions, 0 deletions)
@@ -25,6 +25,7 @@ def __init__(
         chunk_size: Optional[int] = None,
         n_sentences: Optional[int] = None,
         model_has_instructions: bool = False,
+        embedding_model_name: Optional[str] = None,  # for semantic chunking
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -45,6 +46,7 @@ def __init__(
         self.chunking_args = {
             'chunk_size': chunk_size,
             'n_sentences': n_sentences,
+            'embedding_model_name': embedding_model_name,
         }

     def load_data(self, **kwargs):
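
Note: the task class only stores the new keyword in self.chunking_args and forwards it to the chunker. A hypothetical construction under these assumptions (the task name and keyword values are illustrative, not taken from this diff):

    task = SciFactChunked(
        chunking_strategy='semantic',
        chunk_size=None,
        n_sentences=None,
        model_has_instructions=False,
        embedding_model_name='sentence-transformers/all-MiniLM-L6-v2',
    )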
chunked_pooling/wrappers.py: 11 changes (9 additions, 2 deletions)
@@ -2,6 +2,7 @@

 import torch
 import torch.nn as nn
+from sentence_transformers import SentenceTransformer
 from transformers import AutoModel


@@ -61,7 +62,10 @@ def has_instructions():
         return True


-MODEL_WRAPPERS = {'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper}
+MODEL_WRAPPERS = {
+    'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper,
+    'sentence-transformers/all-MiniLM-L6-v2': SentenceTransformer,
+}
 MODELS_WITHOUT_PROMPT_NAME_ARG = [
     'jinaai/jina-embeddings-v2-small-en',
     'jinaai/jina-embeddings-v2-base-en',
@@ -82,7 +86,10 @@ def wrapper(self, *args, **kwargs):
 def load_model(model_name, **model_kwargs):
     if model_name in MODEL_WRAPPERS:
         model = MODEL_WRAPPERS[model_name](model_name, **model_kwargs)
-        has_instructions = MODEL_WRAPPERS[model_name].has_instructions()
+        if hasattr(MODEL_WRAPPERS[model_name], 'has_instructions'):
+            has_instructions = MODEL_WRAPPERS[model_name].has_instructions()
+        else:
+            has_instructions = False
     else:
         model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
         has_instructions = False
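
Note: the hasattr probe is needed because SentenceTransformer, unlike JinaEmbeddingsV3Wrapper, defines no has_instructions() hook. An equivalent, more compact formulation would be (a sketch, not code from this commit):

    wrapper_cls = MODEL_WRAPPERS[model_name]
    # fall back to "no instructions" when the wrapper defines no such hook
    has_instructions = getattr(wrapper_cls, 'has_instructions', lambda: False)()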
run_chunked_eval.py: 9 changes (8 additions, 1 deletion)
@@ -29,7 +29,13 @@
 @click.option(
     '--eval-split', default='test', help='The name of the evaluation split in the task.'
 )
-def main(model_name, strategy, task_name, eval_split):
+@click.option(
+    '--chunking-model',
+    default=None,
+    required=False,
+    help='The name of the model used for semantic chunking.',
+)
+def main(model_name, strategy, task_name, eval_split, chunking_model):
     try:
         task_cls = globals()[task_name]
     except:
@@ -44,6 +50,7 @@ def main(model_name, strategy, task_name, eval_split):
         'n_sentences': DEFAULT_N_SENTENCES,
         'chunking_strategy': strategy,
         'model_has_instructions': has_instructions,
+        'embedding_model_name': chunking_model if chunking_model else model_name,
     }

     if torch.cuda.is_available():
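
Note: the new flag lets a lightweight model drive semantic chunking while a different model is evaluated; when --chunking-model is omitted, the evaluated model itself is used for chunking (chunking_model if chunking_model else model_name). An illustrative invocation, assuming the remaining flag names match main()'s parameters:

    python run_chunked_eval.py \
        --model-name jinaai/jina-embeddings-v3 \
        --strategy semantic \
        --chunking-model sentence-transformers/all-MiniLM-L6-v2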
tests/test_chunking_methods.py: 36 changes (31 additions, 5 deletions)
@@ -98,16 +98,42 @@ def test_chunk_by_tokens():
         assert end - start <= 10


-def test_chunk_semantically():
+@pytest.mark.parametrize(
+    'model_name',
+    ['jinaai/jina-embeddings-v2-small-en', 'sentence-transformers/all-MiniLM-L6-v2'],
+)
+def test_chunk_semantically(model_name):
     chunker = Chunker(chunking_strategy="semantic")
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-    chunks = chunker.chunk(
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokens = tokenizer.encode_plus(
+        EXAMPLE_TEXT_1, add_special_tokens=False, return_offsets_mapping=True
+    )
+    boundary_cues = chunker.chunk(
         EXAMPLE_TEXT_1,
         tokenizer=tokenizer,
         chunking_strategy='semantic',
-        embedding_model_name='jinaai/jina-embeddings-v2-small-en',
+        embedding_model_name=model_name,
     )
-    assert len(chunks) > 0
+
+    # check if it returns boundary cues
+    assert len(boundary_cues) > 0
+
+    # test if boundaries are at the end of sentences
+    for start_token_idx, end_token_idx in boundary_cues:
+        assert (
+            EXAMPLE_TEXT_1[tokens.offset_mapping[end_token_idx - 1][0]] in PUNCTATIONS
+        )
+        decoded_text_chunk = tokenizer.decode(
+            tokens.input_ids[start_token_idx:end_token_idx]
+        )
+
+    # check that the boundary cues are continuous (no token is missing)
+    assert all(
+        [
+            boundary_cues[i][1] == boundary_cues[i + 1][0]
+            for i in range(len(boundary_cues) - 1)
+        ]
+    )


 def test_empty_input():
