Skip to content

Commit 55dba9f

Browse files
feat: improve MySQL mathematical expression type inference
1 parent b807fe9 commit 55dba9f

File tree

2 files changed

+101
-1
lines changed

2 files changed

+101
-1
lines changed

internal/compiler/output_columns.go

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
155155
// TODO: Generate a name for these operations
156156
cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
157157
case lang.IsMathematicalOperator(op):
158-
cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true})
158+
// Improve type inference for mathematical expressions
159+
dataType, notNull := c.inferMathExpressionType(n, tables, op)
160+
cols = append(cols, &Column{Name: name, DataType: dataType, NotNull: notNull})
159161
default:
160162
cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
161163
}
@@ -770,3 +772,100 @@ func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List)
770772

771773
return nil
772774
}
775+
776+
// inferMathExpressionType attempts to infer the data type of a mathematical expression
777+
// by analyzing its operands and the operation being performed
778+
func (c *Compiler) inferMathExpressionType(expr *ast.A_Expr, tables []*Table, op string) (string, bool) {
779+
// Try to infer types from left and right operands
780+
leftType := c.inferOperandType(expr.Lexpr, tables)
781+
rightType := c.inferOperandType(expr.Rexpr, tables)
782+
783+
// Debug logging to understand what's happening
784+
// fmt.Printf("DEBUG: Math expression %s: left=%s, right=%s\n", op, leftType, rightType)
785+
786+
// Determine the result type based on operands and operation
787+
resultType := c.combineTypes(leftType, rightType, op)
788+
789+
// For now, assume nullable since we're dealing with database columns
790+
// In a more sophisticated implementation, we could track nullability through the expression
791+
notNull := false
792+
793+
return resultType, notNull
794+
}
795+
796+
// inferOperandType tries to determine the type of an operand in an expression
797+
func (c *Compiler) inferOperandType(operand ast.Node, tables []*Table) string {
798+
switch n := operand.(type) {
799+
case *ast.ColumnRef:
800+
// Look up the column in the available tables
801+
parts := stringSlice(n.Fields)
802+
var name string
803+
if len(parts) >= 1 {
804+
name = parts[len(parts)-1] // Get the column name (last part)
805+
}
806+
807+
for _, table := range tables {
808+
for _, col := range table.Columns {
809+
if col.Name == name {
810+
return col.DataType
811+
}
812+
}
813+
}
814+
return "any"
815+
case *ast.A_Const:
816+
// Determine type based on constant value
817+
switch n.Val.(type) {
818+
case *ast.Integer:
819+
return "int"
820+
case *ast.Float:
821+
return "float"
822+
case *ast.String:
823+
return "text"
824+
default:
825+
return "any"
826+
}
827+
case *ast.A_Expr:
828+
// Recursive case for nested expressions
829+
if n.Name != nil {
830+
nestedOp := astutils.Join(n.Name, "")
831+
if lang.IsMathematicalOperator(nestedOp) {
832+
resultType, _ := c.inferMathExpressionType(n, tables, nestedOp)
833+
return resultType
834+
}
835+
}
836+
return "any"
837+
default:
838+
return "any"
839+
}
840+
}
841+
842+
// combineTypes determines the result type when combining two operand types with an operation
843+
func (c *Compiler) combineTypes(leftType, rightType, op string) string {
844+
// Handle division specially - division operations typically result in float
845+
if op == "/" {
846+
// If either operand is float, result is float
847+
if leftType == "float" || rightType == "float" {
848+
return "float"
849+
}
850+
// Even integer division might want to be float in many cases
851+
// For safety, return float for division unless both operands are clearly non-numeric
852+
if leftType != "text" && rightType != "text" {
853+
return "float"
854+
}
855+
}
856+
857+
// For other mathematical operations
858+
switch {
859+
case leftType == "float" || rightType == "float":
860+
return "float"
861+
case leftType == "int" && rightType == "int":
862+
return "int"
863+
case leftType == "int" && rightType == "any":
864+
return "int"
865+
case leftType == "any" && rightType == "int":
866+
return "int"
867+
default:
868+
// Default fallback
869+
return "any"
870+
}
871+
}

internal/sql/lang/operator.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ func IsMathematicalOperator(s string) bool {
2323
case "-":
2424
case "*":
2525
case "/":
26+
case "div": // MySQL division operator
2627
case "%":
2728
case "^":
2829
case "|/":

0 commit comments

Comments
 (0)