From cd4748b6343dade7bef61d98d16b81f5d855519d Mon Sep 17 00:00:00 2001 From: ikolomi Date: Mon, 21 Oct 2024 09:30:26 +0000 Subject: [PATCH] Glide-core UDS Socket Handling Rework: 1.Introduced a user-land mechanism for ensuring singleton behavior of the socket, rather than relying on OS-specific semantics. This addresses the issue where macOS and Linux report different errors when the socket path already exists. 2.Simplified the implementation by removing unnecessary abstractions, including redundant connection retry logic. Signed-off-by: ikolomi --- glide-core/src/retry_strategies.rs | 1 + glide-core/src/socket_listener.rs | 230 ++++++++++------------- glide-core/tests/test_socket_listener.rs | 15 +- 3 files changed, 110 insertions(+), 136 deletions(-) diff --git a/glide-core/src/retry_strategies.rs b/glide-core/src/retry_strategies.rs index dbe5683347..d851cb63dd 100644 --- a/glide-core/src/retry_strategies.rs +++ b/glide-core/src/retry_strategies.rs @@ -56,6 +56,7 @@ pub(crate) fn get_exponential_backoff( } #[cfg(feature = "socket-layer")] +#[allow(dead_code)] pub(crate) fn get_fixed_interval_backoff( fixed_interval: u32, number_of_retries: u32, diff --git a/glide-core/src/socket_listener.rs b/glide-core/src/socket_listener.rs index 50445c881d..22fa6ca753 100644 --- a/glide-core/src/socket_listener.rs +++ b/glide-core/src/socket_listener.rs @@ -11,11 +11,10 @@ use crate::connection_request::ConnectionRequest; use crate::errors::{error_message, error_type, RequestErrorType}; use crate::response; use crate::response::Response; -use crate::retry_strategies::get_fixed_interval_backoff; use bytes::Bytes; use directories::BaseDirs; -use dispose::{Disposable, Dispose}; use logger_core::{log_debug, log_error, log_info, log_trace, log_warn}; +use once_cell::sync::Lazy; use protobuf::{Chars, Message}; use redis::cluster_routing::{ MultipleNodeRoutingInfo, Route, RoutingInfo, SingleNodeRoutingInfo, SlotAddr, @@ -23,18 +22,18 @@ use redis::cluster_routing::{ use redis::cluster_routing::{ResponsePolicy, Routable}; use redis::{Cmd, PushInfo, RedisError, ScanStateRC, Value}; use std::cell::Cell; +use std::collections::HashSet; use std::rc::Rc; +use std::sync::{Arc, RwLock}; use std::{env, str}; use std::{io, thread}; use thiserror::Error; -use tokio::io::ErrorKind::AddrInUse; use tokio::net::{UnixListener, UnixStream}; use tokio::runtime::Builder; use tokio::sync::mpsc; use tokio::sync::mpsc::{channel, Sender}; use tokio::sync::Mutex; use tokio::task; -use tokio_retry::Retry; use tokio_util::task::LocalPoolHandle; use ClosingReason::*; use PipeListeningResult::*; @@ -53,20 +52,6 @@ pub const ZSET: &str = "zset"; pub const HASH: &str = "hash"; pub const STREAM: &str = "stream"; -/// struct containing all objects needed to bind to a socket and clean it. -struct SocketListener { - socket_path: String, - cleanup_socket: bool, -} - -impl Dispose for SocketListener { - fn dispose(self) { - if self.cleanup_socket { - close_socket(&self.socket_path); - } - } -} - /// struct containing all objects needed to read from a unix stream. struct UnixStreamListener { read_socket: Rc, @@ -734,109 +719,6 @@ async fn listen_on_client_stream(socket: UnixStream) { log_trace("client closing", "closing connection"); } -enum SocketCreationResult { - // Socket creation was successful, returned a socket listener. - Created(UnixListener), - // There's an existing a socket listener. - PreExisting, - // Socket creation failed with an error. - Err(io::Error), -} - -impl SocketListener { - fn new(socket_path: String) -> Self { - SocketListener { - socket_path, - // Don't cleanup the socket resources unless we know that the socket is in use, and owned by this listener. - cleanup_socket: false, - } - } - - /// Return true if it's possible to connect to socket. - async fn socket_is_available(&self) -> bool { - if UnixStream::connect(&self.socket_path).await.is_ok() { - return true; - } - - let retry_strategy = get_fixed_interval_backoff(10, 3); - - let action = || async { - UnixStream::connect(&self.socket_path) - .await - .map(|_| ()) - .map_err(|_| ()) - }; - let result = Retry::spawn(retry_strategy.get_iterator(), action).await; - result.is_ok() - } - - async fn get_socket_listener(&self) -> SocketCreationResult { - const RETRY_COUNT: u8 = 3; - let mut retries = RETRY_COUNT; - while retries > 0 { - match UnixListener::bind(self.socket_path.clone()) { - Ok(listener) => { - return SocketCreationResult::Created(listener); - } - Err(err) if err.kind() == AddrInUse => { - if self.socket_is_available().await { - return SocketCreationResult::PreExisting; - } else { - // socket file might still exist, even if nothing is listening on it. - close_socket(&self.socket_path); - retries -= 1; - continue; - } - } - Err(err) => { - return SocketCreationResult::Err(err); - } - } - } - SocketCreationResult::Err(io::Error::new( - io::ErrorKind::Other, - "Failed to connect to socket", - )) - } - - pub(crate) async fn listen_on_socket(&mut self, init_callback: InitCallback) - where - InitCallback: FnOnce(Result) + Send + 'static, - { - // Bind to socket - let listener = match self.get_socket_listener().await { - SocketCreationResult::Created(listener) => listener, - SocketCreationResult::Err(err) => { - log_info("listen_on_socket", format!("failed with error: {err}")); - init_callback(Err(err.to_string())); - return; - } - SocketCreationResult::PreExisting => { - init_callback(Ok(self.socket_path.clone())); - return; - } - }; - - self.cleanup_socket = true; - init_callback(Ok(self.socket_path.clone())); - let local_set_pool = LocalPoolHandle::new(num_cpus::get()); - loop { - match listener.accept().await { - Ok((stream, _addr)) => { - local_set_pool.spawn_pinned(move || listen_on_client_stream(stream)); - } - Err(err) => { - log_debug( - "listen_on_socket", - format!("Socket closed with error: `{err}`"), - ); - return; - } - } - } - } -} - #[derive(Debug)] /// Enum describing the reason that a socket listener stopped listening on a socket. pub enum ClosingReason { @@ -924,23 +806,109 @@ pub fn start_socket_listener_internal( init_callback: InitCallback, socket_path: Option, ) where - InitCallback: FnOnce(Result) + Send + 'static, + InitCallback: FnOnce(Result) + Send + Clone + 'static, { + static INITIALIZED_SOCKETS: Lazy>>> = + Lazy::new(|| Arc::new(RwLock::new(HashSet::new()))); + + let socket_path = socket_path.unwrap_or_else(get_socket_path); + + { + // Optimize for already initialized + let initialized_sockets = INITIALIZED_SOCKETS + .read() + .expect("Failed to acquire sockets db read guard"); + if initialized_sockets.contains(&socket_path) { + init_callback(Ok(socket_path.clone())); + return; + } + } + + // Retry with write lock, will be dropped upon the function completion + let mut sockets_write_guard = INITIALIZED_SOCKETS + .write() + .expect("Failed to acquire sockets db write guard"); + if sockets_write_guard.contains(&socket_path) { + init_callback(Ok(socket_path.clone())); + return; + } + + let (tx, rx) = std::sync::mpsc::channel(); + let socket_path_cloned = socket_path.clone(); + let init_callback_cloned = init_callback.clone(); + let tx_cloned = tx.clone(); thread::Builder::new() .name("socket_listener_thread".to_string()) .spawn(move || { - let runtime = Builder::new_current_thread().enable_all().build(); - match runtime { - Ok(runtime) => { - let mut listener = Disposable::new(SocketListener::new( - socket_path.unwrap_or_else(get_socket_path), - )); - runtime.block_on(listener.listen_on_socket(init_callback)); + let init_result = { + let runtime = Builder::new_current_thread().enable_all().build(); + if let Err(err) = runtime { + log_error( + "listen_on_socket", + format!("Error failed to create a new tokio thread: {err}"), + ); + return Err(err); } - Err(err) => init_callback(Err(err.to_string())), + + runtime.unwrap().block_on(async move { + let listener_socket = UnixListener::bind(socket_path_cloned.clone()); + if let Err(err) = listener_socket { + log_error( + "listen_on_socket", + format!("Error failed to bind listening socket: {err}"), + ); + return Err(err); + } + let listener_socket = listener_socket.unwrap(); + + // signal initialization success + init_callback(Ok(socket_path_cloned.clone())); + let _ = tx.send(true); + + let local_set_pool = LocalPoolHandle::new(num_cpus::get()); + loop { + match listener_socket.accept().await { + Ok((stream, _addr)) => { + local_set_pool + .spawn_pinned(move || listen_on_client_stream(stream)); + } + Err(err) => { + log_error( + "listen_on_socket", + format!("Error accepting connection: {err}"), + ); + break; + } + } + } + + // ensure socket file removal + drop(listener_socket); + let _ = std::fs::remove_file(socket_path_cloned.clone()); + + // no more listening on socket - update the sockets db + let mut sockets_write_guard = INITIALIZED_SOCKETS + .write() + .expect("Failed to acquire sockets db write guard"); + sockets_write_guard.remove(&socket_path_cloned); + Ok(()) + }) }; + + if let Err(err) = init_result { + init_callback_cloned(Err(err.to_string())); + let _ = tx_cloned.send(false); + } + Ok(()) }) .expect("Thread spawn failed. Cannot report error because callback was moved."); + + // wait for thread initialization signaling, callback invocation is done in the thread + let _ = rx.recv().map(|res| { + if res { + sockets_write_guard.insert(socket_path); + } + }); } /// Creates a new thread with a main loop task listening on the socket for new connections. @@ -950,7 +918,7 @@ pub fn start_socket_listener_internal( /// * `init_callback` - called when the socket listener fails to initialize, with the reason for the failure. pub fn start_socket_listener(init_callback: InitCallback) where - InitCallback: FnOnce(Result) + Send + 'static, + InitCallback: FnOnce(Result) + Send + Clone + 'static, { start_socket_listener_internal(init_callback, None); } diff --git a/glide-core/tests/test_socket_listener.rs b/glide-core/tests/test_socket_listener.rs index a242eb80d1..e51d62344a 100644 --- a/glide-core/tests/test_socket_listener.rs +++ b/glide-core/tests/test_socket_listener.rs @@ -518,8 +518,10 @@ mod socket_listener { #[rstest] #[timeout(SHORT_STANDALONE_TEST_TIMEOUT)] fn test_working_after_socket_listener_was_dropped() { - let socket_path = - get_socket_path_from_name("test_working_after_socket_listener_was_dropped".to_string()); + let socket_path = get_socket_path_from_name(format!( + "{}_test_working_after_socket_listener_was_dropped", + std::process::id() + )); close_socket(&socket_path); // create a socket listener and drop it, to simulate a panic in a previous iteration. Builder::new_current_thread() @@ -528,6 +530,8 @@ mod socket_listener { .unwrap() .block_on(async { let _ = UnixListener::bind(socket_path.clone()).unwrap(); + // UDS sockets require explicit removal of the socket file + close_socket(&socket_path); }); const CALLBACK_INDEX: u32 = 99; @@ -554,9 +558,10 @@ mod socket_listener { #[rstest] #[timeout(SHORT_STANDALONE_TEST_TIMEOUT)] fn test_multiple_listeners_competing_for_the_socket() { - let socket_path = get_socket_path_from_name( - "test_multiple_listeners_competing_for_the_socket".to_string(), - ); + let socket_path = get_socket_path_from_name(format!( + "{}_test_multiple_listeners_competing_for_the_socket", + std::process::id() + )); close_socket(&socket_path); let server = Arc::new(RedisServer::new(ServerType::Tcp { tls: false }));