Skip to content

Commit

Permalink
newHello/asHello (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelquigley committed Aug 18, 2021
1 parent d9dbc63 commit 2cef278
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 0 deletions.
109 changes: 109 additions & 0 deletions ack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package dilithium

import (
"github.com/openziti/dilithium/util"
"github.com/pkg/errors"
)

/*
* ACK Encoding Format
*
* If the high-bit of the first byte of the ACK region is low, then this is a single int32 containing a sequence number.
*
* If the high-bit of the first byte of the ACK region is high, then we know we're dealing with multiple ACKs (or ACK
* ranges) encoded in series. The remaining 7 bits contain the number of ACKs (or ACK ranges) encoded in the series.
*
* When decoding an ACK from a series, we use the high bit of the 4-byte int32 to determine if this is an ACK range.
* If the high-bit is set, then we know to expect that there are actually 2 int32s in a row, definining the lower and
* upper bounds of the range. If the high-bit is low, then we know this is a single sequence number.
*/

type Ack struct {
Start int32
End int32
}

const ackSeriesMarker = uint8(1 << 7)
const sequenceRangeMarker = uint32(1 << 31)
const sequenceRangeInvert = 0xFFFFFFFF ^ sequenceRangeMarker

func EncodeAcks(acks []Ack, data []byte) (n uint32, err error) {
if len(acks) < 1 {
return 0, nil
}
if len(acks) > 127 {
return 0, errors.Errorf("ack series too large [%d > 127]", len(acks))
}

dataSz := uint32(len(data))

if len(acks) == 1 && acks[0].Start == acks[0].End {
if dataSz < 4 {
return 0, errors.Errorf("insufficient buffer to encode ack [%d < 4]", dataSz)
}
util.WriteInt32(data, int32(uint32(acks[0].Start)&sequenceRangeInvert))
return 4, nil
}

i := uint32(0)
if (i + 1) > dataSz {
return i, errors.Errorf("insufficient buffer to encode ack series [%d < %d]", dataSz, i+1)
}
data[i] = ackSeriesMarker + uint8(len(acks))
i++

for _, a := range acks {
if a.Start == a.End {
if (i + 4) > dataSz {
return i, errors.Errorf("insufficient buffer to encode ack series [%d < %d]", dataSz, i)
}
util.WriteInt32(data[i:i+4], int32(uint32(a.Start)&sequenceRangeInvert))
i += 4

} else {
if (i + 4) > dataSz {
return i, errors.Errorf("insufficient buffer to encode ack series [%d < %d]", dataSz, i)
}
util.WriteInt32(data[i:i+4], int32(uint32(a.Start)|sequenceRangeMarker))
i += 4

if (i + 4) > dataSz {
return i, errors.Errorf("insufficient buffer to encode ack series [%d < %d]", dataSz, i)
}
util.WriteInt32(data[i:i+4], int32(uint32(a.End)&sequenceRangeInvert))
i += 4
}
}

return i, nil
}

func DecodeAcks(data []byte) (acks []Ack, sz uint32, err error) {
dataSz := uint32(len(data))
if dataSz < 4 {
return nil, 0, errors.Errorf("short ack buffer [%d < 4]", dataSz)
}

if data[0]&ackSeriesMarker == 0 {
seq := util.ReadInt32(data[0:4])
acks = append(acks, Ack{seq, seq})
return acks, 4, nil

} else {
seriesSz := int(data[0] ^ ackSeriesMarker)
sz = 1
for i := 0; i < seriesSz; i++ {
first := util.ReadInt32(data[sz : sz+4])
if uint32(first)&sequenceRangeMarker == sequenceRangeMarker {
sz += 4
second := util.ReadInt32(data[sz : sz+4])
acks = append(acks, Ack{int32(uint32(first) & sequenceRangeInvert), int32(uint32(second) & sequenceRangeInvert)})

} else {
acks = append(acks, Ack{int32(uint32(first) & sequenceRangeInvert), int32(uint32(first) & sequenceRangeInvert)})
}
sz += 4
}
}
return
}
55 changes: 55 additions & 0 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,46 @@ func writeWireMessage(wm *WireMessage, t Transport) error {
return nil
}

func newHello(seq int32, h hello, a *Ack, p *Pool) (wm *WireMessage, err error) {
wm = &WireMessage{
Seq: seq,
Mt: HELLO,
buf: p.Get(),
}
var ackSize uint32
var helloSize uint32
if a != nil {
wm.setFlag(INLINE_ACK)
ackSize, err = EncodeAcks([]Ack{*a}, wm.buf.Data[dataStart:])
if err != nil {
return nil, errors.Wrap(err, "error encoding hello ack")
}
}
helloSize, err = encodeHello(h, wm.buf.Data[dataStart+ackSize:])
if err != nil {
return nil, errors.Wrap(err, "error encoding hello")
}
return wm.encodeHeader(uint16(ackSize + helloSize))
}

func (wm *WireMessage) asHello() (h hello, a []Ack, err error) {
if wm.messageType() != HELLO {
return hello{}, nil, errors.Errorf("unexpected message type [%d], expected HELLO", wm.messageType())
}
i := uint32(0)
if wm.hasFlag(INLINE_ACK) {
a, i, err = DecodeAcks(wm.buf.Data[dataStart:])
if err != nil {
return hello{}, nil, errors.Wrap(err, "error decoding acks")
}
}
h, _, err = decodeHello(wm.buf.Data[dataStart+i:])
if err != nil {
return hello{}, nil, errors.Wrap(err, "error decoding hello")
}
return
}

func (wm *WireMessage) encodeHeader(dataSize uint16) (*WireMessage, error) {
if wm.buf.Size < uint32(dataStart+dataSize) {
return nil, errors.Errorf("short buffer for encode [%d < %d]", wm.buf.Size, dataStart+dataSize)
Expand All @@ -90,3 +130,18 @@ func decodeHeader(buf *Buffer) (*WireMessage, error) {
}
return wm, nil
}

func (wm *WireMessage) messageType() messageType {
return messageType(byte(wm.Mt) & messageTypeMask)
}

func (wm *WireMessage) setFlag(flag messageFlag) {
wm.Mt = messageType(uint8(wm.Mt) | uint8(flag))
}

func (wm *WireMessage) hasFlag(flag messageFlag) bool {
if uint8(wm.Mt)&uint8(flag) > 0 {
return true
}
return false
}

0 comments on commit 2cef278

Please sign in to comment.