diff --git a/.gitignore b/.gitignore
index f713869e..d8da873d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -36,6 +36,7 @@ linux-s390x/sqlcmd
# Build artifacts in root
/sqlcmd
/sqlcmd_binary
+/modern
# certificates used for local testing
*.der
diff --git a/cmd/modern/root/query_test.go b/cmd/modern/root/query_test.go
index 496cead4..482e267c 100644
--- a/cmd/modern/root/query_test.go
+++ b/cmd/modern/root/query_test.go
@@ -28,7 +28,7 @@ func TestQueryWithNonDefaultDatabase(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip("stuartpa: This is failing in the pipeline (Login failed for user 'sa'.)")
}
-
+
cmdparser.TestSetup(t)
setupContext(t)
diff --git a/cmd/sqlcmd/pipe_detection_test.go b/cmd/sqlcmd/pipe_detection_test.go
index b2bc8be3..02aee635 100644
--- a/cmd/sqlcmd/pipe_detection_test.go
+++ b/cmd/sqlcmd/pipe_detection_test.go
@@ -6,7 +6,7 @@ package sqlcmd
import (
"os"
"testing"
-
+
"github.com/stretchr/testify/assert"
)
@@ -14,15 +14,15 @@ func TestStdinPipeDetection(t *testing.T) {
// Get stdin info
fi, err := os.Stdin.Stat()
assert.NoError(t, err, "os.Stdin.Stat()")
-
+
// On most CI systems, stdin will be a pipe or file (not a terminal)
// We're testing the logic, not expecting a specific result
isPipe := false
if fi != nil && (fi.Mode()&os.ModeCharDevice) == 0 {
isPipe = true
}
-
+
// Just making sure the detection code doesn't crash
// The actual value will depend on the environment
t.Logf("Stdin detected as pipe: %v", isPipe)
-}
\ No newline at end of file
+}
diff --git a/internal/secret/generate.go b/internal/secret/generate.go
index 62ea28f5..c5cddffd 100644
--- a/internal/secret/generate.go
+++ b/internal/secret/generate.go
@@ -17,7 +17,7 @@ const (
)
// Generate generates a random password of a specified length. The password
-// will contain at least the specified number of special characters,
+// will contain at least the specified number of special characters,
// numeric digits, and upper-case letters. The remaining characters in the
// password will be selected from a combination of lower-case letters, special
// characters, and numeric digits. The special characters are chosen from
diff --git a/pkg/console/console.go b/pkg/console/console.go
index 4eebded7..687f42db 100644
--- a/pkg/console/console.go
+++ b/pkg/console/console.go
@@ -28,7 +28,7 @@ func NewConsole(historyFile string) sqlcmd.Console {
historyFile: historyFile,
stdinRedirected: isStdinRedirected(),
}
-
+
if c.stdinRedirected {
c.stdinReader = bufio.NewReader(os.Stdin)
} else {
@@ -52,7 +52,7 @@ func (c *console) Close() {
f.Close()
}
}
-
+
if !c.stdinRedirected {
c.impl.Close()
}
@@ -79,7 +79,7 @@ func (c *console) Readline() (string, error) {
}
return line, nil
}
-
+
// Interactive terminal mode with prompts
s, err := c.impl.Prompt(c.prompt)
if err == liner.ErrPromptAborted {
diff --git a/pkg/console/console_redirect.go b/pkg/console/console_redirect.go
index b09486d9..d342d987 100644
--- a/pkg/console/console_redirect.go
+++ b/pkg/console/console_redirect.go
@@ -5,6 +5,7 @@ package console
import (
"os"
+
"golang.org/x/term"
)
@@ -15,13 +16,13 @@ func isStdinRedirected() bool {
// If we can't determine, assume it's not redirected
return false
}
-
+
// If it's not a character device, it's coming from a pipe or redirection
if (stat.Mode() & os.ModeCharDevice) == 0 {
return true
}
-
+
// Double-check using term.IsTerminal
fd := int(os.Stdin.Fd())
return !term.IsTerminal(fd)
-}
\ No newline at end of file
+}
diff --git a/pkg/console/console_redirect_test.go b/pkg/console/console_redirect_test.go
index f26cab98..b1b3b2ca 100644
--- a/pkg/console/console_redirect_test.go
+++ b/pkg/console/console_redirect_test.go
@@ -51,4 +51,4 @@ func TestStdinRedirectionDetection(t *testing.T) {
// Clean up
console.Close()
-}
\ No newline at end of file
+}
diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go
index 66dd1dba..192c14ba 100644
--- a/pkg/sqlcmd/commands.go
+++ b/pkg/sqlcmd/commands.go
@@ -1,644 +1,660 @@
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT license.
-
-package sqlcmd
-
-import (
- "flag"
- "fmt"
- "os"
- "regexp"
- "sort"
- "strconv"
- "strings"
-
- "github.com/microsoft/go-sqlcmd/internal/color"
- "golang.org/x/text/encoding/unicode"
- "golang.org/x/text/transform"
-)
-
-// Command defines a sqlcmd action which can be intermixed with the SQL batch
-// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands
-type Command struct {
- // regex must include at least one group if it has parameters
- // Will be matched using FindStringSubmatch
- regex *regexp.Regexp
- // The function that implements the command. Third parameter is the line number
- action func(*Sqlcmd, []string, uint) error
- // Name of the command
- name string
- // whether the command is a system command
- isSystem bool
-}
-
-// Commands is the set of sqlcmd command implementations
-type Commands map[string]*Command
-
-func newCommands() Commands {
- // Commands is the set of Command implementations
- return map[string]*Command{
- "EXIT": {
- regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT([\( \t]+.*\)*$|$)`),
- action: exitCommand,
- name: "EXIT",
- },
- "QUIT": {
- regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`),
- action: quitCommand,
- name: "QUIT",
- },
- "GO": {
- regex: regexp.MustCompile(batchTerminatorRegex("GO")),
- action: goCommand,
- name: "GO",
- },
- "OUT": {
- regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`),
- action: outCommand,
- name: "OUT",
- },
- "ERROR": {
- regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`),
- action: errorCommand,
- name: "ERROR",
- }, "READFILE": {
- regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`),
- action: readFileCommand,
- name: "READFILE",
- },
- "SETVAR": {
- regex: regexp.MustCompile(`(?im)^[ \t]*:SETVAR(?:[ \t]+(.*$)|$)`),
- action: setVarCommand,
- name: "SETVAR",
- },
- "LISTVAR": {
- regex: regexp.MustCompile(`(?im)^[\t ]*?:LISTVAR(?:[ \t]+(.*$)|$)`),
- action: listVarCommand,
- name: "LISTVAR",
- },
- "RESET": {
- regex: regexp.MustCompile(`(?im)^[ \t]*?:?RESET(?:[ \t]+(.*$)|$)`),
- action: resetCommand,
- name: "RESET",
- },
- "LIST": {
- regex: regexp.MustCompile(`(?im)^[ \t]*:LIST(?:[ \t]+(.*$)|$)`),
- action: listCommand,
- name: "LIST",
- },
- "CONNECT": {
- regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`),
- action: connectCommand,
- name: "CONNECT",
- },
- "EXEC": {
- regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(.*$)`),
- action: execCommand,
- name: "EXEC",
- isSystem: true,
- },
- "EDIT": {
- regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`),
- action: editCommand,
- name: "EDIT",
- isSystem: true,
- },
- "ONERROR": {
- regex: regexp.MustCompile(`(?im)^[\t ]*?:?ON ERROR(?:[ \t]+(.*$)|$)`),
- action: onerrorCommand,
- name: "ONERROR",
- },
- "XML": {
- regex: regexp.MustCompile(`(?im)^[\t ]*?:XML(?:[ \t]+(.*$)|$)`),
- action: xmlCommand,
- name: "XML",
- },
- }
-}
-
-// DisableSysCommands disables the ED and :!! commands.
-// When exitOnCall is true, running those commands will exit the process.
-func (c Commands) DisableSysCommands(exitOnCall bool) {
- f := warnDisabled
- if exitOnCall {
- f = errorDisabled
- }
- for _, cmd := range c {
- if cmd.isSystem {
- cmd.action = f
- }
- }
-}
-
-func (c Commands) matchCommand(line string) (*Command, []string) {
- for _, cmd := range c {
- matchedCommand := cmd.regex.FindStringSubmatch(line)
- if matchedCommand != nil {
- return cmd, removeComments(matchedCommand[1:])
- }
- }
- return nil, nil
-}
-
-func removeComments(args []string) []string {
- var pos int
- quote := false
- for i := range args {
- pos, quote = commentStart([]rune(args[i]), quote)
- if pos > -1 {
- out := make([]string, i+1)
- if i > 0 {
- copy(out, args[:i])
- }
- out[i] = args[i][:pos]
- return out
- }
- }
- return args
-}
-
-func commentStart(arg []rune, quote bool) (int, bool) {
- var i int
- space := true
- for ; i < len(arg); i++ {
- c, next := arg[i], grab(arg, i+1, len(arg))
- switch {
- case quote && c == '"' && next != '"':
- quote = false
- case quote && c == '"' && next == '"':
- i++
- case c == '\t' || c == ' ':
- space = true
- // Note we assume none of the regexes would split arguments on non-whitespace boundaries such that "text -- comment" would get split into "text -" and "- comment"
- case !quote && space && c == '-' && next == '-':
- return i, false
- case !quote && c == '"':
- quote = true
- default:
- space = false
- }
- }
- return -1, quote
-}
-
-func warnDisabled(s *Sqlcmd, args []string, line uint) error {
- s.WriteError(s.GetError(), ErrCommandsDisabled)
- return nil
-}
-
-func errorDisabled(s *Sqlcmd, args []string, line uint) error {
- s.WriteError(s.GetError(), ErrCommandsDisabled)
- s.Exitcode = 1
- return ErrExitRequested
-}
-
-func batchTerminatorRegex(terminator string) string {
- return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator))
-}
-
-// SetBatchTerminator attempts to set the batch terminator to the given value
-// Returns an error if the new value is not usable in the regex
-func (c Commands) SetBatchTerminator(terminator string) error {
- cmd := c["GO"]
- regex, err := regexp.Compile(batchTerminatorRegex(terminator))
- if err != nil {
- return err
- }
- cmd.regex = regex
- return nil
-}
-
-// exitCommand has 3 modes.
-// With no (), it just exits without running any query
-// With () it runs whatever batch is in the buffer then exits
-// With any text between () it runs the text as a query then exits
-func exitCommand(s *Sqlcmd, args []string, line uint) error {
- if len(args) == 0 {
- return ErrExitRequested
- }
- params := strings.TrimSpace(args[0])
- if params == "" {
- return ErrExitRequested
- }
- if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") {
- return InvalidCommandError("EXIT", line)
- }
- // First we save the current batch
- query1 := s.batch.String()
- if len(query1) > 0 {
- query1 = s.getRunnableQuery(query1)
- }
- // Now parse the params of EXIT as a batch without commands
- cmd := s.batch.cmd
- s.batch.cmd = nil
- defer func() {
- s.batch.cmd = cmd
- }()
- query2 := strings.TrimSpace(params[1 : len(params)-1])
- if len(query2) > 0 {
- s.batch.Reset([]rune(query2))
- _, _, err := s.batch.Next()
- if err != nil {
- return err
- }
- query2 = s.batch.String()
- if len(query2) > 0 {
- query2 = s.getRunnableQuery(query2)
- }
- }
-
- if len(query1) > 0 || len(query2) > 0 {
- query := query1 + SqlcmdEol + query2
- s.Exitcode, _ = s.runQuery(query)
- }
- return ErrExitRequested
-}
-
-// quitCommand immediately exits the program without running any more batches
-func quitCommand(s *Sqlcmd, args []string, line uint) error {
- if args != nil && strings.TrimSpace(args[0]) != "" {
- return InvalidCommandError("QUIT", line)
- }
- return ErrExitRequested
-}
-
-// goCommand runs the current batch the number of times specified
-func goCommand(s *Sqlcmd, args []string, line uint) error {
- // default to 1 execution
- n := 1
- var err error
- if len(args) > 0 {
- cnt := strings.TrimSpace(args[0])
- if cnt != "" {
- if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil {
- return err
- }
- _, err = fmt.Sscanf(cnt, "%d", &n)
- }
- }
- if err != nil || n < 1 {
- return InvalidCommandError("GO", line)
- }
- if s.EchoInput {
- err = listCommand(s, []string{}, line)
- }
- if err != nil {
- return InvalidCommandError("GO", line)
- }
- query := s.batch.String()
- if query == "" {
- return nil
- }
- query = s.getRunnableQuery(query)
- for i := 0; i < n; i++ {
- if retcode, err := s.runQuery(query); err != nil {
- s.Exitcode = retcode
- return err
- }
- }
- s.batch.Reset(nil)
- return nil
-}
-
-// outCommand changes the output writer to use a file
-func outCommand(s *Sqlcmd, args []string, line uint) error {
- if len(args) == 0 || args[0] == "" {
- return InvalidCommandError("OUT", line)
- }
- filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
- if err != nil {
- return err
- }
-
- switch {
- case strings.EqualFold(filePath, "stdout"):
- s.SetOutput(os.Stdout)
- case strings.EqualFold(filePath, "stderr"):
- s.SetOutput(os.Stderr)
- default:
- o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
- if err != nil {
- return InvalidFileError(err, args[0])
- }
- if s.UnicodeOutputFile {
- // ODBC sqlcmd doesn't write a BOM but we will.
- // Maybe the endian-ness should be configurable.
- win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM)
- encoder := transform.NewWriter(o, win16le.NewEncoder())
- s.SetOutput(encoder)
- } else {
- s.SetOutput(o)
- }
- }
- return nil
-}
-
-// errorCommand changes the error writer to use a file
-func errorCommand(s *Sqlcmd, args []string, line uint) error {
- if len(args) == 0 || args[0] == "" {
- return InvalidCommandError("ERROR", line)
- }
- filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
- if err != nil {
- return err
- }
- switch {
- case strings.EqualFold(filePath, "stderr"):
- s.SetError(os.Stderr)
- case strings.EqualFold(filePath, "stdout"):
- s.SetError(os.Stdout)
- default:
- o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
- if err != nil {
- return InvalidFileError(err, args[0])
- }
- s.SetError(o)
- }
- return nil
-}
-
-func readFileCommand(s *Sqlcmd, args []string, line uint) error {
- if args == nil || len(args) != 1 {
- return InvalidCommandError(":R", line)
- }
- fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false)
- return s.IncludeFile(fileName, false)
-}
-
-// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables
-func setVarCommand(s *Sqlcmd, args []string, line uint) error {
- if args == nil || len(args) != 1 || args[0] == "" {
- return InvalidCommandError(":SETVAR", line)
- }
-
- varname := args[0]
- val := ""
- // The prior incarnation of sqlcmd doesn't require a space between the variable name and its value
- // in some very unexpected cases. This version will require the space.
- sp := strings.IndexRune(args[0], ' ')
- if sp > -1 {
- val = strings.TrimSpace(varname[sp:])
- varname = varname[:sp]
- }
- if err := s.vars.Setvar(varname, val); err != nil {
- switch e := err.(type) {
- case *VariableError:
- return e
- default:
- return InvalidCommandError(":SETVAR", line)
- }
- }
- return nil
-}
-
-// listVarCommand prints the set of Sqlcmd scripting variables.
-// Builtin values are printed first, followed by user-set values in sorted order.
-func listVarCommand(s *Sqlcmd, args []string, line uint) error {
- if args != nil && strings.TrimSpace(args[0]) != "" {
- return InvalidCommandError("LISTVAR", line)
- }
-
- vars := s.vars.All()
- keys := make([]string, 0, len(vars))
- for k := range vars {
- if !contains(builtinVariables, k) {
- keys = append(keys, k)
- }
- }
- sort.Strings(keys)
- keys = append(builtinVariables, keys...)
- for _, k := range keys {
- fmt.Fprintf(s.GetOutput(), `%s = "%s"%s`, k, vars[k], SqlcmdEol)
- }
- return nil
-}
-
-// resetCommand resets the statement cache
-func resetCommand(s *Sqlcmd, args []string, line uint) error {
- if s.batch != nil {
- s.batch.Reset(nil)
- }
-
- return nil
-}
-
-// listCommand displays statements currently in the statement cache
-func listCommand(s *Sqlcmd, args []string, line uint) (err error) {
- cmd := ""
- if args != nil {
- if len(args) > 0 {
- cmd = strings.ToLower(strings.TrimSpace(args[0]))
- if len(args) > 1 || (cmd != "color" && cmd != "") {
- return InvalidCommandError("LIST", line)
- }
- }
- }
- output := s.GetOutput()
- if cmd == "color" {
- sample := "select 'literal' as literal, 100 as number from [sys].[tables]"
- clr := color.TextTypeTSql
- if s.Format.IsXmlMode() {
- sample = `value`
- clr = color.TextTypeXml
- }
- // ignoring errors since it's not critical output
- for _, style := range s.colorizer.Styles() {
- _, _ = output.Write([]byte(style + ": "))
- _ = s.colorizer.Write(output, sample, style, clr)
- _, _ = output.Write([]byte(SqlcmdEol))
- }
- return
- }
- if s.batch == nil || s.batch.String() == "" {
- return
- }
-
- if err = s.colorizer.Write(output, s.batch.String(), s.vars.ColorScheme(), color.TextTypeTSql); err == nil {
- _, err = output.Write([]byte(SqlcmdEol))
- }
-
- return
-}
-
-func connectCommand(s *Sqlcmd, args []string, line uint) error {
- if len(args) == 0 {
- return InvalidCommandError("CONNECT", line)
- }
-
- commandArgs := strings.Fields(args[0])
-
- // Parse flags
- flags := flag.NewFlagSet("connect", flag.ContinueOnError)
- database := flags.String("D", "", "database name")
- username := flags.String("U", "", "user name")
- password := flags.String("P", "", "password")
- loginTimeout := flags.String("l", "", "login timeout")
- authenticationMethod := flags.String("G", "", "authentication method")
-
- err := flags.Parse(commandArgs[1:])
- //err := flags.Parse(args[1:])
- if err != nil {
- return InvalidCommandError("CONNECT", line)
- }
-
- connect := *s.Connect
- connect.UserName, _ = resolveArgumentVariables(s, []rune(*username), false)
- connect.Password, _ = resolveArgumentVariables(s, []rune(*password), false)
- connect.Database, _ = resolveArgumentVariables(s, []rune(*database), false)
-
- timeout, _ := resolveArgumentVariables(s, []rune(*loginTimeout), false)
- if timeout != "" {
- if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil {
- if timeoutSeconds < 0 {
- return InvalidCommandError("CONNECT", line)
- }
- connect.LoginTimeoutSeconds = int(timeoutSeconds)
- }
- }
-
- connect.AuthenticationMethod = *authenticationMethod
-
- // Set server name as the first positional argument
- if len(commandArgs) > 0 {
- connect.ServerName, _ = resolveArgumentVariables(s, []rune(commandArgs[0]), false)
- }
-
- // If no user name is provided we switch to integrated auth
- _ = s.ConnectDb(&connect, s.lineIo == nil)
-
- // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option
- return nil
-}
-
-func execCommand(s *Sqlcmd, args []string, line uint) error {
- if len(args) == 0 {
- return InvalidCommandError("EXEC", line)
- }
- cmdLine := strings.TrimSpace(args[0])
- if cmdLine == "" {
- return InvalidCommandError("EXEC", line)
- }
- if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil {
- return err
- } else {
- cmd := sysCommand(cmdLine)
- cmd.Stderr = s.GetError()
- cmd.Stdout = s.GetOutput()
- _ = cmd.Run()
- }
- return nil
-}
-
-func editCommand(s *Sqlcmd, args []string, line uint) error {
- if args != nil && strings.TrimSpace(args[0]) != "" {
- return InvalidCommandError("ED", line)
- }
- file, err := os.CreateTemp("", "sq*.sql")
- if err != nil {
- return err
- }
- fileName := file.Name()
- defer os.Remove(fileName)
- text := s.batch.String()
- if s.batch.State() == "-" {
- text = fmt.Sprintf("%s%s", text, SqlcmdEol)
- }
- _, err = file.WriteString(text)
- if err != nil {
- return err
- }
- file.Close()
- cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`)
- cmd.Stderr = s.GetError()
- cmd.Stdout = s.GetOutput()
- err = cmd.Run()
- if err != nil {
- return err
- }
- wasEcho := s.echoFileLines
- s.echoFileLines = true
- s.batch.Reset(nil)
- _ = s.IncludeFile(fileName, false)
- s.echoFileLines = wasEcho
- return nil
-}
-
-func onerrorCommand(s *Sqlcmd, args []string, line uint) error {
- if len(args) == 0 || args[0] == "" {
- return InvalidCommandError("ON ERROR", line)
- }
- params := strings.TrimSpace(args[0])
-
- if strings.EqualFold(strings.ToLower(params), "exit") {
- s.Connect.ExitOnError = true
- } else if strings.EqualFold(strings.ToLower(params), "ignore") {
- s.Connect.IgnoreError = true
- s.Connect.ExitOnError = false
- } else {
- return InvalidCommandError("ON ERROR", line)
- }
- return nil
-}
-
-func xmlCommand(s *Sqlcmd, args []string, line uint) error {
- if len(args) != 1 || args[0] == "" {
- return InvalidCommandError("XML", line)
- }
- params := strings.TrimSpace(args[0])
- // "OFF" and "ON" are documented as the allowed values.
- // ODBC sqlcmd treats any value other than "ON" the same as "OFF".
- // So we will too.
- if strings.EqualFold(params, "on") {
- s.Format.XmlMode(true)
- } else {
- s.Format.XmlMode(false)
- }
- return nil
-}
-
-func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) {
- var b *strings.Builder
- end := len(arg)
- for i := 0; i < end && !s.Connect.DisableVariableSubstitution; {
- c, next := arg[i], grab(arg, i+1, end)
- switch {
- case c == '$' && next == '(':
- vl, ok := readVariableReference(arg, i+2, end)
- if ok {
- varName := string(arg[i+2 : vl])
- val, ok := s.resolveVariable(varName)
- if ok {
- if b == nil {
- b = new(strings.Builder)
- b.Grow(len(arg))
- b.WriteString(string(arg[0:i]))
- }
- b.WriteString(val)
- } else {
- if failOnUnresolved {
- return "", UndefinedVariable(varName)
- }
- s.WriteError(s.GetError(), UndefinedVariable(varName))
- if b != nil {
- b.WriteString(string(arg[i : vl+1]))
- }
- }
- i += ((vl - i) + 1)
- } else {
- if b != nil {
- b.WriteString("$(")
- }
- i += 2
- }
- default:
- if b != nil {
- b.WriteRune(c)
- }
- i++
- }
- }
- if b == nil {
- return string(arg), nil
- }
- return b.String(), nil
-}
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+package sqlcmd
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "regexp"
+ "sort"
+ "strconv"
+ "strings"
+
+ "github.com/microsoft/go-sqlcmd/internal/color"
+ "golang.org/x/text/encoding/unicode"
+ "golang.org/x/text/transform"
+)
+
+// Command defines a sqlcmd action which can be intermixed with the SQL batch
+// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands
+type Command struct {
+ // regex must include at least one group if it has parameters
+ // Will be matched using FindStringSubmatch
+ regex *regexp.Regexp
+ // The function that implements the command. Third parameter is the line number
+ action func(*Sqlcmd, []string, uint) error
+ // Name of the command
+ name string
+ // whether the command is a system command
+ isSystem bool
+}
+
+// Commands is the set of sqlcmd command implementations
+type Commands map[string]*Command
+
+func newCommands() Commands {
+ // Commands is the set of Command implementations
+ return map[string]*Command{
+ "EXIT": {
+ regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT([\( \t]+.*\)*$|$)`),
+ action: exitCommand,
+ name: "EXIT",
+ },
+ "QUIT": {
+ regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`),
+ action: quitCommand,
+ name: "QUIT",
+ },
+ "GO": {
+ regex: regexp.MustCompile(batchTerminatorRegex("GO")),
+ action: goCommand,
+ name: "GO",
+ },
+ "OUT": {
+ regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`),
+ action: outCommand,
+ name: "OUT",
+ },
+ "ERROR": {
+ regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`),
+ action: errorCommand,
+ name: "ERROR",
+ }, "READFILE": {
+ regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`),
+ action: readFileCommand,
+ name: "READFILE",
+ },
+ "SETVAR": {
+ regex: regexp.MustCompile(`(?im)^[ \t]*:SETVAR(?:[ \t]+(.*$)|$)`),
+ action: setVarCommand,
+ name: "SETVAR",
+ },
+ "LISTVAR": {
+ regex: regexp.MustCompile(`(?im)^[\t ]*?:LISTVAR(?:[ \t]+(.*$)|$)`),
+ action: listVarCommand,
+ name: "LISTVAR",
+ },
+ "RESET": {
+ regex: regexp.MustCompile(`(?im)^[ \t]*?:?RESET(?:[ \t]+(.*$)|$)`),
+ action: resetCommand,
+ name: "RESET",
+ },
+ "LIST": {
+ regex: regexp.MustCompile(`(?im)^[ \t]*:LIST(?:[ \t]+(.*$)|$)`),
+ action: listCommand,
+ name: "LIST",
+ },
+ "CONNECT": {
+ regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`),
+ action: connectCommand,
+ name: "CONNECT",
+ },
+ "EXEC": {
+ regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(.*$)`),
+ action: execCommand,
+ name: "EXEC",
+ isSystem: true,
+ },
+ "EDIT": {
+ regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`),
+ action: editCommand,
+ name: "EDIT",
+ isSystem: true,
+ },
+ "ONERROR": {
+ regex: regexp.MustCompile(`(?im)^[\t ]*?:?ON ERROR(?:[ \t]+(.*$)|$)`),
+ action: onerrorCommand,
+ name: "ONERROR",
+ },
+ "XML": {
+ regex: regexp.MustCompile(`(?im)^[\t ]*?:XML(?:[ \t]+(.*$)|$)`),
+ action: xmlCommand,
+ name: "XML",
+ },
+ }
+}
+
+// DisableSysCommands disables the ED and :!! commands.
+// When exitOnCall is true, running those commands will exit the process.
+func (c Commands) DisableSysCommands(exitOnCall bool) {
+ f := warnDisabled
+ if exitOnCall {
+ f = errorDisabled
+ }
+ for _, cmd := range c {
+ if cmd.isSystem {
+ cmd.action = f
+ }
+ }
+}
+
+func (c Commands) matchCommand(line string) (*Command, []string) {
+ for _, cmd := range c {
+ matchedCommand := cmd.regex.FindStringSubmatch(line)
+ if matchedCommand != nil {
+ return cmd, removeComments(matchedCommand[1:])
+ }
+ }
+ return nil, nil
+}
+
+func removeComments(args []string) []string {
+ var pos int
+ quote := false
+ for i := range args {
+ pos, quote = commentStart([]rune(args[i]), quote)
+ if pos > -1 {
+ out := make([]string, i+1)
+ if i > 0 {
+ copy(out, args[:i])
+ }
+ out[i] = args[i][:pos]
+ return out
+ }
+ }
+ return args
+}
+
+func commentStart(arg []rune, quote bool) (int, bool) {
+ var i int
+ space := true
+ for ; i < len(arg); i++ {
+ c, next := arg[i], grab(arg, i+1, len(arg))
+ switch {
+ case quote && c == '"' && next != '"':
+ quote = false
+ case quote && c == '"' && next == '"':
+ i++
+ case c == '\t' || c == ' ':
+ space = true
+ // Note we assume none of the regexes would split arguments on non-whitespace boundaries such that "text -- comment" would get split into "text -" and "- comment"
+ case !quote && space && c == '-' && next == '-':
+ return i, false
+ case !quote && c == '"':
+ quote = true
+ default:
+ space = false
+ }
+ }
+ return -1, quote
+}
+
+func warnDisabled(s *Sqlcmd, args []string, line uint) error {
+ s.WriteError(s.GetError(), ErrCommandsDisabled)
+ return nil
+}
+
+func errorDisabled(s *Sqlcmd, args []string, line uint) error {
+ s.WriteError(s.GetError(), ErrCommandsDisabled)
+ s.Exitcode = 1
+ return ErrExitRequested
+}
+
+func batchTerminatorRegex(terminator string) string {
+ return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator))
+}
+
+// SetBatchTerminator attempts to set the batch terminator to the given value
+// Returns an error if the new value is not usable in the regex
+func (c Commands) SetBatchTerminator(terminator string) error {
+ cmd := c["GO"]
+ regex, err := regexp.Compile(batchTerminatorRegex(terminator))
+ if err != nil {
+ return err
+ }
+ cmd.regex = regex
+ return nil
+}
+
+// exitCommand has 3 modes.
+// With no (), it just exits without running any query
+// With () it runs whatever batch is in the buffer then exits
+// With any text between () it runs the text as a query then exits
+func exitCommand(s *Sqlcmd, args []string, line uint) error {
+ if len(args) == 0 {
+ return ErrExitRequested
+ }
+ params := strings.TrimSpace(args[0])
+ if params == "" {
+ return ErrExitRequested
+ }
+ if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") {
+ return InvalidCommandError("EXIT", line)
+ }
+ // First we save the current batch
+ query1 := s.batch.String()
+ if len(query1) > 0 {
+ query1 = s.getRunnableQuery(query1)
+ }
+ // Now parse the params of EXIT as a batch without commands
+ cmd := s.batch.cmd
+ s.batch.cmd = nil
+ defer func() {
+ s.batch.cmd = cmd
+ }()
+ query2 := strings.TrimSpace(params[1 : len(params)-1])
+ if len(query2) > 0 {
+ s.batch.Reset([]rune(query2))
+ _, _, err := s.batch.Next()
+ if err != nil {
+ return err
+ }
+ query2 = s.batch.String()
+ if len(query2) > 0 {
+ query2 = s.getRunnableQuery(query2)
+ }
+ }
+
+ if len(query1) > 0 || len(query2) > 0 {
+ query := query1 + SqlcmdEol + query2
+ s.Exitcode, _ = s.runQuery(query)
+ }
+ return ErrExitRequested
+}
+
+// quitCommand immediately exits the program without running any more batches
+func quitCommand(s *Sqlcmd, args []string, line uint) error {
+ if args != nil && strings.TrimSpace(args[0]) != "" {
+ return InvalidCommandError("QUIT", line)
+ }
+ return ErrExitRequested
+}
+
+// goCommand runs the current batch the number of times specified
+func goCommand(s *Sqlcmd, args []string, line uint) error {
+ // default to 1 execution
+ n := 1
+ var err error
+ if len(args) > 0 {
+ cnt := strings.TrimSpace(args[0])
+ if cnt != "" {
+ if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil {
+ return err
+ }
+ _, err = fmt.Sscanf(cnt, "%d", &n)
+ }
+ }
+ if err != nil || n < 1 {
+ return InvalidCommandError("GO", line)
+ }
+ if s.EchoInput {
+ err = listCommand(s, []string{}, line)
+ }
+ if err != nil {
+ return InvalidCommandError("GO", line)
+ }
+ query := s.batch.String()
+ if query == "" {
+ return nil
+ }
+ query = s.getRunnableQuery(query)
+ for i := 0; i < n; i++ {
+ if retcode, err := s.runQuery(query); err != nil {
+ s.Exitcode = retcode
+ return err
+ }
+ }
+ s.batch.Reset(nil)
+ return nil
+}
+
+// outCommand changes the output writer to use a file
+func outCommand(s *Sqlcmd, args []string, line uint) error {
+ if len(args) == 0 || args[0] == "" {
+ return InvalidCommandError("OUT", line)
+ }
+ filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
+ if err != nil {
+ return err
+ }
+
+ switch {
+ case strings.EqualFold(filePath, "stdout"):
+ s.SetOutput(os.Stdout)
+ case strings.EqualFold(filePath, "stderr"):
+ s.SetOutput(os.Stderr)
+ default:
+ o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
+ if err != nil {
+ return InvalidFileError(err, args[0])
+ }
+ if s.UnicodeOutputFile {
+ // ODBC sqlcmd doesn't write a BOM but we will.
+ // Maybe the endian-ness should be configurable.
+ win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM)
+ encoder := transform.NewWriter(o, win16le.NewEncoder())
+ s.SetOutput(encoder)
+ } else {
+ s.SetOutput(o)
+ }
+ }
+ return nil
+}
+
+// errorCommand changes the error writer to use a file
+func errorCommand(s *Sqlcmd, args []string, line uint) error {
+ if len(args) == 0 || args[0] == "" {
+ return InvalidCommandError("ERROR", line)
+ }
+ filePath, err := resolveArgumentVariables(s, []rune(args[0]), true)
+ if err != nil {
+ return err
+ }
+ switch {
+ case strings.EqualFold(filePath, "stderr"):
+ s.SetError(os.Stderr)
+ case strings.EqualFold(filePath, "stdout"):
+ s.SetError(os.Stdout)
+ default:
+ o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644)
+ if err != nil {
+ return InvalidFileError(err, args[0])
+ }
+ s.SetError(o)
+ }
+ return nil
+}
+
+func readFileCommand(s *Sqlcmd, args []string, line uint) error {
+ if args == nil || len(args) != 1 {
+ return InvalidCommandError(":R", line)
+ }
+ fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false)
+ return s.IncludeFile(fileName, false)
+}
+
+// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables
+func setVarCommand(s *Sqlcmd, args []string, line uint) error {
+ if args == nil || len(args) != 1 || args[0] == "" {
+ return InvalidCommandError(":SETVAR", line)
+ }
+
+ varname := args[0]
+ val := ""
+ // The prior incarnation of sqlcmd doesn't require a space between the variable name and its value
+ // in some very unexpected cases. This version will require the space.
+ sp := strings.IndexRune(args[0], ' ')
+ if sp > -1 {
+ val = strings.TrimSpace(varname[sp:])
+ varname = varname[:sp]
+ }
+ if err := s.vars.Setvar(varname, val); err != nil {
+ switch e := err.(type) {
+ case *VariableError:
+ return e
+ default:
+ return InvalidCommandError(":SETVAR", line)
+ }
+ }
+ return nil
+}
+
+// listVarCommand prints the set of Sqlcmd scripting variables.
+// Builtin values are printed first, followed by user-set values in sorted order.
+func listVarCommand(s *Sqlcmd, args []string, line uint) error {
+ if args != nil && strings.TrimSpace(args[0]) != "" {
+ return InvalidCommandError("LISTVAR", line)
+ }
+
+ vars := s.vars.All()
+ keys := make([]string, 0, len(vars))
+ for k := range vars {
+ if !contains(builtinVariables, k) {
+ keys = append(keys, k)
+ }
+ }
+ sort.Strings(keys)
+ keys = append(builtinVariables, keys...)
+ for _, k := range keys {
+ if _, err := fmt.Fprintf(s.GetOutput(), `%s = "%s"%s`, k, vars[k], SqlcmdEol); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// resetCommand resets the statement cache
+func resetCommand(s *Sqlcmd, args []string, line uint) error {
+ if s.batch != nil {
+ s.batch.Reset(nil)
+ }
+
+ return nil
+}
+
+// listCommand displays statements currently in the statement cache
+func listCommand(s *Sqlcmd, args []string, line uint) (err error) {
+ cmd := ""
+ if args != nil {
+ if len(args) > 0 {
+ cmd = strings.ToLower(strings.TrimSpace(args[0]))
+ if len(args) > 1 || (cmd != "color" && cmd != "") {
+ return InvalidCommandError("LIST", line)
+ }
+ }
+ }
+ output := s.GetOutput()
+ if cmd == "color" {
+ sample := "select 'literal' as literal, 100 as number from [sys].[tables]"
+ clr := color.TextTypeTSql
+ if s.Format.IsXmlMode() {
+ sample = `value`
+ clr = color.TextTypeXml
+ }
+ // ignoring errors since it's not critical output
+ for _, style := range s.colorizer.Styles() {
+ _, _ = output.Write([]byte(style + ": "))
+ _ = s.colorizer.Write(output, sample, style, clr)
+ _, _ = output.Write([]byte(SqlcmdEol))
+ }
+ return
+ }
+ if s.batch == nil || s.batch.String() == "" {
+ return
+ }
+
+ if err = s.colorizer.Write(output, s.batch.String(), s.vars.ColorScheme(), color.TextTypeTSql); err == nil {
+ _, err = output.Write([]byte(SqlcmdEol))
+ }
+
+ return
+}
+
+func connectCommand(s *Sqlcmd, args []string, line uint) error {
+ if len(args) == 0 {
+ return InvalidCommandError("CONNECT", line)
+ }
+
+ commandArgs := strings.Fields(args[0])
+
+ // Require at least the server name parameter
+ if len(commandArgs) == 0 {
+ return InvalidCommandError("CONNECT", line)
+ }
+
+ // Parse flags
+ flags := flag.NewFlagSet("connect", flag.ContinueOnError)
+ database := flags.String("D", "", "database name")
+ username := flags.String("U", "", "user name")
+ password := flags.String("P", "", "password")
+ loginTimeout := flags.String("l", "", "login timeout")
+ authenticationMethod := flags.String("G", "", "authentication method")
+
+ err := flags.Parse(commandArgs[1:])
+ if err != nil {
+ return InvalidCommandError("CONNECT", line)
+ }
+
+ connect := *s.Connect
+ connect.UserName, _ = resolveArgumentVariables(s, []rune(*username), false)
+ connect.Password, _ = resolveArgumentVariables(s, []rune(*password), false)
+ connect.Database, _ = resolveArgumentVariables(s, []rune(*database), false)
+
+ timeout, _ := resolveArgumentVariables(s, []rune(*loginTimeout), false)
+ if timeout != "" {
+ if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil {
+ if timeoutSeconds < 0 {
+ return InvalidCommandError("CONNECT", line)
+ }
+ connect.LoginTimeoutSeconds = int(timeoutSeconds)
+ }
+ }
+
+ connect.AuthenticationMethod = *authenticationMethod
+
+ // Set server name as the first positional argument
+ if len(commandArgs) > 0 {
+ connect.ServerName, _ = resolveArgumentVariables(s, []rune(commandArgs[0]), false)
+ }
+
+ // If no user name is provided we switch to integrated auth
+ _ = s.ConnectDb(&connect, s.lineIo == nil)
+
+ // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option
+ return nil
+}
+
+func execCommand(s *Sqlcmd, args []string, line uint) error {
+ if len(args) == 0 {
+ return InvalidCommandError("EXEC", line)
+ }
+ cmdLine := strings.TrimSpace(args[0])
+ if cmdLine == "" {
+ return InvalidCommandError("EXEC", line)
+ }
+ if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil {
+ return err
+ } else {
+ cmd := sysCommand(cmdLine)
+ cmd.Stderr = s.GetError()
+ cmd.Stdout = s.GetOutput()
+ _ = cmd.Run()
+ }
+ return nil
+}
+
+func editCommand(s *Sqlcmd, args []string, line uint) error {
+ if args != nil && strings.TrimSpace(args[0]) != "" {
+ return InvalidCommandError("ED", line)
+ }
+ file, err := os.CreateTemp("", "sq*.sql")
+ if err != nil {
+ return err
+ }
+ fileName := file.Name()
+ defer func() {
+ // Best-effort cleanup - ignore errors
+ _ = os.Remove(fileName)
+ }()
+ defer func() {
+ // Ensure file is closed on all paths
+ _ = file.Close()
+ }()
+ text := s.batch.String()
+ if s.batch.State() == "-" {
+ text = fmt.Sprintf("%s%s", text, SqlcmdEol)
+ }
+ _, err = file.WriteString(text)
+ if err != nil {
+ return err
+ }
+ // Explicitly close before launching editor (defer will close again but that's safe)
+ if err := file.Close(); err != nil {
+ return err
+ }
+ cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`)
+ cmd.Stderr = s.GetError()
+ cmd.Stdout = s.GetOutput()
+ err = cmd.Run()
+ if err != nil {
+ return err
+ }
+ wasEcho := s.echoFileLines
+ s.echoFileLines = true
+ s.batch.Reset(nil)
+ _ = s.IncludeFile(fileName, false)
+ s.echoFileLines = wasEcho
+ return nil
+}
+
+func onerrorCommand(s *Sqlcmd, args []string, line uint) error {
+ if len(args) == 0 || args[0] == "" {
+ return InvalidCommandError("ON ERROR", line)
+ }
+ params := strings.TrimSpace(args[0])
+
+ if strings.EqualFold(strings.ToLower(params), "exit") {
+ s.Connect.ExitOnError = true
+ } else if strings.EqualFold(strings.ToLower(params), "ignore") {
+ s.Connect.IgnoreError = true
+ s.Connect.ExitOnError = false
+ } else {
+ return InvalidCommandError("ON ERROR", line)
+ }
+ return nil
+}
+
+func xmlCommand(s *Sqlcmd, args []string, line uint) error {
+ if len(args) != 1 || args[0] == "" {
+ return InvalidCommandError("XML", line)
+ }
+ params := strings.TrimSpace(args[0])
+ // "OFF" and "ON" are documented as the allowed values.
+ // ODBC sqlcmd treats any value other than "ON" the same as "OFF".
+ // So we will too.
+ if strings.EqualFold(params, "on") {
+ s.Format.XmlMode(true)
+ } else {
+ s.Format.XmlMode(false)
+ }
+ return nil
+}
+
+func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) {
+ var b *strings.Builder
+ end := len(arg)
+ for i := 0; i < end && !s.Connect.DisableVariableSubstitution; {
+ c, next := arg[i], grab(arg, i+1, end)
+ switch {
+ case c == '$' && next == '(':
+ vl, ok := readVariableReference(arg, i+2, end)
+ if ok {
+ varName := string(arg[i+2 : vl])
+ val, ok := s.resolveVariable(varName)
+ if ok {
+ if b == nil {
+ b = new(strings.Builder)
+ b.Grow(len(arg))
+ b.WriteString(string(arg[0:i]))
+ }
+ b.WriteString(val)
+ } else {
+ if failOnUnresolved {
+ return "", UndefinedVariable(varName)
+ }
+ s.WriteError(s.GetError(), UndefinedVariable(varName))
+ if b != nil {
+ b.WriteString(string(arg[i : vl+1]))
+ }
+ }
+ i += ((vl - i) + 1)
+ } else {
+ if b != nil {
+ b.WriteString("$(")
+ }
+ i += 2
+ }
+ default:
+ if b != nil {
+ b.WriteRune(c)
+ }
+ i++
+ }
+ }
+ if b == nil {
+ return string(arg), nil
+ }
+ return b.String(), nil
+}
diff --git a/pkg/sqlcmd/format.go b/pkg/sqlcmd/format.go
index 55bd2e25..fb6c057b 100644
--- a/pkg/sqlcmd/format.go
+++ b/pkg/sqlcmd/format.go
@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
+ "strconv"
"strings"
"time"
@@ -59,6 +60,12 @@ const (
ControlReplaceConsecutive
)
+const (
+ // Default display widths for float types
+ realDefaultWidth int64 = 14 // For REAL and SMALLMONEY
+ floatDefaultWidth int64 = 24 // For FLOAT and MONEY
+)
+
type columnDetail struct {
displayWidth int64
leftJustify bool
@@ -371,11 +378,11 @@ func calcColumnDetails(cols []*sql.ColumnType, fixed int64, variable int64) ([]c
columnDetails[i].displayWidth = max64(21, nameLen)
case "REAL", "SMALLMONEY":
columnDetails[i].leftJustify = false
- columnDetails[i].displayWidth = max64(14, nameLen)
+ columnDetails[i].displayWidth = max64(realDefaultWidth, nameLen)
columnDetails[i].zeroesAfterDecimal = true
case "FLOAT", "MONEY":
columnDetails[i].leftJustify = false
- columnDetails[i].displayWidth = max64(24, nameLen)
+ columnDetails[i].displayWidth = max64(floatDefaultWidth, nameLen)
columnDetails[i].zeroesAfterDecimal = true
case "DECIMAL":
columnDetails[i].leftJustify = false
@@ -530,6 +537,56 @@ func (f *sqlCmdFormatterType) scanRow(rows *sql.Rows) ([]string, error) {
} else {
row[n] = "0"
}
+ case float64:
+ // Format float64 to match ODBC sqlcmd behavior
+ // Use 'f' format with -1 precision to avoid scientific notation for typical values
+ // Fall back to 'g' format if the result would exceed the column display width
+
+ // Use appropriate bitSize based on the SQL type (REAL=32, FLOAT=64)
+ // REAL columns should use 32-bit precision even though the value is scanned as float64
+ bitSize := 64
+ typeName := f.columnDetails[n].col.DatabaseTypeName()
+ if typeName == "REAL" || typeName == "SMALLMONEY" {
+ bitSize = 32
+ }
+
+ formatted := strconv.FormatFloat(x, 'f', -1, bitSize)
+ displayWidth := f.columnDetails[n].displayWidth
+
+ // Use the type's default display width when displayWidth is 0 (unlimited)
+ // to avoid extremely long strings for extreme values
+ widthThreshold := displayWidth
+ if widthThreshold == 0 {
+ if typeName == "REAL" || typeName == "SMALLMONEY" {
+ widthThreshold = realDefaultWidth
+ } else {
+ widthThreshold = floatDefaultWidth
+ }
+ }
+
+ if int64(len(formatted)) > widthThreshold {
+ // Use 'g' format for very large/small values to avoid truncation issues
+ formatted = strconv.FormatFloat(x, 'g', -1, bitSize)
+ }
+ row[n] = formatted
+ case float32:
+ // Format float32 to match ODBC sqlcmd behavior
+ // float32 values are rare (database/sql typically normalizes to float64)
+ // Use bitSize 32 to maintain precision appropriate for the original float32 value
+ formatted := strconv.FormatFloat(float64(x), 'f', -1, 32)
+ displayWidth := f.columnDetails[n].displayWidth
+
+ // Use default REAL display width when displayWidth is 0
+ widthThreshold := displayWidth
+ if widthThreshold == 0 {
+ widthThreshold = realDefaultWidth
+ }
+
+ if int64(len(formatted)) > widthThreshold {
+ // Use 'g' format for very large/small values to avoid truncation issues
+ formatted = strconv.FormatFloat(float64(x), 'g', -1, 32)
+ }
+ row[n] = formatted
default:
var err error
if row[n], err = fmt.Sprintf("%v", x), nil; err != nil {
diff --git a/pkg/sqlcmd/format_test.go b/pkg/sqlcmd/format_test.go
index f4bee464..c7caf98c 100644
--- a/pkg/sqlcmd/format_test.go
+++ b/pkg/sqlcmd/format_test.go
@@ -162,3 +162,147 @@ func TestFormatterXmlMode(t *testing.T) {
assert.NoError(t, err, "runSqlCmd returned error")
assert.Equal(t, ``+SqlcmdEol, buf.buf.String())
}
+
+func TestFormatterFloatFormatting(t *testing.T) {
+ // Test that float formatting matches ODBC sqlcmd behavior
+ // This addresses the issue where go-sqlcmd was using scientific notation
+ // while ODBC sqlcmd uses decimal notation
+ s, buf := setupSqlCmdWithMemoryOutput(t)
+ defer func() { _ = buf.Close() }()
+
+ // Set SQLCMDMAXVARTYPEWIDTH to a non-zero value so FLOAT columns use the 24-char display width
+ // This enables the width-based fallback logic to be tested properly
+ s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "256")
+
+ // Test query with float values from the issue
+ query := `SELECT
+ CAST(788991.19988463481 AS FLOAT) as Longitude1,
+ CAST(4713347.3103808956 AS FLOAT) as Latitude1,
+ CAST(789288.40771771886 AS FLOAT) as Longitude2,
+ CAST(4712632.075629076 AS FLOAT) as Latitude2,
+ CAST(788569.36558582436 AS FLOAT) as Longitude3,
+ CAST(4714608.0418091472 AS FLOAT) as Latitude3`
+
+ err := runSqlCmd(t, s, []string{query, "GO"})
+ assert.NoError(t, err, "runSqlCmd returned error")
+
+ output := buf.buf.String()
+
+ // Verify that the output contains decimal notation, not scientific notation
+ // Scientific notation would look like "4.713347310380896e+06"
+ // Decimal notation should look like "4713347.3103808956"
+ assert.NotContains(t, output, "e+", "Output should not contain scientific notation (e+)")
+ assert.NotContains(t, output, "E+", "Output should not contain scientific notation (E+)")
+
+ // Verify that specific expected values are present (allowing for precision differences)
+ assert.Contains(t, output, "788991.1998846", "Output should contain decimal representation of Longitude1")
+ assert.Contains(t, output, "4713347.310380", "Output should contain decimal representation of Latitude1")
+ assert.Contains(t, output, "789288.4077177", "Output should contain decimal representation of Longitude2")
+ assert.Contains(t, output, "4712632.075629", "Output should contain decimal representation of Latitude2")
+ assert.Contains(t, output, "788569.3655858", "Output should contain decimal representation of Longitude3")
+ assert.Contains(t, output, "4714608.041809", "Output should contain decimal representation of Latitude3")
+}
+
+func TestFormatterFloatFormattingExtremeValues(t *testing.T) {
+ // Test that extreme float values fall back to scientific notation
+ // to avoid truncation issues with very large or very small numbers
+ s, buf := setupSqlCmdWithMemoryOutput(t)
+ defer func() { _ = buf.Close() }()
+
+ // Set SQLCMDMAXVARTYPEWIDTH to a non-zero value so FLOAT columns use the 24-char display width
+ // This allows the fallback behavior to be tested
+ s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "256")
+
+ // Test query with extreme float values that would exceed the 24-char display width
+ query := `SELECT
+ CAST(1e100 AS FLOAT) as VeryLarge,
+ CAST(1e-100 AS FLOAT) as VerySmall`
+
+ err := runSqlCmd(t, s, []string{query, "GO"})
+ assert.NoError(t, err, "runSqlCmd returned error")
+
+ output := buf.buf.String()
+
+ // Verify that extremely large values use scientific notation with positive exponent
+ // (because decimal format would exceed the 24-char column width)
+ assert.Contains(t, output, "e+", "Output should contain scientific notation (e+) for very large values")
+
+ // Verify that extremely small values use scientific notation with negative exponent
+ assert.Contains(t, output, "e-", "Output should contain scientific notation (e-) for very small values")
+}
+
+func TestFormatterFloatFormattingExtremeValuesUnlimitedWidth(t *testing.T) {
+ // Test that extreme float values fall back to scientific notation even when
+ // displayWidth is 0 (unlimited), using type default widths as threshold
+ s, buf := setupSqlCmdWithMemoryOutput(t)
+ defer func() { _ = buf.Close() }()
+
+ // Leave SQLCMDMAXVARTYPEWIDTH at 0 (setupSqlCmdWithMemoryOutput default)
+ // This sets displayWidth to 0 for all columns, testing the fallback logic
+ // that uses type default widths (24 for FLOAT, 14 for REAL)
+
+ // Test query with extreme float values
+ query := `SELECT
+ CAST(1e100 AS FLOAT) as VeryLarge,
+ CAST(1e-100 AS FLOAT) as VerySmall`
+
+ err := runSqlCmd(t, s, []string{query, "GO"})
+ assert.NoError(t, err, "runSqlCmd returned error")
+
+ output := buf.buf.String()
+
+ // Verify that extreme values still use scientific notation even with unlimited width
+ // (fallback should use type default widths to prevent unbounded output)
+ assert.Contains(t, output, "e+", "Output should contain scientific notation (e+) for very large values even with unlimited width")
+ assert.Contains(t, output, "e-", "Output should contain scientific notation (e-) for very small values even with unlimited width")
+}
+
+func TestFormatterRealFormatting(t *testing.T) {
+ // Test that REAL (float32) values use decimal notation for typical values
+ // and fall back to scientific notation for extreme values
+ s, buf := setupSqlCmdWithMemoryOutput(t)
+ defer func() { _ = buf.Close() }()
+
+ // Set SQLCMDMAXVARTYPEWIDTH to a non-zero value so REAL columns use the 14-char display width
+ s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "256")
+
+ // Test query with REAL values (both typical and extreme)
+ query := `SELECT
+ CAST(123.456789 AS REAL) as TypicalValue,
+ CAST(1e30 AS REAL) as ExtremeValue`
+
+ err := runSqlCmd(t, s, []string{query, "GO"})
+ assert.NoError(t, err, "runSqlCmd returned error")
+
+ output := buf.buf.String()
+
+ // Split output into lines to examine the data row separately from headers
+ lines := strings.Split(output, SqlcmdEol)
+ var dataLine string
+ for _, line := range lines {
+ // Find the data line (contains actual values, not headers or separators)
+ if strings.Contains(line, "123.") {
+ dataLine = line
+ break
+ }
+ }
+
+ // Verify that typical REAL values use decimal notation (not scientific)
+ assert.Contains(t, dataLine, "123.456", "Output should contain decimal representation of typical REAL value")
+
+ // Check that the typical value portion doesn't use scientific notation
+ // Parse columns using whitespace (the default column separator)
+ fields := strings.Fields(dataLine)
+ if len(fields) > 0 {
+ // Find the field containing the typical value
+ for _, field := range fields {
+ if strings.Contains(field, "123.") {
+ assert.NotContains(t, field, "e", "Typical REAL value should not use scientific notation")
+ break
+ }
+ }
+ }
+
+ // Verify that extreme REAL values use scientific notation
+ assert.Contains(t, output, "e+", "Output should contain scientific notation for extreme REAL value")
+}