Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions internal/format/expressions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package format

import (
"fmt"
"strings"

"github.com/sqlc-dev/doubleclick/ast"
)

// Expression formats an expression.
func Expression(sb *strings.Builder, expr ast.Expression) {
if expr == nil {
return
}

switch e := expr.(type) {
case *ast.Literal:
formatLiteral(sb, e)
case *ast.Identifier:
formatIdentifier(sb, e)
case *ast.TableIdentifier:
formatTableIdentifier(sb, e)
case *ast.FunctionCall:
formatFunctionCall(sb, e)
case *ast.BinaryExpr:
formatBinaryExpr(sb, e)
case *ast.UnaryExpr:
formatUnaryExpr(sb, e)
case *ast.Asterisk:
formatAsterisk(sb, e)
case *ast.AliasedExpr:
formatAliasedExpr(sb, e)
default:
// Fallback for unhandled expressions
sb.WriteString(fmt.Sprintf("%v", expr))
}
}

// formatLiteral formats a literal value.
func formatLiteral(sb *strings.Builder, lit *ast.Literal) {
switch lit.Type {
case ast.LiteralString:
sb.WriteString("'")
// Escape single quotes in the string
s := lit.Value.(string)
s = strings.ReplaceAll(s, "'", "''")
sb.WriteString(s)
sb.WriteString("'")
case ast.LiteralInteger:
switch v := lit.Value.(type) {
case int64:
sb.WriteString(fmt.Sprintf("%d", v))
case uint64:
sb.WriteString(fmt.Sprintf("%d", v))
default:
sb.WriteString(fmt.Sprintf("%v", lit.Value))
}
case ast.LiteralFloat:
sb.WriteString(fmt.Sprintf("%v", lit.Value))
case ast.LiteralBoolean:
if lit.Value.(bool) {
sb.WriteString("true")
} else {
sb.WriteString("false")
}
case ast.LiteralNull:
sb.WriteString("NULL")
case ast.LiteralArray:
formatArrayLiteral(sb, lit.Value)
case ast.LiteralTuple:
formatTupleLiteral(sb, lit.Value)
default:
sb.WriteString(fmt.Sprintf("%v", lit.Value))
}
}

// formatArrayLiteral formats an array literal.
func formatArrayLiteral(sb *strings.Builder, val interface{}) {
sb.WriteString("[")
exprs, ok := val.([]ast.Expression)
if ok {
for i, e := range exprs {
if i > 0 {
sb.WriteString(", ")
}
Expression(sb, e)
}
}
sb.WriteString("]")
}

// formatTupleLiteral formats a tuple literal.
func formatTupleLiteral(sb *strings.Builder, val interface{}) {
sb.WriteString("(")
exprs, ok := val.([]ast.Expression)
if ok {
for i, e := range exprs {
if i > 0 {
sb.WriteString(", ")
}
Expression(sb, e)
}
}
sb.WriteString(")")
}

// formatIdentifier formats an identifier.
func formatIdentifier(sb *strings.Builder, id *ast.Identifier) {
sb.WriteString(id.Name())
}

// formatTableIdentifier formats a table identifier.
func formatTableIdentifier(sb *strings.Builder, t *ast.TableIdentifier) {
if t.Database != "" {
sb.WriteString(t.Database)
sb.WriteString(".")
}
sb.WriteString(t.Table)
}

// formatFunctionCall formats a function call.
func formatFunctionCall(sb *strings.Builder, fn *ast.FunctionCall) {
sb.WriteString(fn.Name)
sb.WriteString("(")
if fn.Distinct {
sb.WriteString("DISTINCT ")
}
for i, arg := range fn.Arguments {
if i > 0 {
sb.WriteString(", ")
}
Expression(sb, arg)
}
sb.WriteString(")")
}

// formatBinaryExpr formats a binary expression.
func formatBinaryExpr(sb *strings.Builder, expr *ast.BinaryExpr) {
Expression(sb, expr.Left)
sb.WriteString(" ")
sb.WriteString(expr.Op)
sb.WriteString(" ")
Expression(sb, expr.Right)
}

// formatUnaryExpr formats a unary expression.
func formatUnaryExpr(sb *strings.Builder, expr *ast.UnaryExpr) {
sb.WriteString(expr.Op)
Expression(sb, expr.Operand)
}

// formatAsterisk formats an asterisk.
func formatAsterisk(sb *strings.Builder, a *ast.Asterisk) {
if a.Table != "" {
sb.WriteString(a.Table)
sb.WriteString(".")
}
sb.WriteString("*")
}

// formatAliasedExpr formats an aliased expression.
func formatAliasedExpr(sb *strings.Builder, a *ast.AliasedExpr) {
Expression(sb, a.Expr)
sb.WriteString(" AS ")
sb.WriteString(a.Alias)
}
37 changes: 37 additions & 0 deletions internal/format/format.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Package format provides SQL formatting for ClickHouse AST.
package format

