-
Notifications
You must be signed in to change notification settings - Fork 82
feat(connect): add --server-name flag for tunneled connections #678
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
611d165
7915010
eb57f2a
d3a58c7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,10 @@ type ConnectSettings struct { | |
| HostNameInCertificate string | ||
| // ServerCertificate is the path to a certificate file to match against the server's TLS certificate | ||
| ServerCertificate string | ||
| // ServerNameOverride specifies the server name to use in the login packet. | ||
| // When set, the actual dial address comes from ServerName, but this value | ||
| // is sent in the TDS login packet for server validation. | ||
| ServerNameOverride string | ||
jimmystridh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| func (c ConnectSettings) authenticationMethod() string { | ||
|
|
@@ -100,6 +104,21 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err | |
| if err != nil { | ||
| return "", err | ||
| } | ||
|
|
||
| if connect.useServerNameOverride(protocol, connect.ServerName) { | ||
| overrideName, overrideInstance, _, _, err := splitServer(connect.ServerNameOverride) | ||
| if err != nil { | ||
| return "", err | ||
| } | ||
| if overrideName == "" { | ||
| overrideName = "." | ||
| } | ||
| serverName = overrideName | ||
| if overrideInstance != "" { | ||
| instance = overrideInstance | ||
| } | ||
| } | ||
|
|
||
| query := url.Values{} | ||
| connectionURL := &url.URL{ | ||
| Scheme: "sqlserver", | ||
|
|
@@ -176,3 +195,13 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err | |
| connectionURL.RawQuery = query.Encode() | ||
| return connectionURL.String(), nil | ||
| } | ||
|
|
||
| func (connect ConnectSettings) useServerNameOverride(protocol string, serverName string) bool { | ||
| if connect.ServerNameOverride == "" { | ||
| return false | ||
| } | ||
| if protocol == "np" || strings.HasPrefix(serverName, `\\`) { | ||
| return false | ||
| } | ||
| return true | ||
|
Comment on lines
+199
to
+206
|
||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,51 @@ | ||||||||||||||||||||||||
| // Copyright (c) Microsoft Corporation. | ||||||||||||||||||||||||
| // Licensed under the MIT license. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| package sqlcmd | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| import ( | ||||||||||||||||||||||||
| "context" | ||||||||||||||||||||||||
| "net" | ||||||||||||||||||||||||
| "strings" | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // proxyDialer implements mssql.HostDialer to allow specifying a server name | ||||||||||||||||||||||||
| // for the TDS login packet that differs from the dial address. This enables | ||||||||||||||||||||||||
| // tunneling connections through localhost while authenticating to the real server. | ||||||||||||||||||||||||
| type proxyDialer struct { | ||||||||||||||||||||||||
| serverName string | ||||||||||||||||||||||||
| targetHost string | ||||||||||||||||||||||||
| targetPort string | ||||||||||||||||||||||||
| dialer *net.Dialer | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| func (d *proxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { | ||||||||||||||||||||||||
| if d.dialer == nil { | ||||||||||||||||||||||||
| d.dialer = &net.Dialer{} | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
Comment on lines
+19
to
+25
|
||||||||||||||||||||||||
| dialer *net.Dialer | |
| } | |
| func (d *proxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { | |
| if d.dialer == nil { | |
| d.dialer = &net.Dialer{} | |
| } | |
| dialer net.Dialer | |
| } | |
| func (d *proxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| // Copyright (c) Microsoft Corporation. | ||
| // Licensed under the MIT license. | ||
|
|
||
| package sqlcmd | ||
|
|
||
| import ( | ||
| "context" | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| ) | ||
|
|
||
| func TestProxyDialerHostName(t *testing.T) { | ||
| d := &proxyDialer{serverName: "myserver.database.windows.net"} | ||
| assert.Equal(t, "myserver.database.windows.net", d.HostName()) | ||
| } | ||
|
|
||
| func TestProxyDialerHostNameEmpty(t *testing.T) { | ||
| d := &proxyDialer{} | ||
| assert.Equal(t, "", d.HostName()) | ||
| } | ||
|
|
||
| func TestProxyDialerInitializesNetDialer(t *testing.T) { | ||
| d := &proxyDialer{serverName: "test.server.net"} | ||
| assert.Nil(t, d.dialer) | ||
|
|
||
| // DialContext should fail with an invalid address, but that's fine for this test | ||
| // We just want to verify the dialer gets initialized | ||
| _, _ = d.DialContext(context.Background(), "tcp", "invalid:99999") | ||
| assert.NotNil(t, d.dialer) | ||
| } | ||
|
|
||
| func TestProxyDialerDialAddressOverridesHostAndPortForTCP(t *testing.T) { | ||
| d := &proxyDialer{ | ||
| targetHost: "proxy.local", | ||
| targetPort: "1444", | ||
| } | ||
|
|
||
| dialAddr := d.dialAddress("tcp", "server.example.com:1433") | ||
| assert.Equal(t, "proxy.local:1444", dialAddr) | ||
| } | ||
|
|
||
| func TestProxyDialerDialAddressKeepsPortForUDP(t *testing.T) { | ||
| d := &proxyDialer{ | ||
| targetHost: "proxy.local", | ||
| targetPort: "1444", | ||
| } | ||
|
|
||
| dialAddr := d.dialAddress("udp", "server.example.com:1434") | ||
| assert.Equal(t, "proxy.local:1434", dialAddr) | ||
| } |
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -273,13 +273,44 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error { | |||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| if !useAad { | ||||||||||||||||
| connector, err = mssql.NewConnector(connstr) | ||||||||||||||||
| var c *mssql.Connector | ||||||||||||||||
| c, err = mssql.NewConnector(connstr) | ||||||||||||||||
| connector = c | ||||||||||||||||
| } else { | ||||||||||||||||
| connector, err = GetTokenBasedConnection(connstr, connect.authenticationMethod()) | ||||||||||||||||
| } | ||||||||||||||||
| if err != nil { | ||||||||||||||||
| return err | ||||||||||||||||
| } | ||||||||||||||||
| if connect.ServerNameOverride != "" { | ||||||||||||||||
| serverName, _, port, protocol, err := splitServer(connect.ServerName) | ||||||||||||||||
| if err != nil { | ||||||||||||||||
| return err | ||||||||||||||||
| } | ||||||||||||||||
| if serverName == "" { | ||||||||||||||||
| serverName = "." | ||||||||||||||||
| } | ||||||||||||||||
| if connect.useServerNameOverride(protocol, connect.ServerName) { | ||||||||||||||||
| overrideName, _, _, _, err := splitServer(connect.ServerNameOverride) | ||||||||||||||||
| if err != nil { | ||||||||||||||||
| return err | ||||||||||||||||
| } | ||||||||||||||||
| if overrideName == "" { | ||||||||||||||||
| overrideName = "." | ||||||||||||||||
| } | ||||||||||||||||
| targetPort := "" | ||||||||||||||||
| if port > 0 { | ||||||||||||||||
| targetPort = fmt.Sprintf("%d", port) | ||||||||||||||||
| } | ||||||||||||||||
| if mssqlConnector, ok := connector.(*mssql.Connector); ok { | ||||||||||||||||
| mssqlConnector.Dialer = &proxyDialer{ | ||||||||||||||||
| serverName: overrideName, | ||||||||||||||||
| targetHost: serverName, | ||||||||||||||||
| targetPort: targetPort, | ||||||||||||||||
| } | ||||||||||||||||
|
||||||||||||||||
| } | |
| } | |
| } else { | |
| // When ServerNameOverride is set we must be able to inject a dialer; | |
| // otherwise the connection string host is changed without rewriting | |
| // the actual dial target, which breaks tunneled/proxy scenarios. | |
| return localizer.Errorf("Server name override is not supported with the current authentication method") |
Uh oh!
There was an error while loading. Please reload this page.