Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add shared_fn macro #55

Merged
merged 7 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
136 changes: 132 additions & 4 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
use proc_macro2::{Ident, Span};
use quote::quote;
use syn::{parse_macro_input, ItemFn, ItemForeignMod};
use syn::{parse_macro_input, FnArg, ItemFn, ItemForeignMod};

/// `plugin_fn` is used to define a function that will be exported by a plugin
/// `plugin_fn` is used to define an Extism callable function to export
///
/// It should be added to a function you would like to export, the function should
/// accept a parameter that implements `extism_pdk::FromBytes` and return a
/// `extism_pdk::FnResult` that contains a value that implements
/// `extism_pdk::ToBytes`.
/// `extism_pdk::ToBytes`. This maps input and output parameters to Extism input
/// and output instead of using function arguments directly.
///
/// ## Example
///
/// ```rust
/// use extism_pdk::{FnResult, plugin_fn};
/// #[plugin_fn]
/// pub fn greet(name: String) -> FnResult<String> {
/// let s = format!("Hello, {name}");
/// Ok(s)
/// }
/// ```
#[proc_macro_attribute]
pub fn plugin_fn(
_attr: proc_macro::TokenStream,
Expand Down Expand Up @@ -103,7 +116,122 @@ pub fn plugin_fn(
}
}

/// `host_fn` is used to define a host function that will be callable from within a plugin
/// `shared_fn` is used to define a function that will be exported by a plugin but is not directly
/// callable by an Extism runtime. These functions can be used for runtime linking and mocking host
/// functions for tests. If direct access to Wasm native parameters is needed, then a bare
/// `extern "C" fn` should be used instead.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe ass an example of this here

///
/// All arguments should implement `extism_pdk::ToBytes` and the return value should implement
/// `extism_pdk::FromBytes`
/// ## Example
///
/// ```rust
/// use extism_pdk::{FnResult, shared_fn};
/// #[shared_fn]
/// pub fn greet2(greeting: String, name: String) -> FnResult<String> {
/// let s = format!("{greeting}, {name}");
/// Ok(name)
/// }
/// ```
#[proc_macro_attribute]
pub fn shared_fn(
_attr: proc_macro::TokenStream,
item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let mut function = parse_macro_input!(item as ItemFn);

if !matches!(function.vis, syn::Visibility::Public(..)) {
panic!("extism_pdk::shared_fn expects a public function");
}

let name = &function.sig.ident;
let constness = &function.sig.constness;
let unsafety = &function.sig.unsafety;
let generics = &function.sig.generics;
let inputs = &mut function.sig.inputs;
let output = &mut function.sig.output;
let block = &function.block;

let (raw_inputs, raw_args): (Vec<_>, Vec<_>) = inputs
.iter()
.enumerate()
.map(|(i, x)| {
let t = match x {
FnArg::Receiver(_) => {
panic!("Receiver argument (self) cannot be used in extism_pdk::shared_fn")
}
FnArg::Typed(t) => &t.ty,
};
let arg = Ident::new(&format!("arg{i}"), Span::call_site());
(
quote! { #arg: extism_pdk::MemoryPointer<#t> },
quote! { #arg.get()? },
)
})
.unzip();

if name == "main" {
panic!(
"export_pdk::shared_fn must not be applied to a `main` function. To fix, rename this to something other than `main`."
)
}

let (no_result, raw_output) = match output {
syn::ReturnType::Default => (true, quote! {}),
syn::ReturnType::Type(_, t) => {
if let syn::Type::Path(p) = t.as_ref() {
if let Some(t) = p.path.segments.last() {
if t.ident != "SharedFnResult" {
panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult");
}
} else {
panic!("extism_pdk::shared_fn expects a function that returns extism_pdk::SharedFnResult");
}
};
(false, quote! {-> u64 })
}
};

if no_result {
quote! {
#[no_mangle]
pub #constness #unsafety extern "C" fn #name(#(#raw_inputs,)*) {
#constness #unsafety fn inner #generics(#inputs) -> extism_pdk::SharedFnResult<()> {
#block
}


let r = || inner(#(#raw_args,)*);
if let Err(rc) = r() {
panic!("{}", rc.to_string());
}
}
}
.into()
} else {
quote! {
#[no_mangle]
pub #constness #unsafety extern "C" fn #name(#(#raw_inputs,)*) #raw_output {
#constness #unsafety fn inner #generics(#inputs) #output {
#block
}

let r = || inner(#(#raw_args,)*);
match r().and_then(|x| extism_pdk::Memory::new(&x)) {
Ok(mem) => {
mem.offset()
},
Err(rc) => {
panic!("{}", rc.to_string());
}
}
}
}
.into()
}
}

/// `host_fn` is used to import a host function from an `extern` block
#[proc_macro_attribute]
pub fn host_fn(
attr: proc_macro::TokenStream,
Expand Down
13 changes: 13 additions & 0 deletions examples/reflect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#![no_main]

use extism_pdk::*;

#[shared_fn]
pub fn host_reflect(input: String) -> SharedFnResult<Vec<u8>> {
Ok(input.to_lowercase().into_bytes())
}

#[shared_fn]
pub fn nothing() -> SharedFnResult<()> {
Ok(())
}
2 changes: 1 addition & 1 deletion src/extism.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ extern "C" {
}

/// Loads a byte array from Extism's memory. Only use this if you
/// have already considered the plugin_fn macro as well as the [extism_load_input] function.
/// have already considered the plugin_fn macro as well as the `extism_load_input` function.
///
/// # Arguments
///
Expand Down
7 changes: 5 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ pub mod http;
pub use anyhow::Error;
pub use extism_convert::*;
pub use extism_convert::{FromBytes, FromBytesOwned, ToBytes};
pub use extism_pdk_derive::{host_fn, plugin_fn};
pub use memory::Memory;
pub use extism_pdk_derive::{host_fn, plugin_fn, shared_fn};
pub use memory::{Memory, MemoryPointer};
pub use to_memory::ToMemory;

#[cfg(feature = "http")]
Expand All @@ -37,6 +37,9 @@ pub use http::HttpResponse;
/// The return type of a plugin function
pub type FnResult<T> = Result<T, WithReturnCode<Error>>;

/// The return type of a `shared_fn`
pub type SharedFnResult<T> = Result<T, Error>;

/// Logging levels
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogLevel {
Expand Down
20 changes: 20 additions & 0 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,23 @@ impl From<i64> for Memory {
Memory::find(offset as u64).unwrap_or_else(Memory::null)
}
}

#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct MemoryPointer<T>(u64, std::marker::PhantomData<T>);

impl<T> MemoryPointer<T> {
pub unsafe fn new(x: u64) -> Self {
MemoryPointer(x, Default::default())
}
}

impl<T: FromBytesOwned> MemoryPointer<T> {
pub fn get(&self) -> Result<T, Error> {
let mem = Memory::find(self.0);
match mem {
Some(mem) => T::from_bytes_owned(&mem.to_vec()),
None => anyhow::bail!("Invalid pointer offset {}", self.0),
}
}
}
Loading