diff --git a/internal/endtoend/fmt_test.go b/internal/endtoend/fmt_test.go index db4aaee747..550033de49 100644 --- a/internal/endtoend/fmt_test.go +++ b/internal/endtoend/fmt_test.go @@ -13,6 +13,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/debug" "github.com/sqlc-dev/sqlc/internal/engine/dolphin" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" + "github.com/sqlc-dev/sqlc/internal/engine/sqlite" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/format" ) @@ -79,6 +80,22 @@ func TestFormat(t *testing.T) { } return ast.Format(stmts[0].Raw, mysqlParser), nil } + case config.EngineSQLite: + sqliteParser := sqlite.NewParser() + parse = sqliteParser + formatter = sqliteParser + // For SQLite, we use the same "round-trip" fingerprint strategy as MySQL: + // parse the SQL, format it, and return the formatted string. + fingerprint = func(sql string) (string, error) { + stmts, err := sqliteParser.Parse(strings.NewReader(sql)) + if err != nil { + return "", err + } + if len(stmts) == 0 { + return "", nil + } + return strings.ToLower(ast.Format(stmts[0].Raw, sqliteParser)), nil + } default: // Skip unsupported engines return diff --git a/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go b/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go index e22e5b6f33..b30fa7d95a 100644 --- a/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/select_exists/sqlite/go/query.sql.go @@ -21,9 +21,9 @@ SELECT ) ` -func (q *Queries) BarExists(ctx context.Context, id int64) (int64, error) { +func (q *Queries) BarExists(ctx context.Context, id int64) (bool, error) { row := q.db.QueryRowContext(ctx, barExists, id) - var column_1 int64 - err := row.Scan(&column_1) - return column_1, err + var exists bool + err := row.Scan(&exists) + return exists, err } diff --git a/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go b/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go index 6da4636da8..91dea13570 100644 --- a/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/select_not_exists/sqlite/go/query.sql.go @@ -21,9 +21,9 @@ SELECT ) ` -func (q *Queries) BarNotExists(ctx context.Context, dollar_1 interface{}) (interface{}, error) { - row := q.db.QueryRowContext(ctx, barNotExists, dollar_1) - var column_1 interface{} - err := row.Scan(&column_1) - return column_1, err +func (q *Queries) BarNotExists(ctx context.Context, id int64) (bool, error) { + row := q.db.QueryRowContext(ctx, barNotExists, id) + var not_exists bool + err := row.Scan(¬_exists) + return not_exists, err } diff --git a/internal/engine/dolphin/format.go b/internal/engine/dolphin/format.go index 458ae02363..9c6346756c 100644 --- a/internal/engine/dolphin/format.go +++ b/internal/engine/dolphin/format.go @@ -29,6 +29,13 @@ func (p *Parser) Param(n int) string { return "?" } +// NamedParam returns the named parameter placeholder for the given name. +// MySQL doesn't have native named parameters, so we use ? (positional). +// The actual parameter names are handled by sqlc's rewrite phase. +func (p *Parser) NamedParam(name string) string { + return "?" +} + // Cast returns a type cast expression. // MySQL uses CAST(expr AS type) syntax. func (p *Parser) Cast(arg, typeName string) string { diff --git a/internal/engine/postgresql/reserved.go b/internal/engine/postgresql/reserved.go index 9254bfdb82..b9ccc76d30 100644 --- a/internal/engine/postgresql/reserved.go +++ b/internal/engine/postgresql/reserved.go @@ -64,6 +64,12 @@ func (p *Parser) Param(n int) string { return fmt.Sprintf("$%d", n) } +// NamedParam returns the named parameter placeholder for the given name. +// PostgreSQL/sqlc uses @name syntax. +func (p *Parser) NamedParam(name string) string { + return "@" + name +} + // Cast returns a type cast expression. // PostgreSQL uses expr::type syntax. func (p *Parser) Cast(arg, typeName string) string { diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index 658a9d7f33..e9868f5be6 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -514,7 +514,10 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt()) selectStmt.LimitCount = limitCount selectStmt.LimitOffset = limitOffset - selectStmt.WithClause = &ast.WithClause{Ctes: &ctes} + // Only set WithClause if there are CTEs + if len(ctes.Items) > 0 { + selectStmt.WithClause = &ast.WithClause{Ctes: &ctes} + } return selectStmt } @@ -759,6 +762,13 @@ func (c *cc) convertLiteral(n *parser.Expr_literalContext) ast.Node { Location: n.GetStart().GetStart(), } } + + if literal.NULL_() != nil { + return &ast.A_Const{ + Val: &ast.Null{}, + Location: n.GetStart().GetStart(), + } + } } return todo("convertLiteral", n) } @@ -776,8 +786,14 @@ func (c *cc) convertBinaryNode(n *parser.Expr_binaryContext) ast.Node { } func (c *cc) convertBoolNode(n *parser.Expr_boolContext) ast.Node { + var op ast.BoolExprType + if n.AND_() != nil { + op = ast.BoolExprTypeAnd + } else if n.OR_() != nil { + op = ast.BoolExprTypeOr + } return &ast.BoolExpr{ - // TODO: Set op + Boolop: op, Args: &ast.List{ Items: []ast.Node{ c.convert(n.Expr(0)), @@ -787,6 +803,49 @@ func (c *cc) convertBoolNode(n *parser.Expr_boolContext) ast.Node { } } +func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { + op := n.Unary_operator() + if op == nil { + return c.convert(n.Expr()) + } + + // Get the inner expression + expr := c.convert(n.Expr()) + + // Check the operator type + if opCtx, ok := op.(*parser.Unary_operatorContext); ok { + if opCtx.NOT_() != nil { + // NOT expression + return &ast.BoolExpr{ + Boolop: ast.BoolExprTypeNot, + Args: &ast.List{ + Items: []ast.Node{expr}, + }, + } + } + if opCtx.MINUS() != nil { + // Negative number: -expr + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}}, + Rexpr: expr, + } + } + if opCtx.PLUS() != nil { + // Positive number: +expr (just return expr) + return expr + } + if opCtx.TILDE() != nil { + // Bitwise NOT: ~expr + return &ast.A_Expr{ + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}}, + Rexpr: expr, + } + } + } + + return expr +} + func (c *cc) convertParam(n *parser.Expr_bindContext) ast.Node { if n.NUMBERED_BIND_PARAMETER() != nil { // Parameter numbers start at one @@ -816,7 +875,52 @@ func (c *cc) convertParam(n *parser.Expr_bindContext) ast.Node { } func (c *cc) convertInSelectNode(n *parser.Expr_in_selectContext) ast.Node { - return c.convert(n.Select_stmt()) + // Check if this is EXISTS or NOT EXISTS + if n.EXISTS_() != nil { + linkType := ast.EXISTS_SUBLINK + sublink := &ast.SubLink{ + SubLinkType: linkType, + Subselect: c.convert(n.Select_stmt()), + } + if n.NOT_() != nil { + // NOT EXISTS is represented as a BoolExpr NOT wrapping the EXISTS + return &ast.BoolExpr{ + Boolop: ast.BoolExprTypeNot, + Args: &ast.List{ + Items: []ast.Node{sublink}, + }, + } + } + return sublink + } + + // Check if this is an IN/NOT IN expression: expr IN (SELECT ...) + if n.IN_() != nil && len(n.AllExpr()) > 0 { + linkType := ast.ANY_SUBLINK + sublink := &ast.SubLink{ + SubLinkType: linkType, + Testexpr: c.convert(n.Expr(0)), + Subselect: c.convert(n.Select_stmt()), + } + if n.NOT_() != nil { + return &ast.A_Expr{ + Kind: ast.A_Expr_Kind_OP, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "NOT IN"}}}, + Lexpr: c.convert(n.Expr(0)), + Rexpr: &ast.SubLink{ + SubLinkType: ast.EXPR_SUBLINK, + Subselect: c.convert(n.Select_stmt()), + }, + } + } + return sublink + } + + // Plain subquery in parentheses (SELECT ...) + return &ast.SubLink{ + SubLinkType: ast.EXPR_SUBLINK, + Subselect: c.convert(n.Select_stmt()), + } } func (c *cc) convertReturning_caluseContext(n parser.IReturning_clauseContext) *ast.List { @@ -887,12 +991,8 @@ func (c *cc) convertInsert_stmtContext(n *parser.Insert_stmtContext) ast.Node { } if hasDefaultValues { - // For DEFAULT VALUES, create an empty select statement - insert.SelectStmt = &ast.SelectStmt{ - FromClause: &ast.List{}, - TargetList: &ast.List{}, - ValuesLists: &ast.List{Items: []ast.Node{&ast.List{}}}, // Single empty values list - } + // For DEFAULT VALUES, set the flag instead of creating an empty values list + insert.DefaultValues = true } else if n.Select_stmt() != nil { if ss, ok := c.convert(n.Select_stmt()).(*ast.SelectStmt); ok { ss.ValuesLists = &ast.List{} @@ -976,6 +1076,11 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast tables = append(tables, rv) } else if from.Table_function_name() != nil { rel := from.Table_function_name().GetText() + // Convert function arguments + var args []ast.Node + for _, expr := range from.AllExpr() { + args = append(args, c.convert(expr)) + } rf := &ast.RangeFunction{ Functions: &ast.List{ Items: []ast.Node{ @@ -989,7 +1094,7 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast }, }, Args: &ast.List{ - Items: []ast.Node{&ast.TODO{}}, + Items: args, }, Location: from.GetStart().GetStart(), }, @@ -1189,6 +1294,9 @@ func (c *cc) convert(node node) ast.Node { case *parser.Expr_binaryContext: return c.convertBinaryNode(n) + case *parser.Expr_unaryContext: + return c.convertUnaryExpr(n) + case *parser.Expr_in_selectContext: return c.convertInSelectNode(n) diff --git a/internal/engine/sqlite/format.go b/internal/engine/sqlite/format.go new file mode 100644 index 0000000000..39ac859ca5 --- /dev/null +++ b/internal/engine/sqlite/format.go @@ -0,0 +1,35 @@ +package sqlite + +// QuoteIdent returns a quoted identifier if it needs quoting. +// SQLite uses double quotes for quoting identifiers (SQL standard), +// though backticks are also supported for MySQL compatibility. +func (p *Parser) QuoteIdent(s string) string { + // For now, don't quote - return as-is + return s +} + +// TypeName returns the SQL type name for the given namespace and name. +func (p *Parser) TypeName(ns, name string) string { + if ns != "" { + return ns + "." + name + } + return name +} + +// Param returns the parameter placeholder for the given number. +// SQLite uses ? for positional parameters. +func (p *Parser) Param(n int) string { + return "?" +} + +// NamedParam returns the named parameter placeholder for the given name. +// SQLite uses :name syntax for named parameters. +func (p *Parser) NamedParam(name string) string { + return ":" + name +} + +// Cast returns a type cast expression. +// SQLite uses CAST(expr AS type) syntax. +func (p *Parser) Cast(arg, typeName string) string { + return "CAST(" + arg + " AS " + typeName + ")" +} diff --git a/internal/sql/ast/a_expr.go b/internal/sql/ast/a_expr.go index 3b73d66d37..fc795a77ce 100644 --- a/internal/sql/ast/a_expr.go +++ b/internal/sql/ast/a_expr.go @@ -12,10 +12,36 @@ func (n *A_Expr) Pos() int { return n.Location } +// isNamedParam returns true if this A_Expr represents a named parameter (@name) +// and extracts the parameter name if so. +func (n *A_Expr) isNamedParam() (string, bool) { + if n.Name == nil || len(n.Name.Items) != 1 { + return "", false + } + s, ok := n.Name.Items[0].(*String) + if !ok || s.Str != "@" { + return "", false + } + if set(n.Lexpr) || !set(n.Rexpr) { + return "", false + } + if nameStr, ok := n.Rexpr.(*String); ok { + return nameStr.Str, true + } + return "", false +} + func (n *A_Expr) Format(buf *TrackedBuffer) { if n == nil { return } + + // Check for named parameter first (works regardless of Kind) + if name, ok := n.isNamedParam(); ok { + buf.WriteString(buf.NamedParam(name)) + return + } + switch n.Kind { case A_Expr_Kind_IN: buf.astFormat(n.Lexpr) @@ -64,32 +90,8 @@ func (n *A_Expr) Format(buf *TrackedBuffer) { buf.WriteString(", ") buf.astFormat(n.Rexpr) buf.WriteString(")") - case A_Expr_Kind_OP: - // Check if this is a named parameter (@name) - opName := "" - if n.Name != nil && len(n.Name.Items) == 1 { - if s, ok := n.Name.Items[0].(*String); ok { - opName = s.Str - } - } - if opName == "@" && !set(n.Lexpr) && set(n.Rexpr) { - // Named parameter: @name (no space after @) - buf.WriteString("@") - buf.astFormat(n.Rexpr) - } else { - // Standard binary operator - if set(n.Lexpr) { - buf.astFormat(n.Lexpr) - buf.WriteString(" ") - } - buf.astFormat(n.Name) - if set(n.Rexpr) { - buf.WriteString(" ") - buf.astFormat(n.Rexpr) - } - } default: - // Fallback for other cases + // Standard operator (including A_Expr_Kind_OP) if set(n.Lexpr) { buf.astFormat(n.Lexpr) buf.WriteString(" ") diff --git a/internal/sql/ast/bool_expr.go b/internal/sql/ast/bool_expr.go index 9bbddfd7dc..0241503a06 100644 --- a/internal/sql/ast/bool_expr.go +++ b/internal/sql/ast/bool_expr.go @@ -26,6 +26,12 @@ func (n *BoolExpr) Format(buf *TrackedBuffer) { buf.astFormat(n.Args.Items[0]) } buf.WriteString(" IS NOT NULL") + case BoolExprTypeNot: + // NOT expression: format as NOT + buf.WriteString("NOT ") + if items(n.Args) && len(n.Args.Items) > 0 { + buf.astFormat(n.Args.Items[0]) + } default: buf.WriteString("(") if items(n.Args) { @@ -34,9 +40,6 @@ func (n *BoolExpr) Format(buf *TrackedBuffer) { buf.join(n.Args, " AND ") case BoolExprTypeOr: buf.join(n.Args, " OR ") - case BoolExprTypeNot: - buf.WriteString(" NOT ") - buf.astFormat(n.Args) } } buf.WriteString(")") diff --git a/internal/sql/ast/collate_expr.go b/internal/sql/ast/collate_expr.go index 6c32eece77..fd9a891e08 100644 --- a/internal/sql/ast/collate_expr.go +++ b/internal/sql/ast/collate_expr.go @@ -10,3 +10,12 @@ type CollateExpr struct { func (n *CollateExpr) Pos() int { return n.Location } + +func (n *CollateExpr) Format(buf *TrackedBuffer) { + if n == nil { + return + } + buf.astFormat(n.Xpr) + buf.WriteString(" COLLATE ") + buf.astFormat(n.Arg) +} diff --git a/internal/sql/ast/insert_stmt.go b/internal/sql/ast/insert_stmt.go index f287df4ae7..75ef44863a 100644 --- a/internal/sql/ast/insert_stmt.go +++ b/internal/sql/ast/insert_stmt.go @@ -9,6 +9,7 @@ type InsertStmt struct { ReturningList *List WithClause *WithClause Override OverridingKind + DefaultValues bool // SQLite-specific: INSERT INTO ... DEFAULT VALUES } func (n *InsertStmt) Pos() int { @@ -35,7 +36,9 @@ func (n *InsertStmt) Format(buf *TrackedBuffer) { buf.WriteString(")") } - if set(n.SelectStmt) { + if n.DefaultValues { + buf.WriteString(" DEFAULT VALUES") + } else if set(n.SelectStmt) { buf.WriteString(" ") buf.astFormat(n.SelectStmt) } diff --git a/internal/sql/ast/print.go b/internal/sql/ast/print.go index c4390a15c5..6335846946 100644 --- a/internal/sql/ast/print.go +++ b/internal/sql/ast/print.go @@ -65,6 +65,15 @@ func (t *TrackedBuffer) Cast(arg, typeName string) string { return arg + "::" + typeName } +// NamedParam returns the named parameter placeholder for the given name. +// If no formatter is set, it returns PostgreSQL-style @name. +func (t *TrackedBuffer) NamedParam(name string) string { + if t.formatter != nil { + return t.formatter.NamedParam(name) + } + return "@" + name +} + func (t *TrackedBuffer) astFormat(n Node) { if ft, ok := n.(nodeFormatter); ok { ft.Format(t) diff --git a/internal/sql/format/format.go b/internal/sql/format/format.go index 922b02b61c..02140757f7 100644 --- a/internal/sql/format/format.go +++ b/internal/sql/format/format.go @@ -14,6 +14,10 @@ type Formatter interface { // PostgreSQL uses $1, $2, etc. MySQL uses ? Param(n int) string + // NamedParam returns the named parameter placeholder for the given name. + // PostgreSQL uses @name, SQLite uses :name + NamedParam(name string) string + // Cast formats a type cast expression. // PostgreSQL uses expr::type, MySQL uses CAST(expr AS type) Cast(arg, typeName string) string