forked from src-d/go-mysql-server
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresolve_columns.go
654 lines (562 loc) · 17 KB
/
resolve_columns.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
package analyzer
import (
"fmt"
"sort"
"strings"
"gopkg.in/src-d/go-errors.v1"
"gopkg.in/src-d/go-mysql-server.v0/internal/similartext"
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
"gopkg.in/src-d/go-vitess.v1/vt/sqlparser"
)
// checkAliases returns ErrMisusedAlias when an alias declared inside a Project
// node is referenced by name from another expression of that same Project.
//
// Only the first misuse found is reported. The alias names of a Project are
// checked in sorted order so the reported alias does not depend on map
// iteration order.
func checkAliases(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
	span, _ := ctx.Span("check_aliases")
	defer span.Finish()

	a.Log("check aliases")

	var err error
	plan.Inspect(n, func(node sql.Node) bool {
		// Once a misuse has been found there is nothing left to report.
		if err != nil {
			return false
		}

		p, ok := node.(*plan.Project)
		if !ok {
			return true
		}

		aliases := lookForAliasDeclarations(p)
		names := make([]string, 0, len(aliases))
		for alias := range aliases {
			names = append(names, alias)
		}
		sort.Strings(names)

		for _, alias := range names {
			if isAliasUsed(p, alias) {
				err = ErrMisusedAlias.New(alias)
				return false
			}
		}

		return true
	})

	return n, err
}
// lookForAliasDeclarations collects the name of every *expression.Alias found
// anywhere in the expressions of node, nested ones included. The returned map
// is used as a set: only its keys carry information, and it is never nil.
func lookForAliasDeclarations(node sql.Expressioner) map[string]struct{} {
	declared := make(map[string]struct{})

	collect := func(expr sql.Expression) bool {
		if al, isAlias := expr.(*expression.Alias); isAlias {
			declared[al.Name()] = struct{}{}
		}
		return true
	}

	for _, root := range node.Expressions() {
		expression.Inspect(root, collect)
	}

	return declared
}
// isAliasUsed reports whether some expression of node refers by name to alias.
// The declaration of alias itself, and everything it wraps, is not counted as
// a use. Other aliases are looked through, so references nested under them are
// still detected. Names are compared case-sensitively.
func isAliasUsed(node sql.Expressioner, alias string) bool {
	used := false

	visit := func(expr sql.Expression) bool {
		switch x := expr.(type) {
		case *expression.Alias:
			// Do not descend into the declaration being checked; keep going
			// through any other alias.
			return x.Name() != alias
		case sql.Nameable:
			if x.Name() == alias {
				used = true
				return false
			}
		}
		return true
	}

	for _, root := range node.Expressions() {
		expression.Inspect(root, visit)
		if used {
			return true
		}
	}

	return false
}
// deferredColumn wraps an UnresolvedColumn whose resolution must wait until
// other analyzer phases have run. Because the column is embedded, every method
// not overridden below is delegated to the wrapped *expression.UnresolvedColumn.
type deferredColumn struct {
	*expression.UnresolvedColumn
}

// IsNullable implements the Expression interface. A deferred column always
// reports itself as nullable.
func (deferredColumn) IsNullable() bool { return true }

// TransformUp implements the Expression interface. fn is applied to the
// deferred column itself only; the wrapped UnresolvedColumn is not visited
// separately. NOTE: the receiver is a value, so fn receives a deferredColumn
// value even when TransformUp is invoked through a *deferredColumn.
func (dc deferredColumn) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
	return fn(dc)
}

