|
1 | 1 | package analyzer
|
2 | 2 |
|
3 | 3 | import (
|
4 |
| - "reflect" |
| 4 | + "strings" |
5 | 5 |
|
6 | 6 | "github.com/src-d/go-mysql-server/sql"
|
7 | 7 | "github.com/src-d/go-mysql-server/sql/expression"
|
8 | 8 | "github.com/src-d/go-mysql-server/sql/plan"
|
9 | 9 | )
|
10 | 10 |
|
11 |
| -type transformedJoin struct { |
12 |
| - node sql.Node |
13 |
| - condCols map[string]*transformedSource |
14 |
| -} |
15 |
| - |
16 |
| -type transformedSource struct { |
17 |
| - correct string |
18 |
| - wrong []string |
19 |
| -} |
20 |
| - |
21 | 11 | func resolveNaturalJoins(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
|
22 | 12 | span, _ := ctx.Span("resolve_natural_joins")
|
23 | 13 | defer span.Finish()
|
24 | 14 |
|
25 |
| - if n.Resolved() { |
26 |
| - return n, nil |
27 |
| - } |
28 |
| - |
29 |
| - var transformed []*transformedJoin |
30 |
| - var aliasTables = map[string][]string{} |
31 |
| - var colsToUnresolve = map[string]*transformedSource{} |
32 |
| - a.Log("resolving natural joins, node of type %T", n) |
33 |
| - node, err := n.TransformUp(func(n sql.Node) (sql.Node, error) { |
34 |
| - a.Log("transforming node of type: %T", n) |
| 15 | + var replacements = make(map[tableCol]tableCol) |
| 16 | + var tableAliases = make(map[string]string) |
35 | 17 |
|
36 |
| - if alias, ok := n.(*plan.TableAlias); ok { |
37 |
| - table := alias.Child.(*plan.ResolvedTable).Name() |
38 |
| - aliasTables[alias.Name()] = append(aliasTables[alias.Name()], table) |
| 18 | + return n.TransformUp(func(n sql.Node) (sql.Node, error) { |
| 19 | + switch n := n.(type) { |
| 20 | + case *plan.TableAlias: |
| 21 | + alias := n.Name() |
| 22 | + table := n.Child.(*plan.ResolvedTable).Name() |
| 23 | + tableAliases[strings.ToLower(alias)] = table |
39 | 24 | return n, nil
|
40 |
| - } |
41 |
| - |
42 |
| - if n.Resolved() { |
| 25 | + case *plan.NaturalJoin: |
| 26 | + return resolveNaturalJoin(n, replacements) |
| 27 | + case sql.Expressioner: |
| 28 | + return replaceExpressions(n, replacements, tableAliases) |
| 29 | + default: |
43 | 30 | return n, nil
|
44 | 31 | }
|
| 32 | + }) |
| 33 | +} |
45 | 34 |
|
46 |
| - join, ok := n.(*plan.NaturalJoin) |
47 |
| - if !ok { |
48 |
| - return n, nil |
49 |
| - } |
50 |
| - |
51 |
| - // we need both leaves resolved before resolving this one |
52 |
| - if !join.Left.Resolved() || !join.Right.Resolved() { |
53 |
| - return n, nil |
54 |
| - } |
55 |
| - |
56 |
| - leftSchema, rightSchema := join.Left.Schema(), join.Right.Schema() |
57 |
| - |
58 |
| - var conditions, common, left, right []sql.Expression |
59 |
| - var seen = make(map[string]struct{}) |
60 |
| - |
61 |
| - for i, lcol := range leftSchema { |
62 |
| - var found bool |
63 |
| - leftCol := expression.NewGetFieldWithTable( |
64 |
| - i, |
65 |
| - lcol.Type, |
66 |
| - lcol.Source, |
67 |
| - lcol.Name, |
68 |
| - lcol.Nullable, |
69 |
| - ) |
70 |
| - |
71 |
| - for j, rcol := range rightSchema { |
72 |
| - if lcol.Name == rcol.Name { |
73 |
| - common = append(common, leftCol) |
74 |
| - |
75 |
| - conditions = append( |
76 |
| - conditions, |
77 |
| - expression.NewEquals( |
78 |
| - leftCol, |
79 |
| - expression.NewGetFieldWithTable( |
80 |
| - len(leftSchema)+j, |
81 |
| - rcol.Type, |
82 |
| - rcol.Source, |
83 |
| - rcol.Name, |
84 |
| - rcol.Nullable, |
85 |
| - ), |
86 |
| - ), |
87 |
| - ) |
88 |
| - |
89 |
| - found = true |
90 |
| - seen[lcol.Name] = struct{}{} |
91 |
| - if source, ok := colsToUnresolve[lcol.Name]; ok { |
92 |
| - source.correct = lcol.Source |
93 |
| - source.wrong = append(source.wrong, rcol.Source) |
94 |
| - } else { |
95 |
| - colsToUnresolve[lcol.Name] = &transformedSource{ |
96 |
| - correct: lcol.Source, |
97 |
| - wrong: []string{rcol.Source}, |
98 |
| - } |
99 |
| - } |
100 |
| - |
101 |
| - break |
102 |
| - } |
103 |
| - } |
| 35 | +func resolveNaturalJoin( |
| 36 | + n *plan.NaturalJoin, |
| 37 | + replacements map[tableCol]tableCol, |
| 38 | +) (sql.Node, error) { |
| 39 | + // Both sides of the natural join need to be resolved in order to resolve |
| 40 | + // the natural join itself. |
| 41 | + if !n.Left.Resolved() || !n.Right.Resolved() { |
| 42 | + return n, nil |
| 43 | + } |
104 | 44 |
|
105 |
| - if !found { |
106 |
| - left = append(left, leftCol) |
| 45 | + leftSchema := n.Left.Schema() |
| 46 | + rightSchema := n.Right.Schema() |
| 47 | + |
| 48 | + var conditions, common, left, right []sql.Expression |
| 49 | + for i, lcol := range leftSchema { |
| 50 | + leftCol := expression.NewGetFieldWithTable( |
| 51 | + i, |
| 52 | + lcol.Type, |
| 53 | + lcol.Source, |
| 54 | + lcol.Name, |
| 55 | + lcol.Nullable, |
| 56 | + ) |
| 57 | + if idx, rcol := findCol(rightSchema, lcol.Name); rcol != nil { |
| 58 | + common = append(common, leftCol) |
| 59 | + replacements[tableCol{strings.ToLower(rcol.Source), strings.ToLower(rcol.Name)}] = tableCol{ |
| 60 | + strings.ToLower(lcol.Source), strings.ToLower(lcol.Name), |
107 | 61 | }
|
108 |
| - } |
109 | 62 |
|
110 |
| - if len(conditions) == 0 { |
111 |
| - return plan.NewCrossJoin(join.Left, join.Right), nil |
112 |
| - } |
113 |
| - |
114 |
| - for i, col := range rightSchema { |
115 |
| - if _, ok := seen[col.Name]; !ok { |
116 |
| - right = append( |
117 |
| - right, |
| 63 | + conditions = append( |
| 64 | + conditions, |
| 65 | + expression.NewEquals( |
| 66 | + leftCol, |
118 | 67 | expression.NewGetFieldWithTable(
|
119 |
| - len(leftSchema)+i, |
120 |
| - col.Type, |
121 |
| - col.Source, |
122 |
| - col.Name, |
123 |
| - col.Nullable, |
| 68 | + len(leftSchema)+idx, |
| 69 | + rcol.Type, |
| 70 | + rcol.Source, |
| 71 | + rcol.Name, |
| 72 | + rcol.Nullable, |
124 | 73 | ),
|
125 |
| - ) |
126 |
| - } |
127 |
| - } |
128 |
| - |
129 |
| - projections := append(append(common, left...), right...) |
130 |
| - |
131 |
| - tj := &transformedJoin{ |
132 |
| - node: plan.NewProject( |
133 |
| - projections, |
134 |
| - plan.NewInnerJoin( |
135 |
| - join.Left, |
136 |
| - join.Right, |
137 |
| - expression.JoinAnd(conditions...), |
138 | 74 | ),
|
139 |
| - ), |
140 |
| - condCols: colsToUnresolve, |
| 75 | + ) |
| 76 | + } else { |
| 77 | + left = append(left, leftCol) |
141 | 78 | }
|
142 |
| - |
143 |
| - transformed = append(transformed, tj) |
144 |
| - |
145 |
| - return tj.node, nil |
146 |
| - }) |
147 |
| - |
148 |
| - if err != nil || len(transformed) == 0 { |
149 |
| - return node, err |
150 | 79 | }
|
151 | 80 |
|
152 |
| - var transformedSeen bool |
153 |
| - return node.TransformUp(func(node sql.Node) (sql.Node, error) { |
154 |
| - if ok, _ := isTransformedNode(node, transformed); ok { |
155 |
| - transformedSeen = true |
156 |
| - return node, nil |
157 |
| - } |
158 |
| - |
159 |
| - if !transformedSeen { |
160 |
| - return node, nil |
161 |
| - } |
162 |
| - |
163 |
| - expressioner, ok := node.(sql.Expressioner) |
164 |
| - if !ok { |
165 |
| - return node, nil |
166 |
| - } |
167 |
| - |
168 |
| - return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { |
169 |
| - var col, table string |
170 |
| - switch e := e.(type) { |
171 |
| - case *expression.GetField: |
172 |
| - col, table = e.Name(), e.Table() |
173 |
| - case *expression.UnresolvedColumn: |
174 |
| - col, table = e.Name(), e.Table() |
175 |
| - default: |
176 |
| - return e, nil |
177 |
| - } |
178 |
| - |
179 |
| - sources, ok := colsToUnresolve[col] |
180 |
| - if !ok { |
181 |
| - return e, nil |
182 |
| - } |
183 |
| - |
184 |
| - if !mustUnresolve(aliasTables, table, sources.wrong) { |
185 |
| - return e, nil |
186 |
| - } |
187 |
| - |
188 |
| - return expression.NewUnresolvedQualifiedColumn( |
189 |
| - sources.correct, |
190 |
| - col, |
191 |
| - ), nil |
192 |
| - }) |
193 |
| - }) |
194 |
| -} |
195 |
| - |
196 |
| -func isTransformedNode(node sql.Node, transformed []*transformedJoin) (is bool, colsToUnresolve map[string]*transformedSource) { |
197 |
| - var project *plan.Project |
198 |
| - var join *plan.InnerJoin |
199 |
| - switch n := node.(type) { |
200 |
| - case *plan.Project: |
201 |
| - var ok bool |
202 |
| - join, ok = n.Child.(*plan.InnerJoin) |
203 |
| - if !ok { |
204 |
| - return |
205 |
| - } |
206 |
| - |
207 |
| - project = n |
208 |
| - case *plan.InnerJoin: |
209 |
| - join = n |
210 |
| - |
211 |
| - default: |
212 |
| - return |
| 81 | + if len(conditions) == 0 { |
| 82 | + return plan.NewCrossJoin(n.Left, n.Right), nil |
213 | 83 | }
|
214 | 84 |
|
215 |
| - for _, t := range transformed { |
216 |
| - tproject, ok := t.node.(*plan.Project) |
217 |
| - if !ok { |
218 |
| - return |
219 |
| - } |
220 |
| - |
221 |
| - tjoin, ok := tproject.Child.(*plan.InnerJoin) |
222 |
| - if !ok { |
223 |
| - return |
224 |
| - } |
225 |
| - |
226 |
| - if project != nil && !reflect.DeepEqual(project.Projections, tproject.Projections) { |
227 |
| - continue |
228 |
| - } |
229 |
| - |
230 |
| - if reflect.DeepEqual(join.Cond, tjoin.Cond) { |
231 |
| - is = true |
232 |
| - colsToUnresolve = t.condCols |
| 85 | + for i, col := range rightSchema { |
| 86 | + source := strings.ToLower(col.Source) |
| 87 | + name := strings.ToLower(col.Name) |
| 88 | + if _, ok := replacements[tableCol{source, name}]; !ok { |
| 89 | + right = append( |
| 90 | + right, |
| 91 | + expression.NewGetFieldWithTable( |
| 92 | + len(leftSchema)+i, |
| 93 | + col.Type, |
| 94 | + col.Source, |
| 95 | + col.Name, |
| 96 | + col.Nullable, |
| 97 | + ), |
| 98 | + ) |
233 | 99 | }
|
234 | 100 | }
|
235 | 101 |
|
236 |
| - return |
| 102 | + return plan.NewProject( |
| 103 | + append(append(common, left...), right...), |
| 104 | + plan.NewInnerJoin(n.Left, n.Right, expression.JoinAnd(conditions...)), |
| 105 | + ), nil |
237 | 106 | }
|
238 | 107 |
|
239 |
| -func mustUnresolve(aliasTable map[string][]string, table string, wrongSources []string) bool { |
240 |
| - return isIn(table, wrongSources) || isAliasFor(aliasTable, table, wrongSources) |
241 |
| -} |
242 |
| - |
243 |
| -func isIn(s string, l []string) bool { |
244 |
| - for _, e := range l { |
245 |
| - if s == e { |
246 |
| - return true |
| 108 | +func findCol(s sql.Schema, name string) (int, *sql.Column) { |
| 109 | + for i, c := range s { |
| 110 | + if strings.ToLower(c.Name) == strings.ToLower(name) { |
| 111 | + return i, c |
247 | 112 | }
|
248 | 113 | }
|
249 |
| - |
250 |
| - return false |
| 114 | + return -1, nil |
251 | 115 | }
|
252 | 116 |
|
253 |
| -func isAliasFor(aliasTable map[string][]string, table string, wrongSources []string) bool { |
254 |
| - tables, ok := aliasTable[table] |
255 |
| - if !ok { |
256 |
| - return false |
257 |
| - } |
| 117 | +func replaceExpressions( |
| 118 | + n sql.Expressioner, |
| 119 | + replacements map[tableCol]tableCol, |
| 120 | + tableAliases map[string]string, |
| 121 | +) (sql.Node, error) { |
| 122 | + return n.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { |
| 123 | + switch e := e.(type) { |
| 124 | + case *expression.GetField, *expression.UnresolvedColumn: |
| 125 | + var tableName = e.(sql.Tableable).Table() |
| 126 | + if t, ok := tableAliases[strings.ToLower(tableName)]; ok { |
| 127 | + tableName = t |
| 128 | + } |
258 | 129 |
|
259 |
| - for _, t := range tables { |
260 |
| - if isIn(t, wrongSources) { |
261 |
| - return true |
| 130 | + name := e.(sql.Nameable).Name() |
| 131 | + if col, ok := replacements[tableCol{strings.ToLower(tableName), strings.ToLower(name)}]; ok { |
| 132 | + return expression.NewUnresolvedQualifiedColumn(col.table, col.col), nil |
| 133 | + } |
262 | 134 | }
|
263 |
| - } |
264 |
| - |
265 |
| - return false |
| 135 | + return e, nil |
| 136 | + }) |
266 | 137 | }
|
0 commit comments