Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions internal/endtoend/fmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/debug"
"github.com/sqlc-dev/sqlc/internal/engine/dolphin"
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
"github.com/sqlc-dev/sqlc/internal/engine/sqlite"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/format"
)
Expand Down Expand Up @@ -79,6 +80,22 @@ func TestFormat(t *testing.T) {
}
return ast.Format(stmts[0].Raw, mysqlParser), nil
}
case config.EngineSQLite:
sqliteParser := sqlite.NewParser()
parse = sqliteParser
formatter = sqliteParser
// For SQLite, we use the same "round-trip" fingerprint strategy as MySQL:
// parse the SQL, format it, and return the formatted string.
fingerprint = func(sql string) (string, error) {
stmts, err := sqliteParser.Parse(strings.NewReader(sql))
if err != nil {
return "", err
}
if len(stmts) == 0 {
return "", nil
}
return strings.ToLower(ast.Format(stmts[0].Raw, sqliteParser)), nil
}
default:
// Skip unsupported engines
return
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions internal/engine/dolphin/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ func (p *Parser) Param(n int) string {
return "?"
}

// NamedParam returns the named parameter placeholder for the given name.
// MySQL doesn't have native named parameters, so we use ? (positional).
// The actual parameter names are handled by sqlc's rewrite phase.
func (p *Parser) NamedParam(name string) string {
return "?"
}

// Cast returns a type cast expression.
// MySQL uses CAST(expr AS type) syntax.
func (p *Parser) Cast(arg, typeName string) string {
Expand Down
6 changes: 6 additions & 0 deletions internal/engine/postgresql/reserved.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ func (p *Parser) Param(n int) string {
return fmt.Sprintf("$%d", n)
}

// NamedParam returns the named parameter placeholder for the given name.
// PostgreSQL/sqlc uses @name syntax.
func (p *Parser) NamedParam(name string) string {
return "@" + name
}

// Cast returns a type cast expression.
// PostgreSQL uses expr::type syntax.
func (p *Parser) Cast(arg, typeName string) string {
Expand Down
128 changes: 118 additions & 10 deletions internal/engine/sqlite/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,10 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No
limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt())
selectStmt.LimitCount = limitCount
selectStmt.LimitOffset = limitOffset
selectStmt.WithClause = &ast.WithClause{Ctes: &ctes}
// Only set WithClause if there are CTEs
if len(ctes.Items) > 0 {
selectStmt.WithClause = &ast.WithClause{Ctes: &ctes}
}
return selectStmt
}

Expand Down Expand Up @@ -759,6 +762,13 @@ func (c *cc) convertLiteral(n *parser.Expr_literalContext) ast.Node {
Location: n.GetStart().GetStart(),
}
}

if literal.NULL_() != nil {
return &ast.A_Const{
Val: &ast.Null{},
Location: n.GetStart().GetStart(),
}
}
}
return todo("convertLiteral", n)
}
Expand All @@ -776,8 +786,14 @@ func (c *cc) convertBinaryNode(n *parser.Expr_binaryContext) ast.Node {
}

func (c *cc) convertBoolNode(n *parser.Expr_boolContext) ast.Node {
var op ast.BoolExprType
if n.AND_() != nil {
op = ast.BoolExprTypeAnd
} else if n.OR_() != nil {
op = ast.BoolExprTypeOr
}
return &ast.BoolExpr{
// TODO: Set op
Boolop: op,
Args: &ast.List{
Items: []ast.Node{
c.convert(n.Expr(0)),
Expand All @@ -787,6 +803,49 @@ func (c *cc) convertBoolNode(n *parser.Expr_boolContext) ast.Node {
}
}

func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node {
op := n.Unary_operator()
if op == nil {
return c.convert(n.Expr())
}

// Get the inner expression
expr := c.convert(n.Expr())

// Check the operator type
if opCtx, ok := op.(*parser.Unary_operatorContext); ok {
if opCtx.NOT_() != nil {
// NOT expression
return &ast.BoolExpr{
Boolop: ast.BoolExprTypeNot,
Args: &ast.List{
Items: []ast.Node{expr},
},
}
}
if opCtx.MINUS() != nil {
// Negative number: -expr
return &ast.A_Expr{
Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}},
Rexpr: expr,
}
}
if opCtx.PLUS() != nil {
// Positive number: +expr (just return expr)
return expr
}
if opCtx.TILDE() != nil {
// Bitwise NOT: ~expr
return &ast.A_Expr{
Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}},
Rexpr: expr,
}
}
}

