Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add feat: acceptor template #662

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 0 additions & 85 deletions accepter_test.go

This file was deleted.

165 changes: 120 additions & 45 deletions acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,25 @@ import (

// Acceptor accepts connections from FIX clients and manages the associated sessions.
type Acceptor struct {
app Application
settings *Settings
logFactory LogFactory
storeFactory MessageStoreFactory
globalLog Log
sessions map[SessionID]*session
sessionGroup sync.WaitGroup
listenerShutdown sync.WaitGroup
dynamicSessions bool
dynamicQualifier bool
dynamicQualifierCount int
dynamicSessionChan chan *session
sessionAddr sync.Map
sessionHostPort map[SessionID]int
listeners map[string]net.Listener
connectionValidator ConnectionValidator
app Application
settings *Settings
logFactory LogFactory
storeFactory MessageStoreFactory
globalLog Log
sessions map[SessionID]*session
sessionsLock sync.RWMutex
sessionGroup sync.WaitGroup
listenerShutdown sync.WaitGroup
dynamicSessions bool
dynamicQualifier bool
dynamicQualifierCount int
dynamicSessionChan chan *session
sessionAddr sync.Map
sessionHostPort map[SessionID]int
listeners map[string]net.Listener
connectionValidator ConnectionValidator
templateIDProvider TemplateIDProvider
dynamicAcceptorSessionProvider *dynamicAcceptorSessionProvider
sessionFactory
}

Expand All @@ -60,6 +63,11 @@ type ConnectionValidator interface {

// Start accepting connections.
func (a *Acceptor) Start() (err error) {

if err = a.configureDyanmicSessionProvider(); err != nil {
return
}

socketAcceptHost := ""
if a.settings.GlobalSettings().HasSetting(config.SocketAcceptHost) {
if socketAcceptHost, err = a.settings.GlobalSettings().Setting(config.SocketAcceptHost); err != nil {
Expand Down Expand Up @@ -104,14 +112,8 @@ func (a *Acceptor) Start() (err error) {
a.listeners[address] = &proxyproto.Listener{Listener: a.listeners[address]}
}
}
a.startSessions()

for _, s := range a.sessions {
a.sessionGroup.Add(1)
go func(s *session) {
s.run()
a.sessionGroup.Done()
}(s)
}
if a.dynamicSessions {
a.dynamicSessionChan = make(chan *session)
a.sessionGroup.Add(1)
Expand All @@ -127,6 +129,25 @@ func (a *Acceptor) Start() (err error) {
return
}

func (a *Acceptor) configureDyanmicSessionProvider() error {
if a.templateIDProvider == nil {
defaultTemplateIDProvider, err := NewDefaultTemplateIDProvider(a.settings)
if err != nil {
return err
}
if len(defaultTemplateIDProvider.templateMappings) == 0 {
// no templateMappings
return nil
}
a.templateIDProvider = defaultTemplateIDProvider
}
if setter, ok := a.storeFactory.(TemplateIDProviderSetter); ok {
setter.SetTemplateIDProvider(a.templateIDProvider)
}
a.dynamicAcceptorSessionProvider = newDynamicAcceptorSessionProvider(a.settings, a.storeFactory, a.logFactory, a.app, a.templateIDProvider)
return nil
}

// Stop logs out existing sessions, close their connections, and stop accepting new connections.
func (a *Acceptor) Stop() {
defer func() {
Expand All @@ -140,17 +161,7 @@ func (a *Acceptor) Stop() {
if a.dynamicSessions {
close(a.dynamicSessionChan)
}
for _, session := range a.sessions {
session.stop()
}
a.sessionGroup.Wait()

for sessionID := range a.sessions {
err := UnregisterSession(sessionID)
if err != nil {
return
}
}
a.stopSessions()
}

// RemoteAddr gets remote IP address for a given session.
Expand Down Expand Up @@ -191,6 +202,15 @@ func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Se
}

for sessionID, sessionSettings := range settings.SessionSettings() {
if sessionSettings.HasSetting(config.AcceptorTemplate) {
var acceptorTemplate bool
if acceptorTemplate, err = sessionSettings.BoolSetting(config.AcceptorTemplate); err != nil {
return
}
if acceptorTemplate {
continue
}
}
sessID := sessionID
sessID.Qualifier = ""

Expand Down Expand Up @@ -331,18 +351,27 @@ func (a *Acceptor) handleConnection(netConn net.Conn) {
}
session, ok := a.sessions[sessID]
if !ok {
if !a.dynamicSessions {
a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes)
return
}
dynamicSession, err := a.sessionFactory.createSession(sessID, a.storeFactory, a.settings.globalSettings.clone(), a.logFactory, a.app)
if err != nil {
a.globalLog.OnEventf("Dynamic session %v failed to create: %v", sessID, err)
return
if a.dynamicAcceptorSessionProvider != nil {
session, err = a.dynamicAcceptorSessionProvider.GetSession(sessID)
if err != nil {
a.globalLog.OnEventf("Failed to get session %v from provider: %v", sessID, err)
return
}
a.addMngdDynamicSession(sessID, session)
} else {
if !a.dynamicSessions {
a.globalLog.OnEventf("Session %v not found for incoming message: %s", sessID, msgBytes)
return
}
dynamicSession, err := a.sessionFactory.createSession(sessID, a.storeFactory, a.settings.globalSettings.clone(), a.logFactory, a.app)
if err != nil {
a.globalLog.OnEventf("Dynamic session %v failed to create: %v", sessID, err)
return
}
a.dynamicSessionChan <- dynamicSession
session = dynamicSession
defer session.stop()
}
a.dynamicSessionChan <- dynamicSession
session = dynamicSession
defer session.stop()
}

a.sessionAddr.Store(sessID, netConn.RemoteAddr())
Expand Down Expand Up @@ -412,6 +441,46 @@ LOOP:
}
}

func (a *Acceptor) startSessions() {
a.sessionsLock.RLock()
defer a.sessionsLock.RUnlock()
for _, s := range a.sessions {
a.sessionGroup.Add(1)
go func(s *session) {
s.run()
a.sessionGroup.Done()
}(s)
}
}

func (a *Acceptor) stopSessions() {
a.sessionsLock.RLock()
defer a.sessionsLock.RUnlock()
for _, session := range a.sessions {
session.stop()
}
a.sessionGroup.Wait()

for sessionID := range a.sessions {
err := UnregisterSession(sessionID)
if err != nil {
return
}
}
}

func (a *Acceptor) addMngdDynamicSession(sessID SessionID, session *session) {
a.sessionsLock.Lock()
defer a.sessionsLock.Unlock()

a.sessions[sessID] = session
a.sessionGroup.Add(1)
go func() {
session.run()
a.sessionGroup.Done()
}()
}

// SetConnectionValidator sets an optional connection validator.
// Use it when you need a custom authentication logic that includes lower level interactions,
// like mTLS auth or IP whitelistening.
Expand All @@ -421,3 +490,9 @@ LOOP:
func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) {
a.connectionValidator = validator
}

// SetTemplateIDProvider sets an optional templateID provider.
// If not set and AcceptorTemplate=Y is configured for a session, the `DefaultTemplateIDProvider` will be used.
func (a *Acceptor) SetTemplateIDProvider(templateIDProvider TemplateIDProvider) {
a.templateIDProvider = templateIDProvider
}
Loading
Loading