Skip to content

Commit

Permalink
Filter out ephemeral pub key messages by session ID
Browse files Browse the repository at this point in the history
Added message pre-processing step ensuring messages with a session ID
different from the current one are filtered out. This protects against
the replay attacks on the protocol level.
  • Loading branch information
pdyraga committed Oct 23, 2024
1 parent 12980f7 commit 02c828c
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 15 deletions.
2 changes: 1 addition & 1 deletion gjkr/member.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ type memberIndex uint16
// phase of the protocol.
type member struct {
memberIndex memberIndex
sessionID string
group *group

evidenceLog evidenceLog

logger Logger
Expand Down
7 changes: 6 additions & 1 deletion gjkr/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@ import "threshold.network/roast/ephemeral"
// within the group.
type ephemeralPublicKeyMessage struct {
senderIndex memberIndex // i
sessionID string

ephemeralPublicKeys map[memberIndex]*ephemeral.PublicKey // j -> Y_ij
}

func (m *ephemeralPublicKeyMessage) senderIdx() memberIndex {
func (m *ephemeralPublicKeyMessage) getSenderIndex() memberIndex {
return m.senderIndex
}

func (m *ephemeralPublicKeyMessage) getSessionID() string {
return m.sessionID
}
42 changes: 30 additions & 12 deletions gjkr/message_filter.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
package gjkr

// filterForSession goes through the messages passed as a parameter and finds
// all messages sent for the given session ID.
func filterForSession[T interface{ getSessionID() string }](
sessionID string,
list []T,
) []T {
result := make([]T, 0)

for _, msg := range list {
if msg.getSessionID() == sessionID {
result = append(result, msg)
}
}

return result
}

// findInactive goes through the messages passed as a parameter and finds all
// inactive members for this set of messages. The function does not care if
// the given member was already marked as inactive before. The function makes no
// assumptions about the ordering of the list elements.
func findInactive[T interface{ senderIdx() memberIndex }](
groupSize uint16, list []T,
func findInactive[T interface{ getSenderIndex() memberIndex }](
groupSize uint16,
list []T,
) []memberIndex {
senders := make(map[memberIndex]bool)
for _, item := range list {
senders[item.senderIdx()] = true
senders[item.getSenderIndex()] = true
}

inactive := make([]memberIndex, 0)
Expand All @@ -25,16 +43,16 @@ func findInactive[T interface{ senderIdx() memberIndex }](
// deduplicateBySender removes duplicated items for the given sender. It always
// takes the first item that occurs for the given sender and ignores the
// subsequent ones.
func deduplicateBySender[T interface{ senderIdx() memberIndex }](
func deduplicateBySender[T interface{ getSenderIndex() memberIndex }](
list []T,
) []T {
senders := make(map[memberIndex]bool)
result := make([]T, 0)

for _, item := range list {
if _, exists := senders[item.senderIdx()]; !exists {
senders[item.senderIdx()] = true
result = append(result, item)
for _, msg := range list {
if _, exists := senders[msg.getSenderIndex()]; !exists {
senders[msg.getSenderIndex()] = true
result = append(result, msg)
}
}

Expand All @@ -44,12 +62,12 @@ func deduplicateBySender[T interface{ senderIdx() memberIndex }](
func (m *symmetricKeyGeneratingMember) preProcessMessages(
ephemeralPubKeyMessages []*ephemeralPublicKeyMessage,
) []*ephemeralPublicKeyMessage {
inactiveMembers := findInactive(m.group.groupSize, ephemeralPubKeyMessages)
forThisSession := filterForSession(m.sessionID, ephemeralPubKeyMessages)

inactiveMembers := findInactive(m.group.groupSize, forThisSession)
for _, ia := range inactiveMembers {
m.group.markMemberAsInactive(ia)
}

// TODO: validate session ID

return deduplicateBySender(ephemeralPubKeyMessages)
return deduplicateBySender(forThisSession)
}
20 changes: 19 additions & 1 deletion gjkr/message_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,24 @@ import (
"threshold.network/roast/internal/testutils"
)

func TestFilterForSession(t *testing.T) {
msg0 := &ephemeralPublicKeyMessage{sessionID: "session-2", senderIndex: 1}
msg1 := &ephemeralPublicKeyMessage{sessionID: "session-1", senderIndex: 1}
msg2 := &ephemeralPublicKeyMessage{sessionID: "session-1", senderIndex: 2}
msg3 := &ephemeralPublicKeyMessage{sessionID: "session-2", senderIndex: 3}

filtered := filterForSession("session-1", []*ephemeralPublicKeyMessage{
msg0, msg1, msg2, msg3,
})

testutils.AssertDeepEqual(
t,
"filtered messages",
[]*ephemeralPublicKeyMessage{msg1, msg2},
filtered,
)
}

func TestFindInactive(t *testing.T) {
var tests = map[string]struct {
groupSize uint16
Expand Down Expand Up @@ -80,7 +98,7 @@ func TestDeduplicateBySender(t *testing.T) {

deduplicatedSenders := make([]memberIndex, 0)
for _, msg := range deduplicateBySender(messages) {
deduplicatedSenders = append(deduplicatedSenders, msg.senderIdx())
deduplicatedSenders = append(deduplicatedSenders, msg.getSenderIndex())
}
testutils.AssertUint16SlicesEqual(
t,
Expand Down
17 changes: 17 additions & 0 deletions internal/testutils/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package testutils
import (
"fmt"
"math/big"
"reflect"
"testing"

"golang.org/x/exp/slices"
Expand Down Expand Up @@ -134,3 +135,19 @@ func AssertUint16SlicesEqual[T ~uint16](
)
}
}

func AssertDeepEqual(
t *testing.T,
description string,
expected any,
actual any,
) {
if !reflect.DeepEqual(expected, actual) {
t.Errorf(
"unexpected %s\nexpected: %v\nactual: %v\n",
description,
expected,
actual,
)
}
}

0 comments on commit 02c828c

Please sign in to comment.