diff --git a/x/programs/cmd/simulator/cmd/program.go b/x/programs/cmd/simulator/cmd/program.go index d0c9a98307..7a2b33012b 100644 --- a/x/programs/cmd/simulator/cmd/program.go +++ b/x/programs/cmd/simulator/cmd/program.go @@ -144,7 +144,6 @@ func programExecuteFunc( ) (ids.ID, []int64, uint64, error) { // simulate create program transaction programTxID, err := generateRandomID() - if err != nil { return ids.Empty, nil, 0, err } diff --git a/x/programs/examples/imports/program/program.go b/x/programs/examples/imports/program/program.go index 80f0eff35e..81433cb920 100644 --- a/x/programs/examples/imports/program/program.go +++ b/x/programs/examples/imports/program/program.go @@ -59,12 +59,15 @@ func (i *Import) Register(link *host.Link, callContext program.Context) error { } // callProgramFn makes a call to an entry function of a program in the context of another program's ID. -func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Caller, int64, int64, int64, int64) int64 { +func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Caller, int32, int32, int32, int32, int32, int32, int64) int64 { return func( wasmCaller *wasmtime.Caller, - programID int64, - function int64, - args int64, + programPtr int32, + programLen int32, + functionPtr int32, + functionLen int32, + argsPtr int32, + argsLen int32, maxUnits int64, ) int64 { ctx, cancel := context.WithCancel(context.Background()) @@ -80,7 +83,7 @@ func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Calle } // get the entry function for invoke to call. - functionBytes, err := program.SmartPtr(function).Bytes(memory) + functionBytes, err := memory.Range(uint32(functionPtr), uint32(functionLen)) if err != nil { i.log.Error("failed to read function name from memory", zap.Error(err), @@ -88,7 +91,7 @@ func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Calle return -1 } - programIDBytes, err := program.SmartPtr(programID).Bytes(memory) + programIDBytes, err := memory.Range(uint32(programPtr), uint32(programLen)) if err != nil { i.log.Error("failed to read id from memory", zap.Error(err), @@ -143,7 +146,7 @@ func (i *Import) callProgramFn(callContext program.Context) func(*wasmtime.Calle } }() - argsBytes, err := program.SmartPtr(args).Bytes(memory) + argsBytes, err := memory.Range(uint32(argsPtr), uint32(argsLen)) if err != nil { i.log.Error("failed to read program args from memory", zap.Error(err), diff --git a/x/programs/examples/imports/pstate/pstate.go b/x/programs/examples/imports/pstate/pstate.go index 2bea833234..bf0baf333c 100644 --- a/x/programs/examples/imports/pstate/pstate.go +++ b/x/programs/examples/imports/pstate/pstate.go @@ -43,38 +43,38 @@ func (*Import) Name() string { func (i *Import) Register(link *host.Link, _ program.Context) error { i.meter = link.Meter() wrap := wrap.New(link) - if err := wrap.RegisterAnyParamFn(Name, "put", 3, i.putFnVariadic); err != nil { + if err := wrap.RegisterAnyParamFn(Name, "put", 6, i.putFnVariadic); err != nil { return err } - if err := wrap.RegisterAnyParamFn(Name, "get", 2, i.getFnVariadic); err != nil { + if err := wrap.RegisterAnyParamFn(Name, "get", 4, i.getFnVariadic); err != nil { return err } - return wrap.RegisterAnyParamFn(Name, "delete", 2, i.deleteFnVariadic) + return wrap.RegisterAnyParamFn(Name, "delete", 4, i.deleteFnVariadic) } -func (i *Import) putFnVariadic(caller *program.Caller, args ...int64) (*types.Val, error) { - if len(args) != 3 { - return nil, errors.New("expected 3 arguments") +func (i *Import) putFnVariadic(caller *program.Caller, args ...int32) (*types.Val, error) { + if len(args) != 6 { + return nil, errors.New("expected 6 arguments") } - return i.putFn(caller, args[0], args[1], args[2]) + return i.putFn(caller, args[0], args[1], args[2], args[3], args[4], args[5]) } -func (i *Import) getFnVariadic(caller *program.Caller, args ...int64) (*types.Val, error) { - if len(args) != 2 { - return nil, errors.New("expected 2 arguments") +func (i *Import) getFnVariadic(caller *program.Caller, args ...int32) (*types.Val, error) { + if len(args) != 4 { + return nil, errors.New("expected 4 arguments") } - return i.getFn(caller, args[0], args[1]) + return i.getFn(caller, args[0], args[1], args[2], args[3]) } -func (i *Import) deleteFnVariadic(caller *program.Caller, args ...int64) (*types.Val, error) { - if len(args) != 2 { - return nil, errors.New("expected 2 arguments") +func (i *Import) deleteFnVariadic(caller *program.Caller, args ...int32) (*types.Val, error) { + if len(args) != 4 { + return nil, errors.New("expected 4 arguments") } - return i.deleteFn(caller, args[0], args[1]) + return i.deleteFn(caller, args[0], args[1], args[2], args[3]) } -func (i *Import) putFn(caller *program.Caller, id int64, key int64, value int64) (*types.Val, error) { +func (i *Import) putFn(caller *program.Caller, idPtr int32, idLen int32, keyPtr int32, keyLen int32, valuePtr int32, valueLen int32) (*types.Val, error) { memory, err := caller.Memory() if err != nil { i.log.Error("failed to get memory from caller", @@ -83,7 +83,7 @@ func (i *Import) putFn(caller *program.Caller, id int64, key int64, value int64) return nil, err } - programIDBytes, err := program.SmartPtr(id).Bytes(memory) + programIDBytes, err := memory.Range(uint32(idPtr), uint32(idLen)) if err != nil { i.log.Error("failed to read program id from memory", zap.Error(err), @@ -91,7 +91,7 @@ func (i *Import) putFn(caller *program.Caller, id int64, key int64, value int64) return nil, err } - keyBytes, err := program.SmartPtr(key).Bytes(memory) + keyBytes, err := memory.Range(uint32(keyPtr), uint32(keyLen)) if err != nil { i.log.Error("failed to read key from memory", zap.Error(err), @@ -99,7 +99,7 @@ func (i *Import) putFn(caller *program.Caller, id int64, key int64, value int64) return nil, err } - valueBytes, err := program.SmartPtr(value).Bytes(memory) + valueBytes, err := memory.Range(uint32(valuePtr), uint32(valueLen)) if err != nil { i.log.Error("failed to read value from memory", zap.Error(err), @@ -116,10 +116,10 @@ func (i *Import) putFn(caller *program.Caller, id int64, key int64, value int64) return nil, err } - return types.ValI64(0), nil + return types.ValI32(0), nil } -func (i *Import) getFn(caller *program.Caller, id int64, key int64) (*types.Val, error) { +func (i *Import) getFn(caller *program.Caller, idPtr int32, idLen int32, keyPtr int32, keyLen int32) (*types.Val, error) { memory, err := caller.Memory() if err != nil { i.log.Error("failed to get memory from caller", @@ -128,7 +128,7 @@ func (i *Import) getFn(caller *program.Caller, id int64, key int64) (*types.Val, return nil, err } - programIDBytes, err := program.SmartPtr(id).Bytes(memory) + programIDBytes, err := memory.Range(uint32(idPtr), uint32(idLen)) if err != nil { i.log.Error("failed to read program id from memory", zap.Error(err), @@ -136,7 +136,7 @@ func (i *Import) getFn(caller *program.Caller, id int64, key int64) (*types.Val, return nil, err } - keyBytes, err := program.SmartPtr(key).Bytes(memory) + keyBytes, err := memory.Range(uint32(keyPtr), uint32(keyLen)) if err != nil { i.log.Error("failed to read key from memory", zap.Error(err), @@ -148,7 +148,7 @@ func (i *Import) getFn(caller *program.Caller, id int64, key int64) (*types.Val, if err != nil { if errors.Is(err, database.ErrNotFound) { // TODO: return a more descriptive error - return types.ValI64(-1), nil + return types.ValI32(-1), nil } i.log.Error("failed to get value from storage", zap.Error(err), @@ -163,7 +163,7 @@ func (i *Import) getFn(caller *program.Caller, id int64, key int64) (*types.Val, return nil, err } - ptr, err := program.WriteBytes(memory, val) + valPtr, err := program.WriteBytes(memory, val) if err != nil { { i.log.Error("failed to write to memory", @@ -172,7 +172,7 @@ func (i *Import) getFn(caller *program.Caller, id int64, key int64) (*types.Val, } return nil, err } - argPtr, err := program.NewSmartPtr(ptr, len(val)) + _, err = memory.Range(valPtr, uint32(len(val))) if err != nil { i.log.Error("failed to convert ptr to argument", zap.Error(err), @@ -180,10 +180,10 @@ func (i *Import) getFn(caller *program.Caller, id int64, key int64) (*types.Val, return nil, err } - return types.ValI64(int64(argPtr)), nil + return types.ValI32(int32(valPtr)), nil } -func (i *Import) deleteFn(caller *program.Caller, id int64, key int64) (*types.Val, error) { +func (i *Import) deleteFn(caller *program.Caller, idPtr int32, idLen int32, keyPtr int32, keyLen int32) (*types.Val, error) { memory, err := caller.Memory() if err != nil { i.log.Error("failed to get memory from caller", @@ -192,7 +192,7 @@ func (i *Import) deleteFn(caller *program.Caller, id int64, key int64) (*types.V return nil, err } - programIDBytes, err := program.SmartPtr(id).Bytes(memory) + programIDBytes, err := memory.Range(uint32(idPtr), uint32(idLen)) if err != nil { i.log.Error("failed to read program id from memory", zap.Error(err), @@ -200,7 +200,7 @@ func (i *Import) deleteFn(caller *program.Caller, id int64, key int64) (*types.V return nil, err } - keyBytes, err := program.SmartPtr(key).Bytes(memory) + keyBytes, err := memory.Range(uint32(keyPtr), uint32(keyLen)) if err != nil { i.log.Error("failed to read key from memory", zap.Error(err), @@ -211,7 +211,7 @@ func (i *Import) deleteFn(caller *program.Caller, id int64, key int64) (*types.V k := storage.ProgramPrefixKey(programIDBytes, keyBytes) if err := i.mu.Remove(context.Background(), k); err != nil { i.log.Error("failed to remove from storage", zap.Error(err)) - return types.ValI64(-1), nil + return types.ValI32(-1), nil } - return types.ValI64(0), nil + return types.ValI32(0), nil } diff --git a/x/programs/examples/imports/wrap/wrap.go b/x/programs/examples/imports/wrap/wrap.go index deb8f71213..2b0eec99ac 100644 --- a/x/programs/examples/imports/wrap/wrap.go +++ b/x/programs/examples/imports/wrap/wrap.go @@ -25,13 +25,13 @@ type Wrap struct { link *host.Link } -// RegisterOneParamInt64Fn is a helper method for registering a function with one int64 parameter. +// RegisterAnyParamFn is a helper method for registering a function with one int64 parameter. func (w *Wrap) RegisterAnyParamFn(name, module string, paramCount int, fn AnyParamFn) error { return w.link.RegisterImportWrapFn(name, module, paramCount, NewImportFn[AnyParamFn](fn)) } // AnyParamFn is a generic type that satisfies AnyParamFnType -type AnyParamFn func(*program.Caller, ...int64) (*types.Val, error) +type AnyParamFn func(*program.Caller, ...int32) (*types.Val, error) // ImportFn is a generic type that satisfies ImportFnType type ImportFn[F AnyParamFn] struct { @@ -39,8 +39,8 @@ type ImportFn[F AnyParamFn] struct { } // Invoke calls the underlying function with the given arguments. Currently only -// supports int64 arguments and return values. -func (i ImportFn[F]) Invoke(c *program.Caller, args ...int64) (*types.Val, error) { +// supports int32 arguments and return values. +func (i ImportFn[F]) Invoke(c *program.Caller, args ...int32) (*types.Val, error) { switch fn := any(i.fn).(type) { case AnyParamFn: return fn.Call(c, args...) @@ -49,16 +49,16 @@ func (i ImportFn[F]) Invoke(c *program.Caller, args ...int64) (*types.Val, error } } -func (fn AnyParamFn) Call(c *program.Caller, args ...int64) (*types.Val, error) { +func (fn AnyParamFn) Call(c *program.Caller, args ...int32) (*types.Val, error) { return fn(c, args...) } func NewImportFn[F AnyParamFn](src F) func(caller *program.Caller, wargs ...wasmtime.Val) (*types.Val, error) { importFn := ImportFn[F]{fn: src} fn := func(c *program.Caller, wargs ...wasmtime.Val) (*types.Val, error) { - args := make([]int64, 0, len(wargs)) + args := make([]int32, 0, len(wargs)) for _, arg := range wargs { - args = append(args, arg.I64()) + args = append(args, arg.I32()) } return importFn.Invoke(c, args...) } diff --git a/x/programs/examples/token.go b/x/programs/examples/token.go index 7bc67c3bb7..6617e6d54f 100644 --- a/x/programs/examples/token.go +++ b/x/programs/examples/token.go @@ -386,7 +386,7 @@ func (t *Token) RunShort(ctx context.Context) error { return nil } -func (t *Token) GetUserBalanceFromState(ctx context.Context, programID ids.ID, userPublicKey ed25519.PublicKey) (res int64, err error) { +func (t *Token) GetUserBalanceFromState(ctx context.Context, programID ids.ID, userPublicKey ed25519.PublicKey) (res uint32, err error) { key := storage.ProgramPrefixKey(programID[:], append([]byte{uint8(Balance)}, userPublicKey[:]...)) b, err := t.db.GetValue(ctx, key) if err != nil { diff --git a/x/programs/examples/token_test.go b/x/programs/examples/token_test.go index 58a37358c5..4735c60a2c 100644 --- a/x/programs/examples/token_test.go +++ b/x/programs/examples/token_test.go @@ -26,7 +26,7 @@ func TestTokenProgram(t *testing.T) { t.Run("BurnUserTokens", func(t *testing.T) { wasmBytes := tests.ReadFixture(t, "../tests/fixture/token.wasm") require := require.New(t) - maxUnits := uint64(80000) + maxUnits := uint64(200000) eng := engine.New(engine.NewConfig()) program := newTokenProgram(maxUnits, eng, runtime.NewConfig(), wasmBytes) require.NoError(program.Run(context.Background())) @@ -76,7 +76,7 @@ func TestTokenProgram(t *testing.T) { // read alice balance from state db aliceBalance, err := program.GetUserBalanceFromState(ctx, programID, alicePublicKey) require.NoError(err) - require.Equal(int64(1000), aliceBalance) + require.Equal(uint32(1000), aliceBalance) alicePtr, err = writeToMem(alicePublicKey, mem) require.NoError(err) @@ -92,7 +92,7 @@ func TestTokenProgram(t *testing.T) { wasmBytes := tests.ReadFixture(t, "../tests/fixture/token.wasm") require := require.New(t) - maxUnits := uint64(80000) + maxUnits := uint64(200000) eng := engine.New(engine.NewConfig()) program := newTokenProgram(maxUnits, eng, runtime.NewConfig(), wasmBytes) require.NoError(program.Run(context.Background())) diff --git a/x/programs/host/link.go b/x/programs/host/link.go index 1c6105d3cc..9d183a0dfd 100644 --- a/x/programs/host/link.go +++ b/x/programs/host/link.go @@ -136,12 +136,12 @@ func (l *Link) RegisterImportWrapFn(module, name string, paramCount int, f func( // TODO: support other types? valType := make([]*wasmtime.ValType, paramCount) for i := 0; i < paramCount; i++ { - valType[i] = wasmtime.NewValType(wasmtime.KindI64) + valType[i] = wasmtime.NewValType(wasmtime.KindI32) } funcType := wasmtime.NewFuncType( valType, - []*wasmtime.ValType{wasmtime.NewValType(wasmtime.KindI64)}, + []*wasmtime.ValType{wasmtime.NewValType(wasmtime.KindI32)}, ) return l.wasmLink.FuncNew(module, name, funcType, fn) diff --git a/x/programs/program/function.go b/x/programs/program/function.go index 3c6ae3946d..09fd8faa12 100644 --- a/x/programs/program/function.go +++ b/x/programs/program/function.go @@ -45,7 +45,7 @@ func (f *Func) Call(context Context, params ...uint32) ([]int64, error) { return nil, err } - result, err := f.inner.Call(f.inst.GetStore(), append([]interface{}{int64(contextPtr)}, callParams...)...) + result, err := f.inner.Call(f.inst.GetStore(), append([]interface{}{int32(contextPtr)}, callParams...)...) if err != nil { return nil, HandleTrapError(err) } diff --git a/x/programs/program/memory.go b/x/programs/program/memory.go index b926fa75bd..32e4746b4a 100644 --- a/x/programs/program/memory.go +++ b/x/programs/program/memory.go @@ -4,7 +4,6 @@ package program import ( - "errors" "fmt" "math" "runtime" @@ -158,37 +157,7 @@ func WriteBytes(m *Memory, buf []byte) (uint32, error) { return offset, nil } -// SmartPtr is an int64 where the first 4 bytes represent the length of the bytes -// and the following 4 bytes represent a pointer to WASM memory where the bytes are stored. -type SmartPtr int64 - -// Get returns the int64 value of [s]. -func (s SmartPtr) Get() int64 { - return int64(s) -} - -// Len returns the length of the bytes stored in memory by [s]. -func (s SmartPtr) Len() uint32 { - return uint32(s >> 32) -} - -// PtrOffset returns the offset of the bytes stored in memory by [s]. -func (s SmartPtr) PtrOffset() uint32 { - return uint32(s) -} - -// Bytes returns the bytes stored in memory by [s]. -func (s SmartPtr) Bytes(memory *Memory) ([]byte, error) { - // read the range of PtrOffset + length from memory - bytes, err := memory.Range(s.PtrOffset(), s.Len()) - if err != nil { - return nil, err - } - - return bytes, nil -} - -// AllocateBytes writes [bytes] to memory and returns the resulting SmartPtr. +// AllocateBytes writes [bytes] to memory and returns the resulting pointer. func AllocateBytes(bytes []byte, memory *Memory) (uint32, error) { ptr, err := WriteBytes(memory, bytes) if err != nil { @@ -197,13 +166,3 @@ func AllocateBytes(bytes []byte, memory *Memory) (uint32, error) { return ptr, nil } - -// NewSmartPtr returns a SmartPtr from [ptr] and [byteLen]. -func NewSmartPtr(ptr uint32, byteLen int) (SmartPtr, error) { - // ensure length of bytes is not greater than int32 to prevent overflow - if !EnsureIntToInt32(byteLen) { - return 0, errors.New("length of bytes is greater than int32") - } - - return SmartPtr(int64(ptr)), nil -} diff --git a/x/programs/runtime/runtime_test.go b/x/programs/runtime/runtime_test.go index 01f4415211..e8f16e4e18 100644 --- a/x/programs/runtime/runtime_test.go +++ b/x/programs/runtime/runtime_test.go @@ -27,7 +27,7 @@ func TestStop(t *testing.T) { wasm, err := wasmtime.Wat2Wasm(` (module (memory 1) ;; 1 pages - (func $run (param i64) + (func $run (param i32) (loop br 0) ) (func $alloc (param i32) (result i32) @@ -68,12 +68,11 @@ func TestCallParams(t *testing.T) { defer cancel() // add param[0] + param[1] - //nolint: dupword wasm, err := wasmtime.Wat2Wasm(` (module (memory 1) ;; 1 pages ;; first argument is always the pointer to the context - (func $add (param i64 i64 i64) (result i64) + (func $add (param i32 i64 i64) (result i64) (i64.add local.get 1 local.get 2) ) (func $alloc (param i32) (result i32) @@ -120,7 +119,7 @@ func TestInfiniteLoop(t *testing.T) { wasm, err := wasmtime.Wat2Wasm(` (module (memory 1) ;; 1 pages - (func $run (param i64) + (func $run (param i32) (loop br 0) ) (func $alloc (param i32) (result i32) @@ -159,7 +158,7 @@ func TestMetering(t *testing.T) { wasm, err := wasmtime.Wat2Wasm(` (module (memory 1) ;; 1 pages - (func $get (param i64) (result i32) + (func $get (param i32) (result i32) i32.const 0 ) (func $alloc (param i32) (result i32) @@ -207,7 +206,7 @@ func TestMeterAfterStop(t *testing.T) { wasm, err := wasmtime.Wat2Wasm(` (module (memory 1) ;; 1 pages - (func $get (param i64) (result i32) + (func $get (param i32) (result i32) i32.const 0 ) (func $alloc (param i32) (result i32) @@ -358,7 +357,7 @@ func TestWithMaxWasmStack(t *testing.T) { wasm, err := wasmtime.Wat2Wasm(` (module (memory 1) ;; 1 pages - (func $get (param i64) (result i32) + (func $get (param i32) (result i32) i32.const 0 ) (func $alloc (param i32) (result i32) diff --git a/x/programs/rust/sdk_macros/src/lib.rs b/x/programs/rust/sdk_macros/src/lib.rs index c6c90441a3..930028c1cf 100644 --- a/x/programs/rust/sdk_macros/src/lib.rs +++ b/x/programs/rust/sdk_macros/src/lib.rs @@ -123,7 +123,7 @@ pub fn public(_: TokenStream, item: TokenStream) -> TokenStream { } }); - let param_types = std::iter::repeat(quote! { i64 }).take(param_names.len()); + let param_types = std::iter::repeat(quote! { *const u8 }).take(param_names.len()); // Extract the original function's return type. This must be a WASM supported type. let return_type = &input.sig.output; @@ -132,7 +132,7 @@ pub fn public(_: TokenStream, item: TokenStream) -> TokenStream { // Need to include the original function in the output, so contract can call itself #input #[no_mangle] - pub extern "C" fn #new_name(param_0: i64, #(#param_names: #param_types), *) #return_type { + pub extern "C" fn #new_name(param_0: *const u8, #(#param_names: #param_types), *) #return_type { let param_0: #context_type = unsafe { wasmlanche_sdk::from_host_ptr(param_0).expect("error serializing ptr") }; diff --git a/x/programs/rust/wasmlanche-sdk/src/lib.rs b/x/programs/rust/wasmlanche-sdk/src/lib.rs index 43d4509d95..68b176bd63 100644 --- a/x/programs/rust/wasmlanche-sdk/src/lib.rs +++ b/x/programs/rust/wasmlanche-sdk/src/lib.rs @@ -8,7 +8,7 @@ mod memory; mod program; pub use self::{ - memory::{from_host_ptr, HostPtr}, + memory::{from_host_ptr, CPointer}, params::{serialize_param, Params}, program::Program, }; diff --git a/x/programs/rust/wasmlanche-sdk/src/memory.rs b/x/programs/rust/wasmlanche-sdk/src/memory.rs index 5f491291ec..c800d2ce4a 100644 --- a/x/programs/rust/wasmlanche-sdk/src/memory.rs +++ b/x/programs/rust/wasmlanche-sdk/src/memory.rs @@ -8,33 +8,8 @@ use crate::state::Error as StateError; use borsh::{from_slice, BorshDeserialize}; use std::{alloc::Layout, cell::RefCell, collections::HashMap}; -/// Represents a pointer to a block of memory allocated by the global allocator. -#[derive(Clone, Copy)] -pub struct Pointer(*mut u8); - -impl From for Pointer { - fn from(v: i64) -> Self { - let ptr: *mut u8 = v as *mut u8; - Pointer(ptr) - } -} - -impl From for *const u8 { - fn from(pointer: Pointer) -> Self { - pointer.0.cast_const() - } -} - -impl From for *mut u8 { - fn from(pointer: Pointer) -> Self { - pointer.0 - } -} - -/// `HostPtr` is an i64 where the first 4 bytes represent the length of the bytes -/// and the following 4 bytes represent a pointer to WASM memeory where the bytes are stored. -// #[deprecated] TODO fix in a followup pr -pub type HostPtr = i64; +#[repr(C)] +pub struct CPointer(pub *const u8, pub usize); thread_local! { /// Map of pointer to the length of its content on the heap @@ -47,17 +22,16 @@ thread_local! { /// Returns an [`StateError`] if the pointer or length of `args` exceeds /// the maximum size of a u32. #[allow(clippy::cast_possible_truncation)] -pub fn to_host_ptr(arg: &[u8]) -> Result { - let ptr = arg.as_ptr() as usize; +pub fn to_ffi_ptr(arg: &[u8]) -> Result { + let ptr = arg.as_ptr(); let len = arg.len(); // Make sure the pointer and length fit into u32 - if ptr > u32::MAX as usize || len > u32::MAX as usize { + if ptr as usize > u32::MAX as usize || len > u32::MAX as usize { return Err(StateError::IntegerConversion); } - let host_ptr = i64::from(ptr as u32) | (i64::from(len as u32) << 32); - Ok(host_ptr) + Ok(CPointer(ptr, len)) } /// Converts a raw pointer to a deserialized value. @@ -69,7 +43,7 @@ pub fn to_host_ptr(arg: &[u8]) -> Result { /// This function is unsafe because it dereferences raw pointers. /// # Errors /// Returns an [`StateError`] if the bytes cannot be deserialized. -pub fn from_host_ptr(ptr: i64) -> Result +pub fn from_host_ptr(ptr: *const u8) -> Result where V: BorshDeserialize, { @@ -82,10 +56,10 @@ where /// Reconstructs the vec from the pointer with the length given by the store /// `host_ptr` is encoded using Big Endian as an i64. #[must_use] -fn into_bytes(ptr: HostPtr) -> Option> { +pub fn into_bytes(ptr: *const u8) -> Option> { GLOBAL_STORE - .with_borrow_mut(|s| s.remove(&(ptr as *const u8))) - .map(|len| unsafe { std::vec::Vec::from_raw_parts(ptr as *mut u8, len, len) }) + .with_borrow_mut(|s| s.remove(&ptr)) + .map(|len| unsafe { std::vec::Vec::from_raw_parts(ptr.cast_mut(), len, len) }) } /* memory functions ------------------------------------------- */ @@ -121,7 +95,7 @@ mod tests { let ptr = alloc(len); let vec = vec![1; len]; unsafe { std::ptr::copy(vec.as_ptr(), ptr, vec.len()) } - let val = into_bytes(ptr as i64).unwrap(); + let val = into_bytes(ptr).unwrap(); assert_eq!(val, vec); assert!(GLOBAL_STORE.with_borrow(|s| s.get(&(ptr.cast_const())).is_none())); } diff --git a/x/programs/rust/wasmlanche-sdk/src/params.rs b/x/programs/rust/wasmlanche-sdk/src/params.rs index fedcb05e24..b8d4b695cd 100644 --- a/x/programs/rust/wasmlanche-sdk/src/params.rs +++ b/x/programs/rust/wasmlanche-sdk/src/params.rs @@ -1,5 +1,5 @@ use crate::{ - memory::{to_host_ptr, HostPtr}, + memory::{to_ffi_ptr, CPointer}, state::Error as StateError, Error, }; @@ -23,8 +23,8 @@ pub struct Param(Vec); pub struct Params(Vec); impl Params { - pub(crate) fn into_host_ptr(self) -> Result { - to_host_ptr(&self.0) + pub(crate) fn into_ffi_ptr(self) -> Result { + to_ffi_ptr(&self.0) } } diff --git a/x/programs/rust/wasmlanche-sdk/src/program.rs b/x/programs/rust/wasmlanche-sdk/src/program.rs index 8462f9ae4d..8e67d78de2 100644 --- a/x/programs/rust/wasmlanche-sdk/src/program.rs +++ b/x/programs/rust/wasmlanche-sdk/src/program.rs @@ -1,9 +1,10 @@ -use std::hash::Hash; - +use crate::{ + memory::{to_ffi_ptr, CPointer}, + state::{Error as StateError, Key, State}, + Params, +}; use borsh::{BorshDeserialize, BorshSerialize}; - -use crate::state::Key; -use crate::{memory::to_host_ptr, state::Error as StateError, state::State, Params}; +use std::hash::Hash; /// Represents the current Program in the context of the caller. Or an external /// program that is being invoked. @@ -48,9 +49,9 @@ impl Program { max_units: i64, ) -> Result { // flatten the args into a single byte vector - let target = to_host_ptr(self.id())?; - let function = to_host_ptr(function_name.as_bytes())?; - let args = args.into_host_ptr()?; + let target = to_ffi_ptr(self.id())?; + let function = to_ffi_ptr(function_name.as_bytes())?; + let args = args.into_ffi_ptr()?; Ok(unsafe { _call_program(target, function, args, max_units) }) } @@ -59,5 +60,10 @@ impl Program { #[link(wasm_import_module = "program")] extern "C" { #[link_name = "call_program"] - fn _call_program(target_id: i64, function: i64, args_ptr: i64, max_units: i64) -> i64; + fn _call_program( + target_id: CPointer, + function: CPointer, + args_ptr: CPointer, + max_units: i64, + ) -> i64; } diff --git a/x/programs/rust/wasmlanche-sdk/src/state.rs b/x/programs/rust/wasmlanche-sdk/src/state.rs index f857128477..9bd98a08c0 100644 --- a/x/programs/rust/wasmlanche-sdk/src/state.rs +++ b/x/programs/rust/wasmlanche-sdk/src/state.rs @@ -1,4 +1,4 @@ -use crate::{from_host_ptr, program::Program, state::Error as StateError}; +use crate::{memory::into_bytes, program::Program, state::Error as StateError}; use borsh::{from_slice, to_vec, BorshDeserialize, BorshSerialize}; use std::{collections::HashMap, hash::Hash, ops::Deref}; @@ -102,13 +102,15 @@ where let val_bytes = if let Some(val) = self.cache.get(&key) { val } else { - let val_ptr = unsafe { host::get_bytes(&self.program, &key.clone().into())? }; - if val_ptr < 0 { + let val = unsafe { host::get_bytes(&self.program, &key.clone().into())? }; + let val_ptr = val as *const u8; + // TODO write a test for that + if val_ptr.is_null() { return Err(Error::Read); } // TODO Wrap in OK for now, change from_raw_ptr to return Result - let bytes = from_host_ptr(val_ptr)?; + let bytes = into_bytes(val_ptr).ok_or(Error::InvalidPointer)?; self.cache.entry(key).or_insert(bytes) }; @@ -162,23 +164,22 @@ macro_rules! ffi_linker { #[link(wasm_import_module = $mod)] extern "C" { #[link_name = $link] - fn ffi(caller: i64, key: i64) -> i64; + fn ffi(caller: CPointer, key: CPointer) -> i32; } - let $caller = to_host_ptr($caller.id())?; - let $key = to_host_ptr($key)?; + let $caller = to_ffi_ptr($caller.id())?; + let $key = to_ffi_ptr($key)?; }; ($mod:literal, $link:literal, $caller:ident, $key:ident, $value:ident) => { #[link(wasm_import_module = $mod)] extern "C" { #[link_name = $link] - fn ffi(caller: i64, key: i64, value: i64) -> i64; + fn ffi(caller: CPointer, key: CPointer, value: CPointer) -> i32; } - let $caller = to_host_ptr($caller.id())?; - let $key = to_host_ptr($key)?; - let value_bytes = borsh::to_vec($value).map_err(|_| Error::Serialization)?; - let $value = to_host_ptr(&value_bytes)?; + let $caller = to_ffi_ptr($caller.id())?; + let $key = to_ffi_ptr($key)?; + let $value = to_ffi_ptr($value)?; }; } @@ -205,18 +206,18 @@ macro_rules! call_host_fn { } mod host { - use super::{BorshSerialize, Key, Program}; - use crate::{memory::to_host_ptr, state::Error}; + use super::{Key, Program}; + use crate::{ + memory::{to_ffi_ptr, CPointer}, + state::Error, + }; - /// Persists the bytes at `value` at key on the host storage. - pub(super) unsafe fn put_bytes(caller: &Program, key: &Key, value: &V) -> Result<(), Error> - where - V: BorshSerialize, - { + /// Persists the bytes at key on the host storage. + pub(super) unsafe fn put_bytes(caller: &Program, key: &Key, bytes: &[u8]) -> Result<(), Error> { match call_host_fn! { wasm_import_module = "state" link_name = "put" - args = (caller, key, value) + args = (caller, key, bytes) } { 0 => Ok(()), _ => Err(Error::Write), @@ -224,11 +225,11 @@ mod host { } /// Gets the bytes associated with the key from the host. - pub(super) unsafe fn get_bytes(caller: &Program, key: &Key) -> Result { + pub(super) unsafe fn get_bytes(caller: &Program, key: &Key) -> Result { Ok(call_host_fn! { wasm_import_module = "state" - link_name = "get" - args = (caller, key) + link_name = "get" + args = (caller, key) }) } diff --git a/x/programs/rust/wasmlanche-sdk/tests/public_function.rs b/x/programs/rust/wasmlanche-sdk/tests/public_function.rs index 7e68339d86..4a3372414e 100644 --- a/x/programs/rust/wasmlanche-sdk/tests/public_function.rs +++ b/x/programs/rust/wasmlanche-sdk/tests/public_function.rs @@ -61,8 +61,8 @@ struct TestCrate { store: Store<()>, instance: Instance, allocate_func: TypedFunc, - always_true_func: TypedFunc, - combine_last_bit_of_each_id_byte_func: TypedFunc, + always_true_func: TypedFunc, + combine_last_bit_of_each_id_byte_func: TypedFunc, } impl TestCrate { @@ -75,11 +75,11 @@ impl TestCrate { .get_typed_func::(&mut store, "alloc") .expect("failed to find `alloc` function"); - let always_true_func = instance - .get_typed_func::(&mut store, "always_true_guest") + let always_true_func: TypedFunc = instance + .get_typed_func::(&mut store, "always_true_guest") .expect("failed to find `always_true` function"); let combine_last_bit_of_each_id_byte_func = instance - .get_typed_func::(&mut store, "combine_last_bit_of_each_id_byte_guest") + .get_typed_func::(&mut store, "combine_last_bit_of_each_id_byte_guest") .expect("combine_last_bit_of_each_id_byte should be a function"); Self { @@ -121,14 +121,14 @@ impl TestCrate { fn always_true(&mut self, ptr: i32) -> bool { self.always_true_func - .call(&mut self.store, ptr as i64) + .call(&mut self.store, ptr) .expect("failed to call `always_true` function") - == true as i64 + == true as i32 } fn combine_last_bit_of_each_id_byte(&mut self, ptr: i32) -> u32 { self.combine_last_bit_of_each_id_byte_func - .call(&mut self.store, ptr as i64) + .call(&mut self.store, ptr) .expect("failed to call `combine_last_bit_of_each_id_byte` function") } } diff --git a/x/programs/rust/wasmlanche-sdk/tests/test-crate/src/lib.rs b/x/programs/rust/wasmlanche-sdk/tests/test-crate/src/lib.rs index db0209ea67..caef966296 100644 --- a/x/programs/rust/wasmlanche-sdk/tests/test-crate/src/lib.rs +++ b/x/programs/rust/wasmlanche-sdk/tests/test-crate/src/lib.rs @@ -3,8 +3,8 @@ use wasmlanche_sdk::{public, Context}; #[public] -pub fn always_true(_: Context) -> i64 { - true as i64 +pub fn always_true(_: Context) -> i32 { + true as i32 } #[public] diff --git a/x/programs/tests/fixture/counter.wasm b/x/programs/tests/fixture/counter.wasm index 0f811555bf..243cae1aff 100755 Binary files a/x/programs/tests/fixture/counter.wasm and b/x/programs/tests/fixture/counter.wasm differ diff --git a/x/programs/tests/fixture/token.wasm b/x/programs/tests/fixture/token.wasm index 9d30138ce0..185ceec548 100755 Binary files a/x/programs/tests/fixture/token.wasm and b/x/programs/tests/fixture/token.wasm differ