Skip to content

Commit

Permalink
Add handshake validators
Browse files Browse the repository at this point in the history
Add handshake validators
  • Loading branch information
reinaldooli authored May 24, 2023
2 parents 98580b5 + 0a60f6a commit 04e8c01
Show file tree
Hide file tree
Showing 19 changed files with 902 additions and 539 deletions.
66 changes: 54 additions & 12 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ var (
// hbd contains the heartbeat packet data
hbd []byte
// hrd contains the handshake response data
hrd []byte
hrd []byte
// herd contains the handshake error response data
herd []byte
once sync.Once
)

Expand Down Expand Up @@ -110,6 +112,7 @@ type (
Handle()
IPVersion() string
SendHandshakeResponse() error
SendHandshakeErrorResponse() error
SendRequest(ctx context.Context, serverID, route string, v interface{}) (*protos.Response, error)
AnswerWithError(ctx context.Context, mid uint, err error)
}
Expand Down Expand Up @@ -180,6 +183,7 @@ func newAgent(

once.Do(func() {
hbdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled(), serializerName)
herdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled(), serializerName)
})

a := &agentImpl{
Expand Down Expand Up @@ -475,6 +479,14 @@ func (a *agentImpl) onSessionClosed(s session.Session) {
// SendHandshakeResponse sends a handshake response
func (a *agentImpl) SendHandshakeResponse() error {
_, err := a.conn.Write(hrd)

return err
}

func (a *agentImpl) SendHandshakeErrorResponse() error {
a.SetStatus(constants.StatusClosed)
_, err := a.conn.Write(herd)

return err
}

Expand Down Expand Up @@ -543,33 +555,63 @@ func hbdEncode(heartbeatTimeout time.Duration, packetEncoder codec.PacketEncoder
"serializer": serializerName,
},
}
data, err := gojson.Marshal(hData)

data, err := encodeAndCompress(hData, dataCompression)
if err != nil {
panic(err)
}

if dataCompression {
compressedData, err := compression.DeflateData(data)
if err != nil {
panic(err)
}
hrd, err = packetEncoder.Encode(packet.Handshake, data)
if err != nil {
panic(err)
}

if len(compressedData) < len(data) {
data = compressedData
}
hbd, err = packetEncoder.Encode(packet.Heartbeat, nil)
if err != nil {
panic(err)
}
}

hrd, err = packetEncoder.Encode(packet.Handshake, data)
func herdEncode(heartbeatTimeout time.Duration, packetEncoder codec.PacketEncoder, dataCompression bool, serializerName string) {
hErrData := map[string]interface{}{
"code": 400,
"sys": map[string]interface{}{
"heartbeat": heartbeatTimeout.Seconds(),
"dict": message.GetDictionary(),
"serializer": serializerName,
},
}

errData, err := encodeAndCompress(hErrData, dataCompression)
if err != nil {
panic(err)
}

hbd, err = packetEncoder.Encode(packet.Heartbeat, nil)
herd, err = packetEncoder.Encode(packet.Handshake, errData)
if err != nil {
panic(err)
}
}

func encodeAndCompress(data interface{}, dataCompression bool) ([]byte, error) {
encData, err := gojson.Marshal(data)
if err != nil {
return nil, err
}

if dataCompression {
compressedData, err := compression.DeflateData(encData)
if err != nil {
return nil, err
}

if len(compressedData) < len(encData) {
encData = compressedData
}
}
return encData, nil
}

func (a *agentImpl) reportChannelSize() {
chSendCapacity := a.messagesBufferSize - len(a.chSend)
if chSendCapacity == 0 {
Expand Down
2 changes: 1 addition & 1 deletion agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func TestNewAgent(t *testing.T) {
func(typ packet.Type, d []byte) {
// cannot compare inside the expect because they are equivalent but not equal
assert.EqualValues(t, packet.Handshake, typ)
})
}).Times(2)
mockEncoder.EXPECT().Encode(gomock.Any(), gomock.Nil()).Do(
func(typ packet.Type, d []byte) {
assert.EqualValues(t, packet.Heartbeat, typ)
Expand Down
Loading

0 comments on commit 04e8c01

Please sign in to comment.