Skip to content

Commit

Permalink
Check serialized coroutine state is valid (#101)
Browse files Browse the repository at this point in the history
When serializing a durable coroutine, the serialized representation is
probably only valid for the build that generated it.

There are no checks to ensure that the serialized representation of a
coroutine is valid for a particular build. When the serialized
representation is no longer valid, the program crashes with cryptic
error messages.

This PR fixes the issue by prefixing the serialized representation with
a build hash. If there's a hash mismatch, `Unmarshal` returns an
`ErrInvalidState` error that the program can handle gracefully (for
example, the program may choose to restart the coroutine).

Rather than hashing the binary (which might be large), we extract the Go
Build ID and use it as the hash. The build ID changes when the source
changes, and also when the compiler or any dependency changes (either of
which may affect the serialized representation of coroutines).
I adapted the ELF and Mach-O parsers from
https://github.com/golang/go/tree/3803c8588/src/cmd/internal/buildid,
but kept it simple and only added a parser for binaries generated by the
Go compiler (and not gccgo).
  • Loading branch information
chriso authored Oct 19, 2023
2 parents bca5511 + 1c881fd commit 26bd486
Show file tree
Hide file tree
Showing 12 changed files with 199 additions and 69 deletions.
12 changes: 9 additions & 3 deletions coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ func LoadContext[R, S any]() *Context[R, S] {
}
}

// ErrNotDurable is an error that occurs when attempting to
// serialize a coroutine that is not durable.
var ErrNotDurable = errors.New("only durable coroutines can be serialized")
var (
// ErrNotDurable is an error that occurs when attempting to
// serialize a coroutine that is not durable.
ErrNotDurable = errors.New("only durable coroutines can be serialized")

// ErrInvalidState is an error that occurs when attempting to
// deserialize a coroutine that was serialized in another build.
ErrInvalidState = errors.New("durable coroutine was serialized in another build")
)
9 changes: 8 additions & 1 deletion coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package coroutine

