feat: add semantic chunking to eval script; add wrapper for minilm #11

Merged · 4 commits · Sep 23, 2024
8 changes: 4 additions & 4 deletions chunked_pooling/chunking.py
@@ -31,6 +31,7 @@ def _setup_semantic_chunking(self, embedding_model_name):
         self.embed_model = HuggingFaceEmbedding(
             model_name=self.embedding_model_name,
             trust_remote_code=True,
+            embed_batch_size=1,
         )
         self.splitter = SemanticSplitterNodeParser(
             embed_model=self.embed_model,
@@ -71,13 +72,12 @@ def chunk_semantically(
         start_chunk_index = bisect.bisect_left(
             [offset[0] for offset in token_offsets], char_start
         )
-        end_chunk_index = (
-            bisect.bisect_right([offset[1] for offset in token_offsets], char_end)
-            - 1
+        end_chunk_index = bisect.bisect_right(
+            [offset[1] for offset in token_offsets], char_end
         )

         # Add the chunk span if it's within the tokenized text
-        if start_chunk_index < len(token_offsets) and end_chunk_index < len(
+        if start_chunk_index < len(token_offsets) and end_chunk_index <= len(
             token_offsets
         ):
             chunk_spans.append((start_chunk_index, end_chunk_index))
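
The motivation for the bisect change, sketched below: chunk spans are consumed end-exclusively, so the old `- 1` dropped the final token of every chunk, and removing it means `end_chunk_index` may now legitimately equal `len(token_offsets)`, hence the relaxed `<=` bounds check. A minimal illustration with hypothetical token offsets:

import bisect

# Hypothetical (start, end) character offsets for three tokens.
token_offsets = [(0, 5), (6, 11), (12, 17)]
char_end = 17  # the chunk ends exactly at the last token

# New computation: bisect_right already yields an end-exclusive index.
end_chunk_index = bisect.bisect_right([e for _, e in token_offsets], char_end)
assert end_chunk_index == 3 == len(token_offsets)  # allowed by the `<=` check

# The old extra `- 1` would have cut the span to index 2,
# silently discarding the last token of the chunk.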
2 changes: 2 additions & 0 deletions chunked_pooling/mteb_chunked_eval.py
@@ -25,6 +25,7 @@ def __init__(
         chunk_size: Optional[int] = None,
         n_sentences: Optional[int] = None,
         model_has_instructions: bool = False,
+        embedding_model_name: Optional[str] = None,  # for semantic chunking
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -45,6 +46,7 @@ def __init__(
         self.chunking_args = {
             'chunk_size': chunk_size,
             'n_sentences': n_sentences,
+            'embedding_model_name': embedding_model_name,
         }

     def load_data(self, **kwargs):
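
For orientation, the new keyword is only stored in `self.chunking_args`, which the task later forwards to the chunker. A minimal sketch of the call it ultimately enables (the `Chunker` import path is assumed from the file layout; the call signature matches the updated test further down):

from transformers import AutoTokenizer

from chunked_pooling.chunking import Chunker  # assumed import path

chunker = Chunker(chunking_strategy='semantic')
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

# embedding_model_name travels from the task's chunking_args into chunk().
spans = chunker.chunk(
    'First sentence. Second sentence.',
    tokenizer=tokenizer,
    chunking_strategy='semantic',
    embedding_model_name='sentence-transformers/all-MiniLM-L6-v2',
)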
11 changes: 9 additions & 2 deletions chunked_pooling/wrappers.py
@@ -2,6 +2,7 @@

 import torch
 import torch.nn as nn
+from sentence_transformers import SentenceTransformer
 from transformers import AutoModel


@@ -61,7 +62,10 @@ def has_instructions():
         return True


-MODEL_WRAPPERS = {'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper}
+MODEL_WRAPPERS = {
+    'jinaai/jina-embeddings-v3': JinaEmbeddingsV3Wrapper,
+    'sentence-transformers/all-MiniLM-L6-v2': SentenceTransformer,
+}
 MODELS_WITHOUT_PROMPT_NAME_ARG = [
     'jinaai/jina-embeddings-v2-small-en',
     'jinaai/jina-embeddings-v2-base-en',
@@ -82,7 +86,10 @@ def wrapper(self, *args, **kwargs):
 def load_model(model_name, **model_kwargs):
     if model_name in MODEL_WRAPPERS:
         model = MODEL_WRAPPERS[model_name](model_name, **model_kwargs)
-        has_instructions = MODEL_WRAPPERS[model_name].has_instructions()
+        if hasattr(MODEL_WRAPPERS[model_name], 'has_instructions'):
+            has_instructions = MODEL_WRAPPERS[model_name].has_instructions()
+        else:
+            has_instructions = False
     else:
         model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
         has_instructions = False
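
The reason for the `hasattr` guard: the new `SentenceTransformer` entry in `MODEL_WRAPPERS` is the library class itself, not a custom wrapper, so it defines no `has_instructions`. A standalone sketch of the fallback pattern (class names here are stand-ins, not the repository's):

class CustomWrapper:
    """Stands in for JinaEmbeddingsV3Wrapper, which declares instructions."""

    @staticmethod
    def has_instructions():
        return True


class PlainModel:
    """Stands in for SentenceTransformer: no has_instructions attribute."""


for cls in (CustomWrapper, PlainModel):
    if hasattr(cls, 'has_instructions'):
        has_instructions = cls.has_instructions()
    else:
        has_instructions = False
    print(cls.__name__, has_instructions)  # True, then False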
9 changes: 8 additions & 1 deletion run_chunked_eval.py
@@ -29,7 +29,13 @@
 @click.option(
     '--eval-split', default='test', help='The name of the evaluation split in the task.'
 )
-def main(model_name, strategy, task_name, eval_split):
+@click.option(
+    '--chunking-model',
+    default=None,
+    required=False,
+    help='The name of the model used for semantic chunking.',
+)
+def main(model_name, strategy, task_name, eval_split, chunking_model):
     try:
         task_cls = globals()[task_name]
     except:
@@ -44,6 +50,7 @@ def main(model_name, strategy, task_name, eval_split):
         'n_sentences': DEFAULT_N_SENTENCES,
         'chunking_strategy': strategy,
         'model_has_instructions': has_instructions,
+        'embedding_model_name': chunking_model if chunking_model else model_name,
     }

     if torch.cuda.is_available():
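
A hedged usage sketch of the new flag via click's test runner. Only `--eval-split` and `--chunking-model` appear in this diff, so the `--model-name`, `--strategy`, and `--task-name` flag names, and the task name below, are assumptions:

from click.testing import CliRunner

from run_chunked_eval import main

runner = CliRunner()
result = runner.invoke(
    main,
    [
        '--model-name', 'sentence-transformers/all-MiniLM-L6-v2',
        '--strategy', 'semantic',
        '--task-name', 'SciFactChunked',  # hypothetical task name
        '--chunking-model', 'jinaai/jina-embeddings-v2-small-en',
    ],
)
# When --chunking-model is omitted, the script falls back to model_name for
# semantic chunking (the `chunking_model if chunking_model else model_name`
# expression above).
print(result.exit_code)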
36 changes: 31 additions & 5 deletions tests/test_chunking_methods.py
@@ -98,16 +98,42 @@ def test_chunk_by_tokens():
         assert end - start <= 10


-def test_chunk_semantically():
+@pytest.mark.parametrize(
+    'model_name',
+    ['jinaai/jina-embeddings-v2-small-en', 'sentence-transformers/all-MiniLM-L6-v2'],
+)
+def test_chunk_semantically(model_name):
     chunker = Chunker(chunking_strategy="semantic")
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-    chunks = chunker.chunk(
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokens = tokenizer.encode_plus(
+        EXAMPLE_TEXT_1, add_special_tokens=False, return_offsets_mapping=True
+    )
+    boundary_cues = chunker.chunk(
         EXAMPLE_TEXT_1,
         tokenizer=tokenizer,
         chunking_strategy='semantic',
-        embedding_model_name='jinaai/jina-embeddings-v2-small-en',
+        embedding_model_name=model_name,
     )
-    assert len(chunks) > 0
+
+    # check if it returns boundary cues
+    assert len(boundary_cues) > 0
+
+    # test if boundaries are at the end of sentences
+    for start_token_idx, end_token_idx in boundary_cues:
+        assert (
+            EXAMPLE_TEXT_1[tokens.offset_mapping[end_token_idx - 1][0]] in PUNCTATIONS
+        )
+        decoded_text_chunk = tokenizer.decode(
+            tokens.input_ids[start_token_idx:end_token_idx]
+        )
+
+    # check that the boundary cues are continuous (no token is missing)
+    assert all(
+        [
+            boundary_cues[i][1] == boundary_cues[i + 1][0]
+            for i in range(len(boundary_cues) - 1)
+        ]
+    )


 def test_empty_input():
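
The continuity assertion at the end of the test is exactly the property the bisect fix restores: consecutive spans must share a boundary so that no token falls between chunks. Illustrated on hand-made spans (not real test data):

# Contiguous: each span starts where the previous one ended.
boundary_cues = [(0, 5), (5, 12), (12, 20)]
assert all(
    boundary_cues[i][1] == boundary_cues[i + 1][0]
    for i in range(len(boundary_cues) - 1)
)

# The pre-fix off-by-one could have produced gaps such as
# [(0, 4), (5, 11), (12, 19)], which would fail the assertion above.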