diff --git a/.cargo/config b/.cargo/config.toml similarity index 100% rename from .cargo/config rename to .cargo/config.toml diff --git a/derive/src/lib.rs b/derive/src/lib.rs index f9c233c..77f47b4 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -1,12 +1,25 @@ +use proc_macro2::{Ident, Span}; use quote::quote; -use syn::{parse_macro_input, ItemFn, ItemForeignMod}; +use syn::{parse_macro_input, FnArg, ItemFn, ItemForeignMod}; -/// `plugin_fn` is used to define a function that will be exported by a plugin +/// `plugin_fn` is used to define an Extism callable function to export /// /// It should be added to a function you would like to export, the function should /// accept a parameter that implements `extism_pdk::FromBytes` and return a /// `extism_pdk::FnResult` that contains a value that implements -/// `extism_pdk::ToBytes`. +/// `extism_pdk::ToBytes`. This maps input and output parameters to Extism input +/// and output instead of using function arguments directly. +/// +/// ## Example +/// +/// ```rust +/// use extism_pdk::{FnResult, plugin_fn}; +/// #[plugin_fn] +/// pub fn greet(name: String) -> FnResult { +/// let s = format!("Hello, {name}"); +/// Ok(s) +/// } +/// ``` #[proc_macro_attribute] pub fn plugin_fn( _attr: proc_macro::TokenStream, @@ -103,7 +116,122 @@ pub fn plugin_fn( } } -/// `host_fn` is used to define a host function that will be callable from within a plugin +/// `shared_fn` is used to define a function that will be exported by a plugin but is not directly +/// callable by an Extism runtime. These functions can be used for runtime linking and mocking host +/// functions for tests. If direct access to Wasm native parameters is needed, then a bare +/// `extern "C" fn` should be used instead. +/// +/// All arguments should implement `extism_pdk::ToBytes` and the return value should implement +/// `extism_pdk::FromBytes` +/// ## Example +/// +/// ```rust +/// use extism_pdk::{FnResult, shared_fn}; +/// #[shared_fn] +/// pub fn greet2(greeting: String, name: String) -> FnResult { +/// let s = format!("{greeting}, {name}"); +/// Ok(name) +/// } +/// ``` +#[proc_macro_attribute] +pub fn shared_fn( + _attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let mut function = parse_macro_input!(item as ItemFn); + + if !matches!(function.vis, syn::Visibility::Public(..)) { + panic!("extism_pdk::shared_fn expects a public function"); + } + + let name = &function.sig.ident; + let constness = &function.sig.constness; + let unsafety = &function.sig.unsafety; + let generics = &function.sig.generics; + let inputs = &mut function.sig.inputs; + let output = &mut function.sig.output; + let block = &function.block; + + let (raw_inputs, raw_args): (Vec<_>, Vec<_>) = inputs + .iter() + .enumerate() + .map(|(i, x)| { + let t = match x { + FnArg::Receiver(_) => { + panic!("Receiver argument (self) cannot be used in extism_pdk::shared_fn") + } + FnArg::Typed(t) => &t.ty, + }; + let arg = Ident::new(&format!("arg{i}"), Span::call_site()); + ( + quote! { #arg: extism_pdk::MemoryPointer<#t> }, + quote! { #arg.get()? }, + ) + }) + .unzip(); + + if name == "main" { + panic!( + "export_pdk::shared_fn must not be applied to a `main` function. To fix, rename this to something other than `main`." + ) + } + + let (no_result, raw_output) = match output { + syn::ReturnType::Default => (true, quote! {}), + syn::ReturnType::Type(_, t) => { + if let syn::Type::Path(p) = t.as_ref() { + if let Some(t) = p.path.segments.last() { + if t.ident != "SharedFnResult" { + panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult"); + } + } else { + panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult"); + } + }; + (false, quote! {-> u64 }) + } + }; + + if no_result { + quote! { + #[no_mangle] + pub #constness #unsafety extern "C" fn #name(#(#raw_inputs,)*) { + #constness #unsafety fn inner #generics(#inputs) -> extism_pdk::SharedFnResult<()> { + #block + } + + + let r = || inner(#(#raw_args,)*); + if let Err(rc) = r() { + panic!("{}", rc.to_string()); + } + } + } + .into() + } else { + quote! { + #[no_mangle] + pub #constness #unsafety extern "C" fn #name(#(#raw_inputs,)*) #raw_output { + #constness #unsafety fn inner #generics(#inputs) #output { + #block + } + + let r = || inner(#(#raw_args,)*); + match r().and_then(|x| extism_pdk::Memory::new(&x)) { + Ok(mem) => { + mem.offset() + }, + Err(rc) => { + panic!("{}", rc.to_string()); + } + } + } + } + .into() + } +} + +/// `host_fn` is used to import a host function from an `extern` block #[proc_macro_attribute] pub fn host_fn( attr: proc_macro::TokenStream, diff --git a/examples/reflect.rs b/examples/reflect.rs new file mode 100644 index 0000000..a47df22 --- /dev/null +++ b/examples/reflect.rs @@ -0,0 +1,13 @@ +#![no_main] + +use extism_pdk::*; + +#[shared_fn] +pub fn host_reflect(input: String) -> SharedFnResult> { + Ok(input.to_lowercase().into_bytes()) +} + +#[shared_fn] +pub fn nothing() -> SharedFnResult<()> { + Ok(()) +} diff --git a/src/extism.rs b/src/extism.rs index a473791..35b1116 100644 --- a/src/extism.rs +++ b/src/extism.rs @@ -25,7 +25,7 @@ extern "C" { } /// Loads a byte array from Extism's memory. Only use this if you -/// have already considered the plugin_fn macro as well as the [extism_load_input] function. +/// have already considered the plugin_fn macro as well as the `extism_load_input` function. /// /// # Arguments /// diff --git a/src/lib.rs b/src/lib.rs index 2fcb1e5..52937bb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,8 +22,8 @@ pub mod http; pub use anyhow::Error; pub use extism_convert::*; pub use extism_convert::{FromBytes, FromBytesOwned, ToBytes}; -pub use extism_pdk_derive::{host_fn, plugin_fn}; -pub use memory::Memory; +pub use extism_pdk_derive::{host_fn, plugin_fn, shared_fn}; +pub use memory::{Memory, MemoryPointer}; pub use to_memory::ToMemory; #[cfg(feature = "http")] @@ -37,6 +37,9 @@ pub use http::HttpResponse; /// The return type of a plugin function pub type FnResult = Result>; +/// The return type of a `shared_fn` +pub type SharedFnResult = Result; + /// Logging levels #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum LogLevel { diff --git a/src/memory.rs b/src/memory.rs index c58ae10..eee3ffc 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -174,3 +174,23 @@ impl From for Memory { Memory::find(offset as u64).unwrap_or_else(Memory::null) } } + +#[repr(transparent)] +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +pub struct MemoryPointer(u64, std::marker::PhantomData); + +impl MemoryPointer { + pub unsafe fn new(x: u64) -> Self { + MemoryPointer(x, Default::default()) + } +} + +impl MemoryPointer { + pub fn get(&self) -> Result { + let mem = Memory::find(self.0); + match mem { + Some(mem) => T::from_bytes_owned(&mem.to_vec()), + None => anyhow::bail!("Invalid pointer offset {}", self.0), + } + } +}