workflow.py
from typing import List, Any
from llama_index.core.schema import Document
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.llms.llm import LLM
from llama_index.core.workflow import (
    step,
    Context,
    Workflow,
    Event,
    StartEvent,
    StopEvent,
)
from markdown_pdf import MarkdownPdf, Section
from subquery import get_sub_queries
from tavily import get_docs_from_tavily_search
from compress import get_compressed_context
from report import generate_report_from_context
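

# Typed events passed between the workflow steps; each carries the data the
# next step needs (sub-queries, scraped documents, compressed context, etc.).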
class SubQueriesCreatedEvent(Event):
    sub_queries: List[str]


class ToProcessSubQueryEvent(Event):
    sub_query: str


class DocsScrapedEvent(Event):
    sub_query: str
    docs: List[Document]


class ToCombineContextEvent(Event):
    sub_query: str
    context: str


class ReportPromptCreatedEvent(Event):
    context: str


class LLMResponseEvent(Event):
    response: str
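

# The workflow: break the user query into sub-queries, research each one with
# Tavily search, compress and combine the findings, then write a PDF report.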
class ResearchAssistantWorkflow(Workflow):
    def __init__(
        self,
        *args: Any,
        llm: LLM,
        embed_model: BaseEmbedding,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.llm = llm
        self.embed_model = embed_model
        self.visited_urls: set[str] = set()
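
    # Split the incoming query into sub-queries with the LLM and store the
    # count, so the fan-in step knows how many results to wait for.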
    @step
    async def create_sub_queries(
        self, ctx: Context, ev: StartEvent
    ) -> SubQueriesCreatedEvent:
        query = ev.query
        await ctx.set("query", query)
        sub_queries = await get_sub_queries(query, self.llm)
        await ctx.set("num_sub_queries", len(sub_queries))
        return SubQueriesCreatedEvent(sub_queries=sub_queries)
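
    # Fan out: emit one ToProcessSubQueryEvent per sub-query so each can be
    # researched independently downstream.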
    @step
    async def delegate_sub_queries(
        self, ctx: Context, ev: SubQueriesCreatedEvent
    ) -> ToProcessSubQueryEvent:
        for sub_query in ev.sub_queries:
            ctx.send_event(ToProcessSubQueryEvent(sub_query=sub_query))
        return None
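
    # Search the web for one sub-query via Tavily; visited_urls is threaded
    # through so the same URL is not scraped again for later sub-queries.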
    @step
    async def get_docs_for_subquery(
        self, ev: ToProcessSubQueryEvent
    ) -> DocsScrapedEvent:
        sub_query = ev.sub_query
        docs, visited_urls = await get_docs_from_tavily_search(
            sub_query, self.visited_urls
        )
        self.visited_urls = visited_urls
        return DocsScrapedEvent(sub_query=sub_query, docs=docs)
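
    # num_workers=3 lets up to three sub-queries be compressed concurrently.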
    @step(num_workers=3)
    async def compress_docs(self, ev: DocsScrapedEvent) -> ToCombineContextEvent:
        sub_query = ev.sub_query
        docs = ev.docs
        print(f"\n> Compressing docs for sub query: {sub_query}\n")
        compressed_context = await get_compressed_context(
            sub_query, docs, self.embed_model
        )
        return ToCombineContextEvent(sub_query=sub_query, context=compressed_context)
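
    # Fan in: collect_events returns None until a ToCombineContextEvent has
    # arrived for every sub-query, then the compressed contexts are merged.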
    @step
    async def combine_contexts(
        self, ctx: Context, ev: ToCombineContextEvent
    ) -> ReportPromptCreatedEvent:
        events = ctx.collect_events(
            ev, [ToCombineContextEvent] * await ctx.get("num_sub_queries")
        )
        if events is None:
            return None
        context = ""
        for event in events:
            context += (
                f'Research findings for topic "{event.sub_query}":\n{event.context}\n\n'
            )
        return ReportPromptCreatedEvent(context=context)
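
    # Generate the final report from the combined research context and save
    # it as report.pdf.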
    @step
    async def write_report(
        self, ctx: Context, ev: ReportPromptCreatedEvent
    ) -> StopEvent:
        context = ev.context
        query = await ctx.get("query")
        print("\n> Writing report. This will take a few minutes...\n")
        report = await generate_report_from_context(query, context, self.llm)
        pdf = MarkdownPdf()
        pdf.add_section(Section(report, toc=False))
        pdf.save("report.pdf")
        print("\n> Done writing report to report.pdf! Trying to open the file...\n")
        return StopEvent(result="report.pdf")
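

# A minimal usage sketch: one way this workflow might be run end to end. It
# assumes the llama-index OpenAI integrations are installed and OPENAI_API_KEY
# is set; the model names, timeout, and example query are illustrative
# assumptions, not requirements of this module.
if __name__ == "__main__":
    import asyncio

    from llama_index.embeddings.openai import OpenAIEmbedding
    from llama_index.llms.openai import OpenAI

    async def main() -> None:
        workflow = ResearchAssistantWorkflow(
            llm=OpenAI(model="gpt-4o-mini"),
            embed_model=OpenAIEmbedding(model="text-embedding-3-small"),
            timeout=600,
            verbose=True,
        )
        # Keyword arguments to run() become attributes on the StartEvent
        # (read above as ev.query in create_sub_queries).
        result = await workflow.run(query="How do large language models work?")
        print(result)  # path of the generated PDF, e.g. "report.pdf"

    asyncio.run(main())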