return expr
}

func (c *cc) convertParam(n *parser.Expr_bindContext) ast.Node {
if n.NUMBERED_BIND_PARAMETER() != nil {
// Parameter numbers start at one
Expand Down Expand Up @@ -816,7 +875,52 @@ func (c *cc) convertParam(n *parser.Expr_bindContext) ast.Node {
}

func (c *cc) convertInSelectNode(n *parser.Expr_in_selectContext) ast.Node {
return c.convert(n.Select_stmt())
// Check if this is EXISTS or NOT EXISTS
if n.EXISTS_() != nil {
linkType := ast.EXISTS_SUBLINK
sublink := &ast.SubLink{
SubLinkType: linkType,
Subselect: c.convert(n.Select_stmt()),
}
if n.NOT_() != nil {
// NOT EXISTS is represented as a BoolExpr NOT wrapping the EXISTS
return &ast.BoolExpr{
Boolop: ast.BoolExprTypeNot,
Args: &ast.List{
Items: []ast.Node{sublink},
},
}
}
return sublink
}

// Check if this is an IN/NOT IN expression: expr IN (SELECT ...)
if n.IN_() != nil && len(n.AllExpr()) > 0 {
linkType := ast.ANY_SUBLINK
sublink := &ast.SubLink{
SubLinkType: linkType,
Testexpr: c.convert(n.Expr(0)),
Subselect: c.convert(n.Select_stmt()),
}
if n.NOT_() != nil {
return &ast.A_Expr{
Kind: ast.A_Expr_Kind_OP,
Name: &ast.List{Items: []ast.Node{&ast.String{Str: "NOT IN"}}},
Lexpr: c.convert(n.Expr(0)),
Rexpr: &ast.SubLink{
SubLinkType: ast.EXPR_SUBLINK,
Subselect: c.convert(n.Select_stmt()),
},
}
}
return sublink
}

// Plain subquery in parentheses (SELECT ...)
return &ast.SubLink{
SubLinkType: ast.EXPR_SUBLINK,
Subselect: c.convert(n.Select_stmt()),
}
}

func (c *cc) convertReturning_caluseContext(n parser.IReturning_clauseContext) *ast.List {
Expand Down Expand Up @@ -887,12 +991,8 @@ func (c *cc) convertInsert_stmtContext(n *parser.Insert_stmtContext) ast.Node {
}

if hasDefaultValues {
// For DEFAULT VALUES, create an empty select statement
insert.SelectStmt = &ast.SelectStmt{
FromClause: &ast.List{},
TargetList: &ast.List{},
ValuesLists: &ast.List{Items: []ast.Node{&ast.List{}}}, // Single empty values list
}
// For DEFAULT VALUES, set the flag instead of creating an empty values list
insert.DefaultValues = true
} else if n.Select_stmt() != nil {
if ss, ok := c.convert(n.Select_stmt()).(*ast.SelectStmt); ok {
ss.ValuesLists = &ast.List{}
Expand Down Expand Up @@ -976,6 +1076,11 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast
tables = append(tables, rv)
} else if from.Table_function_name() != nil {
rel := from.Table_function_name().GetText()
// Convert function arguments
var args []ast.Node
for _, expr := range from.AllExpr() {
args = append(args, c.convert(expr))
}
rf := &ast.RangeFunction{
Functions: &ast.List{
Items: []ast.Node{
Expand All @@ -989,7 +1094,7 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast
},
},
Args: &ast.List{
Items: []ast.Node{&ast.TODO{}},
Items: args,
},
Location: from.GetStart().GetStart(),
},
Expand Down Expand Up @@ -1189,6 +1294,9 @@ func (c *cc) convert(node node) ast.Node {
case *parser.Expr_binaryContext:
return c.convertBinaryNode(n)

case *parser.Expr_unaryContext:
return c.convertUnaryExpr(n)

case *parser.Expr_in_selectContext:
return c.convertInSelectNode(n)

Expand Down
35 changes: 35 additions & 0 deletions internal/engine/sqlite/format.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package sqlite

// QuoteIdent returns a quoted identifier if it needs quoting.
// SQLite uses double quotes for quoting identifiers (SQL standard),
// though backticks are also supported for MySQL compatibility.
func (p *Parser) QuoteIdent(s string) string {
// For now, don't quote - return as-is
return s
}

// TypeName returns the SQL type name for the given namespace and name.
func (p *Parser) TypeName(ns, name string) string {
if ns != "" {
return ns + "." + name
}
return name
}

// Param returns the parameter placeholder for the given number.
// SQLite uses ? for positional parameters.
func (p *Parser) Param(n int) string {
return "?"
}

// NamedParam returns the named parameter placeholder for the given name.
// SQLite uses :name syntax for named parameters.
func (p *Parser) NamedParam(name string) string {
return ":" + name
}

// Cast returns a type cast expression.
// SQLite uses CAST(expr AS type) syntax.
func (p *Parser) Cast(arg, typeName string) string {
return "CAST(" + arg + " AS " + typeName + ")"
}
52 changes: 27 additions & 25 deletions internal/sql/ast/a_expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,36 @@ func (n *A_Expr) Pos() int {
return n.Location
}

// isNamedParam returns true if this A_Expr represents a named parameter (@name)
// and extracts the parameter name if so.
func (n *A_Expr) isNamedParam() (string, bool) {
if n.Name == nil || len(n.Name.Items) != 1 {
return "", false
}
s, ok := n.Name.Items[0].(*String)
if !ok || s.Str != "@" {
return "", false
}
if set(n.Lexpr) || !set(n.Rexpr) {
return "", false
}
if nameStr, ok := n.Rexpr.(*String); ok {
return nameStr.Str, true
}
return "", false
}

func (n *A_Expr) Format(buf *TrackedBuffer) {
if n == nil {
return
}

// Check for named parameter first (works regardless of Kind)
if name, ok := n.isNamedParam(); ok {
buf.WriteString(buf.NamedParam(name))
return
}

switch n.Kind {
case A_Expr_Kind_IN:
buf.astFormat(n.Lexpr)
Expand Down Expand Up @@ -64,32 +90,8 @@ func (n *A_Expr) Format(buf *TrackedBuffer) {
buf.WriteString(", ")
buf.astFormat(n.Rexpr)
buf.WriteString(")")
case A_Expr_Kind_OP:
// Check if this is a named parameter (@name)
opName := ""
if n.Name != nil && len(n.Name.Items) == 1 {
if s, ok := n.Name.Items[0].(*String); ok {
opName = s.Str
}
}
if opName == "@" && !set(n.Lexpr) && set(n.Rexpr) {
// Named parameter: @name (no space after @)
buf.WriteString("@")
buf.astFormat(n.Rexpr)
} else {
// Standard binary operator
if set(n.Lexpr) {
buf.astFormat(n.Lexpr)
buf.WriteString(" ")
}
buf.astFormat(n.Name)
if set(n.Rexpr) {
buf.WriteString(" ")
buf.astFormat(n.Rexpr)
}
}
default:
// Fallback for other cases
// Standard operator (including A_Expr_Kind_OP)
if set(n.Lexpr) {
buf.astFormat(n.Lexpr)
buf.WriteString(" ")
Expand Down
Loading
Loading