Skip to content

Commit

Permalink
Make #[derive(PostgresType)] impl its own FromDatum (#1381)
Browse files Browse the repository at this point in the history
Mainly, this removes a source of persistent and confounding type errors
because of the generic blanket impl of FromDatum for all T that fulfill
so-and-so bounds. These may mislead one, if one is writing code generic
over FromDatum, to imagine that one needs a Serialize and Deserialize
impl or bound for a given case, even when those are _not_ required. By
moving these requirements onto the type that derives, this moves any
confusion to the specific cases it actually applies to.

This has a regrettable effect that now PostgresType _requires_ a
Serialize and Deserialize impl in order to work, _unless_ one uses the
hacky `#[bikeshed_postgres_type_manually_impl_from_into_datum]`
attribute, which I intend to rename or otherwise fix up before pgrx
reaches its 0.12.0 release.
  • Loading branch information
workingjubilee authored Nov 8, 2023
1 parent 08f7817 commit 2df66d8
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 103 deletions.
1 change: 1 addition & 0 deletions pgrx-examples/custom_types/src/fixed_size.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use pgrx::{opname, pg_operator, PgVarlena, PgVarlenaInOutFuncs, StringInfo};
use std::str::FromStr;

#[derive(Copy, Clone, PostgresType)]
#[bikeshed_postgres_type_manually_impl_from_into_datum]
#[pgvarlena_inoutfuncs]
pub struct FixedF32Array {
array: [f32; 91],
Expand Down
76 changes: 69 additions & 7 deletions pgrx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ use syn::spanned::Spanned;
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Item, ItemImpl};

use operators::{deriving_postgres_eq, deriving_postgres_hash, deriving_postgres_ord};
use pgrx_sql_entity_graph::{
use pgrx_sql_entity_graph as sql_gen;
use sql_gen::{
parse_extern_attributes, CodeEnrichment, ExtensionSql, ExtensionSqlFile, ExternArgs,
PgAggregate, PgExtern, PostgresEnum, PostgresType, Schema,
PgAggregate, PgExtern, PostgresEnum, Schema,
};

use crate::rewriter::PgGuardRewriter;
Expand Down Expand Up @@ -709,7 +710,16 @@ Optionally accepts the following attributes:
* `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type.
* `sql`: Same arguments as [`#[pgrx(sql = ..)]`](macro@pgrx).
*/
#[proc_macro_derive(PostgresType, attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgrx))]
#[proc_macro_derive(
PostgresType,
attributes(
inoutfuncs,
pgvarlena_inoutfuncs,
bikeshed_postgres_type_manually_impl_from_into_datum,
requires,
pgrx
)
)]
pub fn postgres_type(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);

Expand Down Expand Up @@ -752,10 +762,60 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream>
};

// all #[derive(PostgresType)] need to implement that trait
// and also the FromDatum and IntoDatum
stream.extend(quote! {
impl #generics ::pgrx::PostgresType for #name #generics { }
impl #generics ::pgrx::datum::PostgresType for #name #generics { }
});

if !args.contains(&PostgresTypeAttribute::ManualFromIntoDatum) {
stream.extend(
quote! {
impl #generics ::pgrx::datum::IntoDatum for #name #generics {
fn into_datum(self) -> Option<::pgrx::pg_sys::Datum> {
#[allow(deprecated)]
Some(unsafe { ::pgrx::cbor_encode(&self) }.into())
}

fn type_oid() -> ::pgrx::pg_sys::Oid {
::pgrx::wrappers::rust_regtypein::<Self>()
}
}

impl #generics ::pgrx::datum::FromDatum for #name #generics {
unsafe fn from_polymorphic_datum(
datum: ::pgrx::pg_sys::Datum,
is_null: bool,
_typoid: ::pgrx::pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
#[allow(deprecated)]
::pgrx::cbor_decode(datum.cast_mut_ptr())
}
}

unsafe fn from_datum_in_memory_context(
mut memory_context: ::pgrx::memcxt::PgMemoryContexts,
datum: ::pgrx::pg_sys::Datum,
is_null: bool,
_typoid: ::pgrx::pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
memory_context.switch_to(|_| {
// this gets the varlena Datum copied into this memory context
let varlena = ::pgrx::pg_sys::pg_detoast_datum_copy(datum.cast_mut_ptr());
Self::from_datum(varlena.into(), is_null)
})
}
}
}
}
)
}

// and if we don't have custom inout/funcs, we use the JsonInOutFuncs trait
// which implements _in and _out #[pg_extern] functions that just return the type itself
if args.contains(&PostgresTypeAttribute::Default) {
Expand Down Expand Up @@ -834,7 +894,7 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream>
});
}

let sql_graph_entity_item = PostgresType::from_derive_input(ast)?;
let sql_graph_entity_item = sql_gen::PostgresTypeDerive::from_derive_input(ast)?;
sql_graph_entity_item.to_tokens(&mut stream);

