diff --git a/internal/format/expressions.go b/internal/format/expressions.go
new file mode 100644
index 000000000..81e674dcd
--- /dev/null
+++ b/internal/format/expressions.go
@@ -0,0 +1,166 @@
+package format
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/sqlc-dev/doubleclick/ast"
+)
+
+// stringLiteralEscaper doubles the two characters that cannot appear raw inside
+// a single-quoted ClickHouse string literal: the backslash and the quote itself.
+var stringLiteralEscaper = strings.NewReplacer(`\`, `\\`, "'", "''")
+
+// Expression writes the SQL text of expr to sb.
+//
+// A nil expr writes nothing. A node type without a dedicated formatter is
+// written with fmt's %v verb, which is not guaranteed to produce valid SQL.
+// Only a nil interface is skipped: a typed nil pointer such as
+// (*ast.Literal)(nil) is still dispatched to its formatter.
+func Expression(sb *strings.Builder, expr ast.Expression) {
+	if expr == nil {
+		return
+	}
+
+	switch e := expr.(type) {
+	case *ast.Literal:
+		formatLiteral(sb, e)
+	case *ast.Identifier:
+		formatIdentifier(sb, e)
+	case *ast.TableIdentifier:
+		formatTableIdentifier(sb, e)
+	case *ast.FunctionCall:
+		formatFunctionCall(sb, e)
+	case *ast.BinaryExpr:
+		formatBinaryExpr(sb, e)
+	case *ast.UnaryExpr:
+		formatUnaryExpr(sb, e)
+	case *ast.Asterisk:
+		formatAsterisk(sb, e)
+	case *ast.AliasedExpr:
+		formatAliasedExpr(sb, e)
+	default:
+		// Fallback for expression types that have no formatter yet.
+		fmt.Fprintf(sb, "%v", expr)
+	}
+}
+
+// formatLiteral writes lit to sb as SQL literal text.
+//
+// Strings are single-quoted and escaped, NULL is written as the keyword, and
+// arrays and tuples are delegated to their own formatters. Every other literal
+// type is written with fmt's %v verb.
+func formatLiteral(sb *strings.Builder, lit *ast.Literal) {
+	switch lit.Type {
+	case ast.LiteralString:
+		// fmt.Sprint returns a string Value unchanged and renders any other
+		// dynamic type with %v, so a mistyped Value cannot panic here.
+		s := fmt.Sprint(lit.Value)
+		sb.WriteString("'")
+		sb.WriteString(stringLiteralEscaper.Replace(s))
+		sb.WriteString("'")
+	case ast.LiteralNull:
+		sb.WriteString("NULL")
+	case ast.LiteralArray:
+		formatArrayLiteral(sb, lit.Value)
+	case ast.LiteralTuple:
+		formatTupleLiteral(sb, lit.Value)
+	default:
+		// Integers, floats and booleans land here: %v prints int64/uint64 like
+		// %d and bool like %t, and needs no type assertion that could panic.
+		// NOTE(review): %v drops the fraction of whole floats (1.0 -> "1") and
+		// may use exponent notation; confirm that round-trips as intended.
+		fmt.Fprintf(sb, "%v", lit.Value)
+	}
+}
+
+// formatArrayLiteral formats an array literal.
+// val is expected to hold a []ast.Expression; any other value renders as "[]".
+func formatArrayLiteral(sb *strings.Builder, val interface{}) {
+	elems, _ := val.([]ast.Expression) // nil, i.e. no elements, on a type mismatch
+	sb.WriteString("[")
+	writeExpressionList(sb, elems)
+	sb.WriteString("]")
+}
+
+// formatTupleLiteral formats a tuple literal.
+// val is expected to hold a []ast.Expression; any other value renders as "()".
+func formatTupleLiteral(sb *strings.Builder, val interface{}) {
+	elems, _ := val.([]ast.Expression) // nil, i.e. no elements, on a type mismatch
+	sb.WriteString("(")
+	writeExpressionList(sb, elems)
+	sb.WriteString(")")
+}
+
+// writeExpressionList writes each expression in exprs to sb, separated by ", ".
+func writeExpressionList(sb *strings.Builder, exprs []ast.Expression) {
+	for i, e := range exprs {
+		if i > 0 {
+			sb.WriteString(", ")
+		}
+		Expression(sb, e)
+	}
+}
+
+// formatIdentifier writes the string returned by id.Name() as-is.
+func formatIdentifier(sb *strings.Builder, id *ast.Identifier) {
+	sb.WriteString(id.Name())
+}
+
+// formatTableIdentifier writes "database.table", or just the table when t.Database is empty.
+func formatTableIdentifier(sb *strings.Builder, t *ast.TableIdentifier) {
+	if db := t.Database; db != "" {
+		sb.WriteString(db + ".")
+	}
+	sb.WriteString(t.Table)
+}
+
+// formatFunctionCall writes "name(args)", with DISTINCT before the args when fn.Distinct is set.
+func formatFunctionCall(sb *strings.Builder, fn *ast.FunctionCall) {
+	sb.WriteString(fn.Name + "(")
+	if fn.Distinct {
+		sb.WriteString("DISTINCT ")
+	}
+	for i, arg := range fn.Arguments {
+		if i > 0 {
+			sb.WriteString(", ")
+		}
+		Expression(sb, arg)
+	}
+	sb.WriteString(")")
+}
+
+// formatBinaryExpr writes "left op right" with single spaces around the operator.
+// NOTE(review): no parentheses are emitted, so nested operands are written flat.
+func formatBinaryExpr(sb *strings.Builder, expr *ast.BinaryExpr) {
+	Expression(sb, expr.Left)
+	sb.WriteString(" " + expr.Op + " ")
+	Expression(sb, expr.Right)
+}
+
+// formatUnaryExpr writes the operator followed by its operand. An operator ending in
+// a letter (a keyword such as NOT) gets a separating space: "NOT x", not "NOTx".
+func formatUnaryExpr(sb *strings.Builder, expr *ast.UnaryExpr) {
+	sb.WriteString(expr.Op)
+	if n := len(expr.Op); n > 0 {
+		if c := expr.Op[n-1]; (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') {
+			sb.WriteString(" ")
+		}
+	}
+	Expression(sb, expr.Operand)
+}
+
+// formatAsterisk writes "*", qualified as "table.*" when a.Table is set.
+func formatAsterisk(sb *strings.Builder, a *ast.Asterisk) {
+	if tbl := a.Table; tbl != "" {
+		sb.WriteString(tbl + ".")
+	}
+	sb.WriteString("*")
+}
+
+// formatAliasedExpr formats an aliased expression.
+func formatAliasedExpr(sb *strings.Builder, a *ast.AliasedExpr) {
+	Expression(sb, a.Expr)
+	// The alias is appended verbatim after the AS keyword.
+	sb.WriteString(" AS " + a.Alias)
+}
diff --git a/internal/format/format.go b/internal/format/format.go
new file mode 100644
index 000000000..ad3859eaf
--- /dev/null
+++ b/internal/format/format.go
@@ -0,0 +1,44 @@
+// Package format provides SQL formatting for ClickHouse AST.
+package format
+
+import (
+	"strings"
+
+	"github.com/sqlc-dev/doubleclick/ast"
+)
+
+// Format renders stmts as SQL text.
+//
+// Every statement is terminated by ";" and consecutive statements are separated
+// by a newline. A statement that Statement writes nothing for (nil or an
+// unsupported type) therefore still contributes a bare ";".
+func Format(stmts []ast.Statement) string {
+	var out strings.Builder
+	for idx, stmt := range stmts {
+		if idx > 0 {
+			out.WriteByte('\n')
+		}
+		Statement(&out, stmt)
+		out.WriteByte(';')
+	}
+	return out.String()
+}
+
+// Statement writes the SQL text of a single statement to sb.
+//
+// Only *ast.SelectWithUnionQuery and *ast.SelectQuery are supported; a nil
+// statement or any other statement type writes nothing.
+func Statement(sb *strings.Builder, stmt ast.Statement) {
+	if stmt == nil {
+		return
+	}
+
+	switch s := stmt.(type) {
+	case *ast.SelectQuery:
+		formatSelectQuery(sb, s)
+	case *ast.SelectWithUnionQuery:
+		formatSelectWithUnionQuery(sb, s)
+	default:
+		// Other statement kinds are not formatted yet.
+	}
+}
diff --git a/internal/format/statements.go b/internal/format/statements.go
new file mode 100644
index 000000000..0078d8c9e
--- /dev/null
+++ b/internal/format/statements.go
@@ -0,0 +1,120 @@
+package format
+
+import (
+	"strings"
+
+	"github.com/sqlc-dev/doubleclick/ast"
+)
+
+// formatSelectWithUnionQuery writes q.Selects in order, joined by " UNION ", " UNION ALL " or " UNION DISTINCT ".
+func formatSelectWithUnionQuery(sb *strings.Builder, q *ast.SelectWithUnionQuery) {
+	for idx, sel := range q.Selects {
+		if idx > 0 {
+			sb.WriteString(" UNION ")
+			if m := idx - 1; m < len(q.UnionModes) && q.UnionModes[m] == "ALL" { // UnionModes[m] qualifies this UNION
+				sb.WriteString("ALL ")
+			} else if m < len(q.UnionModes) && q.UnionModes[m] == "DISTINCT" {
+				sb.WriteString("DISTINCT ")
+			}
+		}
+		Statement(sb, sel)
+	}
+}
+
+// formatSelectQuery formats a SELECT query.
+//
+// Clauses are emitted in a fixed order: columns, FROM, WHERE, GROUP BY, HAVING,
+// ORDER BY, LIMIT. A clause whose field is nil or empty is left out entirely.
+func formatSelectQuery(sb *strings.Builder, q *ast.SelectQuery) {
+	sb.WriteString("SELECT ")
+	if q.Distinct {
+		sb.WriteString("DISTINCT ")
+	}
+
+	// Columns: the separator starts empty and becomes ", " after the first item.
+	sep := ""
+	for _, col := range q.Columns {
+		sb.WriteString(sep)
+		Expression(sb, col)
+		sep = ", "
+	}
+
+	// FROM
+	if from := q.From; from != nil {
+		sb.WriteString(" FROM ")
+		formatTablesInSelectQuery(sb, from)
+	}
+
+	// WHERE
+	if where := q.Where; where != nil {
+		sb.WriteString(" WHERE ")
+		Expression(sb, where)
+	}
+
+	// GROUP BY: the keyword doubles as the first separator, so an empty list
+	// writes nothing at all.
+	sep = " GROUP BY "
+	for _, key := range q.GroupBy {
+		sb.WriteString(sep)
+		Expression(sb, key)
+		sep = ", "
+	}
+
+	// HAVING
+	if having := q.Having; having != nil {
+		sb.WriteString(" HAVING ")
+		Expression(sb, having)
+	}
+
+	// ORDER BY: same keyword-as-first-separator scheme as GROUP BY.
+	sep = " ORDER BY "
+	for _, elem := range q.OrderBy {
+		sb.WriteString(sep)
+		formatOrderByElement(sb, elem)
+		sep = ", "
+	}
+
+	// LIMIT
+	if limit := q.Limit; limit != nil {
+		sb.WriteString(" LIMIT ")
+		Expression(sb, limit)
+	}
+}
+
+// formatTablesInSelectQuery writes the entries of the FROM clause in order,
+// separated by ", ".
+func formatTablesInSelectQuery(sb *strings.Builder, t *ast.TablesInSelectQuery) {
+	// TODO: Handle JOINs properly
+	sep := ""
+	for _, elem := range t.Tables {
+		sb.WriteString(sep)
+		formatTablesInSelectQueryElement(sb, elem)
+		sep = ", "
+	}
+}
+
+// formatTablesInSelectQueryElement writes the element's table expression. An
+// element whose Table is nil writes nothing.
+func formatTablesInSelectQueryElement(sb *strings.Builder, t *ast.TablesInSelectQueryElement) {
+	if t.Table == nil {
+		return
+	}
+	formatTableExpression(sb, t.Table)
+}
+
+// formatTableExpression writes the table expression, followed by " AS alias"
+// when t.Alias is non-empty.
+func formatTableExpression(sb *strings.Builder, t *ast.TableExpression) {
+	Expression(sb, t.Table)
+	if alias := t.Alias; alias != "" {
+		sb.WriteString(" AS " + alias)
+	}
+}
+
+// formatOrderByElement formats an ORDER BY element.
+func formatOrderByElement(sb *strings.Builder, o *ast.OrderByElement) {
+	Expression(sb, o.Expression)
+	if o.Descending { // only the DESC keyword is ever written; otherwise nothing is appended
+		sb.WriteString(" DESC")
+	}
+}
diff --git a/parser/format.go b/parser/format.go
new file mode 100644
index 000000000..de2502263
--- /dev/null
+++ b/parser/format.go
@@ -0,0 +1,11 @@
+package parser
+
+import (
+	"github.com/sqlc-dev/doubleclick/ast"
+	"github.com/sqlc-dev/doubleclick/internal/format"
+)
+
+// Format returns the SQL string representation of stmts by delegating to format.Format.
+func Format(stmts []ast.Statement) string {
+	return format.Format(stmts)
+}
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 295b26ce8..aa0bfeba6 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -202,6 +202,14 @@ func TestParser(t *testing.T) {
 				}
 			}
 
+			// Round-trip check, limited to the 00007_array case: parser.Format(stmts) must equal query.
+			if entry.Name() == "00007_array" {
+				formatted := parser.Format(stmts)
+				if formatted != query {
+					t.Errorf("Format output mismatch\nQuery: %s\nFormatted: %s", query, formatted)
+				}
+			}
+
 			// If we get here with a todo test and -check-skipped is set, the test passes!
 			// Automatically remove the todo flag from metadata.json
 			if metadata.Todo && *checkSkipped {