From 2fa2b002e98e292eac14005fb2d5d1ca513ef912 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Fri, 23 Feb 2024 11:28:29 -0500 Subject: [PATCH] add set threshold --- pyproject.toml | 2 +- semantic_router/__init__.py | 2 +- semantic_router/splitters/rolling_window.py | 7 ++++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1db92298..24c1fa09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "semantic-router" -version = "0.0.23" +version = "0.0.24" description = "Super fast semantic router for AI decision making" authors = [ "James Briggs ", diff --git a/semantic_router/__init__.py b/semantic_router/__init__.py index 7ac0c93e..d810106a 100644 --- a/semantic_router/__init__.py +++ b/semantic_router/__init__.py @@ -4,4 +4,4 @@ __all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"] -__version__ = "0.0.23" +__version__ = "0.0.24" diff --git a/semantic_router/splitters/rolling_window.py b/semantic_router/splitters/rolling_window.py index ca9eed95..0e7c651d 100644 --- a/semantic_router/splitters/rolling_window.py +++ b/semantic_router/splitters/rolling_window.py @@ -14,6 +14,7 @@ def __init__( self, encoder: BaseEncoder, threshold_adjustment=0.01, + dynamic_threshold: bool = True, window_size=5, min_split_tokens=100, max_split_tokens=300, @@ -25,6 +26,7 @@ def __init__( self.calculated_threshold: float self.encoder = encoder self.threshold_adjustment = threshold_adjustment + self.dynamic_threshold = dynamic_threshold self.window_size = window_size self.plot_splits = plot_splits self.min_split_tokens = min_split_tokens @@ -321,7 +323,10 @@ def __call__(self, docs: List[str]) -> List[DocumentSplit]: ) docs = split_to_sentences(docs[0]) encoded_docs = self.encode_documents(docs) - self.find_optimal_threshold(docs, encoded_docs) + if self.dynamic_threshold: + self.find_optimal_threshold(docs, encoded_docs) + else: + self.calculated_threshold = self.encoder.score_threshold similarities = self.calculate_similarity_scores(encoded_docs) split_indices = self.find_split_indices(similarities=similarities) splits = self.split_documents(docs, split_indices, similarities)