Ok(stream)
Expand Down Expand Up @@ -933,6 +993,7 @@ enum PostgresTypeAttribute {
InOutFuncs,
PgVarlenaInOutFuncs,
Default,
ManualFromIntoDatum,
}

fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet<PostgresTypeAttribute> {
Expand All @@ -945,11 +1006,12 @@ fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet<PostgresTypeAtt
"inoutfuncs" => {
categorized_attributes.insert(PostgresTypeAttribute::InOutFuncs);
}

"pgvarlena_inoutfuncs" => {
categorized_attributes.insert(PostgresTypeAttribute::PgVarlenaInOutFuncs);
}

"bikeshed_postgres_type_manually_impl_from_into_datum" => {
categorized_attributes.insert(PostgresTypeAttribute::ManualFromIntoDatum);
}
_ => {
// we can just ignore attributes we don't understand
}
Expand Down
2 changes: 1 addition & 1 deletion pgrx-sql-entity-graph/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub use postgres_hash::PostgresHash;
pub use postgres_ord::entity::PostgresOrdEntity;
pub use postgres_ord::PostgresOrd;
pub use postgres_type::entity::PostgresTypeEntity;
pub use postgres_type::PostgresType;
pub use postgres_type::PostgresTypeDerive;
pub use schema::entity::SchemaEntity;
pub use schema::Schema;
pub use to_sql::entity::ToSqlConfigEntity;
Expand Down
27 changes: 12 additions & 15 deletions pgrx-sql-entity-graph/src/postgres_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ use crate::{CodeEnrichment, ToSqlConfig};
/// ```rust
/// use syn::{Macro, parse::Parse, parse_quote, parse};
/// use quote::{quote, ToTokens};
/// use pgrx_sql_entity_graph::PostgresType;
/// use pgrx_sql_entity_graph::PostgresTypeDerive;
///
/// # fn main() -> eyre::Result<()> {
/// use pgrx_sql_entity_graph::CodeEnrichment;
/// let parsed: CodeEnrichment<PostgresType> = parse_quote! {
/// let parsed: CodeEnrichment<PostgresTypeDerive> = parse_quote! {
/// #[derive(PostgresType)]
/// struct Example<'a> {
/// demo: &'a str,
Expand All @@ -49,15 +49,15 @@ use crate::{CodeEnrichment, ToSqlConfig};
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct PostgresType {
pub struct PostgresTypeDerive {
name: Ident,
generics: Generics,
in_fn: Ident,
out_fn: Ident,
to_sql_config: ToSqlConfig,
}

impl PostgresType {
impl PostgresTypeDerive {
pub fn new(
name: Ident,
generics: Generics,
Expand Down Expand Up @@ -100,7 +100,7 @@ impl PostgresType {
}
}

impl ToEntityGraphTokens for PostgresType {
impl ToEntityGraphTokens for PostgresTypeDerive {
fn to_entity_graph_tokens(&self) -> TokenStream2 {
let name = &self.name;
let mut static_generics = self.generics.clone();
Expand Down Expand Up @@ -211,17 +211,14 @@ impl ToEntityGraphTokens for PostgresType {
}
}

impl ToRustCodeTokens for PostgresType {}
impl ToRustCodeTokens for PostgresTypeDerive {}

impl Parse for CodeEnrichment<PostgresType> {
impl Parse for CodeEnrichment<PostgresTypeDerive> {
fn parse(input: ParseStream) -> Result<Self, syn::Error> {
let parsed: ItemStruct = input.parse()?;
let to_sql_config =
ToSqlConfig::from_attributes(parsed.attrs.as_slice())?.unwrap_or_default();
let funcname_in =
Ident::new(&format!("{}_in", parsed.ident).to_lowercase(), parsed.ident.span());
let funcname_out =
Ident::new(&format!("{}_out", parsed.ident).to_lowercase(), parsed.ident.span());
PostgresType::new(parsed.ident, parsed.generics, funcname_in, funcname_out, to_sql_config)
let ItemStruct { attrs, ident, generics, .. } = input.parse()?;
let to_sql_config = ToSqlConfig::from_attributes(attrs.as_slice())?.unwrap_or_default();
let in_fn = Ident::new(&format!("{}_in", ident).to_lowercase(), ident.span());
let out_fn = Ident::new(&format!("{}_out", ident).to_lowercase(), ident.span());
PostgresTypeDerive::new(ident, generics, in_fn, out_fn, to_sql_config)
}
}
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/postgres_type_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use pgrx::{InOutFuncs, PgVarlena, PgVarlenaInOutFuncs, StringInfo};
use serde::{Deserialize, Serialize};
use std::str::FromStr;

#[derive(Copy, Clone, PostgresType)]
#[derive(Copy, Clone, PostgresType, Serialize, Deserialize)]
#[pgvarlena_inoutfuncs]
pub struct VarlenaType {
a: f32,
Expand All @@ -38,7 +38,7 @@ impl PgVarlenaInOutFuncs for VarlenaType {
}
}

#[derive(Copy, Clone, PostgresType)]
#[derive(Copy, Clone, PostgresType, Serialize, Deserialize)]
#[pgvarlena_inoutfuncs]
pub enum VarlenaEnumType {
A,
Expand Down
86 changes: 8 additions & 78 deletions pgrx/src/datum/varlena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
use crate::pg_sys::{VARATT_SHORT_MAX, VARHDRSZ_SHORT};
use crate::{
pg_sys, rust_regtypein, set_varsize, set_varsize_short, vardata_any, varsize_any,
varsize_any_exhdr, void_mut_ptr, FromDatum, IntoDatum, PgMemoryContexts, PostgresType,
StringInfo,
varsize_any_exhdr, void_mut_ptr, FromDatum, IntoDatum, PgMemoryContexts, StringInfo,
};
use pgrx_sql_entity_graph::metadata::{
ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable,
Expand Down Expand Up @@ -60,8 +59,9 @@ impl Clone for PallocdVarlena {
/// use std::str::FromStr;
///
/// use pgrx::prelude::*;
/// use serde::{Serialize, Deserialize};
///
/// #[derive(Copy, Clone, PostgresType)]
/// #[derive(Copy, Clone, PostgresType, Serialize, Deserialize)]
/// #[pgvarlena_inoutfuncs]
/// struct MyType {
/// a: f32,
Expand Down Expand Up @@ -378,50 +378,8 @@ where
}
}

impl<T> IntoDatum for T
where
T: PostgresType + Serialize,
{
fn into_datum(self) -> Option<pg_sys::Datum> {
Some(cbor_encode(&self).into())
}

fn type_oid() -> pg_sys::Oid {
crate::rust_regtypein::<T>()
}
}

impl<'de, T> FromDatum for T
where
T: PostgresType + Deserialize<'de>,
{
unsafe fn from_polymorphic_datum(
datum: pg_sys::Datum,
is_null: bool,
_typoid: pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
cbor_decode(datum.cast_mut_ptr())
}
}

unsafe fn from_datum_in_memory_context(
memory_context: PgMemoryContexts,
datum: pg_sys::Datum,
is_null: bool,
_typoid: pg_sys::Oid,
) -> Option<Self> {
if is_null {
None
} else {
cbor_decode_into_context(memory_context, datum.cast_mut_ptr())
}
}
}

fn cbor_encode<T>(input: T) -> *const pg_sys::varlena
#[doc(hidden)]
pub unsafe fn cbor_encode<T>(input: T) -> *const pg_sys::varlena
where
T: Serialize,
{
Expand All @@ -439,6 +397,7 @@ where
varlena as *const pg_sys::varlena
}

#[doc(hidden)]
pub unsafe fn cbor_decode<'de, T>(input: *mut pg_sys::varlena) -> T
where
T: Deserialize<'de>,
Expand All @@ -450,6 +409,8 @@ where
serde_cbor::from_slice(slice).expect("failed to decode CBOR")
}

#[doc(hidden)]
#[deprecated(since = "0.12.0", note = "just use the FromDatum impl")]
pub unsafe fn cbor_decode_into_context<'de, T>(
mut memory_context: PgMemoryContexts,
input: *mut pg_sys::varlena,
Expand All @@ -464,37 +425,6 @@ where
})
}

#[allow(dead_code)]
fn json_encode<T>(input: T) -> *const pg_sys::varlena
where
T: Serialize,
{
let mut serialized = StringInfo::new();

serialized.push_bytes(&[0u8; pg_sys::VARHDRSZ]); // reserve space for the header
serde_json::to_writer(&mut serialized, &input).expect("failed to encode as JSON");

let size = serialized.len();
let varlena = serialized.into_char_ptr();
unsafe {
set_varsize(varlena as *mut pg_sys::varlena, size as i32);
}

varlena as *const pg_sys::varlena
}

#[allow(dead_code)]
unsafe fn json_decode<'de, T>(input: *mut pg_sys::varlena) -> T
where
T: Deserialize<'de>,
{
let varlena = pg_sys::pg_detoast_datum_packed(input as *mut pg_sys::varlena);
let len = varsize_any_exhdr(varlena);
let data = vardata_any(varlena);
let slice = std::slice::from_raw_parts(data as *const u8, len);
serde_json::from_slice(slice).expect("failed to decode JSON")
}

unsafe impl<T> SqlTranslatable for PgVarlena<T>
where
T: SqlTranslatable + Copy,
Expand Down

0 comments on commit 2df66d8

Please sign in to comment.