Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions pkg/sqlcmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ type ConnectSettings struct {
HostNameInCertificate string
// ServerCertificate is the path to a certificate file to match against the server's TLS certificate
ServerCertificate string
// ServerNameOverride specifies the server name to use in the login packet.
// When set, the actual dial address comes from ServerName, but this value
// is sent in the TDS login packet for server validation.
ServerNameOverride string
}

func (c ConnectSettings) authenticationMethod() string {
Expand Down Expand Up @@ -100,6 +104,21 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err
if err != nil {
return "", err
}

if connect.useServerNameOverride(protocol, connect.ServerName) {
overrideName, overrideInstance, _, _, err := splitServer(connect.ServerNameOverride)
if err != nil {
return "", err
}
if overrideName == "" {
overrideName = "."
}
serverName = overrideName
if overrideInstance != "" {
instance = overrideInstance
}
}

query := url.Values{}
connectionURL := &url.URL{
Scheme: "sqlserver",
Expand Down Expand Up @@ -176,3 +195,13 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err
connectionURL.RawQuery = query.Encode()
return connectionURL.String(), nil
}

func (connect ConnectSettings) useServerNameOverride(protocol string, serverName string) bool {
if connect.ServerNameOverride == "" {
return false
}
if protocol == "np" || strings.HasPrefix(serverName, `\\`) {
return false
}
return true
Comment on lines +199 to +206
Copy link

Copilot AI Jan 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

useServerNameOverride introduces logic to skip applying the override for named pipes (protocol == "np" or \\...), but there’s no test exercising the override+named-pipe combination. A regression here would be easy to miss since ConnectionString() behavior changes based on both ServerName and ServerNameOverride.

Add a unit test in TestConnectionStringFromSqlCmd (or a new test) that sets ServerName to a named-pipe value and also sets ServerNameOverride, asserting that the resulting connection string is unchanged (override ignored).

Copilot uses AI. Check for mistakes.
}
51 changes: 51 additions & 0 deletions pkg/sqlcmd/dialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package sqlcmd

import (
"context"
"net"
"strings"
)

// proxyDialer implements mssql.HostDialer to allow specifying a server name
// for the TDS login packet that differs from the dial address. This enables
// tunneling connections through localhost while authenticating to the real server.
type proxyDialer struct {
serverName string
targetHost string
targetPort string
dialer *net.Dialer
}

func (d *proxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if d.dialer == nil {
d.dialer = &net.Dialer{}
}
Comment on lines +19 to +25
Copy link

Copilot AI Jan 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

proxyDialer lazily initializes d.dialer with an unsynchronized write. If DialContext is called concurrently on the same proxyDialer instance (which is plausible when a driver.Connector is used by database/sql), this introduces a data race.

Safer options: initialize dialer eagerly (remove the nil check), store it as a non-pointer net.Dialer value, or guard initialization with sync.Once/a mutex.

Suggested change
dialer *net.Dialer
}
func (d *proxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if d.dialer == nil {
d.dialer = &net.Dialer{}
}
dialer net.Dialer
}
func (d *proxyDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {

Copilot uses AI. Check for mistakes.
return d.dialer.DialContext(ctx, network, d.dialAddress(network, addr))
}

func (d *proxyDialer) HostName() string {
return d.serverName
}

func (d *proxyDialer) dialAddress(network, addr string) string {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return addr
}

if d.targetHost != "" {
host = d.targetHost
}
if d.targetPort != "" && isTCPNetwork(network) {
port = d.targetPort
}

return net.JoinHostPort(host, port)
}

func isTCPNetwork(network string) bool {
return strings.HasPrefix(network, "tcp")
}
51 changes: 51 additions & 0 deletions pkg/sqlcmd/dialer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package sqlcmd

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func TestProxyDialerHostName(t *testing.T) {
d := &proxyDialer{serverName: "myserver.database.windows.net"}
assert.Equal(t, "myserver.database.windows.net", d.HostName())
}

func TestProxyDialerHostNameEmpty(t *testing.T) {
d := &proxyDialer{}
assert.Equal(t, "", d.HostName())
}

func TestProxyDialerInitializesNetDialer(t *testing.T) {
d := &proxyDialer{serverName: "test.server.net"}
assert.Nil(t, d.dialer)

// DialContext should fail with an invalid address, but that's fine for this test
// We just want to verify the dialer gets initialized
_, _ = d.DialContext(context.Background(), "tcp", "invalid:99999")
assert.NotNil(t, d.dialer)
}

func TestProxyDialerDialAddressOverridesHostAndPortForTCP(t *testing.T) {
d := &proxyDialer{
targetHost: "proxy.local",
targetPort: "1444",
}

dialAddr := d.dialAddress("tcp", "server.example.com:1433")
assert.Equal(t, "proxy.local:1444", dialAddr)
}

func TestProxyDialerDialAddressKeepsPortForUDP(t *testing.T) {
d := &proxyDialer{
targetHost: "proxy.local",
targetPort: "1444",
}

dialAddr := d.dialAddress("udp", "server.example.com:1434")
assert.Equal(t, "proxy.local:1434", dialAddr)
}
33 changes: 32 additions & 1 deletion pkg/sqlcmd/sqlcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,44 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error {
}

if !useAad {
connector, err = mssql.NewConnector(connstr)
var c *mssql.Connector
c, err = mssql.NewConnector(connstr)
connector = c
} else {
connector, err = GetTokenBasedConnection(connstr, connect.authenticationMethod())
}
if err != nil {
return err
}
if connect.ServerNameOverride != "" {
serverName, _, port, protocol, err := splitServer(connect.ServerName)
if err != nil {
return err
}
if serverName == "" {
serverName = "."
}
if connect.useServerNameOverride(protocol, connect.ServerName) {
overrideName, _, _, _, err := splitServer(connect.ServerNameOverride)
if err != nil {
return err
}
if overrideName == "" {
overrideName = "."
}
targetPort := ""
if port > 0 {
targetPort = fmt.Sprintf("%d", port)
}
if mssqlConnector, ok := connector.(*mssql.Connector); ok {
mssqlConnector.Dialer = &proxyDialer{
serverName: overrideName,
targetHost: serverName,
targetPort: targetPort,
}
Copy link

Copilot AI Jan 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When ServerNameOverride is set, the dial rewrite is only applied if connector can be type-asserted to *mssql.Connector. If that assertion ever fails (e.g., for non-mssql.Connector implementations returned by token/AAD connectors), ServerNameOverride will still change the connection string host but the dial target won’t be rewritten back to the original -S host/port, which breaks tunneled/proxy scenarios.

Consider either (1) ensuring GetTokenBasedConnection returns a *mssql.Connector (so this always works), (2) supporting dialer injection for the token connector type, or (3) returning an explicit error when the dialer can’t be applied while ServerNameOverride is set.

Suggested change
}
}
} else {
// When ServerNameOverride is set we must be able to inject a dialer;
// otherwise the connection string host is changed without rewriting
// the actual dial target, which breaks tunneled/proxy scenarios.
return localizer.Errorf("Server name override is not supported with the current authentication method")

Copilot uses AI. Check for mistakes.
}
}
}
db, err := sql.OpenDB(connector).Conn(context.Background())
if err != nil {
fmt.Fprintln(s.GetOutput(), err)
Expand Down
12 changes: 12 additions & 0 deletions pkg/sqlcmd/sqlcmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ func TestConnectionStringFromSqlCmd(t *testing.T) {
&ConnectSettings{ServerName: `tcp:someserver,1045`, Encrypt: "strict", HostNameInCertificate: "*.mydomain.com"},
"sqlserver://someserver:1045?encrypt=strict&hostnameincertificate=%2A.mydomain.com&protocol=tcp",
},
{
&ConnectSettings{ServerName: `tcp:proxyhost,1444`, ServerNameOverride: "realsql"},
"sqlserver://realsql:1444?protocol=tcp",
},
{
&ConnectSettings{ServerName: `proxyhost\instance`, ServerNameOverride: "realsql"},
"sqlserver://realsql/instance",
},
{
&ConnectSettings{ServerName: `proxyhost,1444`, ServerNameOverride: `realsql\inst`},
"sqlserver://realsql:1444/inst",
},
{
&ConnectSettings{ServerName: "someserver", AuthenticationMethod: azuread.ActiveDirectoryServicePrincipal, UserName: "myapp@mytenant", Password: pwd},
fmt.Sprintf("sqlserver://myapp%%40mytenant:%s@someserver", pwd),
Expand Down
Loading