ccat_reranker v0.2.0
nickprock committed May 30, 2024
1 parent ec99ff4 commit 2ca237c
Showing 5 changed files with 28 additions and 8 deletions.
ccat_reranker.py (16 changes: 11 additions & 5 deletions)
@@ -1,6 +1,6 @@
 from cat.mad_hatter.decorators import hook
-from .rankers import get_settings, recent_ranker, litm, filter_ranker
-
+from .rankers import get_settings, recent_ranker, litm, filter_ranker, sbert_ranker
+from sentence_transformers.cross_encoder import CrossEncoder
 
 @hook(priority=1)
 def after_cat_recalls_memories(cat) -> None:
@@ -16,17 +16,23 @@ def after_cat_recalls_memories(cat) -> None:
"""
settings = get_settings()
#TODO print(cat.working_memory.history[0]['message'])
if settings["RECENTNESS"]:
if cat.working_memory['episodic_memories']:
recent_docs = recent_ranker(cat.working_memory['episodic_memories'])
cat.working_memory['episodic_memories'] = recent_docs
else:
print("#HicSuntGattones")

if settings["LITM"]:
if settings["SBERT"]:
model = CrossEncoder(settings["ranker"])
if cat.working_memory['declarative_memories']:
litm_docs = litm(cat.working_memory['declarative_memories'])
cat.working_memory['declarative_memories'] = litm_docs
sbert_docs = sbert_ranker(cat.working_memory['declarative_memories'], cat.working_memory.history[0]['message'], model)
if settings["LITM"]:
litm_docs = litm(sbert_docs)
cat.working_memory['declarative_memories'] = litm_docs
else:
cat.working_memory['declarative_memories'] = sbert_docs
else:
print("#HicSuntGattones")

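The new declarative branch instantiates the cross-encoder named by the ranker setting and uses it to score (query, passage) pairs. As a point of reference only, a minimal sketch of that model call on its own, assuming the default cross-encoder/ms-marco-MiniLM-L-6-v2 checkpoint:

    from sentence_transformers.cross_encoder import CrossEncoder

    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    scores = model.predict([
        ["what does the plugin do?", "The plugin reranks recalled memories."],
        ["what does the plugin do?", "Cats sleep most of the day."],
    ])
    # One relevance score per (query, passage) pair; higher means more relevant.
    print(scores)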
plugin.json (4 changes: 2 additions & 2 deletions)
@@ -1,7 +1,7 @@
 {
     "name": "Cheshire Cat ReRanker",
-    "version": "0.1.2",
-    "description": "This plugin applies a reranking to each memory. It rearranges the episodic memory from the most recent to the oldest, the declarative using lost in the middle, and the procedural using a filter.",
+    "version": "0.2.0",
+    "description": "This plugin applies a reranking to each memory. It rearranges the episodic memory from the most recent to the oldest, the declarative using an SBERT cross-encoder with an added lost-in-the-middle step, and the procedural using a filter.",
     "author_name": "nickprock",
     "plugin_url": "https://github.com/nickprock/ccat_reranker",
     "tags": "cat, memory, advanced, reranker",
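The lost-in-the-middle step mentioned in the description is implemented by litm in rankers.py and is not part of this diff. As a rough illustration only: a typical lost-in-the-middle reorder takes documents sorted from most to least relevant and alternates them between the two ends of the list, so the strongest documents sit at the start and end of the context and the weakest land in the middle. A minimal sketch under that assumption (not the plugin's actual litm code):

    def lost_in_the_middle(docs_most_relevant_first):
        # Alternate sides so the strongest documents sit at the edges of the
        # context window and the weakest end up in the middle.
        reordered = []
        for i, doc in enumerate(reversed(docs_most_relevant_first)):
            if i % 2 == 1:
                reordered.append(doc)
            else:
                reordered.insert(0, doc)
        return reordered

    print(lost_in_the_middle([1, 2, 3, 4, 5]))  # -> [1, 3, 5, 4, 2]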
rankers.py (13 changes: 12 additions & 1 deletion)
@@ -1,5 +1,7 @@
 import os
 import json
+import numpy as np
+
 def get_settings():
     if os.path.isfile("cat/plugins/ccat_reranker/settings.json"):
         with open("cat/plugins/ccat_reranker/settings.json", "r") as json_file:
@@ -66,4 +68,13 @@ def filter_ranker(documents, tool_threshold):
         filtered: The same list but filtered
     """
     filtered = [d for d in documents if d[1]>tool_threshold]
-    return filtered
+    return filtered
+
+def sbert_ranker(documents, query, model):
+    sentence_combinations = [[query, document[0].page_content] for document in documents]
+    scores = model.predict(sentence_combinations)
+    ranked_indices = np.argsort(scores)[::-1]
+    out_list = [documents[idx] for idx in ranked_indices]
+    # Keep the original bi-encoder scores on the Documents: the cross-encoder
+    # scores are on a different scale and replacing them could cause mistakes downstream.
+    return out_list
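A usage sketch for sbert_ranker: the function assumes each memory is a (Document, score) tuple whose Document exposes page_content, and it returns the same tuples reordered by cross-encoder relevance without touching the stored scores. The Doc stand-in below is for illustration only:

    from dataclasses import dataclass
    from sentence_transformers.cross_encoder import CrossEncoder
    # sbert_ranker as defined above in rankers.py

    @dataclass
    class Doc:  # stand-in for the Document objects held in working memory
        page_content: str

    memories = [
        (Doc("Cats sleep up to sixteen hours a day."), 0.51),
        (Doc("The reranker plugin reorders recalled memories."), 0.49),
    ]

    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    ranked = sbert_ranker(memories, "how are recalled memories reordered?", model)
    # ranked keeps the original bi-encoder scores; only the order changes.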
requirements.txt (1 change: 1 addition & 0 deletions)
@@ -0,0 +1 @@
+sentence-transformers
settings.py (2 changes: 2 additions & 0 deletions)
@@ -6,6 +6,8 @@ class MySettings(BaseModel):
     LITM: bool = True,
     RECENTNESS: bool = True,
     FILTER: bool = True,
+    SBERT: bool = True
+    ranker: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
     tool_threshold: float = 0.5
 
 @plugin
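get_settings() in rankers.py reads cat/plugins/ccat_reranker/settings.json when that file exists. Assuming the saved settings are a flat mapping of the MySettings fields above (the values here are illustrative only), such a file could be produced like this:

    import json

    # Illustrative only: a flat settings mapping matching the MySettings fields above.
    settings = {
        "LITM": True,
        "RECENTNESS": True,
        "FILTER": True,
        "SBERT": True,
        "ranker": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "tool_threshold": 0.5,
    }

    with open("cat/plugins/ccat_reranker/settings.json", "w") as f:
        json.dump(settings, f, indent=4)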