// column is the interface shared by *expression.UnresolvedColumn and
// deferredColumn: an expression that has a Name() and a Table().
type column interface {
	sql.Expression
	sql.Nameable
	sql.Tableable
}
// qualifyColumns qualifies unresolved columns with the name of the table they
// belong to, and replaces aliases of plain tables with the real table name.
//
// Unqualified columns are qualified only when exactly one known table has a
// column with that name. When several tables do, ErrAmbiguousColumnName is
// returned, except inside a GroupBy, where the column is left unqualified.
// Columns and stars qualified with an unknown table yield sql.ErrTableNotFound.
// Table qualifiers are matched case-insensitively.
func qualifyColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
	span, _ := ctx.Span("qualify_columns")
	defer span.Finish()

	a.Log("qualify columns")

	// tables maps lower-cased table names (and aliases of non-table nodes) to
	// their node.
	tables := make(map[string]sql.Node)
	// tableAliases maps lower-cased aliases of plain tables to the lower-cased
	// name of the aliased table.
	tableAliases := make(map[string]string)
	// colIndex maps lower-cased column names to the lower-cased names of the
	// tables providing them.
	colIndex := make(map[string][]string)

	indexCols := func(table string, schema sql.Schema) {
		for _, col := range schema {
			name := strings.ToLower(col.Name)
			colIndex[name] = append(colIndex[name], strings.ToLower(table))
		}
	}

	var projects, seenProjects int
	plan.Inspect(n, func(n sql.Node) bool {
		if _, ok := n.(*plan.Project); ok {
			projects++
		}
		return true
	})

	return n.TransformUp(func(n sql.Node) (sql.Node, error) {
		a.Log("transforming node of type: %T", n)
		switch n := n.(type) {
		case *plan.TableAlias:
			switch t := n.Child.(type) {
			case *plan.ResolvedTable, *plan.UnresolvedTable:
				name := strings.ToLower(t.(sql.Nameable).Name())
				tableAliases[strings.ToLower(n.Name())] = name
			default:
				tables[strings.ToLower(n.Name())] = n.Child
				indexCols(n.Name(), n.Schema())
			}
		case *plan.ResolvedTable, *plan.SubqueryAlias:
			name := strings.ToLower(n.(sql.Nameable).Name())
			tables[name] = n
			indexCols(name, n.Schema())
		}

		exp, ok := n.(sql.Expressioner)
		if !ok {
			return n, nil
		}

		result, err := exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
			a.Log("transforming expression of type: %T", e)
			switch col := e.(type) {
			case *expression.UnresolvedColumn:
				// Skip this step for global and session variables.
				if isGlobalOrSessionColumn(col) {
					return col, nil
				}

				col = expression.NewUnresolvedQualifiedColumn(col.Table(), col.Name())

				name := strings.ToLower(col.Name())
				table := strings.ToLower(col.Table())
				if table == "" {
					// If a column has no table, it might be an alias defined
					// in a child projection, so check that instead of
					// incorrectly qualifying it.
					if isDefinedInChildProject(n, col) {
						return col, nil
					}

					candidates := dedupStrings(colIndex[name])
					switch len(candidates) {
					case 0:
						// No known table has a column with this name. It may
						// be an alias, so leave it as it is and wait for the
						// reorder of the projection.
						return col, nil
					case 1:
						col = expression.NewUnresolvedQualifiedColumn(
							candidates[0],
							col.Name(),
						)
					default:
						// Inside a GroupBy, ambiguous columns are left
						// unqualified instead of failing.
						if _, ok := n.(*plan.GroupBy); ok {
							return expression.NewUnresolvedColumn(col.Name()), nil
						}
						return nil, ErrAmbiguousColumnName.New(col.Name(), strings.Join(candidates, ", "))
					}
				} else {
					// An alias of a plain table is replaced by the table name.
					if target, ok := tableAliases[table]; ok {
						col = expression.NewUnresolvedQualifiedColumn(
							target,
							col.Name(),
						)
					}

					// Keys of tables are lower-cased, so the qualifier must be
					// lower-cased too, as the Star branch below already does.
					if _, ok := tables[strings.ToLower(col.Table())]; !ok {
						if len(tables) == 0 {
							return nil, sql.ErrTableNotFound.New(col.Table())
						}
						similar := similartext.FindFromMap(tables, col.Table())
						return nil, sql.ErrTableNotFound.New(col.Table() + similar)
					}
				}

				a.Log("column %q was qualified with table %q", col.Name(), col.Table())
				return col, nil
			case *expression.Star:
				if col.Table != "" {
					if target, ok := tableAliases[strings.ToLower(col.Table)]; ok {
						col = expression.NewQualifiedStar(target)
					}

					if _, ok := tables[strings.ToLower(col.Table)]; !ok {
						return nil, sql.ErrTableNotFound.New(col.Table)
					}

					return col, nil
				}
			default:
				// If any other kind of expression has a star, just replace it
				// with an unqualified star because it cannot be expanded.
				return e.TransformUp(func(e sql.Expression) (sql.Expression, error) {
					if _, ok := e.(*expression.Star); ok {
						return expression.NewStar(), nil
					}
					return e, nil
				})
			}
			return e, nil
		})
		if err != nil {
			return nil, err
		}

		// We should ignore the topmost project, because some nodes are
		// reordered, such as Sort, and they would not be resolved well.
		if proj, ok := result.(*plan.Project); ok && projects-seenProjects > 1 {
			seenProjects++

			// Restrict the indexed columns to what this project exposes.
			// A column qualified with a table makes only that table
			// available for it; an unqualified one keeps the tables that are
			// currently indexed for that name. Without this, columns that are
			// not projected would remain visible above this node and could
			// cause false errors or unintended results.
			projected := make(map[string][]string)
			for _, p := range proj.Projections {
				c, ok := p.(column)
				if !ok {
					continue
				}

				colName := strings.ToLower(c.Name())
				table := strings.ToLower(c.Table())
				if table != "" {
					projected[colName] = append(projected[colName], table)
				} else {
					projected[colName] = append(projected[colName], colIndex[colName]...)
				}
			}

			colIndex = make(map[string][]string)
			for col, tbls := range projected {
				colIndex[col] = dedupStrings(tbls)
			}
		}

		return result, nil
	})
}
// isDefinedInChildProject reports whether col's name matches, ignoring case,
// the name of a top-level alias of the first Project or GroupBy found when
// walking n's children in order. The walk does not descend into subquery
// aliases, nor below the Project/GroupBy it finds.
func isDefinedInChildProject(n sql.Node, col *expression.UnresolvedColumn) bool {
	var nearest sql.Node
	for _, child := range n.Children() {
		plan.Inspect(child, func(node sql.Node) bool {
			switch node.(type) {
			case *plan.SubqueryAlias:
				return false
			case *plan.Project, *plan.GroupBy:
				if nearest == nil {
					nearest = node
				}
				return false
			}
			return true
		})

		if nearest != nil {
			break
		}
	}

	if nearest == nil {
		return false
	}

	want := strings.ToLower(col.Name())
	for _, expr := range nearest.(sql.Expressioner).Expressions() {
		if alias, isAlias := expr.(*expression.Alias); isAlias && strings.ToLower(alias.Name()) == want {
			return true
		}
	}

	return false
}
var errGlobalVariablesNotSupported = errors.NewKind("can't resolve global variable, %s was requested")
const (
sessionTable = "@@" + sqlparser.SessionStr
sessionPrefix = sqlparser.SessionStr + "."
globalPrefix = sqlparser.GlobalStr + "."
)
// resolveColumns replaces unresolved column references with GetField
// expressions pointing at the right index of the row a node evaluates. A node
// is only processed once all of its children are resolved.
//
// For each unresolved column (or deferredColumn) of a node:
//   - it is first looked up in the children's schema, which also covers
//     aliases projected by a child;
//   - "@@" columns become GetSessionField expressions holding the value read
//     from ctx, and a qualifier other than the session table is rejected
//     with errGlobalVariablesNotSupported;
//   - an UnresolvedColumn that cannot be matched yet is wrapped in a
//     deferredColumn so a later phase can retry it, while a column that still
//     cannot be matched produces ErrColumnNotFound or ErrColumnTableNotFound.
func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
	span, ctx := ctx.Span("resolve_columns")
	defer span.Finish()

	a.Log("resolve columns, node of type: %T", n)
	return n.TransformUp(func(n sql.Node) (sql.Node, error) {
		a.Log("transforming node of type: %T", n)
		if n.Resolved() {
			return n, nil
		}

		// Make sure all children are resolved before resolving a node, and
		// gather their columns both in order and by lower-cased name.
		var childSchema sql.Schema
		colMap := make(map[string][]*sql.Column)
		for _, child := range n.Children() {
			if !child.Resolved() {
				a.Log("a children with type %T of node %T were not resolved, skipping", child, n)
				return n, nil
			}
			for _, col := range child.Schema() {
				name := strings.ToLower(col.Name)
				colMap[name] = append(colMap[name], col)
				childSchema = append(childSchema, col)
			}
		}

		expressioner, ok := n.(sql.Expressioner)
		if !ok {
			return n, nil
		}

		// Inside a GroupBy the first column with a matching name is taken
		// whatever its table (qualifyColumns may leave GroupBy columns
		// unqualified).
		_, isGroupBy := n.(*plan.GroupBy)

		return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
			a.Log("transforming expression of type: %T", e)
			if e.Resolved() {
				return e, nil
			}

			uc, ok := e.(column)
			if !ok {
				return e, nil
			}

			name := strings.ToLower(uc.Name())
			table := strings.ToLower(uc.Table())

			// First of all, try to find the field in the child schema, which
			// will resolve aliases.
			if idx := childSchema.IndexOf(name, table); idx >= 0 {
				col := childSchema[idx]
				return expression.NewGetFieldWithTable(idx, col.Type, col.Source, col.Name, col.Nullable), nil
			}

			columns, ok := colMap[name]
			if !ok {
				switch uc := uc.(type) {
				case *expression.UnresolvedColumn:
					if isGlobalOrSessionColumn(uc) {
						if table != "" && table != sessionTable {
							return nil, errGlobalVariablesNotSupported.New(uc)
						}
						// Strip the leading '@' characters and any scope
						// prefix to get the bare variable name.
						name := strings.TrimLeft(uc.Name(), "@")
						name = strings.TrimPrefix(strings.TrimPrefix(name, globalPrefix), sessionPrefix)
						typ, value := ctx.Get(name)
						return expression.NewGetSessionField(name, typ, value), nil
					}

					a.Log("evaluation of column %q was deferred", uc.Name())
					return &deferredColumn{uc}, nil
				default:
					if len(colMap) == 0 {
						return nil, ErrColumnNotFound.New(uc.Name())
					}

					if table != "" {
						return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name())
					}

					similar := similartext.FindFromMap(colMap, uc.Name())
					return nil, ErrColumnNotFound.New(uc.Name() + similar)
				}
			}

			var col *sql.Column
			var found bool
			for _, c := range columns {
				if isGroupBy || strings.ToLower(c.Source) == table {
					col, found = c, true
					break
				}
			}

			if !found {
				if table != "" {
					return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name())
				}

				switch uc := uc.(type) {
				case *expression.UnresolvedColumn:
					return &deferredColumn{uc}, nil
				default:
					return nil, ErrColumnNotFound.New(uc.Name())
				}
			}

			// For a unary node the expressions are evaluated against the
			// child's row, so the index must come from the child's schema.
			var schema sql.Schema
			if plan.IsUnary(n) {
				schema = n.Children()[0].Schema()
			} else {
				schema = n.Schema()
			}

			idx := schema.IndexOf(col.Name, col.Source)
			if idx < 0 {
				return nil, ErrColumnNotFound.New(col.Name)
			}

			a.Log("column resolved to %q.%q", col.Source, col.Name)

			return expression.NewGetFieldWithTable(
				idx,
				col.Type,
				col.Source,
				col.Name,
				col.Nullable,
			), nil
		})
	})
}
// resolveGroupingColumns reorders the aggregation of a GroupBy so aliases
// defined in it can be resolved in its grouping. Every non-aggregation alias
// referenced by the grouping is pushed down to a new Project node placed under
// the GroupBy, and the aggregate then refers to it by name.
//
// Every other column needed by the grouping or the aggregate is projected too,
// exactly once. When such a column has the same name as a pushed-down alias, it
// is projected under a fresh "<col>_NN" name that clashes with no other
// projected name, and references to it in the aggregate are rewritten
// (case-insensitively) to that name.
func resolveGroupingColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
	span, _ := ctx.Span("resolve_grouping_columns")
	defer span.Finish()

	a.Log("resolving group columns")
	if n.Resolved() {
		return n, nil
	}

	return n.TransformUp(func(n sql.Node) (sql.Node, error) {
		g, ok := n.(*plan.GroupBy)
		if n.Resolved() || !ok || len(g.Grouping) == 0 {
			return n, nil
		}

		// The reason we have two sets of columns, one for grouping and
		// one for aggregate, is that an alias can redefine a column name
		// of the child schema. In the grouping, if that column is referenced
		// it refers to the alias, and not the one in the child. However,
		// in the aggregate, aliases in that same aggregate cannot be used,
		// so it refers to the column in the child node.
		groupingColumns := make(map[string]struct{})
		for _, grouping := range g.Grouping {
			for _, col := range findAllColumns(grouping) {
				groupingColumns[strings.ToLower(col)] = struct{}{}
			}
		}

		aggregateColumns := make(map[string]struct{})
		for _, agg := range g.Aggregate {
			// This alias is going to be pushed down, so don't bother
			// gathering its requirements.
			if alias, ok := agg.(*expression.Alias); ok {
				if _, ok := groupingColumns[strings.ToLower(alias.Name())]; ok {
					continue
				}
			}

			for _, col := range findAllColumns(agg) {
				aggregateColumns[strings.ToLower(col)] = struct{}{}
			}
		}

		var newAggregate []sql.Expression
		var projection []sql.Expression
		// aliases keeps the lower-cased names of the pushed-down aliases
		// and their index in the new aggregate.
		aliases := make(map[string]int)
		var needsReorder bool
		for _, agg := range g.Aggregate {
			alias, ok := agg.(*expression.Alias)
			// Aliases of aggregations cannot be used in the grouping because
			// the grouping is needed before computing the aggregation.
			if !ok || containsAggregation(alias) {
				newAggregate = append(newAggregate, agg)
				continue
			}

			// Only an alias required by the grouping triggers a reorder.
			name := strings.ToLower(alias.Name())
			if _, ok := groupingColumns[name]; !ok {
				newAggregate = append(newAggregate, agg)
				continue
			}

			aliases[name] = len(newAggregate)
			needsReorder = true
			delete(groupingColumns, name)
			projection = append(projection, agg)
			newAggregate = append(newAggregate, expression.NewUnresolvedColumn(alias.Name()))
		}

		if !needsReorder {
			return n, nil
		}

		// Merge the columns required by both sides so each one is projected
		// only once, and sort them so executions of the rule are consistent.
		required := make(map[string]struct{}, len(aggregateColumns)+len(groupingColumns))
		for col := range aggregateColumns {
			required[col] = struct{}{}
		}
		for col := range groupingColumns {
			required[col] = struct{}{}
		}

		missingCols := make([]string, 0, len(required))
		for col := range required {
			missingCols = append(missingCols, col)
		}
		sort.Strings(missingCols)

		// taken holds every name the new projection exposes, so generated
		// names never clash with a column or a pushed-down alias.
		taken := make(map[string]struct{}, len(missingCols)+len(aliases))
		for _, col := range missingCols {
			taken[col] = struct{}{}
		}
		for name := range aliases {
			taken[name] = struct{}{}
		}

		// All columns required by expressions in both grouping and
		// aggregation must also be projected in the new projection node or
		// they will not be able to resolve.
		renames := make(map[string]string)
		for _, col := range missingCols {
			if _, conflict := aliases[col]; !conflict {
				projection = append(projection, expression.NewUnresolvedColumn(col))
				continue
			}

			// A pushed-down alias already uses this name, so the column
			// needs a unique one.
			var name string
			for i := 1; ; i++ {
				name = fmt.Sprintf("%s_%02d", col, i)
				if _, used := taken[name]; !used {
					break
				}
			}
			taken[name] = struct{}{}

			renames[col] = name
			projection = append(projection, expression.NewAlias(
				expression.NewUnresolvedColumn(col),
				name,
			))
		}

		// If there is any name conflict between columns we need to rename
		// every usage inside the aggregate.
		if len(renames) > 0 {
			for i, expr := range newAggregate {
				renamed, err := expr.TransformUp(func(e sql.Expression) (sql.Expression, error) {
					col, ok := e.(*expression.UnresolvedColumn)
					if !ok {
						return e, nil
					}

					// Keys of renames and aliases are lower-cased. The
					// reference to the pushed-down alias itself is kept.
					name := strings.ToLower(col.Name())
					if to, ok := renames[name]; ok && aliases[name] != i {
						return expression.NewUnresolvedColumn(to), nil
					}
					return e, nil
				})
				if err != nil {
					return nil, err
				}
				newAggregate[i] = renamed
			}
		}

		return plan.NewGroupBy(
			newAggregate, g.Grouping,
			plan.NewProject(projection, g.Child),
		), nil
	})
}
// findAllColumns returns the names of every *expression.UnresolvedColumn in e
// (e itself included), in the order expression.Inspect visits them. Duplicates
// are kept, and nil is returned when there is none.
func findAllColumns(e sql.Expression) []string {
	var names []string

	expression.Inspect(e, func(node sql.Expression) bool {
		if uc, isCol := node.(*expression.UnresolvedColumn); isCol {
			names = append(names, uc.Name())
		}
		return true
	})

	return names
}
// dedupStrings returns the distinct strings of in, in order of first
// occurrence. Comparison is exact (case-sensitive). It returns nil when in is
// empty.
func dedupStrings(in []string) []string {
	var unique []string
	seen := map[string]bool{}

	for i := range in {
		if s := in[i]; !seen[s] {
			seen[s] = true
			unique = append(unique, s)
		}
	}

	return unique
}
// isGlobalOrSessionColumn reports whether col refers to a system variable,
// that is, whether its name or its table qualifier starts with "@@".
func isGlobalOrSessionColumn(col *expression.UnresolvedColumn) bool {
	switch {
	case strings.HasPrefix(col.Name(), "@@"):
		return true
	default:
		return strings.HasPrefix(col.Table(), "@@")
	}
}