Skip to content

Commit

Permalink
feat: Teach #[pg_cast] how to pass down arguments useful to `#[pg_e…
Browse files Browse the repository at this point in the history
…xtern]`

The `#[pg_cast]` macro does two things.  It creates a regular UDF and it
then generates the appropriate `CREATE CAST` sql.

Ultimately, the cast is just a function, so we should support all the attibutes
that `#[pg_extern]` supports.
  • Loading branch information
eeeebbbbrrrr committed Oct 27, 2024
1 parent 0193fd0 commit 8541b2b
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 21 deletions.
40 changes: 24 additions & 16 deletions pgrx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ By default if no attribute is specified, the cast function can only be used in a
Functions MUST accept and return exactly one value whose type MUST be a `pgrx` supported type. `pgrx` supports many PostgreSQL types by default.
New types can be defined via [`macro@PostgresType`] or [`macro@PostgresEnum`].
`#[pg_cast]` also supports all the attributes supported by the [`macro@pg_extern]` macro, which are
passed down to the underlying function.
Example usage:
```rust,ignore
use pgrx::*;
Expand All @@ -173,31 +176,36 @@ fn cast_json_to_int(input: Json) -> i32 { todo!() }
pub fn pg_cast(attr: TokenStream, item: TokenStream) -> TokenStream {
fn wrapped(attr: TokenStream, item: TokenStream) -> Result<TokenStream, syn::Error> {
use syn::parse::Parser;
use syn::punctuated::Punctuated;

let mut cast = PgCast::Default;
match syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated.parse(attr)
{
let mut pg_extern_attrs = proc_macro2::TokenStream::new();

// look for the attributes `#[pg_cast]` directly understands
match Punctuated::<syn::Path, syn::Token![,]>::parse_terminated.parse(attr) {
Ok(paths) => {
if paths.len() > 1 {
panic!(
"pg_cast must take either 0 or 1 attribute. Found {}: {}",
paths.len(),
paths.to_token_stream()
)
} else if paths.len() == 1 {
match paths.first().unwrap().segments.last().unwrap().ident.to_string().as_str()
{
"implicit" => cast = PgCast::Implicit,
"assignment" => cast = PgCast::Assignment,
other => panic!("Unrecognized pg_cast option: {other}. "),
let mut new_paths = Punctuated::<syn::Path, syn::Token![,]>::new();
for path in paths {
if path.is_ident("implicit") {
cast = PgCast::Implicit
} else if path.is_ident("assignment") {
cast = PgCast::Assignment
} else {
// ... and anything it doesn't understand is blindly passed through to the
// underlying `#[pg_extern]` function that gets created, which will ultimately
// decide what's naughty and what's nice
new_paths.push(path);
}
}

pg_extern_attrs.extend(new_paths.into_token_stream());
}
Err(err) => {
panic!("Failed to parse attribute to pg_cast: {err}")
}
}
// `pg_cast` does not support other `pg_extern` attributes for now, pass an empty attribute token stream.
let pg_extern = PgExtern::new(TokenStream::new().into(), item.clone().into())?.0;

let pg_extern = PgExtern::new(pg_extern_attrs.into(), item.clone().into())?.0;
Ok(CodeEnrichment(pg_extern.as_cast(cast)).to_token_stream().into())
}

Expand Down
28 changes: 27 additions & 1 deletion pgrx-tests/src/tests/pg_cast_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use pgrx::prelude::*;

#[pg_schema]
mod pg_catalog {
use pgrx::pg_cast;
use pgrx::prelude::*;
use serde::{Deserialize, Serialize};
use serde_json::Value::Number;

#[pg_cast(implicit)]
Expand All @@ -27,6 +28,15 @@ mod pg_catalog {
};
panic!("Error casting json value {} to an integer", value.0)
}

#[derive(PostgresType, Serialize, Deserialize)]
struct TestCastType;

#[pg_cast(implicit, immutable)]
fn testcasttype_to_bool(_i: TestCastType) -> bool {
// look, it's just a test, okay?
true
}
}

#[cfg(any(test, feature = "pg_test"))]
Expand Down Expand Up @@ -60,4 +70,20 @@ mod tests {
fn test_pg_cast_implicit_type_cast() {
assert_eq!(Spi::get_one::<i32>("SELECT 1 + ('{\"a\": 1}'::json->'a')"), Ok(Some(2)));
}

#[pg_test]
fn assert_cast_func_is_immutable() {
let is_immutable = Spi::get_one::<bool>(
"SELECT provolatile = 'i' FROM pg_proc WHERE proname = 'testcasttype_to_bool';",
);
assert_eq!(is_immutable, Ok(Some(true)));
}

#[pg_test]
fn assert_cast_is_implicit() {
let is_immutable = Spi::get_one::<bool>(
"SELECT castcontext = 'i' FROM pg_cast WHERE castsource = 'TestCastType'::regtype AND casttarget = 'bool'::regtype;",
);
assert_eq!(is_immutable, Ok(Some(true)));
}
}
12 changes: 8 additions & 4 deletions pgrx-tests/tests/compile-fail/invalid_pgcast_options.stderr
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
error: custom attribute panicked
--> tests/compile-fail/invalid_pgcast_options.rs:3:1
error: Invalid option `invalid_opt` inside `invalid_opt `
--> tests/compile-fail/invalid_pgcast_options.rs:3:11
|
3 | #[pg_cast(invalid_opt)]
| ^^^^^^^^^^^^^^^^^^^^^^^
| ^^^^^^^^^^^

error: failed parsing pg_extern arguments
--> tests/compile-fail/invalid_pgcast_options.rs:3:11
|
= help: message: Unrecognized pg_cast option: invalid_opt.
3 | #[pg_cast(invalid_opt)]
| ^^^^^^^^^^^

0 comments on commit 8541b2b

Please sign in to comment.