From db80b13169a650d71291c0ea0f01cf8656bcf30e Mon Sep 17 00:00:00 2001
From: Joe Cheng
Date: Wed, 28 Aug 2024 11:07:47 -0700
Subject: [PATCH] wip: streaming support

* Requires r-lib/httr2#521
* The "explain plot" feature is currently broken
---
 R/chat.R  |   2 ++
 R/query.R | 103 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 app.R     |  73 ++++++++++++--------------------------
 3 files changed, 127 insertions(+), 51 deletions(-)

diff --git a/R/chat.R b/R/chat.R
index abaf4fe..6842afa 100644
--- a/R/chat.R
+++ b/R/chat.R
@@ -43,6 +43,8 @@ chat_append_message <- function(id, msg, chunk = FALSE, session = getDefaultReac
     chunk_type <- "message_start"
   } else if (chunk == "end") {
     chunk_type <- "message_end"
+  } else if (isTRUE(chunk)) {
+    chunk_type <- NULL
   } else {
     stop("Invalid chunk argument")
   }
diff --git a/R/query.R b/R/query.R
index 665c54f..b85f184 100644
--- a/R/query.R
+++ b/R/query.R
@@ -1,6 +1,7 @@
 # pak::pak("irudnyts/openai@r6")
 library(openai)
 library(here)
+library(httr2)
 
 log <- function(...) {}
 log <- message
@@ -63,6 +64,108 @@ df_to_schema <- function(df, name, categorical_threshold) {
   return(paste(schema, collapse = "\n"))
 }
 
+chat_async <- function(
+  messages,
+  model = "gpt-4o",
+  on_chunk = \(chunk) {},
+  polling_interval_secs = 0.2,
+  .ctx = NULL
+) {
+  api_endpoint <- "https://api.openai.com/v1/chat/completions"
+  api_key <- Sys.getenv("OPENAI_API_KEY")
+
+  # Build the streaming request (requires r-lib/httr2#521 for req_perform_connection)
+  response <- request(api_endpoint) %>%
+    req_headers(
+      "Content-Type" = "application/json",
+      "Authorization" = paste("Bearer", api_key)
+    ) %>%
+    req_body_json(list(
+      model = model,
+      stream = TRUE,
+      temperature = 0.7,
+      messages = messages$as_list(),
+      tools = tool_infos
+    )) %>%
+    req_perform_connection(mode = "rt", blocking = FALSE)
+
+  chunks <- list()
+
+  promises::promise(\(resolve, reject) {
+    do_next <- \() {
+      shiny:::withLogErrors({
+        while (TRUE) {
+          sse <- read_sse(response$body)
+          if (!is.null(sse)) {
+            if (identical(sse$data, "[DONE]")) {
+              break
+            }
+            chunk <- jsonlite::fromJSON(sse$data, simplifyVector = FALSE)
+            # message(sse$data)
+            on_chunk(chunk)
+            chunks <<- c(chunks, list(chunk))
+          } else {
+            if (isIncomplete(response$body)) {
+              later::later(do_next, polling_interval_secs)  # no data yet; poll again soon
+              return()
+            } else {
+              break
+            }
+          }
+        }
+
+        # We've gathered all the chunks
+        resolve(Reduce(elmer:::merge_dicts, chunks))
+      })
+    }
+    do_next()
+  }) %>% promises::then(\(completion) {
+    msg <- completion$choices[[1]]$delta
+    messages$add(msg)
+    if (!is.null(msg$tool_calls)) {
+      log("Handling tool calls")
+      # TODO: optionally return the tool calls to the caller as well
+      tool_response_msgs <- lapply(msg$tool_calls, \(tool_call) {
+        id <- tool_call$id
+        type <- tool_call$type
+        fname <- tool_call$`function`$name
+        log("Calling ", fname)
+        args <- jsonlite::parse_json(tool_call$`function`$arguments)
+        func <- tool_funcs[[fname]]
+        if (is.null(func)) {
+          stop("Called unknown tool '", fname, "'")
+        }
+        if (".ctx" %in% names(formals(func))) {
+          args$.ctx <- .ctx
+        }
+        result <- tryCatch(
+          {
+            do.call(func, args)
+          },
+          error = \(e) {
+            message(conditionMessage(e))
+            list(success = FALSE, error = "An error occurred")
+          }
+        )
+
+        list(
+          role = "tool",
+          tool_call_id = id,
+          name = fname,
+          content = jsonlite::toJSON(result, auto_unbox = TRUE)
+        )
+      })
+      for (tool_response_msg in tool_response_msgs) {
+        messages$add(tool_response_msg)
+      }
+
+      chat_async(messages, model = model, on_chunk = on_chunk, polling_interval_secs = polling_interval_secs, .ctx = .ctx)
+    } else {
+      invisible()
+    }
+  })
+}
+
 query <-
   function(messages, model = "gpt-4o", ..., .ctx = NULL) {
     # TODO: verify it's a good response
diff --git a/app.R b/app.R
index 26431a9..5c654fc 100644
--- a/app.R
+++ b/app.R
@@ -47,7 +47,7 @@ ui <- page_sidebar(
     chat_ui("chat", height = "100%", fill = TRUE)
   ),
   useBusyIndicators(),
-  
+
   # 🏷️ Header
   textOutput("show_title", container = h3),
   verbatimTextOutput("show_query") |>
@@ -297,60 +297,31 @@ server <- function(input, output, session) {
       list(role = "user", content = input$chat_user_input)
     )
 
-    prog <- Progress$new()
-    prog$set(value = NULL, message = "Thinking...")
-
-    mirai(
-      msgs = messages$as_list(),
-      model = input$model,
-      {
-        library(duckdb)
-        library(DBI)
-        library(here)
-        source(here("R/query.R"), local = TRUE)
-
-        conn <- dbConnect(duckdb(), dbdir = here("tips.duckdb"), read_only = TRUE)
-        on.exit(dbDisconnect(conn))
-
-        result_query <- NULL
-        result_title <- NULL
-
-        update_dashboard <- function(query, title) {
-          result_query <<- query
-          result_title <<- title
-        }
-
-        ctx <- list(conn = conn, update_dashboard = update_dashboard)
-
-        c(
-          query(msgs, model = model, .ctx = ctx),
-          query = result_query,
-          title = result_title
-        )
+    # Accumulate streamed content and re-render the growing message on each chunk
+    content_accum <- ""
+    on_chunk <- function(chunk) {
+      if (!is.null(chunk$choices[[1]]$delta$content)) {
+        content_accum <<- paste0(content_accum, chunk$choices[[1]]$delta$content)
+        chat_append_message("chat", list(role = "assistant", content = content_accum), chunk = TRUE, session = session)
       }
-    ) |>
-      then(\(result) {
-        for (imsg in result$intermediate_messages) {
-          messages$add(imsg)
-        }
-
-        if (!is.null(result$query)) {
-          current_query(result$query)
-        }
-        if (!is.null(result$title)) {
-          current_title(result$title)
-        }
+    }
 
-        completion <- result$completion
+    update_dashboard <- function(query, title) {
+      if (!is.null(query)) {
+        current_query(query)
+      }
+      if (!is.null(title)) {
+        current_title(title)
+      }
+    }
 
-        response_msg <- completion$choices[[1]]$message
-        # print(response_msg)
+    chat_append_message("chat", list(role = "assistant", content = ""), chunk = "start")
 
-        # Add response to the chat history
-        messages$add(response_msg)
+    chat_async(
+      messages,
+      model = input$model,
+      on_chunk = on_chunk,
+      .ctx = list(conn = conn, update_dashboard = update_dashboard)) |>
-        chat_append_message("chat", response_msg)
-      }) |>
       catch(\(err) {
         print(err)
         err_msg <- list(
@@ -362,7 +333,7 @@ server <- function(input, output, session) {
         chat_append_message("chat", err_msg)
       }) |>
       finally(\() {
-        prog$close()
+        chat_append_message("chat", list(role = "assistant", content = content_accum), chunk = "end")
       })
   })
 }
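
---

Reviewer note: a minimal sketch of driving the new chat_async() outside Shiny, handy for
testing the streaming loop at the console. The MessageStore class here is hypothetical
(the repo's real message store is not part of this patch); the only assumption is an
object exposing the $as_list()/$add() methods that chat_async() already calls. Assumes
OPENAI_API_KEY is set in the environment.

    source("R/query.R")  # defines chat_async(), tool_infos, tool_funcs

    # Hypothetical stand-in for the app's message store (not part of this patch)
    msgs <- MessageStore$new()
    msgs$add(list(role = "user", content = "Which day has the highest tips?"))

    chat_async(
      msgs,
      model = "gpt-4o",
      on_chunk = \(chunk) {
        delta <- chunk$choices[[1]]$delta$content
        if (!is.null(delta)) cat(delta)  # print each streamed delta as it arrives
      }
    )

    # chat_async() schedules itself with later::later(), so pump the event loop
    # until the stream (and any tool-call follow-ups) have finished.
    while (!later::loop_empty()) later::run_now()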