Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type Acceptor struct {
listeners map[string]net.Listener
connectionValidator ConnectionValidator
tlsConfig *tls.Config
newListenerCallback NewListenerCallback
sessionFactory
}

Expand All @@ -59,6 +60,9 @@ type ConnectionValidator interface {
Validate(netConn net.Conn, session SessionID) error
}

// NewListenerCallback is a function that returns a net.Listener for the given address and tls.Config struct.
type NewListenerCallback func(address string, tlsConfig *tls.Config) (net.Listener, error)

// Start accepting connections.
func (a *Acceptor) Start() (err error) {
socketAcceptHost := ""
Expand Down Expand Up @@ -90,6 +94,15 @@ func (a *Acceptor) Start() (err error) {
a.tlsConfig = tlsConfig
}

if a.newListenerCallback == nil {
a.newListenerCallback = func(address string, tlsConfig *tls.Config) (net.Listener, error) {
if tlsConfig != nil {
return tls.Listen("tcp", address, a.tlsConfig)
}
return net.Listen("tcp", address)
}
}

var useTCPProxy bool
if a.settings.GlobalSettings().HasSetting(config.UseTCPProxy) {
if useTCPProxy, err = a.settings.GlobalSettings().BoolSetting(config.UseTCPProxy); err != nil {
Expand All @@ -98,11 +111,7 @@ func (a *Acceptor) Start() (err error) {
}

for address := range a.listeners {
if a.tlsConfig != nil {
if a.listeners[address], err = tls.Listen("tcp", address, a.tlsConfig); err != nil {
return
}
} else if a.listeners[address], err = net.Listen("tcp", address); err != nil {
if a.listeners[address], err = a.newListenerCallback(address, a.tlsConfig); err != nil {
return
} else if useTCPProxy {
a.listeners[address] = &proxyproto.Listener{Listener: a.listeners[address]}
Expand Down Expand Up @@ -435,3 +444,9 @@ func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) {
func (a *Acceptor) SetTLSConfig(tlsConfig *tls.Config) {
a.tlsConfig = tlsConfig
}

// SetNewListenerCallback allows the creator of the Acceptor to specify the callback used to create each net.Listener
// which will be used in the Start() method.
func (a *Acceptor) SetNewListenerCallback(cb NewListenerCallback) {
a.newListenerCallback = cb
}
48 changes: 48 additions & 0 deletions acceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,51 @@ func TestAcceptor_SetTLSConfig(t *testing.T) {
assert.NotNil(t, conn)
defer conn.Close()
}

func TestAcceptor_SetCallback(t *testing.T) {
sessionSettings := NewSessionSettings()
sessionSettings.Set(config.BeginString, BeginStringFIX42)
sessionSettings.Set(config.SenderCompID, "sender")
sessionSettings.Set(config.TargetCompID, "target")

genericSettings := NewSettings()

genericSettings.GlobalSettings().Set("SocketAcceptPort", "5001")
_, err := genericSettings.AddSession(sessionSettings)
require.NoError(t, err)

logger, err := NewNullLogFactory().Create()
require.NoError(t, err)
acceptor := &Acceptor{settings: genericSettings, globalLog: logger}
defer acceptor.Stop()
// example of a customized tls.Config that loads the certificates dynamically by the `GetCertificate` function
// as opposed to the Certificates slice, that is static in nature, and is only populated once and needs application restart to reload the certs.
customizedTLSConfig := tls.Config{
Certificates: []tls.Certificate{},
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair("_test_data/localhost.crt", "_test_data/localhost.key")
if err != nil {
return nil, err
}
return &cert, nil
},
}

didUseCallback := false
acceptor.SetTLSConfig(&customizedTLSConfig)
acceptor.SetNewListenerCallback(func(address string, tlsConfig *tls.Config) (net.Listener, error) {
didUseCallback = true
assert.Equal(t, &customizedTLSConfig, tlsConfig)
return tls.Listen("tcp", address, tlsConfig)
})
assert.NoError(t, acceptor.Start())
assert.Len(t, acceptor.listeners, 1)

conn, err := tls.Dial("tcp", "localhost:5001", &tls.Config{
InsecureSkipVerify: true,
})
require.NoError(t, err)
assert.NotNil(t, conn)
assert.True(t, didUseCallback)
defer conn.Close()
}
Loading