import (
"errors"
"runtime"
"unsafe"

Expand Down Expand Up @@ -93,7 +94,13 @@ func (c *Context[R, S]) Marshal() ([]byte, error) {
// context.
func (c *Context[R, S]) Unmarshal(b []byte) (int, error) {
start := len(b)
v, b := types.Deserialize(b)
v, b, err := types.Deserialize(b)
if err != nil {
if errors.Is(err, types.ErrBuildIDMismatch) {
return 0, ErrInvalidState
}
return 0, err
}
s := v.(*serializedCoroutine)
c.entry = s.entry
c.Stack = s.stack
Expand Down
2 changes: 2 additions & 0 deletions examples/scrape/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ module scrape
go 1.21.0

require github.com/stealthrocket/coroutine v0.0.0-20230927150141-7c62a3508ce8

replace github.com/stealthrocket/coroutine => ../../
6 changes: 5 additions & 1 deletion examples/scrape/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ func main() {
log.Fatal(err)
}
} else if _, err := coro.Context().Unmarshal(state); err != nil {
log.Fatal(err)
if errors.Is(err, coroutine.ErrInvalidState) {
log.Println("warning: coroutine state is no longer valid. Starting fresh")
} else {
log.Fatal(err)
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions types/buildid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package types

// buildID is the build identifier for the binary.
//
// It is declared here without a value; the platform-specific object file
// readers in this package assign it. The serializer writes it as a prefix
// and the deserializer compares the prefix it reads against it.
var buildID string
6 changes: 5 additions & 1 deletion types/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package types

import (
"debug/gosym"
"errors"
"io"
"reflect"
"runtime"
Expand Down Expand Up @@ -150,7 +151,10 @@ func initFunctionTables(pclntab, symtab []byte) {
}
}

func readAll(r io.ReaderAt, size uint64) ([]byte, error) {
func readSection(r io.ReaderAt, size uint64) ([]byte, error) {
if r == nil {
return nil, errors.New("section missing")
}
b := make([]byte, size)
n, err := r.ReadAt(b, 0)
if err != nil && n < len(b) {
Expand Down
27 changes: 0 additions & 27 deletions types/func_darwin.go

This file was deleted.

27 changes: 0 additions & 27 deletions types/func_linux.go

This file was deleted.

65 changes: 65 additions & 0 deletions types/obj_darwin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package types

import (
"bytes"
"debug/macho"
"os"
"strconv"
)

// init opens the running executable as a Mach-O file and uses it to
// populate the function tables and the build ID.
//
// It panics if the executable cannot be opened or parsed as Mach-O.
func init() {
	// Prefer os.Executable over os.Args[0]: when the program is started
	// through a PATH lookup, os.Args[0] may be just the command name and
	// not a path that can be opened. Fall back to os.Args[0] only if the
	// executable path cannot be determined.
	path, err := os.Executable()
	if err != nil {
		path = os.Args[0]
	}

	f, err := macho.Open(path)
	if err != nil {
		panic("cannot read Mach-O binary: " + err.Error())
	}
	defer f.Close()

	initMachOFunctionTables(f)
	initMachOBuildID(f)
}

// initMachOFunctionTables reads the __gopclntab and __gosymtab sections of
// the Mach-O file and hands their contents to initFunctionTables.
//
// It panics if either section is missing or cannot be read.
func initMachOFunctionTables(f *macho.File) {
	// f.Section returns nil when the section does not exist. Check before
	// touching .Size so the failure is a descriptive panic rather than a
	// nil pointer dereference. (A nil *macho.Section stored in an
	// io.ReaderAt is a non-nil interface, so readSection cannot detect it.)
	pclntab := f.Section("__gopclntab")
	if pclntab == nil {
		panic("cannot read pclntab: section missing")
	}
	pclntabData, err := readSection(pclntab, pclntab.Size)
	if err != nil {
		panic("cannot read pclntab: " + err.Error())
	}

	symtab := f.Section("__gosymtab")
	if symtab == nil {
		panic("cannot read symtab: section missing")
	}
	symtabData, err := readSection(symtab, symtab.Size)
	if err != nil {
		panic("cannot read symtab: " + err.Error())
	}

	initFunctionTables(pclntabData, symtabData)
}

// initMachOBuildID extracts the Go build ID from the start of the __text
// section of the Mach-O file and stores it in buildID.
//
// It panics if the section is missing or unreadable, or if no build ID
// marker is found in the bytes that were read.
func initMachOBuildID(f *macho.File) {
	// f.Section returns nil when the section does not exist; guard before
	// dereferencing text.Size below.
	text := f.Section("__text")
	if text == nil {
		panic("cannot read __text: section missing")
	}

	// Read up to 32KB from the text section.
	// See https://github.com/golang/go/blob/3803c858/src/cmd/internal/buildid/note.go#L199
	data, err := readSection(text, min(text.Size, 32*1024))
	if err != nil {
		panic("cannot read __text: " + err.Error())
	}

	// From https://github.com/golang/go/blob/3803c858/src/cmd/internal/buildid/buildid.go#L300
	i := bytes.Index(data, buildIDPrefix)
	if i < 0 {
		panic("build ID not found")
	}
	j := bytes.Index(data[i+len(buildIDPrefix):], buildIDEnd)
	if j < 0 {
		panic("build ID not found")
	}

	// The prefix ends with the opening quote and the end marker starts
	// with the closing quote, so widen the span by one byte on each side
	// to obtain the complete quoted string for strconv.Unquote.
	quoted := data[i+len(buildIDPrefix)-1 : i+len(buildIDPrefix)+j+1]
	id, err := strconv.Unquote(string(quoted))
	if err != nil {
		panic("build ID not found")
	}
	buildID = id
}

// buildIDPrefix is the byte sequence that initMachOBuildID searches for to
// locate the start of the build ID; it ends with the opening quote.
var buildIDPrefix = []byte("\xff Go build ID: \"")

// buildIDEnd is the byte sequence that initMachOBuildID searches for to
// locate the end of the build ID; it starts with the closing quote.
var buildIDEnd = []byte("\"\n \xff")
56 changes: 56 additions & 0 deletions types/obj_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package types

import (
"bytes"
"debug/elf"
"os"
)

// init opens the running executable as an ELF file and uses it to
// populate the function tables and the build ID.
//
// It panics if the executable cannot be opened or parsed as ELF.
func init() {
	// Prefer os.Executable over os.Args[0]: when the program is started
	// through a PATH lookup, os.Args[0] may be just the command name and
	// not a path that can be opened. Fall back to os.Args[0] only if the
	// executable path cannot be determined.
	path, err := os.Executable()
	if err != nil {
		path = os.Args[0]
	}

	f, err := elf.Open(path)
	if err != nil {
		panic("cannot read elf binary: " + err.Error())
	}
	defer f.Close()

	initELFFunctionTables(f)
	initELFBuildID(f)
}

// initELFFunctionTables reads the .gopclntab and .gosymtab sections of the
// ELF file and hands their contents to initFunctionTables.
//
// It panics if either section is missing or cannot be read.
func initELFFunctionTables(f *elf.File) {
	// f.Section returns nil when the section does not exist. Check before
	// touching .Size so the failure is a descriptive panic rather than a
	// nil pointer dereference. (A nil *elf.Section stored in an
	// io.ReaderAt is a non-nil interface, so readSection cannot detect it.)
	pclntab := f.Section(".gopclntab")
	if pclntab == nil {
		panic("cannot read pclntab: section missing")
	}
	pclntabData, err := readSection(pclntab, pclntab.Size)
	if err != nil {
		panic("cannot read pclntab: " + err.Error())
	}

	symtab := f.Section(".gosymtab")
	if symtab == nil {
		panic("cannot read symtab: section missing")
	}
	symtabData, err := readSection(symtab, symtab.Size)
	if err != nil {
		panic("cannot read symtab: " + err.Error())
	}

	initFunctionTables(pclntabData, symtabData)
}

// initELFBuildID reads the .note.go.buildid section of the ELF file,
// validates the note header, and stores the note's value in buildID.
//
// It panics if the section is missing or unreadable, or if the note is
// truncated or does not carry the expected name and tag.
func initELFBuildID(f *elf.File) {
	// f.Section returns nil when the section does not exist; guard before
	// dereferencing noteSection.Size below.
	noteSection := f.Section(".note.go.buildid")
	if noteSection == nil {
		panic("cannot read build ID: section missing")
	}
	note, err := readSection(noteSection, noteSection.Size)
	if err != nil {
		panic("cannot read build ID: " + err.Error())
	}

	// The header is three 32-bit words (name size, value size, tag)
	// followed by the 4-byte name. Make sure it is all present before
	// indexing into the slice.
	const headerSize = 16
	if len(note) < headerSize {
		panic("build ID not found")
	}

	// See https://github.com/golang/go/blob/3803c858/src/cmd/internal/buildid/note.go#L135C3-L135C3
	nameSize := f.ByteOrder.Uint32(note)
	valSize := f.ByteOrder.Uint32(note[4:])
	tag := f.ByteOrder.Uint32(note[8:])
	nname := note[12:headerSize]
	if nameSize != 4 || tag != buildIDTag || !bytes.Equal(nname, buildIDNote) {
		panic("build ID not found")
	}

	// Compare in 64 bits against the remaining length so that a huge
	// valSize cannot wrap around and slip past the bounds check.
	if uint64(valSize) > uint64(len(note)-headerSize) {
		panic("build ID not found")
	}
	buildID = string(note[headerSize : headerSize+int(valSize)])
}

// buildIDNote is the 4-byte note name that initELFBuildID expects to find
// in the header of the build ID note.
var buildIDNote = []byte("Go\x00\x00")

// buildIDTag is the note tag that initELFBuildID expects to find in the
// header of the build ID note.
var buildIDTag = uint32(4)
34 changes: 29 additions & 5 deletions types/serde.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package types

import (
"encoding/binary"
"errors"
"fmt"
"reflect"
"unsafe"
Expand All @@ -15,6 +16,10 @@ import (
// sID is the unique sID of a pointer or type in the serialized format.
type sID int64

// ErrBuildIDMismatch is an error that occurs when a program attempts
// to deserialize objects from another build.
var ErrBuildIDMismatch = errors.New("build ID mismatch")

// Serialize x.
//
// The output of Serialize can be reconstructed back to a Go value using
Expand All @@ -35,14 +40,17 @@ func Serialize(x any) []byte {
}

// Deserialize value from b. Return left over bytes.
func Deserialize(b []byte) (interface{}, []byte) {
d := newDeserializer(b)
func Deserialize(b []byte) (interface{}, []byte, error) {
d, err := newDeserializer(b)
if err != nil {
return nil, nil, err
}
var x interface{}
px := &x
t := reflect.TypeOf(px).Elem()
p := unsafe.Pointer(px)
deserializeInterface(d, t, p)
return x, d.b
return x, d.b, nil
}

type Deserializer struct {
Expand All @@ -54,11 +62,22 @@ type Deserializer struct {
b []byte
}

func newDeserializer(b []byte) *Deserializer {
func newDeserializer(b []byte) (*Deserializer, error) {
buildIDLength, n := binary.Varint(b)
if n <= 0 || buildIDLength <= 0 || buildIDLength > int64(len(buildID)) || int64(len(b)-n) < buildIDLength {
return nil, fmt.Errorf("missing or invalid build ID")
}
b = b[n:]
serializedBuildID := string(b[:buildIDLength])
b = b[buildIDLength:]
if serializedBuildID != buildID {
return nil, fmt.Errorf("%w: got %v, expect %v", ErrBuildIDMismatch, serializedBuildID, buildID)
}

return &Deserializer{
ptrs: make(map[sID]unsafe.Pointer),
b: b,
}
}, nil
}

func (d *Deserializer) readPtr() (unsafe.Pointer, sID) {
Expand Down Expand Up @@ -123,9 +142,14 @@ type Serializer struct {
}

// newSerializer returns a Serializer with empty pointer-tracking maps and
// an output buffer that already holds the build ID header: the varint
// encoded length of buildID followed by the buildID bytes.
func newSerializer() *Serializer {
	s := &Serializer{
		ptrs:     make(map[unsafe.Pointer]sID),
		scanptrs: make(map[reflect.Value]struct{}),
		b:        make([]byte, 0, 128),
	}

	// Length-prefixed build ID, written before any serialized value.
	s.b = binary.AppendVarint(s.b, int64(len(buildID)))
	s.b = append(s.b, buildID...)

	return s
}

Expand Down
20 changes: 16 additions & 4 deletions types/serde_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ func TestSerdeTime(t *testing.T) {

func testSerdeTime(t *testing.T, x time.Time) {
b := Serialize(x)
out, _ := Deserialize(b)
out, _, err := Deserialize(b)
if err != nil {
t.Fatal(err)
}

if !x.Equal(out.(time.Time)) {
t.Errorf("expected %v, got %v", x, out)
Expand Down Expand Up @@ -120,7 +123,10 @@ func TestReflect(t *testing.T) {
typ := reflect.TypeOf(x)
t.Run(fmt.Sprintf("%d-%s", i, typ), func(t *testing.T) {
b := Serialize(x)
out, b := Deserialize(b)
out, b, err := Deserialize(b)
if err != nil {
t.Fatal(err)
}

assertEqual(t, x, out)

Expand Down Expand Up @@ -302,7 +308,10 @@ func TestReflectCustom(t *testing.T) {
// unserializable function in CheckRedirect.

b := Serialize(x)
out, b := Deserialize(b)
out, b, err := Deserialize(b)
if err != nil {
t.Fatal(err)
}

assertEqual(t, x.Timeout, out.(http.Client).Timeout)

Expand Down Expand Up @@ -615,7 +624,10 @@ func assertRoundTrip[T any](t *testing.T, orig T) T {
t.Helper()

b := Serialize(orig)
out, b := Deserialize(b)
out, b, err := Deserialize(b)
if err != nil {
t.Fatal(err)
}

assertEqual(t, orig, out)

Expand Down

0 comments on commit 26bd486

Please sign in to comment.