Skip to content

Commit

Permalink
Merge pull request #9 from OpenPecha/feat-set_threshold_for_context_s…
Browse files Browse the repository at this point in the history
…imilarity

set threshold for context similarity
  • Loading branch information
tenzin3 authored Oct 24, 2024
2 parents 0533bdd + 89bbe5c commit 06497dc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
17 changes: 13 additions & 4 deletions src/bo_rag_prep_tool/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@ def cosine_similarity(vec1, vec2):
return similarity


def get_threshold(similarity_scores) -> float:
    """Return the similarity cut-off for a collection of scores.

    The cut-off is the highest score minus one standard deviation of all
    the scores (NumPy's default ``np.std``), so a tight cluster of scores
    yields a cut-off close to the maximum while a widely spread set yields
    a lower one.

    Args:
        similarity_scores: Non-empty sequence of numeric similarity scores.

    Returns:
        The threshold as a plain Python float.

    Raises:
        ValueError: If ``similarity_scores`` is empty.
    """
    scores = np.asarray(similarity_scores, dtype=float)
    # np.max on an empty array raises an opaque "zero-size array" ValueError;
    # fail with the same exception type but an actionable message.
    if scores.size == 0:
        raise ValueError("similarity_scores must not be empty")
    return float(np.max(scores) - np.std(scores))


def get_context(query: str):
query_embedding = get_openai_embedding([query])[0]

context_datas = read_json(Path("resource/ངོས་ཀྱི་ཡུལ་དང་ངོས་ཀྱི་མི་མང་།.json"))

similarities = []
# Store top three contexts data for llm generation
for context_data in context_datas:
Expand All @@ -39,8 +43,13 @@ def get_context(query: str):

# Sort the context data based on the similarity score in descending order
top_contexts = sorted(similarities, key=lambda x: x[0], reverse=True)[:10]
top_context_similarity_scores = [context[0] for context in top_contexts]
threshold = get_threshold(top_context_similarity_scores)

# Keep only the top contexts whose similarity meets the threshold
top_three_contexts = [context[1] for context in top_contexts]
final_contexts = []
for context in top_contexts:
context_similarity, context_data = context
if context_similarity >= threshold:
final_contexts.append(context_data)

return top_three_contexts
return final_contexts
4 changes: 1 addition & 3 deletions src/bo_rag_prep_tool/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def get_answer(query: str):
Contexts: {context_texts}
"""
# answer = get_claude_response(prompt)
answer = get_monlam_llm_response(prompt)
answer = get_claude_response(prompt)
return answer


Expand Down Expand Up @@ -79,6 +78,5 @@ def extract_text_from_monlam_response(response):

if __name__ == "__main__":
query = "Who is Songtsen Gampo?"
# answer = get_monlam_llm_response(query)
answer = get_answer(query)
print(answer)

0 comments on commit 06497dc

Please sign in to comment.