diff --git a/pgrx-examples/shmem/Cargo.toml b/pgrx-examples/shmem/Cargo.toml index 86ed554c6..8c85221fd 100644 --- a/pgrx-examples/shmem/Cargo.toml +++ b/pgrx-examples/shmem/Cargo.toml @@ -17,7 +17,7 @@ edition = "2021" crate-type = ["cdylib"] [features] -default = ["pg13"] +default = ["pg13", "pgrx/cshim"] pg12 = ["pgrx/pg12", "pgrx-tests/pg12" ] pg13 = ["pgrx/pg13", "pgrx-tests/pg13" ] pg14 = ["pgrx/pg14", "pgrx-tests/pg14" ] diff --git a/pgrx-examples/shmem/src/lib.rs b/pgrx-examples/shmem/src/lib.rs index e34d015a5..0bd396b25 100644 --- a/pgrx-examples/shmem/src/lib.rs +++ b/pgrx-examples/shmem/src/lib.rs @@ -11,6 +11,7 @@ use pgrx::atomics::*; use pgrx::lwlock::PgLwLock; use pgrx::prelude::*; use pgrx::shmem::*; +use pgrx::shmem_hash::*; use pgrx::{pg_shmem_init, warning}; use serde::*; use std::iter::Iterator; @@ -39,6 +40,7 @@ static HASH: PgLwLock> = PgLwLock::new(); static STRUCT: PgLwLock = PgLwLock::new(); static PRIMITIVE: PgLwLock = PgLwLock::new(); static ATOMIC: PgAtomic = PgAtomic::new(); +static HASH_TABLE: ShmemHashMap = ShmemHashMap::new(250); #[pg_guard] pub extern "C" fn _PG_init() { @@ -48,6 +50,7 @@ pub extern "C" fn _PG_init() { pg_shmem_init!(STRUCT); pg_shmem_init!(PRIMITIVE); pg_shmem_init!(ATOMIC); + pg_shmem_init!(HASH_TABLE); } #[pg_extern] @@ -60,6 +63,26 @@ fn vec_count() -> i32 { VEC.share().len() as i32 } +#[pg_extern] +fn hash_table_insert(key: i64, value: i64) -> Option { + HASH_TABLE.insert(key, value).unwrap() +} + +#[pg_extern] +fn hash_table_get(key: i64) -> Option { + HASH_TABLE.get(key) +} + +#[pg_extern] +fn hash_table_remove(key: i64) -> Option { + HASH_TABLE.remove(key) +} + +#[pg_extern] +fn hash_table_len() -> i64 { + HASH_TABLE.len() as i64 +} + #[pg_extern] fn vec_drain() -> SetOfIterator<'static, Pgtest> { let mut vec = VEC.exclusive(); diff --git a/pgrx-tests/src/tests/shmem_tests.rs b/pgrx-tests/src/tests/shmem_tests.rs index 353d8e363..0dea386ae 100644 --- a/pgrx-tests/src/tests/shmem_tests.rs +++ b/pgrx-tests/src/tests/shmem_tests.rs @@ -11,22 +11,35 @@ use pgrx::prelude::*; use pgrx::{pg_shmem_init, PgAtomic, PgLwLock, PgSharedMemoryInitialization}; use std::sync::atomic::AtomicBool; +#[cfg(feature = "cshim")] +use pgrx::PgHashMap; + static ATOMIC: PgAtomic = PgAtomic::new(); static LWLOCK: PgLwLock = PgLwLock::new(); +#[cfg(feature = "cshim")] +static HASH_MAP: PgHashMap = PgHashMap::new(500); + #[pg_guard] pub extern "C" fn _PG_init() { // This ensures that this functionality works across PostgreSQL versions pg_shmem_init!(ATOMIC); pg_shmem_init!(LWLOCK); + + #[cfg(feature = "cshim")] + pg_shmem_init!(HASH_MAP); } + #[cfg(any(test, feature = "pg_test"))] #[pgrx::pg_schema] mod tests { #[allow(unused_imports)] use crate as pgrx_tests; + #[cfg(feature = "cshim")] + use crate::tests::shmem_tests::HASH_MAP; use crate::tests::shmem_tests::LWLOCK; + use pgrx::prelude::*; #[pg_test] @@ -53,4 +66,55 @@ mod tests { }); let _lock = LWLOCK.exclusive(); } + + #[cfg(feature = "cshim")] + #[pg_test] + pub fn test_pg_hash_map() { + use rand::prelude::IteratorRandom; + + for i in 1..250 { + assert_eq!(HASH_MAP.insert(i, i), Ok(None)); + } + + assert_eq!(HASH_MAP.len(), 249); + + for i in 1..250 { + assert_eq!(HASH_MAP.get(i), Some(i)); + } + + assert_eq!(HASH_MAP.len(), 249); + + for i in 251..500 { + assert_eq!(HASH_MAP.get(i), None); + } + + assert_eq!(HASH_MAP.len(), 249); + + for i in 1..250 { + assert_eq!(HASH_MAP.insert(i, i), Ok(Some(i))); + } + + assert_eq!(HASH_MAP.len(), 249); + + for i in 1..250 { + assert_eq!(HASH_MAP.remove(i), Some(i)); + } + + assert_eq!(HASH_MAP.len(), 0); + + for i in 1..250 { + assert_eq!(HASH_MAP.get(i), None); + } + + assert_eq!(HASH_MAP.len(), 0); + + for _ in 0..25_000 { + for key in 0..250 { + let value = (0..1000).choose(&mut rand::thread_rng()).unwrap(); + assert!(HASH_MAP.insert(key, value).is_ok()); + } + } + + assert_eq!(HASH_MAP.len(), 250); + } } diff --git a/pgrx/src/lib.rs b/pgrx/src/lib.rs index c57d90c22..ea49e4357 100644 --- a/pgrx/src/lib.rs +++ b/pgrx/src/lib.rs @@ -65,6 +65,8 @@ pub mod pg_catalog; pub mod pgbox; pub mod rel; pub mod shmem; +#[cfg(feature = "cshim")] +pub mod shmem_hash; pub mod spi; #[cfg(feature = "cshim")] pub mod spinlock; @@ -106,6 +108,8 @@ pub use nodes::*; pub use pgbox::*; pub use rel::*; pub use shmem::*; +#[cfg(feature = "cshim")] +pub use shmem_hash::*; pub use spi::Spi; // only Spi. We don't want the top-level namespace polluted with spi::Result and spi::Error pub use stringinfo::*; pub use trigger_support::*; diff --git a/pgrx/src/shmem_hash.rs b/pgrx/src/shmem_hash.rs new file mode 100644 index 000000000..c0f89d3c0 --- /dev/null +++ b/pgrx/src/shmem_hash.rs @@ -0,0 +1,250 @@ +//! Shared memory hash map implemented with Postgres' internal `HTAB`, +//! which is used by other extensions like `pg_stat_statements`. +use crate::{pg_sys, shmem::PgSharedMemoryInitialization, spinlock::*, PGRXSharedMemory}; +use once_cell::sync::OnceCell; +use std::ffi::c_void; +use uuid::Uuid; + +#[derive(Debug, Eq, PartialEq)] +#[non_exhaustive] +pub enum ShmemHashMapError { + /// Hash table can't have more entries due to fixed allocation size. + HashTableFull, +} + +#[derive(Copy, Clone, Debug)] +struct ShmemHashMapInner { + htab: *mut pg_sys::HTAB, + elements: i64, +} + +unsafe impl PGRXSharedMemory for ShmemHashMapInner {} +unsafe impl Send for ShmemHashMapInner {} +unsafe impl Sync for ShmemHashMapInner {} + +#[repr(C)] +#[derive(Copy, Clone, Debug)] +struct Key { + // We copy it with std::ptr::copy, but we don't actually use the field + // in Rust, hence the warning. + #[allow(dead_code)] + key: K, +} + +#[repr(C)] +#[derive(Copy, Clone, Debug)] +struct Value { + #[allow(dead_code)] + key: Key, + value: V, +} + +impl Default for ShmemHashMapInner { + fn default() -> Self { + Self { htab: std::ptr::null_mut(), elements: 0 } + } +} + +/// A shared memory HashMap using Postgres' `HTAB`. +/// This HashMap is used for `pg_stat_statements` and Postgres +/// internals to store key/value pairs in shared memory. +pub struct ShmemHashMap { + /// HTAB protected by a SpinLock. + htab: OnceCell>, + + /// Max size, allocated at server start. + size: i64, + + // Markers for key/value types. + _phantom_key: std::marker::PhantomData, + _phantom_value: std::marker::PhantomData, +} + +/// Compute the hash for the key and its pointer +/// to pass to `pg_sys::hash_search_with_hash_value`. +/// Lock on HTAB should be taken, although not strictly required I think. +macro_rules! key { + ($key:expr, $htab:expr) => {{ + let key = Key { key: $key }; + let key_ptr: *const c_void = std::ptr::addr_of!(key) as *const Key as *const c_void; + let hash_value = unsafe { pg_sys::get_hash_value($htab.htab, key_ptr) }; + + (key_ptr, hash_value) + }}; +} + +/// Get the value pointer. It's stored next to the key. +macro_rules! value_ptr { + ($entry:expr) => {{ + let value_ptr: *mut Value = $entry as *mut Value; + + value_ptr + }}; +} + +impl ShmemHashMap { + /// Create new `ShmemHashMap`. This still needs to be allocated with + /// `pg_shmem_init!` just like any other shared memory structure. + /// + /// # Arguments + /// + /// * `size` - Maximum number of elements in the HashMap. This is allocated + /// at server start and cannot be changed. `i64` is the expected type + /// for `pg_sys::ShmemInitHash`, so we don't attempt runtime conversions + /// unnecessarily. + /// + pub const fn new(size: i64) -> ShmemHashMap { + ShmemHashMap { + htab: OnceCell::new(), + size, + _phantom_key: std::marker::PhantomData, + _phantom_value: std::marker::PhantomData, + } + } + + /// Insert a key and value into the `ShmemHashMap`. If the key is already + /// present, it will be replaced and returned. If the `ShmemHashMap` is full, + /// an error is returned. + pub fn insert(&self, key: K, value: V) -> Result, ShmemHashMapError> { + let mut found = false; + let mut htab = self.htab.get().unwrap().lock(); + let (key_ptr, hash_value) = key!(key, htab); + + let entry = unsafe { + pg_sys::hash_search_with_hash_value( + htab.htab, + key_ptr, + hash_value, + pg_sys::HASHACTION_HASH_FIND, + &mut found, + ) + }; + + let return_value = if entry.is_null() { + None + } else { + let value_ptr = value_ptr!(entry); + let value = unsafe { std::ptr::read(value_ptr) }; + Some(value.value) + }; + + // If we don't do this check, pg will overwrite + // some random entry with our key/value pair... + if entry.is_null() && htab.elements == self.size { + return Err(ShmemHashMapError::HashTableFull); + } + + let entry = unsafe { + pg_sys::hash_search_with_hash_value( + htab.htab, + key_ptr, + hash_value, + pg_sys::HASHACTION_HASH_ENTER_NULL, + &mut found, + ) + }; + + if !entry.is_null() { + let value_ptr = value_ptr!(entry); + let value = Value { key: Key { key }, value }; + unsafe { + std::ptr::copy(std::ptr::addr_of!(value), value_ptr, 1); + } + // We inserted a new element, increasing the size of the table. + if return_value.is_none() { + htab.elements += 1; + } + Ok(return_value) + } else { + // OOM. We pre-allocate at server start, so this should never be an issue. + return Err(ShmemHashMapError::HashTableFull); + } + } + + /// Get a value from the HashMap using the key. + /// If the key doesn't exist, return `None`. + pub fn get(&self, key: K) -> Option { + let htab = self.htab.get().unwrap().lock(); + let (key_ptr, hash_value) = key!(key, htab); + + let entry = unsafe { + pg_sys::hash_search_with_hash_value( + htab.htab, + key_ptr, + hash_value, + pg_sys::HASHACTION_HASH_FIND, + std::ptr::null_mut(), + ) + }; + + if entry.is_null() { + return None; + } else { + let value_ptr = value_ptr!(entry); + let value = unsafe { std::ptr::read(value_ptr) }; + return Some(value.value); + } + } + + /// Remove the value from the `ShmemHashMap` and return it. + /// If the key doesn't exist, return None. + pub fn remove(&self, key: K) -> Option { + if let Some(value) = self.get(key) { + let mut htab = self.htab.get().unwrap().lock(); + let (key_ptr, hash_value) = key!(key, htab); + + // Dangling pointer, don't touch it. + let _ = unsafe { + pg_sys::hash_search_with_hash_value( + htab.htab, + key_ptr, + hash_value, + pg_sys::HASHACTION_HASH_REMOVE, + std::ptr::null_mut(), + ); + }; + + htab.elements -= 1; + return Some(value); + } else { + return None; + } + } + + /// Get the number of elements in the HashMap. + pub fn len(&self) -> i64 { + let htab = self.htab.get().unwrap().lock(); + htab.elements + } +} + +impl PgSharedMemoryInitialization for ShmemHashMap { + fn pg_init(&'static self) { + self.htab + .set(PgSpinLock::new(ShmemHashMapInner::default())) + .expect("htab cell is not empty"); + } + + fn shmem_init(&'static self) { + let mut htab = self.htab.get().unwrap().lock(); + + let mut hash_ctl = pg_sys::HASHCTL::default(); + hash_ctl.keysize = std::mem::size_of::>(); + hash_ctl.entrysize = std::mem::size_of::>(); + + let shm_name = + alloc::ffi::CString::new(Uuid::new_v4().to_string()).expect("CString::new() failed"); + + let htab_ptr = unsafe { + pg_sys::ShmemInitHash( + shm_name.into_raw(), + self.size, + self.size, + &mut hash_ctl, + (pg_sys::HASH_ELEM | pg_sys::HASH_BLOBS).try_into().unwrap(), + ) + }; + + htab.htab = htab_ptr; + } +} diff --git a/pgrx/src/spinlock.rs b/pgrx/src/spinlock.rs index c6a8f5e06..84c76be35 100644 --- a/pgrx/src/spinlock.rs +++ b/pgrx/src/spinlock.rs @@ -9,6 +9,7 @@ //LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file. use crate::pg_sys; use core::mem::MaybeUninit; +use std::fmt; use std::{cell::UnsafeCell, marker::PhantomData}; /// A Rust locking mechanism which uses a PostgreSQL `slock_t` to lock the data. @@ -38,6 +39,12 @@ pub struct PgSpinLock { unsafe impl Send for PgSpinLock {} unsafe impl Sync for PgSpinLock {} +impl fmt::Debug for PgSpinLock { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PgSpinLock").finish() + } +} + impl PgSpinLock { /// Create a new [`PgSpinLock`]. See the type documentation for more info. #[inline]