Skip to content

Commit

Permalink
Safe unmarshal (#121)
Browse files Browse the repository at this point in the history
The deserialization layer panics on invalid input. Bad state can take
down the entire program currently.

This PR updates `(*coroutine.Context).Unmarshal` to catch these panics
and return an error value, allowing the caller to gracefully handle bad
input.

A better long term solution would be to abandon the `deserialize*()`
functions and instead build the deserialization layer on the scanner
introduced in #116, which
avoids panics in many cases and could be updated to avoid panics in all
cases.
  • Loading branch information
chriso authored Nov 30, 2023
2 parents 0f55583 + aa1ec10 commit b3b03b5
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
14 changes: 10 additions & 4 deletions types/serde.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@ func Serialize(x any) ([]byte, error) {
}

// Deserialize value from b. Return left over bytes.
func Deserialize(b []byte) (interface{}, error) {
func Deserialize(b []byte) (x interface{}, err error) {
defer func() {
// FIXME: the deserialize*() functions panic on invalid input
if e := recover(); e != nil {
err = fmt.Errorf("cannot deserialize state: %v", e)
}
}()

var state coroutinev1.State
if err := state.UnmarshalVT(b); err != nil {
return nil, err
Expand All @@ -77,16 +84,15 @@ func Deserialize(b []byte) (interface{}, error) {

d := newDeserializer(state.Root.Data, state.Types, state.Functions, state.Regions, state.Strings)

var x interface{}
px := &x
t := reflect.TypeOf(px).Elem()
p := unsafe.Pointer(px)
deserializeInterface(d, t, p)

if len(d.b) != 0 {
return nil, errors.New("trailing bytes")
err = errors.New("trailing bytes")
}
return x, nil
return
}

type Deserializer struct {
Expand Down
30 changes: 30 additions & 0 deletions types/serde_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"testing"
"time"
"unsafe"

coroutinev1 "github.com/stealthrocket/coroutine/gen/proto/go/coroutine/v1"
)

func TestSerdeTime(t *testing.T) {
Expand Down Expand Up @@ -936,3 +938,31 @@ func BenchmarkRoundtripString(b *testing.B) {
}
}
}

func TestUnmarshalInvalid(t *testing.T) {
for _, invalid := range [][]byte{
[]byte("foobar"),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
mustSerialize(&coroutinev1.State{
Build: buildInfo,
Root: &coroutinev1.Region{
Type: 0,
Data: []byte{1, 0},
},
}),
} {
t.Run(fmt.Sprintf("%v", invalid), func(t *testing.T) {
if _, err := Deserialize(invalid); err == nil {
t.Error("expected error, got nil")
}
})
}
}

func mustSerialize(s *coroutinev1.State) []byte {
b, err := s.MarshalVT()
if err != nil {
panic(err)
}
return b
}

0 comments on commit b3b03b5

Please sign in to comment.