# rag.py (forked from SciPhi-AI/R2R)
"""
Abstract base class for completion pipelines.
"""
import logging
import uuid
from abc import abstractmethod
from typing import Any, Optional, Tuple, Union

from openai.types import Completion
from openai.types.chat import ChatCompletion

from ..providers.llm import GenerationConfig, LLMProvider
from ..providers.logging import LoggingDatabaseConnection, log_execution_to_db
from .pipeline import Pipeline

logger = logging.getLogger(__name__)
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."

DEFAULT_TASK_PROMPT = """
## Task:
Answer the query given immediately below given the context which follows later.

### Query:
{query}

### Context:
{context}

### Query:
{query}

## Response:
"""


class RAGPipeline(Pipeline):
    def __init__(
        self,
        llm: "LLMProvider",
        generation_config: "GenerationConfig",
        system_prompt: Optional[str] = None,
        task_prompt: Optional[str] = None,
        logging_provider: Optional[LoggingDatabaseConnection] = None,
        *args,
        **kwargs,
    ):
        self.llm = llm
        self.generation_config = generation_config
        self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
        self.task_prompt = task_prompt or DEFAULT_TASK_PROMPT
        self.logging_provider = logging_provider
        self.pipeline_run_info = None
        super().__init__(logging_provider=logging_provider, **kwargs)

    def initialize_pipeline(
        self, query: str, search_only: bool, *args, **kwargs
    ) -> None:
        self.pipeline_run_info = {
            "run_id": uuid.uuid4(),
            "type": "search" if search_only else "rag",
        }
        self.ingress(query)

    @log_execution_to_db
    def ingress(self, data: Any) -> Any:
        """
        Ingresses data into the pipeline; execution is logged to the
        database via the decorator.
        """
        self._check_pipeline_initialized()
        return data

    @abstractmethod
    def transform_query(self, query: str) -> Any:
        """
        Transforms the input query for retrieval.
        """
        pass

    @abstractmethod
    def search(
        self,
        transformed_query,
        filters: dict[str, Any],
        limit: int,
        *args,
        **kwargs,
    ) -> list:
        """
        Retrieves results based on the transformed query.
        """
        pass
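
    # Hedged illustration (not in the original source): filters and limit
    # narrow retrieval; the exact filter keys depend on the concrete vector
    # store. A subclass call might look like:
    #
    #   self.search(transformed_query, filters={"document_id": doc_id}, limit=10)
    #
    # where "document_id" is a hypothetical metadata field.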

    @abstractmethod
    def rerank_results(self, results: list) -> list:
        """
        Reranks the retrieved results based on relevance or other criteria.
        """
        pass

    @abstractmethod
    def _get_extra_args(self, *args, **kwargs) -> dict[str, Any]:
        """
        Returns extra arguments for the generation request.
        """
        pass

    @abstractmethod
    def _format_results(self, results: list) -> str:
        """
        Formats the retrieved results into a context string for generation.
        """
        pass

    @log_execution_to_db
    def construct_context(
        self,
        results: list,
    ) -> str:
        """
        Reranks the search results and formats them into a context string.
        """
        reranked_results = self.rerank_results(results)
        return self._format_results(reranked_results)

    @log_execution_to_db
    def construct_prompt(self, inputs: dict[str, str]) -> str:
        """
        Constructs a prompt for generation from the query and context inputs.
        """
        return self.task_prompt.format(**inputs)

    @log_execution_to_db
    def generate_completion(
        self,
        prompt: str,
        generate_with_chat: bool = True,
    ) -> Union[ChatCompletion, Completion]:
        """
        Generates a completion based on the prompt.
        """
        self._check_pipeline_initialized()
        if generate_with_chat:
            # The system prompt frames the assistant; the constructed RAG
            # prompt is passed as the user message.
            return self.llm.get_chat_completion(
                [
                    {
                        "role": "system",
                        "content": self.system_prompt,
                    },
                    {
                        "role": "user",
                        "content": prompt,
                    },
                ],
                self.generation_config,
                **self._get_extra_args(),
            )
        else:
            raise NotImplementedError(
                "Generation without chat is not implemented yet."
            )

    # TODO - Clean up the return types
    def run(
        self,
        query: str,
        filters: Optional[dict[str, Any]] = None,
        limit: int = 10,
        search_only: bool = False,
    ) -> Tuple[Optional[str], Union[ChatCompletion, Completion, list]]:
        """
        Runs the completion pipeline: transform, search, (optionally) rerank,
        build the prompt, and generate.
        """
        self.initialize_pipeline(query, search_only)
        logger.debug(f"Pipeline run type: {self.pipeline_run_info}")

        transformed_query = self.transform_query(query)
        # Avoid a mutable default argument: fall back to an empty filter dict.
        search_results = self.search(transformed_query, filters or {}, limit)
        if search_only:
            # Short-circuit: return the raw search results without generation.
            return None, search_results

        context = self.construct_context(search_results)
        prompt = self.construct_prompt(
            {"query": transformed_query, "context": context}
        )
        completion = self.generate_completion(prompt, generate_with_chat=True)
        return context, completion
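
# Minimal usage sketch (hypothetical, not part of this module): a concrete
# subclass must implement the abstract hooks before `run` can be called.
# `BasicRAGPipeline`, `my_llm`, and `my_vector_search` are placeholder names
# assumed for illustration.
#
#   class BasicRAGPipeline(RAGPipeline):
#       def transform_query(self, query: str) -> str:
#           return query.strip()
#
#       def search(self, transformed_query, filters, limit, *args, **kwargs) -> list:
#           return my_vector_search(transformed_query, filters=filters, limit=limit)
#
#       def rerank_results(self, results: list) -> list:
#           return results  # identity rerank
#
#       def _get_extra_args(self, *args, **kwargs) -> dict:
#           return {}
#
#       def _format_results(self, results: list) -> str:
#           return "\n".join(str(r) for r in results)
#
#   pipeline = BasicRAGPipeline(llm=my_llm, generation_config=GenerationConfig(...))
#   context, completion = pipeline.run("What is RAG?", limit=5)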