Commit

fix truncation; style and docs changes (#19)
guenthermi authored Oct 7, 2024
1 parent 6a4fc5d commit a79df8b
Showing 2 changed files with 33 additions and 18 deletions.
33 changes: 23 additions & 10 deletions chunked_pooling/mteb_chunked_eval.py
@@ -106,26 +106,35 @@ def evaluate(
 
     def _truncate_documents(self, corpus):
         for k, v in corpus.items():
+            title_tokens = 0
             if 'title' in v:
-                raise NotImplementedError(
-                    'Currently truncation is only implemented for documents without titles'
+                tokens = self.tokenizer(
+                    v['title'] + ' ',
+                    return_offsets_mapping=True,
+                    max_length=self.truncate_max_length,
                 )
+                title_tokens = len(tokens.input_ids)
             tokens = self.tokenizer(
                 v['text'],
                 return_offsets_mapping=True,
-                max_length=self.truncate_max_length,
+                max_length=self.truncate_max_length - title_tokens,
             )
             last_token_span = tokens.offset_mapping[-2]
             v['text'] = v['text'][: last_token_span[1]]
         return corpus
 
     def _embed_with_overlap(self, model, model_inputs):
 
         len_tokens = len(model_inputs["input_ids"][0])
 
         if len_tokens > self.long_late_chunking_embed_size:
             indices = []
-            for i in range(0, len_tokens, self.long_late_chunking_embed_size - self.long_late_chunking_overlap_size):
+            for i in range(
+                0,
+                len_tokens,
+                self.long_late_chunking_embed_size
+                - self.long_late_chunking_overlap_size,
+            ):
                 start = i
                 end = min(i + self.long_late_chunking_embed_size, len_tokens)
                 indices.append((start, end))
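
For context, the title-aware truncation above can be exercised in isolation. A minimal, runnable sketch, assuming a fast Hugging Face tokenizer (offset mappings require one); the model name and the 32-token budget are illustrative, and truncation=True is passed explicitly here for clarity:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
truncate_max_length = 32
doc = {'title': 'Late Chunking', 'text': 'A long body that may exceed the token budget. ' * 20}

# Reserve part of the token budget for the title, as the patch does.
title_tokens = 0
if 'title' in doc:
    enc_title = tokenizer(
        doc['title'] + ' ', max_length=truncate_max_length, truncation=True
    )
    title_tokens = len(enc_title.input_ids)

# Tokenize the body with the remaining budget; offsets map each token back
# to character positions, so the raw string can be cut at a token boundary.
enc = tokenizer(
    doc['text'],
    return_offsets_mapping=True,
    max_length=truncate_max_length - title_tokens,
    truncation=True,
)
# The final token is typically a special token with span (0, 0), which is
# why the last real token is read from offset_mapping[-2].
last_token_span = enc.offset_mapping[-2]
doc['text'] = doc['text'][: last_token_span[1]]
print(doc['text'])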
@@ -138,10 +147,12 @@ def _embed_with_overlap(self, model, model_inputs):
             batch_inputs = {k: v[:, start:end] for k, v in model_inputs.items()}
 
             with torch.no_grad():
-                model_output = model(**batch_inputs)
+                model_output = model(**batch_inputs)
 
             if start > 0:
-                outputs.append(model_output[0][:, self.long_late_chunking_overlap_size:])
+                outputs.append(
+                    model_output[0][:, self.long_late_chunking_overlap_size :]
+                )
             else:
                 outputs.append(model_output[0])
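
The window arithmetic above is easier to see on concrete numbers. A standalone sketch (function name and sizes are illustrative, not from the repo): windows of embed_size tokens advance by embed_size - overlap, and every window after the first drops its first overlap positions when stitched, so each token position ends up embedded exactly once.

def overlap_windows(len_tokens, embed_size, overlap):
    # Mirrors the index loop above: the step is embed_size - overlap.
    windows = []
    for start in range(0, len_tokens, embed_size - overlap):
        windows.append((start, min(start + embed_size, len_tokens)))
    return windows

# 1000 tokens, 512-token windows, 128 tokens of overlap:
print(overlap_windows(1000, 512, 128))
# -> [(0, 512), (384, 896), (768, 1000)]
# Dropping the first 128 positions of every window but the first stitches
# the spans 0-512, 512-896, 896-1000: full coverage, no duplicates.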

@@ -227,10 +238,12 @@ def _evaluate_monolingual(
                     output_embs = chunked_pooling(
                         [model_outputs], annotations, max_length=None
                     )
-                else: # truncation
+                else:  # truncation
                     model_outputs = model(**model_inputs)
                     output_embs = chunked_pooling(
-                        model_outputs, annotations, max_length=self.truncate_max_length
+                        model_outputs,
+                        annotations,
+                        max_length=self.truncate_max_length,
                     )
                 corpus_embs.extend(output_embs)
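
For readers who have not seen chunked_pooling, the underlying idea can be sketched as follows. This is a hypothetical illustration, not the repository's implementation: given one document's per-token embeddings and (start, end) token-span annotations, each chunk embedding is the mean of the token embeddings inside its span.

import torch

def pool_chunks(token_embs, spans):
    # token_embs: (seq_len, hidden) token embeddings for one document.
    # spans: list of (start, end) token indices, one pair per chunk.
    return [token_embs[start:end].mean(dim=0) for start, end in spans]

token_embs = torch.randn(10, 4)  # 10 tokens, hidden size 4
chunk_embs = pool_chunks(token_embs, [(0, 4), (4, 10)])
print([tuple(e.shape) for e in chunk_embs])  # [(4,), (4,)]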

18 changes: 10 additions & 8 deletions run_chunked_eval.py
@@ -11,8 +11,8 @@
 DEFAULT_N_SENTENCES = 5
 BATCH_SIZE = 1
 DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE = 256
-DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE = 0 # set to 0 to disable long late chunking
-DEFAULT_TRUNCATE_MAX_LENGTH = 8192
+DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE = 0  # set to 0 to disable long late chunking
+DEFAULT_TRUNCATE_MAX_LENGTH = None
 
 
 @click.command()
@@ -60,13 +60,13 @@
     '--long-late-chunking-embed-size',
     default=DEFAULT_LONG_LATE_CHUNKING_EMBED_SIZE,
     type=int,
-    help='Token length of the embeddings that come before/after soft boundaries (i.e. overlapping embeddings). Above zero, overlap is used between neighbouring embeddings.',
+    help='Number of tokens per chunk for fixed strategy.',
 )
 @click.option(
     '--long-late-chunking-overlap-size',
     default=DEFAULT_LONG_LATE_CHUNKING_OVERLAP_SIZE,
     type=int,
-    help='Number of tokens per chunk for fixed strategy.',
+    help='Token length of the embeddings that come before/after soft boundaries (i.e. overlapping embeddings). Above zero, overlap is used between neighbouring embeddings.',
 )
@@ -78,17 +78,19 @@ def main(
     chunk_size,
     n_sentences,
     long_late_chunking_embed_size,
-    long_late_chunking_overlap_size
+    long_late_chunking_overlap_size,
 ):
     try:
         task_cls = globals()[task_name]
     except:
         raise ValueError(f'Unknown task name: {task_name}')
 
     if truncate_max_length is not None and (long_late_chunking_embed_size > 0):
         truncate_max_length = None
-        print(f'Truncation is disabled because Long Late Chunking algorithm is enabled.')
+        print(
+            f'Truncation is disabled because Long Late Chunking algorithm is enabled.'
+        )
 
     model, has_instructions = load_model(model_name)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
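
The guard in main() gives long late chunking precedence over truncation. A tiny sketch of that rule (the helper function is illustrative; the parameter names mirror the CLI options):

def resolve_modes(truncate_max_length, long_late_chunking_embed_size):
    # Long late chunking (embed size > 0) disables truncation.
    if truncate_max_length is not None and long_late_chunking_embed_size > 0:
        truncate_max_length = None
    return truncate_max_length, long_late_chunking_embed_size > 0

print(resolve_modes(8192, 512))  # (None, True): long late chunking wins
print(resolve_modes(8192, 0))    # (8192, False): plain truncation only
print(resolve_modes(None, 0))    # (None, False): neither is active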