diff --git a/.gitignore b/.gitignore index f713869e..d8da873d 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ linux-s390x/sqlcmd # Build artifacts in root /sqlcmd /sqlcmd_binary +/modern # certificates used for local testing *.der diff --git a/cmd/modern/root/query_test.go b/cmd/modern/root/query_test.go index 496cead4..482e267c 100644 --- a/cmd/modern/root/query_test.go +++ b/cmd/modern/root/query_test.go @@ -28,7 +28,7 @@ func TestQueryWithNonDefaultDatabase(t *testing.T) { if runtime.GOOS != "windows" { t.Skip("stuartpa: This is failing in the pipeline (Login failed for user 'sa'.)") } - + cmdparser.TestSetup(t) setupContext(t) diff --git a/cmd/sqlcmd/pipe_detection_test.go b/cmd/sqlcmd/pipe_detection_test.go index b2bc8be3..02aee635 100644 --- a/cmd/sqlcmd/pipe_detection_test.go +++ b/cmd/sqlcmd/pipe_detection_test.go @@ -6,7 +6,7 @@ package sqlcmd import ( "os" "testing" - + "github.com/stretchr/testify/assert" ) @@ -14,15 +14,15 @@ func TestStdinPipeDetection(t *testing.T) { // Get stdin info fi, err := os.Stdin.Stat() assert.NoError(t, err, "os.Stdin.Stat()") - + // On most CI systems, stdin will be a pipe or file (not a terminal) // We're testing the logic, not expecting a specific result isPipe := false if fi != nil && (fi.Mode()&os.ModeCharDevice) == 0 { isPipe = true } - + // Just making sure the detection code doesn't crash // The actual value will depend on the environment t.Logf("Stdin detected as pipe: %v", isPipe) -} \ No newline at end of file +} diff --git a/internal/secret/generate.go b/internal/secret/generate.go index 62ea28f5..c5cddffd 100644 --- a/internal/secret/generate.go +++ b/internal/secret/generate.go @@ -17,7 +17,7 @@ const ( ) // Generate generates a random password of a specified length. The password -// will contain at least the specified number of special characters, +// will contain at least the specified number of special characters, // numeric digits, and upper-case letters. The remaining characters in the // password will be selected from a combination of lower-case letters, special // characters, and numeric digits. The special characters are chosen from diff --git a/pkg/console/console.go b/pkg/console/console.go index 4eebded7..687f42db 100644 --- a/pkg/console/console.go +++ b/pkg/console/console.go @@ -28,7 +28,7 @@ func NewConsole(historyFile string) sqlcmd.Console { historyFile: historyFile, stdinRedirected: isStdinRedirected(), } - + if c.stdinRedirected { c.stdinReader = bufio.NewReader(os.Stdin) } else { @@ -52,7 +52,7 @@ func (c *console) Close() { f.Close() } } - + if !c.stdinRedirected { c.impl.Close() } @@ -79,7 +79,7 @@ func (c *console) Readline() (string, error) { } return line, nil } - + // Interactive terminal mode with prompts s, err := c.impl.Prompt(c.prompt) if err == liner.ErrPromptAborted { diff --git a/pkg/console/console_redirect.go b/pkg/console/console_redirect.go index b09486d9..d342d987 100644 --- a/pkg/console/console_redirect.go +++ b/pkg/console/console_redirect.go @@ -5,6 +5,7 @@ package console import ( "os" + "golang.org/x/term" ) @@ -15,13 +16,13 @@ func isStdinRedirected() bool { // If we can't determine, assume it's not redirected return false } - + // If it's not a character device, it's coming from a pipe or redirection if (stat.Mode() & os.ModeCharDevice) == 0 { return true } - + // Double-check using term.IsTerminal fd := int(os.Stdin.Fd()) return !term.IsTerminal(fd) -} \ No newline at end of file +} diff --git a/pkg/console/console_redirect_test.go b/pkg/console/console_redirect_test.go index f26cab98..b1b3b2ca 100644 --- a/pkg/console/console_redirect_test.go +++ b/pkg/console/console_redirect_test.go @@ -51,4 +51,4 @@ func TestStdinRedirectionDetection(t *testing.T) { // Clean up console.Close() -} \ No newline at end of file +} diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 66dd1dba..192c14ba 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -1,644 +1,660 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package sqlcmd - -import ( - "flag" - "fmt" - "os" - "regexp" - "sort" - "strconv" - "strings" - - "github.com/microsoft/go-sqlcmd/internal/color" - "golang.org/x/text/encoding/unicode" - "golang.org/x/text/transform" -) - -// Command defines a sqlcmd action which can be intermixed with the SQL batch -// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands -type Command struct { - // regex must include at least one group if it has parameters - // Will be matched using FindStringSubmatch - regex *regexp.Regexp - // The function that implements the command. Third parameter is the line number - action func(*Sqlcmd, []string, uint) error - // Name of the command - name string - // whether the command is a system command - isSystem bool -} - -// Commands is the set of sqlcmd command implementations -type Commands map[string]*Command - -func newCommands() Commands { - // Commands is the set of Command implementations - return map[string]*Command{ - "EXIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT([\( \t]+.*\)*$|$)`), - action: exitCommand, - name: "EXIT", - }, - "QUIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`), - action: quitCommand, - name: "QUIT", - }, - "GO": { - regex: regexp.MustCompile(batchTerminatorRegex("GO")), - action: goCommand, - name: "GO", - }, - "OUT": { - regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`), - action: outCommand, - name: "OUT", - }, - "ERROR": { - regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`), - action: errorCommand, - name: "ERROR", - }, "READFILE": { - regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`), - action: readFileCommand, - name: "READFILE", - }, - "SETVAR": { - regex: regexp.MustCompile(`(?im)^[ \t]*:SETVAR(?:[ \t]+(.*$)|$)`), - action: setVarCommand, - name: "SETVAR", - }, - "LISTVAR": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:LISTVAR(?:[ \t]+(.*$)|$)`), - action: listVarCommand, - name: "LISTVAR", - }, - "RESET": { - regex: regexp.MustCompile(`(?im)^[ \t]*?:?RESET(?:[ \t]+(.*$)|$)`), - action: resetCommand, - name: "RESET", - }, - "LIST": { - regex: regexp.MustCompile(`(?im)^[ \t]*:LIST(?:[ \t]+(.*$)|$)`), - action: listCommand, - name: "LIST", - }, - "CONNECT": { - regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`), - action: connectCommand, - name: "CONNECT", - }, - "EXEC": { - regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(.*$)`), - action: execCommand, - name: "EXEC", - isSystem: true, - }, - "EDIT": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`), - action: editCommand, - name: "EDIT", - isSystem: true, - }, - "ONERROR": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:?ON ERROR(?:[ \t]+(.*$)|$)`), - action: onerrorCommand, - name: "ONERROR", - }, - "XML": { - regex: regexp.MustCompile(`(?im)^[\t ]*?:XML(?:[ \t]+(.*$)|$)`), - action: xmlCommand, - name: "XML", - }, - } -} - -// DisableSysCommands disables the ED and :!! commands. -// When exitOnCall is true, running those commands will exit the process. -func (c Commands) DisableSysCommands(exitOnCall bool) { - f := warnDisabled - if exitOnCall { - f = errorDisabled - } - for _, cmd := range c { - if cmd.isSystem { - cmd.action = f - } - } -} - -func (c Commands) matchCommand(line string) (*Command, []string) { - for _, cmd := range c { - matchedCommand := cmd.regex.FindStringSubmatch(line) - if matchedCommand != nil { - return cmd, removeComments(matchedCommand[1:]) - } - } - return nil, nil -} - -func removeComments(args []string) []string { - var pos int - quote := false - for i := range args { - pos, quote = commentStart([]rune(args[i]), quote) - if pos > -1 { - out := make([]string, i+1) - if i > 0 { - copy(out, args[:i]) - } - out[i] = args[i][:pos] - return out - } - } - return args -} - -func commentStart(arg []rune, quote bool) (int, bool) { - var i int - space := true - for ; i < len(arg); i++ { - c, next := arg[i], grab(arg, i+1, len(arg)) - switch { - case quote && c == '"' && next != '"': - quote = false - case quote && c == '"' && next == '"': - i++ - case c == '\t' || c == ' ': - space = true - // Note we assume none of the regexes would split arguments on non-whitespace boundaries such that "text -- comment" would get split into "text -" and "- comment" - case !quote && space && c == '-' && next == '-': - return i, false - case !quote && c == '"': - quote = true - default: - space = false - } - } - return -1, quote -} - -func warnDisabled(s *Sqlcmd, args []string, line uint) error { - s.WriteError(s.GetError(), ErrCommandsDisabled) - return nil -} - -func errorDisabled(s *Sqlcmd, args []string, line uint) error { - s.WriteError(s.GetError(), ErrCommandsDisabled) - s.Exitcode = 1 - return ErrExitRequested -} - -func batchTerminatorRegex(terminator string) string { - return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator)) -} - -// SetBatchTerminator attempts to set the batch terminator to the given value -// Returns an error if the new value is not usable in the regex -func (c Commands) SetBatchTerminator(terminator string) error { - cmd := c["GO"] - regex, err := regexp.Compile(batchTerminatorRegex(terminator)) - if err != nil { - return err - } - cmd.regex = regex - return nil -} - -// exitCommand has 3 modes. -// With no (), it just exits without running any query -// With () it runs whatever batch is in the buffer then exits -// With any text between () it runs the text as a query then exits -func exitCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 { - return ErrExitRequested - } - params := strings.TrimSpace(args[0]) - if params == "" { - return ErrExitRequested - } - if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") { - return InvalidCommandError("EXIT", line) - } - // First we save the current batch - query1 := s.batch.String() - if len(query1) > 0 { - query1 = s.getRunnableQuery(query1) - } - // Now parse the params of EXIT as a batch without commands - cmd := s.batch.cmd - s.batch.cmd = nil - defer func() { - s.batch.cmd = cmd - }() - query2 := strings.TrimSpace(params[1 : len(params)-1]) - if len(query2) > 0 { - s.batch.Reset([]rune(query2)) - _, _, err := s.batch.Next() - if err != nil { - return err - } - query2 = s.batch.String() - if len(query2) > 0 { - query2 = s.getRunnableQuery(query2) - } - } - - if len(query1) > 0 || len(query2) > 0 { - query := query1 + SqlcmdEol + query2 - s.Exitcode, _ = s.runQuery(query) - } - return ErrExitRequested -} - -// quitCommand immediately exits the program without running any more batches -func quitCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("QUIT", line) - } - return ErrExitRequested -} - -// goCommand runs the current batch the number of times specified -func goCommand(s *Sqlcmd, args []string, line uint) error { - // default to 1 execution - n := 1 - var err error - if len(args) > 0 { - cnt := strings.TrimSpace(args[0]) - if cnt != "" { - if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil { - return err - } - _, err = fmt.Sscanf(cnt, "%d", &n) - } - } - if err != nil || n < 1 { - return InvalidCommandError("GO", line) - } - if s.EchoInput { - err = listCommand(s, []string{}, line) - } - if err != nil { - return InvalidCommandError("GO", line) - } - query := s.batch.String() - if query == "" { - return nil - } - query = s.getRunnableQuery(query) - for i := 0; i < n; i++ { - if retcode, err := s.runQuery(query); err != nil { - s.Exitcode = retcode - return err - } - } - s.batch.Reset(nil) - return nil -} - -// outCommand changes the output writer to use a file -func outCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || args[0] == "" { - return InvalidCommandError("OUT", line) - } - filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) - if err != nil { - return err - } - - switch { - case strings.EqualFold(filePath, "stdout"): - s.SetOutput(os.Stdout) - case strings.EqualFold(filePath, "stderr"): - s.SetOutput(os.Stderr) - default: - o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return InvalidFileError(err, args[0]) - } - if s.UnicodeOutputFile { - // ODBC sqlcmd doesn't write a BOM but we will. - // Maybe the endian-ness should be configurable. - win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) - encoder := transform.NewWriter(o, win16le.NewEncoder()) - s.SetOutput(encoder) - } else { - s.SetOutput(o) - } - } - return nil -} - -// errorCommand changes the error writer to use a file -func errorCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || args[0] == "" { - return InvalidCommandError("ERROR", line) - } - filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) - if err != nil { - return err - } - switch { - case strings.EqualFold(filePath, "stderr"): - s.SetError(os.Stderr) - case strings.EqualFold(filePath, "stdout"): - s.SetError(os.Stdout) - default: - o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return InvalidFileError(err, args[0]) - } - s.SetError(o) - } - return nil -} - -func readFileCommand(s *Sqlcmd, args []string, line uint) error { - if args == nil || len(args) != 1 { - return InvalidCommandError(":R", line) - } - fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false) - return s.IncludeFile(fileName, false) -} - -// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables -func setVarCommand(s *Sqlcmd, args []string, line uint) error { - if args == nil || len(args) != 1 || args[0] == "" { - return InvalidCommandError(":SETVAR", line) - } - - varname := args[0] - val := "" - // The prior incarnation of sqlcmd doesn't require a space between the variable name and its value - // in some very unexpected cases. This version will require the space. - sp := strings.IndexRune(args[0], ' ') - if sp > -1 { - val = strings.TrimSpace(varname[sp:]) - varname = varname[:sp] - } - if err := s.vars.Setvar(varname, val); err != nil { - switch e := err.(type) { - case *VariableError: - return e - default: - return InvalidCommandError(":SETVAR", line) - } - } - return nil -} - -// listVarCommand prints the set of Sqlcmd scripting variables. -// Builtin values are printed first, followed by user-set values in sorted order. -func listVarCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("LISTVAR", line) - } - - vars := s.vars.All() - keys := make([]string, 0, len(vars)) - for k := range vars { - if !contains(builtinVariables, k) { - keys = append(keys, k) - } - } - sort.Strings(keys) - keys = append(builtinVariables, keys...) - for _, k := range keys { - fmt.Fprintf(s.GetOutput(), `%s = "%s"%s`, k, vars[k], SqlcmdEol) - } - return nil -} - -// resetCommand resets the statement cache -func resetCommand(s *Sqlcmd, args []string, line uint) error { - if s.batch != nil { - s.batch.Reset(nil) - } - - return nil -} - -// listCommand displays statements currently in the statement cache -func listCommand(s *Sqlcmd, args []string, line uint) (err error) { - cmd := "" - if args != nil { - if len(args) > 0 { - cmd = strings.ToLower(strings.TrimSpace(args[0])) - if len(args) > 1 || (cmd != "color" && cmd != "") { - return InvalidCommandError("LIST", line) - } - } - } - output := s.GetOutput() - if cmd == "color" { - sample := "select 'literal' as literal, 100 as number from [sys].[tables]" - clr := color.TextTypeTSql - if s.Format.IsXmlMode() { - sample = `value` - clr = color.TextTypeXml - } - // ignoring errors since it's not critical output - for _, style := range s.colorizer.Styles() { - _, _ = output.Write([]byte(style + ": ")) - _ = s.colorizer.Write(output, sample, style, clr) - _, _ = output.Write([]byte(SqlcmdEol)) - } - return - } - if s.batch == nil || s.batch.String() == "" { - return - } - - if err = s.colorizer.Write(output, s.batch.String(), s.vars.ColorScheme(), color.TextTypeTSql); err == nil { - _, err = output.Write([]byte(SqlcmdEol)) - } - - return -} - -func connectCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 { - return InvalidCommandError("CONNECT", line) - } - - commandArgs := strings.Fields(args[0]) - - // Parse flags - flags := flag.NewFlagSet("connect", flag.ContinueOnError) - database := flags.String("D", "", "database name") - username := flags.String("U", "", "user name") - password := flags.String("P", "", "password") - loginTimeout := flags.String("l", "", "login timeout") - authenticationMethod := flags.String("G", "", "authentication method") - - err := flags.Parse(commandArgs[1:]) - //err := flags.Parse(args[1:]) - if err != nil { - return InvalidCommandError("CONNECT", line) - } - - connect := *s.Connect - connect.UserName, _ = resolveArgumentVariables(s, []rune(*username), false) - connect.Password, _ = resolveArgumentVariables(s, []rune(*password), false) - connect.Database, _ = resolveArgumentVariables(s, []rune(*database), false) - - timeout, _ := resolveArgumentVariables(s, []rune(*loginTimeout), false) - if timeout != "" { - if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil { - if timeoutSeconds < 0 { - return InvalidCommandError("CONNECT", line) - } - connect.LoginTimeoutSeconds = int(timeoutSeconds) - } - } - - connect.AuthenticationMethod = *authenticationMethod - - // Set server name as the first positional argument - if len(commandArgs) > 0 { - connect.ServerName, _ = resolveArgumentVariables(s, []rune(commandArgs[0]), false) - } - - // If no user name is provided we switch to integrated auth - _ = s.ConnectDb(&connect, s.lineIo == nil) - - // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option - return nil -} - -func execCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 { - return InvalidCommandError("EXEC", line) - } - cmdLine := strings.TrimSpace(args[0]) - if cmdLine == "" { - return InvalidCommandError("EXEC", line) - } - if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil { - return err - } else { - cmd := sysCommand(cmdLine) - cmd.Stderr = s.GetError() - cmd.Stdout = s.GetOutput() - _ = cmd.Run() - } - return nil -} - -func editCommand(s *Sqlcmd, args []string, line uint) error { - if args != nil && strings.TrimSpace(args[0]) != "" { - return InvalidCommandError("ED", line) - } - file, err := os.CreateTemp("", "sq*.sql") - if err != nil { - return err - } - fileName := file.Name() - defer os.Remove(fileName) - text := s.batch.String() - if s.batch.State() == "-" { - text = fmt.Sprintf("%s%s", text, SqlcmdEol) - } - _, err = file.WriteString(text) - if err != nil { - return err - } - file.Close() - cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`) - cmd.Stderr = s.GetError() - cmd.Stdout = s.GetOutput() - err = cmd.Run() - if err != nil { - return err - } - wasEcho := s.echoFileLines - s.echoFileLines = true - s.batch.Reset(nil) - _ = s.IncludeFile(fileName, false) - s.echoFileLines = wasEcho - return nil -} - -func onerrorCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) == 0 || args[0] == "" { - return InvalidCommandError("ON ERROR", line) - } - params := strings.TrimSpace(args[0]) - - if strings.EqualFold(strings.ToLower(params), "exit") { - s.Connect.ExitOnError = true - } else if strings.EqualFold(strings.ToLower(params), "ignore") { - s.Connect.IgnoreError = true - s.Connect.ExitOnError = false - } else { - return InvalidCommandError("ON ERROR", line) - } - return nil -} - -func xmlCommand(s *Sqlcmd, args []string, line uint) error { - if len(args) != 1 || args[0] == "" { - return InvalidCommandError("XML", line) - } - params := strings.TrimSpace(args[0]) - // "OFF" and "ON" are documented as the allowed values. - // ODBC sqlcmd treats any value other than "ON" the same as "OFF". - // So we will too. - if strings.EqualFold(params, "on") { - s.Format.XmlMode(true) - } else { - s.Format.XmlMode(false) - } - return nil -} - -func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { - var b *strings.Builder - end := len(arg) - for i := 0; i < end && !s.Connect.DisableVariableSubstitution; { - c, next := arg[i], grab(arg, i+1, end) - switch { - case c == '$' && next == '(': - vl, ok := readVariableReference(arg, i+2, end) - if ok { - varName := string(arg[i+2 : vl]) - val, ok := s.resolveVariable(varName) - if ok { - if b == nil { - b = new(strings.Builder) - b.Grow(len(arg)) - b.WriteString(string(arg[0:i])) - } - b.WriteString(val) - } else { - if failOnUnresolved { - return "", UndefinedVariable(varName) - } - s.WriteError(s.GetError(), UndefinedVariable(varName)) - if b != nil { - b.WriteString(string(arg[i : vl+1])) - } - } - i += ((vl - i) + 1) - } else { - if b != nil { - b.WriteString("$(") - } - i += 2 - } - default: - if b != nil { - b.WriteRune(c) - } - i++ - } - } - if b == nil { - return string(arg), nil - } - return b.String(), nil -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "flag" + "fmt" + "os" + "regexp" + "sort" + "strconv" + "strings" + + "github.com/microsoft/go-sqlcmd/internal/color" + "golang.org/x/text/encoding/unicode" + "golang.org/x/text/transform" +) + +// Command defines a sqlcmd action which can be intermixed with the SQL batch +// Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands +type Command struct { + // regex must include at least one group if it has parameters + // Will be matched using FindStringSubmatch + regex *regexp.Regexp + // The function that implements the command. Third parameter is the line number + action func(*Sqlcmd, []string, uint) error + // Name of the command + name string + // whether the command is a system command + isSystem bool +} + +// Commands is the set of sqlcmd command implementations +type Commands map[string]*Command + +func newCommands() Commands { + // Commands is the set of Command implementations + return map[string]*Command{ + "EXIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT([\( \t]+.*\)*$|$)`), + action: exitCommand, + name: "EXIT", + }, + "QUIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`), + action: quitCommand, + name: "QUIT", + }, + "GO": { + regex: regexp.MustCompile(batchTerminatorRegex("GO")), + action: goCommand, + name: "GO", + }, + "OUT": { + regex: regexp.MustCompile(`(?im)^[ \t]*:OUT(?:[ \t]+(.*$)|$)`), + action: outCommand, + name: "OUT", + }, + "ERROR": { + regex: regexp.MustCompile(`(?im)^[ \t]*:ERROR(?:[ \t]+(.*$)|$)`), + action: errorCommand, + name: "ERROR", + }, "READFILE": { + regex: regexp.MustCompile(`(?im)^[ \t]*:R(?:[ \t]+(.*$)|$)`), + action: readFileCommand, + name: "READFILE", + }, + "SETVAR": { + regex: regexp.MustCompile(`(?im)^[ \t]*:SETVAR(?:[ \t]+(.*$)|$)`), + action: setVarCommand, + name: "SETVAR", + }, + "LISTVAR": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:LISTVAR(?:[ \t]+(.*$)|$)`), + action: listVarCommand, + name: "LISTVAR", + }, + "RESET": { + regex: regexp.MustCompile(`(?im)^[ \t]*?:?RESET(?:[ \t]+(.*$)|$)`), + action: resetCommand, + name: "RESET", + }, + "LIST": { + regex: regexp.MustCompile(`(?im)^[ \t]*:LIST(?:[ \t]+(.*$)|$)`), + action: listCommand, + name: "LIST", + }, + "CONNECT": { + regex: regexp.MustCompile(`(?im)^[ \t]*:CONNECT(?:[ \t]+(.*$)|$)`), + action: connectCommand, + name: "CONNECT", + }, + "EXEC": { + regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(.*$)`), + action: execCommand, + name: "EXEC", + isSystem: true, + }, + "EDIT": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?ED(?:[ \t]+(.*$)|$)`), + action: editCommand, + name: "EDIT", + isSystem: true, + }, + "ONERROR": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:?ON ERROR(?:[ \t]+(.*$)|$)`), + action: onerrorCommand, + name: "ONERROR", + }, + "XML": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:XML(?:[ \t]+(.*$)|$)`), + action: xmlCommand, + name: "XML", + }, + } +} + +// DisableSysCommands disables the ED and :!! commands. +// When exitOnCall is true, running those commands will exit the process. +func (c Commands) DisableSysCommands(exitOnCall bool) { + f := warnDisabled + if exitOnCall { + f = errorDisabled + } + for _, cmd := range c { + if cmd.isSystem { + cmd.action = f + } + } +} + +func (c Commands) matchCommand(line string) (*Command, []string) { + for _, cmd := range c { + matchedCommand := cmd.regex.FindStringSubmatch(line) + if matchedCommand != nil { + return cmd, removeComments(matchedCommand[1:]) + } + } + return nil, nil +} + +func removeComments(args []string) []string { + var pos int + quote := false + for i := range args { + pos, quote = commentStart([]rune(args[i]), quote) + if pos > -1 { + out := make([]string, i+1) + if i > 0 { + copy(out, args[:i]) + } + out[i] = args[i][:pos] + return out + } + } + return args +} + +func commentStart(arg []rune, quote bool) (int, bool) { + var i int + space := true + for ; i < len(arg); i++ { + c, next := arg[i], grab(arg, i+1, len(arg)) + switch { + case quote && c == '"' && next != '"': + quote = false + case quote && c == '"' && next == '"': + i++ + case c == '\t' || c == ' ': + space = true + // Note we assume none of the regexes would split arguments on non-whitespace boundaries such that "text -- comment" would get split into "text -" and "- comment" + case !quote && space && c == '-' && next == '-': + return i, false + case !quote && c == '"': + quote = true + default: + space = false + } + } + return -1, quote +} + +func warnDisabled(s *Sqlcmd, args []string, line uint) error { + s.WriteError(s.GetError(), ErrCommandsDisabled) + return nil +} + +func errorDisabled(s *Sqlcmd, args []string, line uint) error { + s.WriteError(s.GetError(), ErrCommandsDisabled) + s.Exitcode = 1 + return ErrExitRequested +} + +func batchTerminatorRegex(terminator string) string { + return fmt.Sprintf(`(?im)^[\t ]*?%s(?:[ ]+(.*$)|$)`, regexp.QuoteMeta(terminator)) +} + +// SetBatchTerminator attempts to set the batch terminator to the given value +// Returns an error if the new value is not usable in the regex +func (c Commands) SetBatchTerminator(terminator string) error { + cmd := c["GO"] + regex, err := regexp.Compile(batchTerminatorRegex(terminator)) + if err != nil { + return err + } + cmd.regex = regex + return nil +} + +// exitCommand has 3 modes. +// With no (), it just exits without running any query +// With () it runs whatever batch is in the buffer then exits +// With any text between () it runs the text as a query then exits +func exitCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return ErrExitRequested + } + params := strings.TrimSpace(args[0]) + if params == "" { + return ErrExitRequested + } + if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") { + return InvalidCommandError("EXIT", line) + } + // First we save the current batch + query1 := s.batch.String() + if len(query1) > 0 { + query1 = s.getRunnableQuery(query1) + } + // Now parse the params of EXIT as a batch without commands + cmd := s.batch.cmd + s.batch.cmd = nil + defer func() { + s.batch.cmd = cmd + }() + query2 := strings.TrimSpace(params[1 : len(params)-1]) + if len(query2) > 0 { + s.batch.Reset([]rune(query2)) + _, _, err := s.batch.Next() + if err != nil { + return err + } + query2 = s.batch.String() + if len(query2) > 0 { + query2 = s.getRunnableQuery(query2) + } + } + + if len(query1) > 0 || len(query2) > 0 { + query := query1 + SqlcmdEol + query2 + s.Exitcode, _ = s.runQuery(query) + } + return ErrExitRequested +} + +// quitCommand immediately exits the program without running any more batches +func quitCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("QUIT", line) + } + return ErrExitRequested +} + +// goCommand runs the current batch the number of times specified +func goCommand(s *Sqlcmd, args []string, line uint) error { + // default to 1 execution + n := 1 + var err error + if len(args) > 0 { + cnt := strings.TrimSpace(args[0]) + if cnt != "" { + if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil { + return err + } + _, err = fmt.Sscanf(cnt, "%d", &n) + } + } + if err != nil || n < 1 { + return InvalidCommandError("GO", line) + } + if s.EchoInput { + err = listCommand(s, []string{}, line) + } + if err != nil { + return InvalidCommandError("GO", line) + } + query := s.batch.String() + if query == "" { + return nil + } + query = s.getRunnableQuery(query) + for i := 0; i < n; i++ { + if retcode, err := s.runQuery(query); err != nil { + s.Exitcode = retcode + return err + } + } + s.batch.Reset(nil) + return nil +} + +// outCommand changes the output writer to use a file +func outCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError("OUT", line) + } + filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) + if err != nil { + return err + } + + switch { + case strings.EqualFold(filePath, "stdout"): + s.SetOutput(os.Stdout) + case strings.EqualFold(filePath, "stderr"): + s.SetOutput(os.Stderr) + default: + o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return InvalidFileError(err, args[0]) + } + if s.UnicodeOutputFile { + // ODBC sqlcmd doesn't write a BOM but we will. + // Maybe the endian-ness should be configurable. + win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) + encoder := transform.NewWriter(o, win16le.NewEncoder()) + s.SetOutput(encoder) + } else { + s.SetOutput(o) + } + } + return nil +} + +// errorCommand changes the error writer to use a file +func errorCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError("ERROR", line) + } + filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) + if err != nil { + return err + } + switch { + case strings.EqualFold(filePath, "stderr"): + s.SetError(os.Stderr) + case strings.EqualFold(filePath, "stdout"): + s.SetError(os.Stdout) + default: + o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return InvalidFileError(err, args[0]) + } + s.SetError(o) + } + return nil +} + +func readFileCommand(s *Sqlcmd, args []string, line uint) error { + if args == nil || len(args) != 1 { + return InvalidCommandError(":R", line) + } + fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false) + return s.IncludeFile(fileName, false) +} + +// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables +func setVarCommand(s *Sqlcmd, args []string, line uint) error { + if args == nil || len(args) != 1 || args[0] == "" { + return InvalidCommandError(":SETVAR", line) + } + + varname := args[0] + val := "" + // The prior incarnation of sqlcmd doesn't require a space between the variable name and its value + // in some very unexpected cases. This version will require the space. + sp := strings.IndexRune(args[0], ' ') + if sp > -1 { + val = strings.TrimSpace(varname[sp:]) + varname = varname[:sp] + } + if err := s.vars.Setvar(varname, val); err != nil { + switch e := err.(type) { + case *VariableError: + return e + default: + return InvalidCommandError(":SETVAR", line) + } + } + return nil +} + +// listVarCommand prints the set of Sqlcmd scripting variables. +// Builtin values are printed first, followed by user-set values in sorted order. +func listVarCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("LISTVAR", line) + } + + vars := s.vars.All() + keys := make([]string, 0, len(vars)) + for k := range vars { + if !contains(builtinVariables, k) { + keys = append(keys, k) + } + } + sort.Strings(keys) + keys = append(builtinVariables, keys...) + for _, k := range keys { + if _, err := fmt.Fprintf(s.GetOutput(), `%s = "%s"%s`, k, vars[k], SqlcmdEol); err != nil { + return err + } + } + return nil +} + +// resetCommand resets the statement cache +func resetCommand(s *Sqlcmd, args []string, line uint) error { + if s.batch != nil { + s.batch.Reset(nil) + } + + return nil +} + +// listCommand displays statements currently in the statement cache +func listCommand(s *Sqlcmd, args []string, line uint) (err error) { + cmd := "" + if args != nil { + if len(args) > 0 { + cmd = strings.ToLower(strings.TrimSpace(args[0])) + if len(args) > 1 || (cmd != "color" && cmd != "") { + return InvalidCommandError("LIST", line) + } + } + } + output := s.GetOutput() + if cmd == "color" { + sample := "select 'literal' as literal, 100 as number from [sys].[tables]" + clr := color.TextTypeTSql + if s.Format.IsXmlMode() { + sample = `value` + clr = color.TextTypeXml + } + // ignoring errors since it's not critical output + for _, style := range s.colorizer.Styles() { + _, _ = output.Write([]byte(style + ": ")) + _ = s.colorizer.Write(output, sample, style, clr) + _, _ = output.Write([]byte(SqlcmdEol)) + } + return + } + if s.batch == nil || s.batch.String() == "" { + return + } + + if err = s.colorizer.Write(output, s.batch.String(), s.vars.ColorScheme(), color.TextTypeTSql); err == nil { + _, err = output.Write([]byte(SqlcmdEol)) + } + + return +} + +func connectCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return InvalidCommandError("CONNECT", line) + } + + commandArgs := strings.Fields(args[0]) + + // Require at least the server name parameter + if len(commandArgs) == 0 { + return InvalidCommandError("CONNECT", line) + } + + // Parse flags + flags := flag.NewFlagSet("connect", flag.ContinueOnError) + database := flags.String("D", "", "database name") + username := flags.String("U", "", "user name") + password := flags.String("P", "", "password") + loginTimeout := flags.String("l", "", "login timeout") + authenticationMethod := flags.String("G", "", "authentication method") + + err := flags.Parse(commandArgs[1:]) + if err != nil { + return InvalidCommandError("CONNECT", line) + } + + connect := *s.Connect + connect.UserName, _ = resolveArgumentVariables(s, []rune(*username), false) + connect.Password, _ = resolveArgumentVariables(s, []rune(*password), false) + connect.Database, _ = resolveArgumentVariables(s, []rune(*database), false) + + timeout, _ := resolveArgumentVariables(s, []rune(*loginTimeout), false) + if timeout != "" { + if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil { + if timeoutSeconds < 0 { + return InvalidCommandError("CONNECT", line) + } + connect.LoginTimeoutSeconds = int(timeoutSeconds) + } + } + + connect.AuthenticationMethod = *authenticationMethod + + // Set server name as the first positional argument + if len(commandArgs) > 0 { + connect.ServerName, _ = resolveArgumentVariables(s, []rune(commandArgs[0]), false) + } + + // If no user name is provided we switch to integrated auth + _ = s.ConnectDb(&connect, s.lineIo == nil) + + // ConnectDb prints connection errors already, and failure to connect is not fatal even with -b option + return nil +} + +func execCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 { + return InvalidCommandError("EXEC", line) + } + cmdLine := strings.TrimSpace(args[0]) + if cmdLine == "" { + return InvalidCommandError("EXEC", line) + } + if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil { + return err + } else { + cmd := sysCommand(cmdLine) + cmd.Stderr = s.GetError() + cmd.Stdout = s.GetOutput() + _ = cmd.Run() + } + return nil +} + +func editCommand(s *Sqlcmd, args []string, line uint) error { + if args != nil && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError("ED", line) + } + file, err := os.CreateTemp("", "sq*.sql") + if err != nil { + return err + } + fileName := file.Name() + defer func() { + // Best-effort cleanup - ignore errors + _ = os.Remove(fileName) + }() + defer func() { + // Ensure file is closed on all paths + _ = file.Close() + }() + text := s.batch.String() + if s.batch.State() == "-" { + text = fmt.Sprintf("%s%s", text, SqlcmdEol) + } + _, err = file.WriteString(text) + if err != nil { + return err + } + // Explicitly close before launching editor (defer will close again but that's safe) + if err := file.Close(); err != nil { + return err + } + cmd := sysCommand(s.vars.TextEditor() + " " + `"` + fileName + `"`) + cmd.Stderr = s.GetError() + cmd.Stdout = s.GetOutput() + err = cmd.Run() + if err != nil { + return err + } + wasEcho := s.echoFileLines + s.echoFileLines = true + s.batch.Reset(nil) + _ = s.IncludeFile(fileName, false) + s.echoFileLines = wasEcho + return nil +} + +func onerrorCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError("ON ERROR", line) + } + params := strings.TrimSpace(args[0]) + + if strings.EqualFold(strings.ToLower(params), "exit") { + s.Connect.ExitOnError = true + } else if strings.EqualFold(strings.ToLower(params), "ignore") { + s.Connect.IgnoreError = true + s.Connect.ExitOnError = false + } else { + return InvalidCommandError("ON ERROR", line) + } + return nil +} + +func xmlCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) != 1 || args[0] == "" { + return InvalidCommandError("XML", line) + } + params := strings.TrimSpace(args[0]) + // "OFF" and "ON" are documented as the allowed values. + // ODBC sqlcmd treats any value other than "ON" the same as "OFF". + // So we will too. + if strings.EqualFold(params, "on") { + s.Format.XmlMode(true) + } else { + s.Format.XmlMode(false) + } + return nil +} + +func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) { + var b *strings.Builder + end := len(arg) + for i := 0; i < end && !s.Connect.DisableVariableSubstitution; { + c, next := arg[i], grab(arg, i+1, end) + switch { + case c == '$' && next == '(': + vl, ok := readVariableReference(arg, i+2, end) + if ok { + varName := string(arg[i+2 : vl]) + val, ok := s.resolveVariable(varName) + if ok { + if b == nil { + b = new(strings.Builder) + b.Grow(len(arg)) + b.WriteString(string(arg[0:i])) + } + b.WriteString(val) + } else { + if failOnUnresolved { + return "", UndefinedVariable(varName) + } + s.WriteError(s.GetError(), UndefinedVariable(varName)) + if b != nil { + b.WriteString(string(arg[i : vl+1])) + } + } + i += ((vl - i) + 1) + } else { + if b != nil { + b.WriteString("$(") + } + i += 2 + } + default: + if b != nil { + b.WriteRune(c) + } + i++ + } + } + if b == nil { + return string(arg), nil + } + return b.String(), nil +} diff --git a/pkg/sqlcmd/format.go b/pkg/sqlcmd/format.go index 55bd2e25..fb6c057b 100644 --- a/pkg/sqlcmd/format.go +++ b/pkg/sqlcmd/format.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "strconv" "strings" "time" @@ -59,6 +60,12 @@ const ( ControlReplaceConsecutive ) +const ( + // Default display widths for float types + realDefaultWidth int64 = 14 // For REAL and SMALLMONEY + floatDefaultWidth int64 = 24 // For FLOAT and MONEY +) + type columnDetail struct { displayWidth int64 leftJustify bool @@ -371,11 +378,11 @@ func calcColumnDetails(cols []*sql.ColumnType, fixed int64, variable int64) ([]c columnDetails[i].displayWidth = max64(21, nameLen) case "REAL", "SMALLMONEY": columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(14, nameLen) + columnDetails[i].displayWidth = max64(realDefaultWidth, nameLen) columnDetails[i].zeroesAfterDecimal = true case "FLOAT", "MONEY": columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(24, nameLen) + columnDetails[i].displayWidth = max64(floatDefaultWidth, nameLen) columnDetails[i].zeroesAfterDecimal = true case "DECIMAL": columnDetails[i].leftJustify = false @@ -530,6 +537,56 @@ func (f *sqlCmdFormatterType) scanRow(rows *sql.Rows) ([]string, error) { } else { row[n] = "0" } + case float64: + // Format float64 to match ODBC sqlcmd behavior + // Use 'f' format with -1 precision to avoid scientific notation for typical values + // Fall back to 'g' format if the result would exceed the column display width + + // Use appropriate bitSize based on the SQL type (REAL=32, FLOAT=64) + // REAL columns should use 32-bit precision even though the value is scanned as float64 + bitSize := 64 + typeName := f.columnDetails[n].col.DatabaseTypeName() + if typeName == "REAL" || typeName == "SMALLMONEY" { + bitSize = 32 + } + + formatted := strconv.FormatFloat(x, 'f', -1, bitSize) + displayWidth := f.columnDetails[n].displayWidth + + // Use the type's default display width when displayWidth is 0 (unlimited) + // to avoid extremely long strings for extreme values + widthThreshold := displayWidth + if widthThreshold == 0 { + if typeName == "REAL" || typeName == "SMALLMONEY" { + widthThreshold = realDefaultWidth + } else { + widthThreshold = floatDefaultWidth + } + } + + if int64(len(formatted)) > widthThreshold { + // Use 'g' format for very large/small values to avoid truncation issues + formatted = strconv.FormatFloat(x, 'g', -1, bitSize) + } + row[n] = formatted + case float32: + // Format float32 to match ODBC sqlcmd behavior + // float32 values are rare (database/sql typically normalizes to float64) + // Use bitSize 32 to maintain precision appropriate for the original float32 value + formatted := strconv.FormatFloat(float64(x), 'f', -1, 32) + displayWidth := f.columnDetails[n].displayWidth + + // Use default REAL display width when displayWidth is 0 + widthThreshold := displayWidth + if widthThreshold == 0 { + widthThreshold = realDefaultWidth + } + + if int64(len(formatted)) > widthThreshold { + // Use 'g' format for very large/small values to avoid truncation issues + formatted = strconv.FormatFloat(float64(x), 'g', -1, 32) + } + row[n] = formatted default: var err error if row[n], err = fmt.Sprintf("%v", x), nil; err != nil { diff --git a/pkg/sqlcmd/format_test.go b/pkg/sqlcmd/format_test.go index f4bee464..c7caf98c 100644 --- a/pkg/sqlcmd/format_test.go +++ b/pkg/sqlcmd/format_test.go @@ -162,3 +162,147 @@ func TestFormatterXmlMode(t *testing.T) { assert.NoError(t, err, "runSqlCmd returned error") assert.Equal(t, ``+SqlcmdEol, buf.buf.String()) } + +func TestFormatterFloatFormatting(t *testing.T) { + // Test that float formatting matches ODBC sqlcmd behavior + // This addresses the issue where go-sqlcmd was using scientific notation + // while ODBC sqlcmd uses decimal notation + s, buf := setupSqlCmdWithMemoryOutput(t) + defer func() { _ = buf.Close() }() + + // Set SQLCMDMAXVARTYPEWIDTH to a non-zero value so FLOAT columns use the 24-char display width + // This enables the width-based fallback logic to be tested properly + s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "256") + + // Test query with float values from the issue + query := `SELECT + CAST(788991.19988463481 AS FLOAT) as Longitude1, + CAST(4713347.3103808956 AS FLOAT) as Latitude1, + CAST(789288.40771771886 AS FLOAT) as Longitude2, + CAST(4712632.075629076 AS FLOAT) as Latitude2, + CAST(788569.36558582436 AS FLOAT) as Longitude3, + CAST(4714608.0418091472 AS FLOAT) as Latitude3` + + err := runSqlCmd(t, s, []string{query, "GO"}) + assert.NoError(t, err, "runSqlCmd returned error") + + output := buf.buf.String() + + // Verify that the output contains decimal notation, not scientific notation + // Scientific notation would look like "4.713347310380896e+06" + // Decimal notation should look like "4713347.3103808956" + assert.NotContains(t, output, "e+", "Output should not contain scientific notation (e+)") + assert.NotContains(t, output, "E+", "Output should not contain scientific notation (E+)") + + // Verify that specific expected values are present (allowing for precision differences) + assert.Contains(t, output, "788991.1998846", "Output should contain decimal representation of Longitude1") + assert.Contains(t, output, "4713347.310380", "Output should contain decimal representation of Latitude1") + assert.Contains(t, output, "789288.4077177", "Output should contain decimal representation of Longitude2") + assert.Contains(t, output, "4712632.075629", "Output should contain decimal representation of Latitude2") + assert.Contains(t, output, "788569.3655858", "Output should contain decimal representation of Longitude3") + assert.Contains(t, output, "4714608.041809", "Output should contain decimal representation of Latitude3") +} + +func TestFormatterFloatFormattingExtremeValues(t *testing.T) { + // Test that extreme float values fall back to scientific notation + // to avoid truncation issues with very large or very small numbers + s, buf := setupSqlCmdWithMemoryOutput(t) + defer func() { _ = buf.Close() }() + + // Set SQLCMDMAXVARTYPEWIDTH to a non-zero value so FLOAT columns use the 24-char display width + // This allows the fallback behavior to be tested + s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "256") + + // Test query with extreme float values that would exceed the 24-char display width + query := `SELECT + CAST(1e100 AS FLOAT) as VeryLarge, + CAST(1e-100 AS FLOAT) as VerySmall` + + err := runSqlCmd(t, s, []string{query, "GO"}) + assert.NoError(t, err, "runSqlCmd returned error") + + output := buf.buf.String() + + // Verify that extremely large values use scientific notation with positive exponent + // (because decimal format would exceed the 24-char column width) + assert.Contains(t, output, "e+", "Output should contain scientific notation (e+) for very large values") + + // Verify that extremely small values use scientific notation with negative exponent + assert.Contains(t, output, "e-", "Output should contain scientific notation (e-) for very small values") +} + +func TestFormatterFloatFormattingExtremeValuesUnlimitedWidth(t *testing.T) { + // Test that extreme float values fall back to scientific notation even when + // displayWidth is 0 (unlimited), using type default widths as threshold + s, buf := setupSqlCmdWithMemoryOutput(t) + defer func() { _ = buf.Close() }() + + // Leave SQLCMDMAXVARTYPEWIDTH at 0 (setupSqlCmdWithMemoryOutput default) + // This sets displayWidth to 0 for all columns, testing the fallback logic + // that uses type default widths (24 for FLOAT, 14 for REAL) + + // Test query with extreme float values + query := `SELECT + CAST(1e100 AS FLOAT) as VeryLarge, + CAST(1e-100 AS FLOAT) as VerySmall` + + err := runSqlCmd(t, s, []string{query, "GO"}) + assert.NoError(t, err, "runSqlCmd returned error") + + output := buf.buf.String() + + // Verify that extreme values still use scientific notation even with unlimited width + // (fallback should use type default widths to prevent unbounded output) + assert.Contains(t, output, "e+", "Output should contain scientific notation (e+) for very large values even with unlimited width") + assert.Contains(t, output, "e-", "Output should contain scientific notation (e-) for very small values even with unlimited width") +} + +func TestFormatterRealFormatting(t *testing.T) { + // Test that REAL (float32) values use decimal notation for typical values + // and fall back to scientific notation for extreme values + s, buf := setupSqlCmdWithMemoryOutput(t) + defer func() { _ = buf.Close() }() + + // Set SQLCMDMAXVARTYPEWIDTH to a non-zero value so REAL columns use the 14-char display width + s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "256") + + // Test query with REAL values (both typical and extreme) + query := `SELECT + CAST(123.456789 AS REAL) as TypicalValue, + CAST(1e30 AS REAL) as ExtremeValue` + + err := runSqlCmd(t, s, []string{query, "GO"}) + assert.NoError(t, err, "runSqlCmd returned error") + + output := buf.buf.String() + + // Split output into lines to examine the data row separately from headers + lines := strings.Split(output, SqlcmdEol) + var dataLine string + for _, line := range lines { + // Find the data line (contains actual values, not headers or separators) + if strings.Contains(line, "123.") { + dataLine = line + break + } + } + + // Verify that typical REAL values use decimal notation (not scientific) + assert.Contains(t, dataLine, "123.456", "Output should contain decimal representation of typical REAL value") + + // Check that the typical value portion doesn't use scientific notation + // Parse columns using whitespace (the default column separator) + fields := strings.Fields(dataLine) + if len(fields) > 0 { + // Find the field containing the typical value + for _, field := range fields { + if strings.Contains(field, "123.") { + assert.NotContains(t, field, "e", "Typical REAL value should not use scientific notation") + break + } + } + } + + // Verify that extreme REAL values use scientific notation + assert.Contains(t, output, "e+", "Output should contain scientific notation for extreme REAL value") +}