diff --git a/filestore.go b/filestore.go index 120ba83c0..893466cd7 100644 --- a/filestore.go +++ b/filestore.go @@ -55,6 +55,10 @@ type fileStore struct { fileSync bool } +func (store *fileStore) SaveMessagesAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error { + return errors.New("not implemented") +} + // NewFileStoreFactory returns a file-based implementation of MessageStoreFactory. func NewFileStoreFactory(settings *Settings) MessageStoreFactory { return fileStoreFactory{settings: settings} diff --git a/mongostore.go b/mongostore.go index 290cdc236..0c9233723 100644 --- a/mongostore.go +++ b/mongostore.go @@ -327,6 +327,10 @@ func (store *mongoStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg [ return store.cache.SetNextSenderMsgSeqNum(next) } +func (store *mongoStore) SaveMessagesAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error { + return errors.New("not implemented") +} + func (store *mongoStore) GetMessages(beginSeqNum, endSeqNum int) (msgs [][]byte, err error) { msgFilter := generateMessageFilter(&store.sessionID) // Marshal into database form. diff --git a/registry.go b/registry.go index 291d44bee..8c0c99aab 100644 --- a/registry.go +++ b/registry.go @@ -64,6 +64,25 @@ func SendToTarget(m Messagable, sessionID SessionID) error { return session.queueForSend(msg) } +// SendAppToTarget is similar to SendToTarget, but it sends application messages in batch to the sessionID. +// The entire batch would fail if: +// - any message in the batch fails ToApp() validation +// - any message in the batch is an admin message +// This is more efficient compare to SendToTarget in the case of sending a burst of application messages, +// especially when using a persistent store like SQLStore, because it allows batching at the storage layer. +func SendAppToTarget(m []Messagable, sessionID SessionID) error { + session, ok := lookupSession(sessionID) + if !ok { + return errUnknownSession + } + msg := make([]*Message, len(m)) + for i, v := range m { + msg[i] = v.ToMessage() + } + + return session.queueBatchAppsForSend(msg) +} + // UnregisterSession removes a session from the set of known sessions. func UnregisterSession(sessionID SessionID) error { sessionsLock.Lock() diff --git a/session.go b/session.go index b359245e2..57adaf52d 100644 --- a/session.go +++ b/session.go @@ -223,6 +223,27 @@ func (s *session) resend(msg *Message) bool { return s.application.ToApp(msg, s.sessionID) == nil } +// queueBatchAppsForSend will validate, persist, and queue the messages for send. +func (s *session) queueBatchAppsForSend(msg []*Message) error { + s.sendMutex.Lock() + defer s.sendMutex.Unlock() + + msgBytes, err := s.prepBatchAppMessagesForSend(msg) + if err != nil { + return err + } + + for _, mb := range msgBytes { + s.toSend = append(s.toSend, mb) + select { + case s.messageEvent <- true: + default: + } + } + + return nil +} + // queueForSend will validate, persist, and queue the message for send. func (s *session) queueForSend(msg *Message) error { s.sendMutex.Lock() @@ -295,6 +316,30 @@ func (s *session) dropAndSendInReplyTo(msg *Message, inReplyTo *Message) error { return nil } +func (s *session) prepBatchAppMessagesForSend(msg []*Message) (msgBytes [][]byte, err error) { + seqNum := s.store.NextSenderMsgSeqNum() + for i, m := range msg { + s.fillDefaultHeader(m, nil) + m.Header.SetField(tagMsgSeqNum, FIXInt(seqNum+i)) + msgType, err := m.Header.GetBytes(tagMsgType) + if err != nil { + return nil, err + } + if isAdminMessageType(msgType) { + return nil, fmt.Errorf("cannot send admin messages in batch") + } + if errToApp := s.application.ToApp(m, s.sessionID); errToApp != nil { + return nil, errToApp + } + msgBytes = append(msgBytes, m.build()) + } + err = s.persistBatch(seqNum, msgBytes) + if err != nil { + return nil, err + } + return msgBytes, nil +} + func (s *session) prepMessageForSend(msg *Message, inReplyTo *Message) (msgBytes []byte, err error) { s.fillDefaultHeader(msg, inReplyTo) seqNum := s.store.NextSenderMsgSeqNum() @@ -338,6 +383,14 @@ func (s *session) prepMessageForSend(msg *Message, inReplyTo *Message) (msgBytes return } +func (s *session) persistBatch(seqNum int, msgBytes [][]byte) error { + if !s.DisableMessagePersist { + return s.store.SaveMessagesAndIncrNextSenderMsgSeqNum(seqNum, msgBytes) + } + + return s.store.SetNextSenderMsgSeqNum(seqNum + len(msgBytes)) +} + func (s *session) persist(seqNum int, msgBytes []byte) error { if !s.DisableMessagePersist { return s.store.SaveMessageAndIncrNextSenderMsgSeqNum(seqNum, msgBytes) diff --git a/sqlstore.go b/sqlstore.go index 5e492be2c..d536cb3ca 100644 --- a/sqlstore.go +++ b/sqlstore.go @@ -19,6 +19,7 @@ import ( "database/sql" "fmt" "regexp" + "strings" "time" "github.com/pkg/errors" @@ -297,6 +298,56 @@ func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error { return err } +func (store *sqlStore) SaveMessagesAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error { + s := store.sessionID + + tx, err := store.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + const values = "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" + placeholders := make([]string, 0, len(msg)) + params := make([]interface{}, 0, len(msg)*10) + for offset, m := range msg { + placeholders = append(placeholders, values) + params = append(params, seqNum+offset, string(m), + s.BeginString, s.Qualifier, + s.SenderCompID, s.SenderSubID, s.SenderLocationID, + s.TargetCompID, s.TargetSubID, s.TargetLocationID) + } + _, err = tx.Exec(sqlString(`INSERT INTO messages ( + msgseqnum, message, + beginstring, session_qualifier, + sendercompid, sendersubid, senderlocid, + targetcompid, targetsubid, targetlocid) + VALUES`+strings.Join(placeholders, ","), store.placeholder), + params...) + if err != nil { + return err + } + + next := store.cache.NextSenderMsgSeqNum() + len(msg) + _, err = tx.Exec(sqlString(`UPDATE sessions SET outgoing_seqnum = ? + WHERE beginstring=? AND session_qualifier=? + AND sendercompid=? AND sendersubid=? AND senderlocid=? + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), + next, s.BeginString, s.Qualifier, + s.SenderCompID, s.SenderSubID, s.SenderLocationID, + s.TargetCompID, s.TargetSubID, s.TargetLocationID) + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + return store.cache.SetNextSenderMsgSeqNum(next) +} + func (store *sqlStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error { s := store.sessionID diff --git a/store.go b/store.go index c56e1896e..ce0588ea3 100644 --- a/store.go +++ b/store.go @@ -36,6 +36,7 @@ type MessageStore interface { SaveMessage(seqNum int, msg []byte) error SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error + SaveMessagesAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) Refresh() error @@ -121,6 +122,15 @@ func (store *memoryStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg return store.IncrNextSenderMsgSeqNum() } +func (store *memoryStore) SaveMessagesAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error { + for offset, m := range msg { + if err := store.SaveMessageAndIncrNextSenderMsgSeqNum(seqNum+offset, m); err != nil { + return err + } + } + return nil +} + func (store *memoryStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { var msgs [][]byte for seqNum := beginSeqNum; seqNum <= endSeqNum; seqNum++ { diff --git a/store_test.go b/store_test.go index 075b3fd6c..3351f70bc 100644 --- a/store_test.go +++ b/store_test.go @@ -16,6 +16,7 @@ package quickfix import ( + "strings" "testing" "time" @@ -168,6 +169,56 @@ func (s *MessageStoreTestSuite) TestMessageStore_SaveMessage_AndIncrement_GetMes s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2])) } +func (s *MessageStoreTestSuite) TestMessageStore_SaveMessages_AndIncrement_GetMessage() { + if !strings.Contains(s.T().Name(), "TestSqlStoreTestSuite") { + s.T().Skip("Only SQL store implemented this method for now") + } + s.Require().Nil(s.msgStore.SetNextSenderMsgSeqNum(420)) + + // Given the following saved messages + const ( + m1 = "In the frozen land of Nador" + m2 = "they were forced to eat Robin's minstrels" + m3 = "and there was much rejoicing" + ) + expectedMsgsBySeqNum := map[int]string{ + 1: m1, + 2: m2, + 3: m3, + } + s.Require().Nil(s.msgStore.SaveMessagesAndIncrNextSenderMsgSeqNum(1, [][]byte{ + []byte(m1), + []byte(m2), + []byte(m3), + })) + s.Equal(423, s.msgStore.NextSenderMsgSeqNum()) + + // When the messages are retrieved from the MessageStore + actualMsgs, err := s.msgStore.GetMessages(1, 3) + s.Require().Nil(err) + + // Then the messages should be + s.Require().Len(actualMsgs, 3) + s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0])) + s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1])) + s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2])) + + // When the store is refreshed from its backing store + s.Require().Nil(s.msgStore.Refresh()) + + // And the messages are retrieved from the MessageStore + actualMsgs, err = s.msgStore.GetMessages(1, 3) + s.Require().Nil(err) + + s.Equal(423, s.msgStore.NextSenderMsgSeqNum()) + + // Then the messages should still be + s.Require().Len(actualMsgs, 3) + s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0])) + s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1])) + s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2])) +} + func (s *MessageStoreTestSuite) TestMessageStore_GetMessages_EmptyStore() { // When messages are retrieved from an empty store messages, err := s.msgStore.GetMessages(1, 2)