wip: streaming support
* Requires r-lib/httr2#521
* Explain plot feature is currently broken
jcheng5 committed Aug 28, 2024
1 parent 971cc46 commit db80b13
Showing 3 changed files with 127 additions and 51 deletions.
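For orientation, the streaming support added here centers on a new `chat_async()` in R/query.R, which streams the completion over server-sent events and returns a promise. A minimal sketch of how a caller might drive it, assuming the `tool_infos`/`tool_funcs` globals defined elsewhere in R/query.R are in scope and `messages` is the same message-store object app.R uses (hypothetical standalone usage, not part of this commit):

library(promises)
source("R/query.R")  # assumed relative path; app.R uses here("R/query.R")

# `messages` must expose $as_list() and $add(), as in app.R
chat_async(
  messages,
  on_chunk = function(chunk) {
    delta <- chunk$choices[[1]]$delta$content
    if (!is.null(delta)) cat(delta)  # print each streamed token as it arrives
  }
) |>
  then(\(x) message("\n[stream complete]"))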
2 changes: 2 additions & 0 deletions R/chat.R
@@ -43,6 +43,8 @@ chat_append_message <- function(id, msg, chunk = FALSE, session = getDefaultReac
chunk_type <- "message_start"
} else if (chunk == "end") {
chunk_type <- "message_end"
} else if (isTRUE(chunk)) {
chunk_type <- NULL
} else {
stop("Invalid chunk argument")
}
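In use, these branches give `chat_append_message()` a three-phase streaming protocol, which the app.R changes below exercise; `partial` and `full` here are placeholder variables for the accumulated text so far and the final text:

chat_append_message("chat", list(role = "assistant", content = ""), chunk = "start")    # open a streaming message
chat_append_message("chat", list(role = "assistant", content = partial), chunk = TRUE)  # deliver accumulated content
chat_append_message("chat", list(role = "assistant", content = full), chunk = "end")    # finalize the message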
103 changes: 103 additions & 0 deletions R/query.R
@@ -1,6 +1,7 @@
 # pak::pak("irudnyts/openai@r6")
 library(openai)
 library(here)
+library(httr2)
 
 log <- function(...) {}
 log <- message
@@ -63,6 +64,108 @@ df_to_schema <- function(df, name, categorical_threshold) {
   return(paste(schema, collapse = "\n"))
 }
 
+chat_async <- function(
+  messages,
+  model = "gpt-4o",
+  on_chunk = \(chunk) {},
+  polling_interval_secs = 0.2,
+  .ctx = NULL
+) {
+  api_endpoint <- "https://api.openai.com/v1/chat/completions"
+  api_key <- Sys.getenv("OPENAI_API_KEY")
+
+  # Build the streaming request. req_perform_connection() is the experimental
+  # non-blocking connection API from r-lib/httr2#521; it returns immediately
+  # with an open connection rather than waiting for the full response body.
+  response <- request(api_endpoint) %>%
+    req_headers(
+      "Content-Type" = "application/json",
+      "Authorization" = paste("Bearer", api_key)
+    ) %>%
+    req_body_json(list(
+      model = model,
+      stream = TRUE,
+      temperature = 0.7,
+      messages = messages$as_list(),
+      tools = tool_infos
+    )) %>%
+    req_perform_connection(mode = "rt", blocking = FALSE)
+
+  chunks <- list()
+
+  promises::promise(\(resolve, reject) {
+    # Drain whatever SSE events are currently available. If the stream is
+    # still open but has no data yet, reschedule with later() so we never
+    # block the main R process.
+    do_next <- \() {
+      shiny:::withLogErrors({
+        while (TRUE) {
+          sse <- read_sse(response$body)
+          if (!is.null(sse)) {
+            if (identical(sse$data, "[DONE]")) {
+              break
+            }
+            chunk <- jsonlite::fromJSON(sse$data, simplifyVector = FALSE)
+            # message(sse$data)
+            on_chunk(chunk)
+            chunks <<- c(chunks, list(chunk))
+          } else {
+            if (isIncomplete(response$body)) {
+              later::later(do_next, polling_interval_secs)
+              return()
+            } else {
+              break
+            }
+          }
+        }
+
+        # We've gathered all the chunks; merge the deltas into one completion
+        resolve(Reduce(elmer:::merge_dicts, chunks))
+      })
+    }
+    do_next()
+  }) %>% promises::then(\(completion) {
+    msg <- completion$choices[[1]]$delta
+    messages$add(msg)
+    if (!is.null(msg$tool_calls)) {
+      log("Handling tool calls")
+      # TODO: optionally return the tool calls to the caller as well
+      tool_response_msgs <- lapply(msg$tool_calls, \(tool_call) {
+        id <- tool_call$id
+        type <- tool_call$type
+        fname <- tool_call$`function`$name
+        log("Calling ", fname)
+        args <- jsonlite::parse_json(tool_call$`function`$arguments)
+        func <- tool_funcs[[fname]]
+        if (is.null(func)) {
+          stop("Called unknown tool '", fname, "'")
+        }
+        if (".ctx" %in% names(formals(func))) {
+          args$.ctx <- .ctx
+        }
+        result <- tryCatch(
+          {
+            do.call(func, args)
+          },
+          error = \(e) {
+            message(conditionMessage(e))
+            list(success = FALSE, error = "An error occurred")
+          }
+        )
+
+        list(
+          role = "tool",
+          tool_call_id = id,
+          name = fname,
+          content = jsonlite::toJSON(result, auto_unbox = TRUE)
+        )
+      })
+      for (tool_response_msg in tool_response_msgs) {
+        messages$add(tool_response_msg)
+      }
+
+      # Recurse so the model can respond to the tool results; streaming
+      # continues through the same on_chunk callback
+      chat_async(
+        messages,
+        model = model,
+        on_chunk = on_chunk,
+        polling_interval_secs = polling_interval_secs,
+        .ctx = .ctx
+      )
+    } else {
+      invisible()
+    }
+  })
+}
+
 query <- function(messages, model = "gpt-4o", ..., .ctx = NULL) {
   # TODO: verify it's a good response
 
73 changes: 22 additions & 51 deletions app.R
@@ -47,7 +47,7 @@ ui <- page_sidebar(
chat_ui("chat", height = "100%", fill = TRUE)
),
useBusyIndicators(),

# 🏷️ Header
textOutput("show_title", container = h3),
verbatimTextOutput("show_query") |>
@@ -297,60 +297,31 @@ server <- function(input, output, session) {
list(role = "user", content = input$chat_user_input)
)

prog <- Progress$new()
prog$set(value = NULL, message = "Thinking...")

mirai(
msgs = messages$as_list(),
model = input$model,
{
library(duckdb)
library(DBI)
library(here)
source(here("R/query.R"), local = TRUE)

conn <- dbConnect(duckdb(), dbdir = here("tips.duckdb"), read_only = TRUE)
on.exit(dbDisconnect(conn))

result_query <- NULL
result_title <- NULL

update_dashboard <- function(query, title) {
result_query <<- query
result_title <<- title
}

ctx <- list(conn = conn, update_dashboard = update_dashboard)

c(
query(msgs, model = model, .ctx = ctx),
query = result_query,
title = result_title
)
content_accum <- ""
on_chunk <- function(chunk) {
if (!is.null(chunk$choices[[1]]$delta$content)) {
content_accum <<- paste0(content_accum, chunk$choices[[1]]$delta$content)
chat_append_message("chat", list(role = "assistant", content = content_accum), chunk = TRUE, session = session)
}
) |>
then(\(result) {
for (imsg in result$intermediate_messages) {
messages$add(imsg)
}

if (!is.null(result$query)) {
current_query(result$query)
}
if (!is.null(result$title)) {
current_title(result$title)
}
}

completion <- result$completion
update_dashboard <- function(query, title) {
if (!is.null(query)) {
current_query(query)
}
if (!is.null(title)) {
current_title(title)
}
}

response_msg <- completion$choices[[1]]$message
# print(response_msg)
chat_append_message("chat", list(role = "assistant", content = ""), chunk = "start")

# Add response to the chat history
messages$add(response_msg)
chat_async(
messages,
model = "gpt-4o",
on_chunk = on_chunk,
.ctx = list(conn = conn, update_dashboard = update_dashboard)) |>

chat_append_message("chat", response_msg)
}) |>
catch(\(err) {
print(err)
err_msg <- list(
@@ -362,7 +333,7 @@ server <- function(input, output, session) {
chat_append_message("chat", err_msg)
}) |>
finally(\() {
prog$close()
chat_append_message("chat", list(role = "assistant", content = content_accum), chunk = "end")
})
})
}