import (
"strings"

"github.com/sqlc-dev/doubleclick/ast"
)

// Format returns the SQL string representation of the statements.
func Format(stmts []ast.Statement) string {
var sb strings.Builder
for i, stmt := range stmts {
if i > 0 {
sb.WriteString("\n")
}
Statement(&sb, stmt)
sb.WriteString(";")
}
return sb.String()
}

// Statement formats a single statement.
func Statement(sb *strings.Builder, stmt ast.Statement) {
if stmt == nil {
return
}

switch s := stmt.(type) {
case *ast.SelectWithUnionQuery:
formatSelectWithUnionQuery(sb, s)
case *ast.SelectQuery:
formatSelectQuery(sb, s)
default:
// For now, only handle SELECT statements
}
}
120 changes: 120 additions & 0 deletions internal/format/statements.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package format

import (
"strings"

"github.com/sqlc-dev/doubleclick/ast"
)

// formatSelectWithUnionQuery formats a SELECT with UNION query.
func formatSelectWithUnionQuery(sb *strings.Builder, q *ast.SelectWithUnionQuery) {
for i, sel := range q.Selects {
if i > 0 {
sb.WriteString(" UNION ")
if len(q.UnionModes) > i-1 && q.UnionModes[i-1] == "ALL" {
sb.WriteString("ALL ")
} else if len(q.UnionModes) > i-1 && q.UnionModes[i-1] == "DISTINCT" {
sb.WriteString("DISTINCT ")
}
}
Statement(sb, sel)
}
}

// formatSelectQuery formats a SELECT query.
func formatSelectQuery(sb *strings.Builder, q *ast.SelectQuery) {
sb.WriteString("SELECT ")

if q.Distinct {
sb.WriteString("DISTINCT ")
}

// Format columns
for i, col := range q.Columns {
if i > 0 {
sb.WriteString(", ")
}
Expression(sb, col)
}

// Format FROM clause
if q.From != nil {
sb.WriteString(" FROM ")
formatTablesInSelectQuery(sb, q.From)
}

// Format WHERE clause
if q.Where != nil {
sb.WriteString(" WHERE ")
Expression(sb, q.Where)
}

// Format GROUP BY clause
if len(q.GroupBy) > 0 {
sb.WriteString(" GROUP BY ")
for i, expr := range q.GroupBy {
if i > 0 {
sb.WriteString(", ")
}
Expression(sb, expr)
}
}

// Format HAVING clause
if q.Having != nil {
sb.WriteString(" HAVING ")
Expression(sb, q.Having)
}

// Format ORDER BY clause
if len(q.OrderBy) > 0 {
sb.WriteString(" ORDER BY ")
for i, elem := range q.OrderBy {
if i > 0 {
sb.WriteString(", ")
}
formatOrderByElement(sb, elem)
}
}

// Format LIMIT clause
if q.Limit != nil {
sb.WriteString(" LIMIT ")
Expression(sb, q.Limit)
}
}

// formatTablesInSelectQuery formats the FROM clause tables.
func formatTablesInSelectQuery(sb *strings.Builder, t *ast.TablesInSelectQuery) {
for i, elem := range t.Tables {
if i > 0 {
// TODO: Handle JOINs properly
sb.WriteString(", ")
}
formatTablesInSelectQueryElement(sb, elem)
}
}

// formatTablesInSelectQueryElement formats a single table element.
func formatTablesInSelectQueryElement(sb *strings.Builder, t *ast.TablesInSelectQueryElement) {
if t.Table != nil {
formatTableExpression(sb, t.Table)
}
}

// formatTableExpression formats a table expression.
func formatTableExpression(sb *strings.Builder, t *ast.TableExpression) {
Expression(sb, t.Table)
if t.Alias != "" {
sb.WriteString(" AS ")
sb.WriteString(t.Alias)
}
}

// formatOrderByElement formats an ORDER BY element.
func formatOrderByElement(sb *strings.Builder, o *ast.OrderByElement) {
Expression(sb, o.Expression)
if o.Descending {
sb.WriteString(" DESC")
}
}
11 changes: 11 additions & 0 deletions parser/format.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package parser

import (
"github.com/sqlc-dev/doubleclick/ast"
"github.com/sqlc-dev/doubleclick/internal/format"
)

// Format returns the SQL string representation of the statements.
func Format(stmts []ast.Statement) string {
return format.Format(stmts)
}
8 changes: 8 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,14 @@ func TestParser(t *testing.T) {
}
}

// Check Format output for 00007_array test
if entry.Name() == "00007_array" {
formatted := parser.Format(stmts)
if formatted != query {
t.Errorf("Format output mismatch\nQuery: %s\nFormatted: %s", query, formatted)
}
}

// If we get here with a todo test and -check-skipped is set, the test passes!
// Automatically remove the todo flag from metadata.json
if metadata.Todo && *checkSkipped {
Expand Down