diff --git a/gen/proto/go/coroutine/v1/coroutine.pb.go b/gen/proto/go/coroutine/v1/coroutine.pb.go index 223f467..063de0c 100644 --- a/gen/proto/go/coroutine/v1/coroutine.pb.go +++ b/gen/proto/go/coroutine/v1/coroutine.pb.go @@ -197,10 +197,16 @@ type Region struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - // Type is the type of the region. + // Type is the type of the region, shifted left by one. + // + // The least significant bit indicates that this region represents + // an array, and that the type is the array element type rather + // than the object that's encoded in this region. Type uint32 `protobuf:"varint,1,opt,name=type,proto3" json:"type,omitempty"` + // Array length, for regions that are arrays. + ArrayLength uint32 `protobuf:"varint,2,opt,name=array_length,json=arrayLength,proto3" json:"array_length,omitempty"` // Data is the encoded contents of the memory region. - Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + Data []byte `protobuf:"bytes,3,opt,name=data,proto3" json:"data,omitempty"` } func (x *Region) Reset() { @@ -242,6 +248,13 @@ func (x *Region) GetType() uint32 { return 0 } +func (x *Region) GetArrayLength() uint32 { + if x != nil { + return x.ArrayLength + } + return 0 +} + func (x *Region) GetData() []byte { if x != nil { return x.Data @@ -279,23 +292,25 @@ var file_coroutine_v1_coroutine_proto_rawDesc = []byte{ 0x73, 0x22, 0x3b, 0x0a, 0x05, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x0e, 0x0a, 0x02, 0x6f, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x6f, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x61, 0x72, - 0x63, 0x68, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x61, 0x72, 0x63, 0x68, 0x22, 0x30, + 0x63, 0x68, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x61, 0x72, 0x63, 0x68, 0x22, 0x53, 0x0a, 0x06, 0x52, 0x65, 0x67, 0x69, 0x6f, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, - 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, - 0x42, 0xbd, 0x01, 0x0a, 0x10, 0x63, 0x6f, 0x6d, 0x2e, 0x63, 0x6f, 0x72, 0x6f, 0x75, 0x74, 0x69, - 0x6e, 0x65, 0x2e, 0x76, 0x31, 0x42, 0x0e, 0x43, 0x6f, 0x72, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x65, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x48, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x74, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x72, 0x6f, 0x63, 0x6b, 0x65, - 0x74, 0x2f, 0x63, 0x6f, 0x72, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x65, 0x2f, 0x67, 0x65, 0x6e, 0x2f, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, 0x2f, 0x63, 0x6f, 0x72, 0x6f, 0x75, 0x74, 0x69, - 0x6e, 0x65, 0x2f, 0x76, 0x31, 0x3b, 0x63, 0x6f, 0x72, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x65, 0x76, - 0x31, 0xa2, 0x02, 0x03, 0x43, 0x58, 0x58, 0xaa, 0x02, 0x0c, 0x43, 0x6f, 0x72, 0x6f, 0x75, 0x74, - 0x69, 0x6e, 0x65, 0x2e, 0x56, 0x31, 0xca, 0x02, 0x0c, 0x43, 0x6f, 0x72, 0x6f, 0x75, 0x74, 0x69, - 0x6e, 0x65, 0x5c, 0x56, 0x31, 0xe2, 0x02, 0x18, 0x43, 0x6f, 0x72, 0x6f, 0x75, 0x74, 0x69, 0x6e, - 0x65, 0x5c, 0x56, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, - 0xea, 0x02, 0x0d, 0x43, 0x6f, 0x72, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x65, 0x3a, 0x3a, 0x56, 0x31, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x21, 0x0a, 0x0c, + 0x61, 0x72, 0x72, 0x61, 0x79, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0d, 0x52, 0x0b, 0x61, 0x72, 0x72, 0x61, 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x12, + 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, + 0x61, 0x74, 0x61, 0x42, 0xbd, 0x01, 0x0a, 0x10, 0x63, 0x6f, 0x6d, 0x2e, 0x63, 0x6f, 0x72, 0x6f, + 0x75, 0x74, 0x69, 0x6e, 0x65, 0x2e, 0x76, 0x31, 0x42, 0x0e, 0x43, 0x6f, 0x72, 0x6f, 0x75, 0x74, + 0x69, 0x6e, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x48, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x74, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x72, 0x6f, + 0x63, 0x6b, 0x65, 0x74, 0x2f, 0x63, 0x6f, 0x72, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x65, 0x2f, 0x67, + 0x65, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, 0x2f, 0x63, 0x6f, 0x72, 0x6f, + 0x75, 0x74, 0x69, 0x6e, 0x65, 0x2f, 0x76, 0x31, 0x3b, 0x63, 0x6f, 0x72, 0x6f, 0x75, 0x74, 0x69, + 0x6e, 0x65, 0x76, 0x31, 0xa2, 0x02, 0x03, 0x43, 0x58, 0x58, 0xaa, 0x02, 0x0c, 0x43, 0x6f, 0x72, + 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x65, 0x2e, 0x56, 0x31, 0xca, 0x02, 0x0c, 0x43, 0x6f, 0x72, 0x6f, + 0x75, 0x74, 0x69, 0x6e, 0x65, 0x5c, 0x56, 0x31, 0xe2, 0x02, 0x18, 0x43, 0x6f, 0x72, 0x6f, 0x75, + 0x74, 0x69, 0x6e, 0x65, 0x5c, 0x56, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, + 0x61, 0x74, 0x61, 0xea, 0x02, 0x0d, 0x43, 0x6f, 0x72, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x65, 0x3a, + 0x3a, 0x56, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/gen/proto/go/coroutine/v1/coroutine_vtproto.pb.go b/gen/proto/go/coroutine/v1/coroutine_vtproto.pb.go index cac79dd..e12fc18 100644 --- a/gen/proto/go/coroutine/v1/coroutine_vtproto.pb.go +++ b/gen/proto/go/coroutine/v1/coroutine_vtproto.pb.go @@ -211,7 +211,12 @@ func (m *Region) MarshalToSizedBufferVT(dAtA []byte) (int, error) { copy(dAtA[i:], m.Data) i = encodeVarint(dAtA, i, uint64(len(m.Data))) i-- - dAtA[i] = 0x12 + dAtA[i] = 0x1a + } + if m.ArrayLength != 0 { + i = encodeVarint(dAtA, i, uint64(m.ArrayLength)) + i-- + dAtA[i] = 0x10 } if m.Type != 0 { i = encodeVarint(dAtA, i, uint64(m.Type)) @@ -298,6 +303,9 @@ func (m *Region) SizeVT() (n int) { if m.Type != 0 { n += 1 + sov(uint64(m.Type)) } + if m.ArrayLength != 0 { + n += 1 + sov(uint64(m.ArrayLength)) + } l = len(m.Data) if l > 0 { n += 1 + l + sov(uint64(l)) @@ -793,6 +801,25 @@ func (m *Region) UnmarshalVT(dAtA []byte) error { } } case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field ArrayLength", wireType) + } + m.ArrayLength = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.ArrayLength |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 3: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) } diff --git a/proto/coroutine/v1/coroutine.proto b/proto/coroutine/v1/coroutine.proto index 3205334..c5acd64 100644 --- a/proto/coroutine/v1/coroutine.proto +++ b/proto/coroutine/v1/coroutine.proto @@ -42,9 +42,16 @@ message Build { // Region is an encoded region of memory. message Region { - // Type is the type of the region. + // Type is the type of the region, shifted left by one. + // + // The least significant bit indicates that this region represents + // an array, and that the type is the array element type rather + // than the object that's encoded in this region. uint32 type = 1; + // Array length, for regions that are arrays. + uint32 array_length = 2; + // Data is the encoded contents of the memory region. - bytes data = 2; + bytes data = 3; } diff --git a/types/inspect.go b/types/inspect.go index 272490d..51b7c3a 100644 --- a/types/inspect.go +++ b/types/inspect.go @@ -124,7 +124,8 @@ type Type struct { index int } -// Index is the index of the type in the serialized state. +// Index is the index of the type in the serialized state, or -1 +// if the type is derived from a serialized type. func (t *Type) Index() int { return t.index } @@ -597,7 +598,27 @@ func (t *Region) Index() int { // Type is the type of the region. func (r *Region) Type() *Type { - return r.state.Type(int(r.region.Type - 1)) + t := r.state.Type(int((r.region.Type >> 1) - 1)) + if r.region.Type&1 == 1 { + t = newArrayType(r.state, int64(r.region.ArrayLength), t) + } + return t +} + +func newArrayType(state *State, length int64, t *Type) *Type { + idx := t.Index() + if idx < 0 { + panic("BUG") + } + return &Type{ + state: state, + typ: &coroutinev1.Type{ + Kind: coroutinev1.Kind_KIND_ARRAY, + Length: int64(length), + Elem: uint32(idx + 1), + }, + index: -1, // aka. a derived type + } } // Size is the size of the region in bytes. @@ -727,7 +748,10 @@ func (s *Scanner) Next() bool { case scancustom: if uint64(s.pos) < last.customtil { - return s.readTypedAny(len(s.stack)) + if !s.readType() { + return false + } + return s.readAny(s.typ, len(s.stack)) } if uint64(s.pos) > last.customtil { s.err = fmt.Errorf("invalid custom object size") @@ -951,13 +975,22 @@ func (s *Scanner) readAny(t *Type, depth int) (ok bool) { } } -func (s *Scanner) readTypedAny(depth int) (ok bool) { +func (s *Scanner) readType() (ok bool) { id, ok := s.getVarint() if !ok { return false } t := s.state.Type(int(id - 1)) - return s.readAny(t, depth) + + len, ok := s.getVarint() + if !ok { + return false + } + if len >= 0 { + t = newArrayType(s.state, len, t) + } + s.typ = t + return true } func (s *Scanner) readUint8() (ok bool) { @@ -1110,13 +1143,9 @@ func (s *Scanner) readInterface() (ok bool) { s.nil = true return true } - - typeid, ok := s.getVarint() - if !ok { + if !s.readType() { return false } - s.typ = s.state.Type(int(typeid - 1)) - return s.readRegionPointer() } diff --git a/types/reflect.go b/types/reflect.go index 67d8752..98b313a 100644 --- a/types/reflect.go +++ b/types/reflect.go @@ -11,13 +11,23 @@ import ( ) func serializeType(s *Serializer, t reflect.Type) { - x := s.types.ToType(t) - serializeVarint(s, int(x)) + if t != nil && t.Kind() == reflect.Array { + id := s.types.ToType(t.Elem()) + serializeVarint(s, int(id)) + serializeVarint(s, t.Len()) + } else { + id := s.types.ToType(t) + serializeVarint(s, int(id)) + serializeVarint(s, -1) + } + } -func deserializeType(d *Deserializer) reflect.Type { +func deserializeType(d *Deserializer) (reflect.Type, int) { id := deserializeVarint(d) - return d.types.ToReflect(typeid(id)) + length := deserializeVarint(d) + t := d.types.ToReflect(typeid(id)) + return t, length } func serializeAny(s *Serializer, t reflect.Type, p unsafe.Pointer) { @@ -105,7 +115,15 @@ func deserializeAny(d *Deserializer, t reflect.Type, p unsafe.Pointer) { switch t { case reflectValueType: - rt := deserializeType(d) + rt, length := deserializeType(d) + if length >= 0 { + // We can't avoid the ArrayOf call here. We need to build a + // reflect.Type in order to return a reflect.Value. The only + // time this path is taken is if the user has explicitly serialized + // a reflect.Value, or some other data type that contains or points + // to a reflect.Value. + rt = reflect.ArrayOf(length, rt) + } v := deserializeReflectValue(d, rt) reflect.NewAt(reflectValueType, p).Elem().Set(reflect.ValueOf(v)) return @@ -237,7 +255,7 @@ func serializeReflectValue(s *Serializer, t reflect.Type, v reflect.Value) { serializeFunc(s, t, unsafe.Pointer(&addr)) } case reflect.Pointer: - serializePointedAt(s, t.Elem(), v.UnsafePointer()) + serializePointedAt(s, t.Elem(), -1, v.UnsafePointer()) default: panic(fmt.Sprintf("not implemented: serializing reflect.Value with type %s (%s)", t, t.Kind())) } @@ -338,16 +356,16 @@ func deserializeReflectValue(d *Deserializer, t reflect.Type) (v reflect.Value) *(*unsafe.Pointer)(p) = unsafe.Pointer(&fn.Addr) } case reflect.Pointer: - ep := deserializePointedAt(d, t.Elem()) + ep := deserializePointedAt(d, t.Elem(), -1) v = reflect.New(t).Elem() - v.Set(ep) + v.Set(reflect.NewAt(t.Elem(), ep)) default: panic(fmt.Sprintf("not implemented: deserializing reflect.Value with type %s", t)) } return } -func serializePointedAt(s *Serializer, t reflect.Type, p unsafe.Pointer) { +func serializePointedAt(s *Serializer, et reflect.Type, length int, p unsafe.Pointer) { // If this is a nil pointer, write it as such. if p == nil { serializeVarint(s, 0) @@ -373,14 +391,15 @@ func serializePointedAt(s *Serializer, t reflect.Type, p unsafe.Pointer) { // to the serializer). Scanning here might cause known regions to // expand, invalidating those that have already been encoded. if !r.valid() { - if t == nil { + if et == nil { panic("cannot serialize unsafe.Pointer pointing to region of unknown size") } r.addr = p - r.typ = t + r.typ = et + r.len = length } - if r.typ.Kind() == reflect.Map { + if r.len < 0 && r.typ.Kind() == reflect.Map { serializeMap(s, r.typ, r.addr) return } @@ -396,46 +415,57 @@ func serializePointedAt(s *Serializer, t reflect.Type, p unsafe.Pointer) { } region := &coroutinev1.Region{ - Type: s.types.ToType(r.typ), + Type: s.types.ToType(r.typ) << 1, + } + if r.len >= 0 { + region.Type |= 1 + region.ArrayLength = uint32(r.len) } s.regions = append(s.regions, region) - if r.typ.Kind() == reflect.Array && r.typ.Elem().Kind() == reflect.Uint8 { - // Fast path for byte arrays. - if n := r.typ.Len(); n > 0 { - region.Data = unsafe.Slice((*byte)(r.addr), n) + // Fast path for byte arrays. + if r.len >= 0 && r.typ.Kind() == reflect.Uint8 { + if r.len > 0 { + region.Data = unsafe.Slice((*byte)(r.addr), r.len) + } + return + } + + regionSer := s.fork() + if r.len >= 0 { // array + es := int(r.typ.Size()) + for i := 0; i < r.len; i++ { + serializeAny(regionSer, r.typ, unsafe.Add(r.addr, i*es)) } } else { - regionSer := s.fork() serializeAny(regionSer, r.typ, r.addr) - region.Data = regionSer.b } + region.Data = regionSer.b } -func deserializePointedAt(d *Deserializer, t reflect.Type) reflect.Value { +func deserializePointedAt(d *Deserializer, t reflect.Type, length int) unsafe.Pointer { // This function is a bit different than the other deserialize* ones // because it deserializes into an unknown location. As a result, - // instead of taking an unsafe.Pointer as an input, it returns a - // reflect.Value that contains a *T (where T is given by the argument - // t). + // instead of taking an unsafe.Pointer as an input, it returns an + // unsafe.Pointer to a deserialized object. - if t.Kind() == reflect.Map { + if length < 0 && t.Kind() == reflect.Map { m := reflect.New(t) + p := m.UnsafePointer() deserializeMapReflect(d, t, m.Elem(), m.UnsafePointer()) - return m + return p } id := deserializeVarint(d) if id == 0 { // Nil pointer. - return reflect.NewAt(t, unsafe.Pointer(nil)) + return unsafe.Pointer(nil) } offset := deserializeVarint(d) if id == -1 { // Pointer into static uint64 table. - p := staticPointer(offset) - return reflect.NewAt(t, p) + return staticPointer(offset) } p := d.ptrs[sID(id)] @@ -446,25 +476,38 @@ func deserializePointedAt(d *Deserializer, t reflect.Type) reflect.Value { } region := d.regions[id-1] - regionType := d.types.ToReflect(typeid(region.Type)) + regionType := d.types.ToReflect(typeid(region.Type >> 1)) - container := reflect.New(regionType) - p = container.UnsafePointer() - d.store(sID(id), p) + if region.Type&1 == 1 { + elemSize := int(regionType.Size()) + length := int(region.ArrayLength) + data := make([]byte, elemSize*length) + p = unsafe.Pointer(unsafe.SliceData(data)) + d.store(sID(id), p) - if regionType.Kind() == reflect.Array && regionType.Elem().Kind() == reflect.Uint8 { // Fast path for byte arrays. - if n := regionType.Len(); n > 0 { - copy(unsafe.Slice((*byte)(p), n), region.Data) + if regionType.Kind() == reflect.Uint8 { + if length > 0 { + copy(unsafe.Slice((*byte)(p), length), region.Data) + } + } else { + regionDeser := d.fork(region.Data) + for i := 0; i < length; i++ { + deserializeAny(regionDeser, regionType, unsafe.Add(p, elemSize*i)) + } } } else { + container := reflect.New(regionType) + p = container.UnsafePointer() + d.store(sID(id), p) regionDeser := d.fork(region.Data) deserializeAny(regionDeser, regionType, p) } + } // Create the pointer with an offset into the container. - return reflect.NewAt(t, unsafe.Add(p, offset)) + return unsafe.Add(p, offset) } func serializeMap(s *Serializer, t reflect.Type, p unsafe.Pointer) { @@ -491,7 +534,7 @@ func serializeMapReflect(s *Serializer, t reflect.Type, r reflect.Value) { size := r.Len() region := &coroutinev1.Region{ - Type: s.types.ToType(t), + Type: s.types.ToType(t) << 1, } s.regions = append(s.regions, region) @@ -562,26 +605,20 @@ func serializeSlice(s *Serializer, t reflect.Type, p unsafe.Pointer) { serializeVarint(s, r.Len()) serializeVarint(s, r.Cap()) - - at := reflect.ArrayOf(r.Cap(), t.Elem()) - ap := r.UnsafePointer() - - serializePointedAt(s, at, ap) + serializePointedAt(s, t.Elem(), r.Cap(), r.UnsafePointer()) } func deserializeSlice(d *Deserializer, t reflect.Type, p unsafe.Pointer) { l := deserializeVarint(d) c := deserializeVarint(d) - at := reflect.ArrayOf(c, t.Elem()) - - ar := deserializePointedAt(d, at) - if ar.IsNil() { + ar := deserializePointedAt(d, t.Elem(), c) + if ar == nil { return } s := (*slice)(p) - s.data = ar.UnsafePointer() + s.data = ar s.cap = c s.len = l } @@ -608,20 +645,20 @@ func deserializeArray(d *Deserializer, t reflect.Type, p unsafe.Pointer) { func serializePointer(s *Serializer, t reflect.Type, p unsafe.Pointer) { r := reflect.NewAt(t, p).Elem() x := r.UnsafePointer() - serializePointedAt(s, t.Elem(), x) + serializePointedAt(s, t.Elem(), -1, x) } func deserializePointer(d *Deserializer, t reflect.Type, p unsafe.Pointer) { - ep := deserializePointedAt(d, t.Elem()) + ep := deserializePointedAt(d, t.Elem(), -1) r := reflect.NewAt(t, p) - r.Elem().Set(ep) + r.Elem().Set(reflect.NewAt(t.Elem(), ep)) } func serializeUnsafePointer(s *Serializer, p unsafe.Pointer) { if p == nil { - serializePointedAt(s, nil, nil) + serializePointedAt(s, nil, -1, nil) } else { - serializePointedAt(s, nil, *(*unsafe.Pointer)(p)) + serializePointedAt(s, nil, -1, *(*unsafe.Pointer)(p)) } } @@ -630,10 +667,9 @@ var unsafePointerType = reflect.TypeOf(unsafe.Pointer(nil)) func deserializeUnsafePointer(d *Deserializer, p unsafe.Pointer) { r := reflect.NewAt(unsafePointerType, p) - ep := deserializePointedAt(d, unsafePointerType) - if !ep.IsNil() { - up := ep.UnsafePointer() - r.Elem().Set(reflect.ValueOf(up)) + ep := deserializePointedAt(d, unsafePointerType, -1) + if ep != nil { + r.Elem().Set(reflect.ValueOf(ep)) } } @@ -735,7 +771,11 @@ func serializeInterface(s *Serializer, t reflect.Type, p unsafe.Pointer) { // noescape? } - serializePointedAt(s, et, eptr) + if et.Kind() == reflect.Array { + serializePointedAt(s, et.Elem(), et.Len(), eptr) + } else { + serializePointedAt(s, et, -1, eptr) + } } func deserializeInterface(d *Deserializer, t reflect.Type, p unsafe.Pointer) { @@ -746,15 +786,24 @@ func deserializeInterface(d *Deserializer, t reflect.Type, p unsafe.Pointer) { } // Deserialize the type - et := deserializeType(d) + et, length := deserializeType(d) + if et == nil { + return + } // Deserialize the pointer - ep := deserializePointedAt(d, et) + ep := deserializePointedAt(d, et, length) // Store the result in the interface r := reflect.NewAt(t, p) - if !ep.IsNil() { - r.Elem().Set(ep.Elem()) + if ep != nil { + // FIXME: is there a way to avoid ArrayOf+NewAt here? We can + // access the iface via p. We can set the ptr, but not the typ. + if length >= 0 { + et = reflect.ArrayOf(length, et) + } + x := reflect.NewAt(et, ep) + r.Elem().Set(x.Elem()) } else { r.Elem().Set(reflect.Zero(et)) } @@ -770,10 +819,9 @@ func serializeString(s *Serializer, x *string) { return } - at := reflect.ArrayOf(l, byteT) - ap := unsafe.Pointer(unsafe.StringData(*x)) + p := unsafe.Pointer(unsafe.StringData(*x)) - serializePointedAt(s, at, ap) + serializePointedAt(s, byteT, l, p) } func deserializeString(d *Deserializer, x *string) { @@ -783,10 +831,9 @@ func deserializeString(d *Deserializer, x *string) { return } - at := reflect.ArrayOf(l, byteT) - ar := deserializePointedAt(d, at) + ar := deserializePointedAt(d, byteT, l) - *x = unsafe.String((*byte)(ar.UnsafePointer()), l) + *x = unsafe.String((*byte)(ar), l) } func serializeBool(s *Serializer, x bool) { diff --git a/types/scan.go b/types/scan.go index 5baab9a..3b484a2 100644 --- a/types/scan.go +++ b/types/scan.go @@ -11,6 +11,7 @@ import ( type container struct { addr unsafe.Pointer typ reflect.Type + len int // >=0 for arrays, -1 for other types } // Returns true iff at least one byte of the address space is shared between c @@ -37,15 +38,18 @@ func (c container) after(x container) bool { // Size in bytes of c. func (c container) size() uintptr { + if c.len >= 0 { + return uintptr(c.len) * c.typ.Size() + } return c.typ.Size() } func (c container) isStruct() bool { - return c.typ.Kind() == reflect.Struct + return !c.isArray() && c.typ.Kind() == reflect.Struct } func (c container) isArray() bool { - return c.typ.Kind() == reflect.Array + return c.len >= 0 } func (c container) valid() bool { @@ -71,7 +75,7 @@ func (c container) compare(p unsafe.Pointer) int { } func (c container) String() string { - return fmt.Sprintf("[%d-%d[ %d %s", c.addr, uintptr(c.addr)+c.size(), c.size(), c.typ) + return fmt.Sprintf("[%d-%d] %d %s(%d)", c.addr, uintptr(c.addr)+c.size(), c.size(), c.typ, c.len) } type containers []container @@ -97,7 +101,10 @@ func (c *containers) of(p unsafe.Pointer) container { return s[i] } -func (c *containers) add(t reflect.Type, p unsafe.Pointer) { +func (c *containers) add(t reflect.Type, length int, p unsafe.Pointer) { + if length == 0 { + return + } if t.Size() == 0 { return } @@ -105,11 +112,6 @@ func (c *containers) add(t reflect.Type, p unsafe.Pointer) { if p == nil { panic("tried to add nil pointer") } - switch t.Kind() { - case reflect.Struct, reflect.Array: - default: - panic(fmt.Errorf("tried to add non struct or array container: %s (%s)", t, t.Kind())) - } defer func() { r := recover() @@ -119,7 +121,7 @@ func (c *containers) add(t reflect.Type, p unsafe.Pointer) { } }() - x := container{addr: p, typ: t} + x := container{addr: p, typ: t, len: length} i := c.insert(x) c.fixup(i) if i > 0 { @@ -156,7 +158,7 @@ func (c *containers) fixup(i int) { // There is some overlap. The only thing we accept to merge are arrays // of the same type. - if !x.isArray() || !next.isArray() || x.typ.Elem() != next.typ.Elem() { + if !x.isArray() || !next.isArray() || x.typ != next.typ { panic(fmt.Errorf("only support merging arrays of same type (%s, %s)", x.typ, next.typ)) } @@ -171,16 +173,14 @@ func (c *containers) merge(i int) { a := s[i] b := s[i+1] - elemSize := a.typ.Elem().Size() + elemSize := a.typ.Size() // sanity check alignment if (uintptr(b.addr)-uintptr(a.addr))%uintptr(elemSize) != 0 { panic("overlapping arrays aren't aligned") } - // new element count of the array - newlen := int((uintptr(b.addr)-uintptr(a.addr))/elemSize) + b.typ.Len() - s[i].typ = reflect.ArrayOf(newlen, a.typ.Elem()) + s[i].len = int((uintptr(b.addr)-uintptr(a.addr))/elemSize) + b.len c.remove(i + 1) } @@ -267,7 +267,7 @@ func (s *Serializer) scan1(t reflect.Type, p unsafe.Pointer, seen map[reflect.Va case reflect.Invalid: panic("handling invalid reflect.Type") case reflect.Array: - s.containers.add(t, p) + s.containers.add(t.Elem(), t.Len(), p) et := t.Elem() es := int(et.Size()) for i := 0; i < t.Len(); i++ { @@ -287,9 +287,7 @@ func (s *Serializer) scan1(t reflect.Type, p unsafe.Pointer, seen map[reflect.Va et := t.Elem() es := int(et.Size()) - // Create a new type for the backing array. - xt := reflect.ArrayOf(sr.Cap(), t.Elem()) - s.containers.add(xt, ep) + s.containers.add(et, sr.Cap(), ep) for i := 0; i < sr.Len(); i++ { ep := unsafe.Add(ep, es*i) s.scan1(et, ep, seen) @@ -314,7 +312,7 @@ func (s *Serializer) scan1(t reflect.Type, p unsafe.Pointer, seen map[reflect.Va s.scan1(et, eptr, seen) case reflect.Struct: - s.containers.add(t, p) + s.containers.add(t, -1, p) n := t.NumField() for i := 0; i < n; i++ { f := t.Field(i) @@ -335,8 +333,7 @@ func (s *Serializer) scan1(t reflect.Type, p unsafe.Pointer, seen map[reflect.Va // empty strings are represented as nil pointers. return } - xt := reflect.ArrayOf(len(str), byteT) - s.containers.add(xt, unsafe.Pointer(sp)) + s.containers.add(byteT, len(str), unsafe.Pointer(sp)) case reflect.Map: m := r.Elem() if m.IsNil() || m.Len() == 0 { diff --git a/types/serde.go b/types/serde.go index 6db80ef..73f87f8 100644 --- a/types/serde.go +++ b/types/serde.go @@ -58,7 +58,7 @@ func Serialize(x any) ([]byte, error) { Strings: s.strings.strings, Regions: s.regions, Root: &coroutinev1.Region{ - Type: s.types.ToType(t), + Type: s.types.ToType(t) << 1, Data: s.b, }, } @@ -240,9 +240,13 @@ func DeserializeTo[T any](d *Deserializer, x *T) { r := reflect.ValueOf(x) t := r.Type().Elem() p := r.UnsafePointer() - actualType := deserializeType(d) - if actualType != t { - panic(fmt.Sprintf("cannot deserialize %s as %s", actualType, t)) + actualType, length := deserializeType(d) + if length < 0 { + if t != actualType { + panic(fmt.Sprintf("cannot deserialize %s as %s", actualType, t)) + } + } else if t.Kind() != reflect.Array || t.Len() != length || t != actualType.Elem() { + panic(fmt.Sprintf("cannot deserialize [%d]%s as %s", length, actualType, t)) } deserializeAny(d, t, p) } diff --git a/types/serde_test.go b/types/serde_test.go index 6b72edd..3be1d69 100644 --- a/types/serde_test.go +++ b/types/serde_test.go @@ -134,6 +134,8 @@ func TestReflect(t *testing.T) { struct{ a, b int }{a: 1, b: 2}, [1][2]int{{1, 2}}, + [1][2][3][4]byte{{{{1}}}}, + []any{(*int)(nil)}, funcType(nil),