Skip to content

Commit

Permalink
fix score method in SentenceEmbeddingFilter
Browse files (browse the repository at this point in the history)
  • Loading branch information
svirpioj committed Apr 3, 2024
1 parent 3345dda commit 49efc92
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
3 changes: 2 additions & 1 deletion opusfilter/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def _score_chunk(self, chunk):

def score(self, pairs):
    """Yield the scores for all of *pairs*, processed chunk by chunk.

    The input is grouped with ``grouper(pairs, self.chunksize)`` and each
    group is handed to ``self._score_chunk``; everything that call produces
    is yielded in order before moving on to the next group.
    """
    for chunk in grouper(pairs, self.chunksize):
        # Do not ``return`` the first chunk's result here: that would end
        # the iteration after one chunk and drop the scores of the rest.
        yield from self._score_chunk(chunk)

def accept(self, score):
    """Return True when every similarity in *score* reaches ``self.threshold``.

    An empty *score* is accepted, since there is no value below the threshold.
    """
    limit = self.threshold
    for value in score:
        # Written as "not >=" (rather than "<") so that values which compare
        # false both ways, such as NaN, are still rejected.
        if not value >= limit:
            return False
    return True
Expand Down
15 changes: 15 additions & 0 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from opusfilter import ConfigurationError
from opusfilter.embeddings import *
from opusfilter.pipeline import FilterPipeline


try:
Expand Down Expand Up @@ -80,3 +81,17 @@ def test_bilingual_margin_ratios(self):
results = [testfilter.accept(x) for x in testfilter.score(self.bi_inputs)]
for result, correct in zip(results, expected):
self.assertEqual(result, correct)

def test_chunking(self):
    """Check scoring and filtering when the input spans several chunks.

    The filter uses chunksize 19 and the pipeline chunksize 30, and the
    input is the bilingual fixture repeated 50 times, so results have to
    be collected across chunk boundaries.
    """
    testfilter = SentenceEmbeddingFilter(languages=self.bi_langs, threshold=0.4, chunksize=19)
    inputs = 50 * self.bi_inputs
    expected = 50 * [True, True, False, False]
    results = [testfilter.accept(x) for x in testfilter.score(inputs)]
    # Compare the complete lists: iterating over zip(results, expected)
    # stops at the shorter one and would hide scores that were never yielded.
    self.assertEqual(results, expected)
    pipeline = FilterPipeline(filters=[testfilter])
    pipeline.chunksize = 30
    filtered = list(pipeline.filter(inputs))
    # Only the pairs expected to be accepted should survive filtering.
    self.assertEqual(len(filtered), expected.count(True))
    scores = list(pipeline.score(inputs))
    self.assertEqual(len(scores), len(expected))

0 comments on commit 49efc92

Please sign in to comment.