diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index ad53908383f..4987fb2504a 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -37,6 +37,7 @@ use std::{ sync::atomic::{AtomicBool, Ordering}, time::Duration, }; +use tokio::sync::OnceCell; use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; use websocket::connect_via_websocket; @@ -232,27 +233,10 @@ impl PostgresNativeUrl { impl PostgreSql { /// Create a new connection to the database. - pub async fn new(url: PostgresNativeUrl) -> crate::Result { + pub async fn new(url: PostgresNativeUrl, tls_manager: &MakeTlsConnectorManager) -> crate::Result { let config = url.to_config(); - let mut tls_builder = TlsConnector::builder(); - - { - let ssl_params = url.ssl_params(); - let auth = ssl_params.to_owned().into_auth().await?; - - if let Some(certificate) = auth.certificate.0 { - tls_builder.add_root_certificate(certificate); - } - - tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); - - if let Some(identity) = auth.identity.0 { - tls_builder.identity(identity); - } - } - - let tls = MakeTlsConnector::new(tls_builder.build()?); + let tls = tls_manager.get_connector().await?; let (client, conn) = timeout::connect(url.connect_timeout(), config.connect(tls)).await?; let is_cockroachdb = conn.parameter("crdb_version").is_some(); @@ -926,6 +910,48 @@ fn is_safe_identifier(ident: &str) -> bool { true } +pub struct MakeTlsConnectorManager { + url: PostgresNativeUrl, + connector: OnceCell, +} + +impl MakeTlsConnectorManager { + pub fn new(url: PostgresNativeUrl) -> Self { + MakeTlsConnectorManager { + url, + connector: OnceCell::new(), + } + } + + pub async fn get_connector(&self) -> crate::Result { + self.connector + .get_or_try_init(|| async { + let mut tls_builder = TlsConnector::builder(); + + { + let ssl_params = self.url.ssl_params(); + let auth = ssl_params.to_owned().into_auth().await?; + + if let Some(certificate) = auth.certificate.0 { + tls_builder.add_root_certificate(certificate); + } + + tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); + + if let Some(identity) = auth.identity.0 { + tls_builder.identity(identity); + } + } + + let tls_connector = MakeTlsConnector::new(tls_builder.build()?); + + Ok(tls_connector) + }) + .await + .cloned() + } +} + #[cfg(test)] mod tests { use super::*; @@ -944,7 +970,9 @@ mod tests { let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -996,7 +1024,9 @@ mod tests { let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -1047,7 +1077,9 @@ mod tests { let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -1098,7 +1130,9 @@ mod tests { let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Unknown); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -1149,7 +1183,9 @@ mod tests { let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Unknown); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 2026679cd48..a9bd660ab3a 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -314,7 +314,11 @@ impl Builder { url.set_flavour(flavour); } - if let QuaintManager::Postgres { ref mut url } = self.manager { + if let QuaintManager::Postgres { + ref mut url, + tls_manager: _, + } = self.manager + { url.set_flavour(flavour); } } @@ -423,7 +427,8 @@ impl Quaint { let max_connection_lifetime = url.max_connection_lifetime(); let max_idle_connection_lifetime = url.max_idle_connection_lifetime(); - let manager = QuaintManager::Postgres { url }; + let tls_manager = crate::connector::MakeTlsConnectorManager::new(url.clone()); + let manager = QuaintManager::Postgres { url, tls_manager }; let mut builder = Builder::new(s, manager)?; if let Some(limit) = connection_limit { diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 0a2fc0adbd0..4b77a3761db 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -3,7 +3,7 @@ use crate::connector::MssqlUrl; #[cfg(feature = "mysql-native")] use crate::connector::MysqlUrl; #[cfg(feature = "postgresql-native")] -use crate::connector::PostgresNativeUrl; +use crate::connector::{MakeTlsConnectorManager, PostgresNativeUrl}; use crate::{ ast, connector::{self, impl_default_TransactionCapable, IsolationLevel, Queryable, Transaction, TransactionCapable}, @@ -85,7 +85,10 @@ pub enum QuaintManager { Mysql { url: MysqlUrl }, #[cfg(feature = "postgresql")] - Postgres { url: PostgresNativeUrl }, + Postgres { + url: PostgresNativeUrl, + tls_manager: MakeTlsConnectorManager, + }, #[cfg(feature = "sqlite")] Sqlite { url: String, db_name: String }, @@ -117,9 +120,9 @@ impl Manager for QuaintManager { } #[cfg(feature = "postgresql-native")] - QuaintManager::Postgres { url } => { + QuaintManager::Postgres { url, tls_manager } => { use crate::connector::PostgreSql; - Ok(Box::new(PostgreSql::new(url.clone()).await?) as Self::Connection) + Ok(Box::new(PostgreSql::new(url.clone(), tls_manager).await?) as Self::Connection) } #[cfg(feature = "mssql-native")] diff --git a/quaint/src/single.rs b/quaint/src/single.rs index fd018925852..dec27ebf802 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -149,7 +149,9 @@ impl Quaint { #[cfg(feature = "postgresql-native")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { let url = connector::PostgresNativeUrl::new(url::Url::parse(s)?)?; - let psql = connector::PostgreSql::new(url).await?; + let tls_manager = connector::MakeTlsConnectorManager::new(url.clone()); + let psql = connector::PostgreSql::new(url, &tls_manager + ).await?; Arc::new(psql) as Arc } #[cfg(feature = "mssql-native")] diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs index cb31b4394d7..af532661d3e 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs @@ -4,7 +4,7 @@ use enumflags2::BitFlags; use indoc::indoc; use psl::PreviewFeature; use quaint::{ - connector::{self, tokio_postgres::error::ErrorPosition, PostgresUrl}, + connector::{self, tokio_postgres::error::ErrorPosition, MakeTlsConnectorManager, PostgresUrl}, prelude::{ConnectionInfo, Queryable}, }; use schema_connector::{ConnectorError, ConnectorResult, Namespaces}; @@ -20,10 +20,15 @@ pub(super) struct Connection(connector::PostgreSql); impl Connection { pub(super) async fn new(url: url::Url) -> ConnectorResult { let url = MigratePostgresUrl::new(url)?; - + let quaint = match url.0 { - PostgresUrl::Native(ref native_url) => connector::PostgreSql::new(native_url.as_ref().clone()).await, - PostgresUrl::WebSocket(ref ws_url) => connector::PostgreSql::new_with_websocket(ws_url.clone()).await, + PostgresUrl::Native(ref native_url) => { + let tls_manager = MakeTlsConnectorManager::new(native_url.as_ref().clone()); + connector::PostgreSql::new(native_url.as_ref().clone(), &tls_manager).await + }, + PostgresUrl::WebSocket(ref ws_url) => { + connector::PostgreSql::new_with_websocket(ws_url.clone()).await + } } .map_err(quaint_err(&url))?;