@@ -20,7 +20,8 @@ type QueryValue struct {
2020
2121 // Column is kept so late in the generation process around to differentiate
2222 // between mysql slices and pg arrays
23- Column *plugin.Column
23+ Column *plugin.Column
24+ QueryText string
2425}
2526
2627func (v QueryValue) EmitStruct() bool {
@@ -85,6 +86,9 @@ func (v QueryValue) SlicePair() string {
8586
8687func (v QueryValue) Type() string {
8788 if v.Typ != "" {
89+ if v.isUsedWithArrayComparison() {
90+ return strings.Trim(v.Typ, "[]") // Return single type if used in array comparison.
91+ }
8892 return v.Typ
8993 }
9094 if v.Struct != nil {
@@ -113,6 +117,9 @@ func (v QueryValue) UniqueFields() []Field {
113117 fields := make([]Field, 0, len(v.Struct.Fields))
114118
115119 for _, field := range v.Struct.Fields {
120+ if v.isUsedWithArrayComparison() {
121+ field.Type = strings.Trim(field.Type, "[]")
122+ }
116123 if _, found := seen[field.Name]; found {
117124 continue
118125 }
@@ -129,14 +136,14 @@ func (v QueryValue) Params() string {
129136 }
130137 var out []string
131138 if v.Struct == nil {
132- if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() {
139+ if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() && !v.isUsedWithArrayComparison() {
133140 out = append(out, "pq.Array("+escape(v.Name)+")")
134141 } else {
135142 out = append(out, escape(v.Name))
136143 }
137144 } else {
138145 for _, f := range v.Struct.Fields {
139- if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
146+ if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() && !v.isUsedWithArrayComparison() {
140147 out = append(out, "pq.Array("+escape(v.VariableForField(f))+")")
141148 } else {
142149 out = append(out, escape(v.VariableForField(f)))
@@ -254,6 +261,22 @@ func (v QueryValue) VariableForField(f Field) string {
254261 return v.Name + "." + f.Name
255262}
256263
264+ // isUsedWithArrayComparison returns true if the parameter is used with the ANY/SOME/ALL keyword in query.
265+ func (v QueryValue) isUsedWithArrayComparison() bool {
266+ if v.Struct != nil {
267+ for _, f := range v.Struct.Fields {
268+ if strings.Contains(v.QueryText, fmt.Sprintf("ANY(%s)", f.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("SOME(%s)", f.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("ALL(%s)", f.DBName)) {
269+ return true
270+ }
271+ }
272+ } else {
273+ if strings.Contains(v.QueryText, fmt.Sprintf("ANY(%s)", v.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("SOME(%s)", v.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("ALL(%s)", v.DBName)) {
274+ return true
275+ }
276+ }
277+ return false
278+ }
279+
257280// A struct used to generate methods and fields on the Queries struct
258281type Query struct {
259282 Cmd string
0 commit comments