diff --git a/acceptor.go b/acceptor.go index 2f9f6c48b..f58ef01f7 100644 --- a/acceptor.go +++ b/acceptor.go @@ -49,6 +49,7 @@ type Acceptor struct { listeners map[string]net.Listener connectionValidator ConnectionValidator tlsConfig *tls.Config + newListenerCallback NewListenerCallback sessionFactory } @@ -59,6 +60,9 @@ type ConnectionValidator interface { Validate(netConn net.Conn, session SessionID) error } +// NewListenerCallback is a function that returns a net.Listener for the given address and tls.Config struct. +type NewListenerCallback func(address string, tlsConfig *tls.Config) (net.Listener, error) + // Start accepting connections. func (a *Acceptor) Start() (err error) { socketAcceptHost := "" @@ -90,6 +94,15 @@ func (a *Acceptor) Start() (err error) { a.tlsConfig = tlsConfig } + if a.newListenerCallback == nil { + a.newListenerCallback = func(address string, tlsConfig *tls.Config) (net.Listener, error) { + if tlsConfig != nil { + return tls.Listen("tcp", address, a.tlsConfig) + } + return net.Listen("tcp", address) + } + } + var useTCPProxy bool if a.settings.GlobalSettings().HasSetting(config.UseTCPProxy) { if useTCPProxy, err = a.settings.GlobalSettings().BoolSetting(config.UseTCPProxy); err != nil { @@ -98,11 +111,7 @@ func (a *Acceptor) Start() (err error) { } for address := range a.listeners { - if a.tlsConfig != nil { - if a.listeners[address], err = tls.Listen("tcp", address, a.tlsConfig); err != nil { - return - } - } else if a.listeners[address], err = net.Listen("tcp", address); err != nil { + if a.listeners[address], err = a.newListenerCallback(address, a.tlsConfig); err != nil { return } else if useTCPProxy { a.listeners[address] = &proxyproto.Listener{Listener: a.listeners[address]} @@ -435,3 +444,9 @@ func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) { func (a *Acceptor) SetTLSConfig(tlsConfig *tls.Config) { a.tlsConfig = tlsConfig } + +// SetNewListenerCallback allows the creator of the Acceptor to specify the callback used to create each net.Listener +// which will be used in the Start() method. +func (a *Acceptor) SetNewListenerCallback(cb NewListenerCallback) { + a.newListenerCallback = cb +} diff --git a/acceptor_test.go b/acceptor_test.go index 059c42715..593eb841d 100644 --- a/acceptor_test.go +++ b/acceptor_test.go @@ -126,3 +126,51 @@ func TestAcceptor_SetTLSConfig(t *testing.T) { assert.NotNil(t, conn) defer conn.Close() } + +func TestAcceptor_SetCallback(t *testing.T) { + sessionSettings := NewSessionSettings() + sessionSettings.Set(config.BeginString, BeginStringFIX42) + sessionSettings.Set(config.SenderCompID, "sender") + sessionSettings.Set(config.TargetCompID, "target") + + genericSettings := NewSettings() + + genericSettings.GlobalSettings().Set("SocketAcceptPort", "5001") + _, err := genericSettings.AddSession(sessionSettings) + require.NoError(t, err) + + logger, err := NewNullLogFactory().Create() + require.NoError(t, err) + acceptor := &Acceptor{settings: genericSettings, globalLog: logger} + defer acceptor.Stop() + // example of a customized tls.Config that loads the certificates dynamically by the `GetCertificate` function + // as opposed to the Certificates slice, that is static in nature, and is only populated once and needs application restart to reload the certs. + customizedTLSConfig := tls.Config{ + Certificates: []tls.Certificate{}, + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair("_test_data/localhost.crt", "_test_data/localhost.key") + if err != nil { + return nil, err + } + return &cert, nil + }, + } + + didUseCallback := false + acceptor.SetTLSConfig(&customizedTLSConfig) + acceptor.SetNewListenerCallback(func(address string, tlsConfig *tls.Config) (net.Listener, error) { + didUseCallback = true + assert.Equal(t, &customizedTLSConfig, tlsConfig) + return tls.Listen("tcp", address, tlsConfig) + }) + assert.NoError(t, acceptor.Start()) + assert.Len(t, acceptor.listeners, 1) + + conn, err := tls.Dial("tcp", "localhost:5001", &tls.Config{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + assert.NotNil(t, conn) + assert.True(t, didUseCallback) + defer conn.Close() +}