@@ -155,7 +155,9 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
155155 // TODO: Generate a name for these operations
156156 cols = append (cols , & Column {Name : name , DataType : "bool" , NotNull : true })
157157 case lang .IsMathematicalOperator (op ):
158- cols = append (cols , & Column {Name : name , DataType : "int" , NotNull : true })
158+ // Improve type inference for mathematical expressions
159+ dataType , notNull := c .inferMathExpressionType (n , tables , op )
160+ cols = append (cols , & Column {Name : name , DataType : dataType , NotNull : notNull })
159161 default :
160162 cols = append (cols , & Column {Name : name , DataType : "any" , NotNull : false })
161163 }
@@ -770,3 +772,100 @@ func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List)
770772
771773 return nil
772774}
775+
776+ // inferMathExpressionType attempts to infer the data type of a mathematical expression
777+ // by analyzing its operands and the operation being performed
778+ func (c * Compiler ) inferMathExpressionType (expr * ast.A_Expr , tables []* Table , op string ) (string , bool ) {
779+ // Try to infer types from left and right operands
780+ leftType := c .inferOperandType (expr .Lexpr , tables )
781+ rightType := c .inferOperandType (expr .Rexpr , tables )
782+
783+ // Debug logging to understand what's happening
784+ // fmt.Printf("DEBUG: Math expression %s: left=%s, right=%s\n", op, leftType, rightType)
785+
786+ // Determine the result type based on operands and operation
787+ resultType := c .combineTypes (leftType , rightType , op )
788+
789+ // For now, assume nullable since we're dealing with database columns
790+ // In a more sophisticated implementation, we could track nullability through the expression
791+ notNull := false
792+
793+ return resultType , notNull
794+ }
795+
796+ // inferOperandType tries to determine the type of an operand in an expression
797+ func (c * Compiler ) inferOperandType (operand ast.Node , tables []* Table ) string {
798+ switch n := operand .(type ) {
799+ case * ast.ColumnRef :
800+ // Look up the column in the available tables
801+ parts := stringSlice (n .Fields )
802+ var name string
803+ if len (parts ) >= 1 {
804+ name = parts [len (parts )- 1 ] // Get the column name (last part)
805+ }
806+
807+ for _ , table := range tables {
808+ for _ , col := range table .Columns {
809+ if col .Name == name {
810+ return col .DataType
811+ }
812+ }
813+ }
814+ return "any"
815+ case * ast.A_Const :
816+ // Determine type based on constant value
817+ switch n .Val .(type ) {
818+ case * ast.Integer :
819+ return "int"
820+ case * ast.Float :
821+ return "float"
822+ case * ast.String :
823+ return "text"
824+ default :
825+ return "any"
826+ }
827+ case * ast.A_Expr :
828+ // Recursive case for nested expressions
829+ if n .Name != nil {
830+ nestedOp := astutils .Join (n .Name , "" )
831+ if lang .IsMathematicalOperator (nestedOp ) {
832+ resultType , _ := c .inferMathExpressionType (n , tables , nestedOp )
833+ return resultType
834+ }
835+ }
836+ return "any"
837+ default :
838+ return "any"
839+ }
840+ }
841+
842+ // combineTypes determines the result type when combining two operand types with an operation
843+ func (c * Compiler ) combineTypes (leftType , rightType , op string ) string {
844+ // Handle division specially - division operations typically result in float
845+ if op == "/" {
846+ // If either operand is float, result is float
847+ if leftType == "float" || rightType == "float" {
848+ return "float"
849+ }
850+ // Even integer division might want to be float in many cases
851+ // For safety, return float for division unless both operands are clearly non-numeric
852+ if leftType != "text" && rightType != "text" {
853+ return "float"
854+ }
855+ }
856+
857+ // For other mathematical operations
858+ switch {
859+ case leftType == "float" || rightType == "float" :
860+ return "float"
861+ case leftType == "int" && rightType == "int" :
862+ return "int"
863+ case leftType == "int" && rightType == "any" :
864+ return "int"
865+ case leftType == "any" && rightType == "int" :
866+ return "int"
867+ default :
868+ // Default fallback
869+ return "any"
870+ }
871+ }
0 commit comments