diff --git a/crates/pg_lsp/Cargo.toml b/crates/pg_lsp/Cargo.toml index 122e3ccd..84e97785 100644 --- a/crates/pg_lsp/Cargo.toml +++ b/crates/pg_lsp/Cargo.toml @@ -32,6 +32,7 @@ pg_base_db.workspace = true pg_schema_cache.workspace = true pg_workspace.workspace = true pg_diagnostics.workspace = true +tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread", "sync"] } [dev-dependencies] diff --git a/crates/pg_lsp/src/main.rs b/crates/pg_lsp/src/main.rs index eb5eddb6..803e0f39 100644 --- a/crates/pg_lsp/src/main.rs +++ b/crates/pg_lsp/src/main.rs @@ -1,9 +1,12 @@ use lsp_server::Connection; use pg_lsp::server::Server; -fn main() -> anyhow::Result<()> { +#[tokio::main] +async fn main() -> anyhow::Result<()> { let (connection, threads) = Connection::stdio(); - Server::init(connection)?; + let server = Server::init(connection)?; + + server.run().await?; threads.join()?; Ok(()) diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 927d7f16..351112a0 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -3,7 +3,6 @@ mod dispatch; pub mod options; use async_std::task::{self}; -use crossbeam_channel::{unbounded, Receiver, Sender}; use lsp_server::{Connection, ErrorCode, Message, RequestId}; use lsp_types::{ notification::{ @@ -33,6 +32,8 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; use text_size::TextSize; use threadpool::ThreadPool; +use tokio::sync::{mpsc, oneshot}; + use crate::{ client::{client_flags::ClientFlags, LspClient}, utils::{file_path, from_proto, line_index_ext::LineIndexExt, normalize_uri, to_proto}, @@ -68,11 +69,39 @@ impl DbConnection { } } +/// `lsp-servers` `Connection` type uses a crossbeam channel, which is not compatible with tokio's async runtime. +/// For now, we move it into a separate task and use tokio's channels to communicate. +fn get_client_receiver( + connection: Connection, +) -> (mpsc::UnboundedReceiver, oneshot::Receiver<()>) { + let (message_tx, message_rx) = mpsc::unbounded_channel(); + let (close_tx, close_rx) = oneshot::channel(); + + tokio::task::spawn(async move { + // TODO: improve Result handling + loop { + let msg = connection.receiver.recv().unwrap(); + + match msg { + Message::Request(r) if connection.handle_shutdown(&r).unwrap() => { + close_tx.send(()).unwrap(); + return; + } + + _ => message_tx.send(msg).unwrap(), + }; + } + }); + + (message_rx, close_rx) +} + pub struct Server { - connection: Arc, + client_rx: mpsc::UnboundedReceiver, + close_rx: oneshot::Receiver<()>, client: LspClient, - internal_tx: Sender, - internal_rx: Receiver, + internal_tx: mpsc::UnboundedSender, + internal_rx: mpsc::UnboundedReceiver, pool: Arc, client_flags: Arc, ide: Arc, @@ -81,10 +110,10 @@ pub struct Server { } impl Server { - pub fn init(connection: Connection) -> anyhow::Result<()> { + pub fn init(connection: Connection) -> anyhow::Result { let client = LspClient::new(connection.sender.clone()); - let (internal_tx, internal_rx) = unbounded(); + let (internal_tx, internal_rx) = mpsc::unbounded_channel(); let (id, params) = connection.initialize_start()?; let params: InitializeParams = serde_json::from_value(params)?; @@ -110,8 +139,11 @@ impl Server { let cloned_pool = pool.clone(); let cloned_client = client.clone(); + let (client_rx, close_rx) = get_client_receiver(connection); + let server = Self { - connection: Arc::new(connection), + close_rx, + client_rx, internal_rx, internal_tx, client, @@ -158,8 +190,7 @@ impl Server { pool, }; - server.run()?; - Ok(()) + Ok(server) } fn compute_now(&self) { @@ -763,67 +794,84 @@ impl Server { Ok(()) } - fn process_messages(&mut self) -> anyhow::Result<()> { + async fn process_messages(&mut self) -> anyhow::Result<()> { loop { - crossbeam_channel::select! { - recv(&self.connection.receiver) -> msg => { - match msg? { - Message::Request(request) => { - if self.connection.handle_shutdown(&request)? { - return Ok(()); - } - - if let Some(response) = dispatch::RequestDispatcher::new(request) - .on::(|id, params| self.inlay_hint(id, params))? - .on::(|id, params| self.hover(id, params))? - .on::(|id, params| self.execute_command(id, params))? - .on::(|id, params| { - self.completion(id, params) - })? - .on::(|id, params| { - self.code_actions(id, params) - })? - .default() - { - self.client.send_response(response)?; - } - } - Message::Notification(notification) => { - dispatch::NotificationDispatcher::new(notification) - .on::(|params| { - self.did_change_configuration(params) - })? - .on::(|params| self.did_close(params))? - .on::(|params| self.did_open(params))? - .on::(|params| self.did_change(params))? - .on::(|params| self.did_save(params))? - .on::(|params| self.did_close(params))? - .default(); - } - Message::Response(response) => { - self.client.recv_response(response)?; - } - }; + tokio::select! { + _ = &mut self.close_rx => { + return Ok(()) }, - recv(&self.internal_rx) -> msg => { - match msg? { - InternalMessage::SetSchemaCache(c) => { - self.ide.set_schema_cache(c); - self.compute_now(); - } - InternalMessage::RefreshSchemaCache => { - self.refresh_schema_cache(); - } - InternalMessage::PublishDiagnostics(uri) => { - self.publish_diagnostics(uri)?; - } - InternalMessage::SetOptions(options) => { - self.update_options(options); - } - }; + + msg = self.internal_rx.recv() => { + match msg { + // TODO: handle internal sender close? Is that valid state? + None => return Ok(()), + Some(m) => self.handle_internal_message(m) + } + }, + + msg = self.client_rx.recv() => { + match msg { + // the client sender is closed, we can return + None => return Ok(()), + Some(m) => self.handle_message(m) + } + }, + }?; + } + } + + fn handle_message(&mut self, msg: Message) -> anyhow::Result<()> { + match msg { + Message::Request(request) => { + if let Some(response) = dispatch::RequestDispatcher::new(request) + .on::(|id, params| self.inlay_hint(id, params))? + .on::(|id, params| self.hover(id, params))? + .on::(|id, params| self.execute_command(id, params))? + .on::(|id, params| self.completion(id, params))? + .on::(|id, params| self.code_actions(id, params))? + .default() + { + self.client.send_response(response)?; } - }; + } + Message::Notification(notification) => { + dispatch::NotificationDispatcher::new(notification) + .on::(|params| { + self.did_change_configuration(params) + })? + .on::(|params| self.did_close(params))? + .on::(|params| self.did_open(params))? + .on::(|params| self.did_change(params))? + .on::(|params| self.did_save(params))? + .on::(|params| self.did_close(params))? + .default(); + } + Message::Response(response) => { + self.client.recv_response(response)?; + } } + + Ok(()) + } + + fn handle_internal_message(&mut self, msg: InternalMessage) -> anyhow::Result<()> { + match msg { + InternalMessage::SetSchemaCache(c) => { + self.ide.set_schema_cache(c); + self.compute_now(); + } + InternalMessage::RefreshSchemaCache => { + self.refresh_schema_cache(); + } + InternalMessage::PublishDiagnostics(uri) => { + self.publish_diagnostics(uri)?; + } + InternalMessage::SetOptions(options) => { + self.update_options(options); + } + } + + Ok(()) } fn pull_options(&mut self) { @@ -881,10 +929,10 @@ impl Server { } } - pub fn run(mut self) -> anyhow::Result<()> { + pub async fn run(mut self) -> anyhow::Result<()> { self.register_configuration(); self.pull_options(); - self.process_messages()?; + self.process_messages().await?; self.pool.join(); Ok(()) }