Skip to content

Commit

Permalink
refactor: tokio main, isolate crossbeam channels
Browse files Browse the repository at this point in the history
  • Loading branch information
juleswritescode committed Oct 11, 2024
1 parent d76481b commit 20b8567
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 70 deletions.
1 change: 1 addition & 0 deletions crates/pg_lsp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pg_base_db.workspace = true
pg_schema_cache.workspace = true
pg_workspace.workspace = true
pg_diagnostics.workspace = true
tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread", "sync"] }

[dev-dependencies]

Expand Down
7 changes: 5 additions & 2 deletions crates/pg_lsp/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use lsp_server::Connection;
use pg_lsp::server::Server;

fn main() -> anyhow::Result<()> {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let (connection, threads) = Connection::stdio();
Server::init(connection)?;
let server = Server::init(connection)?;

server.run().await?;
threads.join()?;

Ok(())
Expand Down
184 changes: 116 additions & 68 deletions crates/pg_lsp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ mod dispatch;
pub mod options;

use async_std::task::{self};
use crossbeam_channel::{unbounded, Receiver, Sender};
use lsp_server::{Connection, ErrorCode, Message, RequestId};
use lsp_types::{
notification::{
Expand Down Expand Up @@ -33,6 +32,8 @@ use std::{collections::HashSet, sync::Arc, time::Duration};
use text_size::TextSize;
use threadpool::ThreadPool;

use tokio::sync::{mpsc, oneshot};

use crate::{
client::{client_flags::ClientFlags, LspClient},
utils::{file_path, from_proto, line_index_ext::LineIndexExt, normalize_uri, to_proto},
Expand Down Expand Up @@ -68,11 +69,39 @@ impl DbConnection {
}
}

/// `lsp-servers` `Connection` type uses a crossbeam channel, which is not compatible with tokio's async runtime.
/// For now, we move it into a separate task and use tokio's channels to communicate.
fn get_client_receiver(
connection: Connection,
) -> (mpsc::UnboundedReceiver<Message>, oneshot::Receiver<()>) {
let (message_tx, message_rx) = mpsc::unbounded_channel();
let (close_tx, close_rx) = oneshot::channel();

tokio::task::spawn(async move {
// TODO: improve Result handling
loop {
let msg = connection.receiver.recv().unwrap();

match msg {
Message::Request(r) if connection.handle_shutdown(&r).unwrap() => {
close_tx.send(()).unwrap();
return;
}

_ => message_tx.send(msg).unwrap(),
};
}
});

(message_rx, close_rx)
}

pub struct Server {
connection: Arc<Connection>,
client_rx: mpsc::UnboundedReceiver<Message>,
close_rx: oneshot::Receiver<()>,
client: LspClient,
internal_tx: Sender<InternalMessage>,
internal_rx: Receiver<InternalMessage>,
internal_tx: mpsc::UnboundedSender<InternalMessage>,
internal_rx: mpsc::UnboundedReceiver<InternalMessage>,
pool: Arc<ThreadPool>,
client_flags: Arc<ClientFlags>,
ide: Arc<Workspace>,
Expand All @@ -81,10 +110,10 @@ pub struct Server {
}

impl Server {
pub fn init(connection: Connection) -> anyhow::Result<()> {
pub fn init(connection: Connection) -> anyhow::Result<Self> {
let client = LspClient::new(connection.sender.clone());

let (internal_tx, internal_rx) = unbounded();
let (internal_tx, internal_rx) = mpsc::unbounded_channel();

let (id, params) = connection.initialize_start()?;
let params: InitializeParams = serde_json::from_value(params)?;
Expand All @@ -110,8 +139,11 @@ impl Server {
let cloned_pool = pool.clone();
let cloned_client = client.clone();

let (client_rx, close_rx) = get_client_receiver(connection);

let server = Self {
connection: Arc::new(connection),
close_rx,
client_rx,
internal_rx,
internal_tx,
client,
Expand Down Expand Up @@ -158,8 +190,7 @@ impl Server {
pool,
};

server.run()?;
Ok(())
Ok(server)
}

fn compute_now(&self) {
Expand Down Expand Up @@ -763,67 +794,84 @@ impl Server {
Ok(())
}

fn process_messages(&mut self) -> anyhow::Result<()> {
async fn process_messages(&mut self) -> anyhow::Result<()> {
loop {
crossbeam_channel::select! {
recv(&self.connection.receiver) -> msg => {
match msg? {
Message::Request(request) => {
if self.connection.handle_shutdown(&request)? {
return Ok(());
}

if let Some(response) = dispatch::RequestDispatcher::new(request)
.on::<InlayHintRequest, _>(|id, params| self.inlay_hint(id, params))?
.on::<HoverRequest, _>(|id, params| self.hover(id, params))?
.on::<ExecuteCommand,_>(|id, params| self.execute_command(id, params))?
.on::<Completion, _>(|id, params| {
self.completion(id, params)
})?
.on::<CodeActionRequest, _>(|id, params| {
self.code_actions(id, params)
})?
.default()
{
self.client.send_response(response)?;
}
}
Message::Notification(notification) => {
dispatch::NotificationDispatcher::new(notification)
.on::<DidChangeConfiguration, _>(|params| {
self.did_change_configuration(params)
})?
.on::<DidCloseTextDocument, _>(|params| self.did_close(params))?
.on::<DidOpenTextDocument, _>(|params| self.did_open(params))?
.on::<DidChangeTextDocument, _>(|params| self.did_change(params))?
.on::<DidSaveTextDocument, _>(|params| self.did_save(params))?
.on::<DidCloseTextDocument, _>(|params| self.did_close(params))?
.default();
}
Message::Response(response) => {
self.client.recv_response(response)?;
}
};
tokio::select! {
_ = &mut self.close_rx => {
return Ok(())
},
recv(&self.internal_rx) -> msg => {
match msg? {
InternalMessage::SetSchemaCache(c) => {
self.ide.set_schema_cache(c);
self.compute_now();
}
InternalMessage::RefreshSchemaCache => {
self.refresh_schema_cache();
}
InternalMessage::PublishDiagnostics(uri) => {
self.publish_diagnostics(uri)?;
}
InternalMessage::SetOptions(options) => {
self.update_options(options);
}
};

msg = self.internal_rx.recv() => {
match msg {
// TODO: handle internal sender close? Is that valid state?
None => return Ok(()),
Some(m) => self.handle_internal_message(m)
}
},

msg = self.client_rx.recv() => {
match msg {
// the client sender is closed, we can return
None => return Ok(()),
Some(m) => self.handle_message(m)
}
},
}?;
}
}

fn handle_message(&mut self, msg: Message) -> anyhow::Result<()> {
match msg {
Message::Request(request) => {
if let Some(response) = dispatch::RequestDispatcher::new(request)
.on::<InlayHintRequest, _>(|id, params| self.inlay_hint(id, params))?
.on::<HoverRequest, _>(|id, params| self.hover(id, params))?
.on::<ExecuteCommand, _>(|id, params| self.execute_command(id, params))?
.on::<Completion, _>(|id, params| self.completion(id, params))?
.on::<CodeActionRequest, _>(|id, params| self.code_actions(id, params))?
.default()
{
self.client.send_response(response)?;
}
};
}
Message::Notification(notification) => {
dispatch::NotificationDispatcher::new(notification)
.on::<DidChangeConfiguration, _>(|params| {
self.did_change_configuration(params)
})?
.on::<DidCloseTextDocument, _>(|params| self.did_close(params))?
.on::<DidOpenTextDocument, _>(|params| self.did_open(params))?
.on::<DidChangeTextDocument, _>(|params| self.did_change(params))?
.on::<DidSaveTextDocument, _>(|params| self.did_save(params))?
.on::<DidCloseTextDocument, _>(|params| self.did_close(params))?
.default();
}
Message::Response(response) => {
self.client.recv_response(response)?;
}
}

Ok(())
}

fn handle_internal_message(&mut self, msg: InternalMessage) -> anyhow::Result<()> {
match msg {
InternalMessage::SetSchemaCache(c) => {
self.ide.set_schema_cache(c);
self.compute_now();
}
InternalMessage::RefreshSchemaCache => {
self.refresh_schema_cache();
}
InternalMessage::PublishDiagnostics(uri) => {
self.publish_diagnostics(uri)?;
}
InternalMessage::SetOptions(options) => {
self.update_options(options);
}
}

Ok(())
}

fn pull_options(&mut self) {
Expand Down Expand Up @@ -881,10 +929,10 @@ impl Server {
}
}

pub fn run(mut self) -> anyhow::Result<()> {
pub async fn run(mut self) -> anyhow::Result<()> {
self.register_configuration();
self.pull_options();
self.process_messages()?;
self.process_messages().await?;
self.pool.join();
Ok(())
}
Expand Down

0 comments on commit 20b8567

Please sign in to comment.