Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 5f48ea3

Browse files
authored
Merge pull request #750 from erizocosmico/feature/refactor-natural-join-rule
sql/analyzer: refactor resolve_natural_joins rule
2 parents 5336d8a + 4f6c4f8 commit 5f48ea3

File tree

4 files changed

+100
-235
lines changed

4 files changed

+100
-235
lines changed

Diff for: go.mod

-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ require (
2323
github.com/stretchr/testify v1.2.2
2424
go.etcd.io/bbolt v1.3.2
2525
golang.org/x/net v0.0.0-20190227022144-312bce6e941f // indirect
26-
google.golang.org/genproto v0.0.0-20180831171423-11092d34479b // indirect
2726
google.golang.org/grpc v1.19.0 // indirect
2827
gopkg.in/src-d/go-errors.v1 v1.0.0
2928
gopkg.in/yaml.v2 v2.2.2

Diff for: go.sum

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekf
4040
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
4141
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
4242
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
43+
github.com/golang/protobuf v1.3.0 h1:kbxbvI4Un1LUWKxufD+BiE6AEExYYgkQLQmLFqA1LFk=
44+
github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0=
4345
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
4446
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
4547
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c h1:964Od4U6p2jUkFxvCydnIczKteheJEzHRToSGK3Bnlw=
@@ -137,6 +139,7 @@ golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72
137139
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
138140
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
139141
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
142+
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
140143
golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
141144
golang.org/x/net v0.0.0-20190227022144-312bce6e941f h1:tbtX/qtlxzhZjgQue/7u7ygFwDEckd+DmS5+t8FgeKE=
142145
golang.org/x/net v0.0.0-20190227022144-312bce6e941f/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=

Diff for: sql/analyzer/resolve_natural_joins.go

+96-225
Original file line numberDiff line numberDiff line change
@@ -1,266 +1,137 @@
11
package analyzer
22

33
import (
4-
"reflect"
4+
"strings"
55

66
"github.com/src-d/go-mysql-server/sql"
77
"github.com/src-d/go-mysql-server/sql/expression"
88
"github.com/src-d/go-mysql-server/sql/plan"
99
)
1010

11-
type transformedJoin struct {
12-
node sql.Node
13-
condCols map[string]*transformedSource
14-
}
15-
16-
type transformedSource struct {
17-
correct string
18-
wrong []string
19-
}
20-
2111
func resolveNaturalJoins(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
2212
span, _ := ctx.Span("resolve_natural_joins")
2313
defer span.Finish()
2414

25-
if n.Resolved() {
26-
return n, nil
27-
}
28-
29-
var transformed []*transformedJoin
30-
var aliasTables = map[string][]string{}
31-
var colsToUnresolve = map[string]*transformedSource{}
32-
a.Log("resolving natural joins, node of type %T", n)
33-
node, err := n.TransformUp(func(n sql.Node) (sql.Node, error) {
34-
a.Log("transforming node of type: %T", n)
15+
var replacements = make(map[tableCol]tableCol)
16+
var tableAliases = make(map[string]string)
3517

36-
if alias, ok := n.(*plan.TableAlias); ok {
37-
table := alias.Child.(*plan.ResolvedTable).Name()
38-
aliasTables[alias.Name()] = append(aliasTables[alias.Name()], table)
18+
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
19+
switch n := n.(type) {
20+
case *plan.TableAlias:
21+
alias := n.Name()
22+
table := n.Child.(*plan.ResolvedTable).Name()
23+
tableAliases[strings.ToLower(alias)] = table
3924
return n, nil
40-
}
41-
42-
if n.Resolved() {
25+
case *plan.NaturalJoin:
26+
return resolveNaturalJoin(n, replacements)
27+
case sql.Expressioner:
28+
return replaceExpressions(n, replacements, tableAliases)
29+
default:
4330
return n, nil
4431
}
32+
})
33+
}
4534

46-
join, ok := n.(*plan.NaturalJoin)
47-
if !ok {
48-
return n, nil
49-
}
50-
51-
// we need both leaves resolved before resolving this one
52-
if !join.Left.Resolved() || !join.Right.Resolved() {
53-
return n, nil
54-
}
55-
56-
leftSchema, rightSchema := join.Left.Schema(), join.Right.Schema()
57-
58-
var conditions, common, left, right []sql.Expression
59-
var seen = make(map[string]struct{})
60-
61-
for i, lcol := range leftSchema {
62-
var found bool
63-
leftCol := expression.NewGetFieldWithTable(
64-
i,
65-
lcol.Type,
66-
lcol.Source,
67-
lcol.Name,
68-
lcol.Nullable,
69-
)
70-
71-
for j, rcol := range rightSchema {
72-
if lcol.Name == rcol.Name {
73-
common = append(common, leftCol)
74-
75-
conditions = append(
76-
conditions,
77-
expression.NewEquals(
78-
leftCol,
79-
expression.NewGetFieldWithTable(
80-
len(leftSchema)+j,
81-
rcol.Type,
82-
rcol.Source,
83-
rcol.Name,
84-
rcol.Nullable,
85-
),
86-
),
87-
)
88-
89-
found = true
90-
seen[lcol.Name] = struct{}{}
91-
if source, ok := colsToUnresolve[lcol.Name]; ok {
92-
source.correct = lcol.Source
93-
source.wrong = append(source.wrong, rcol.Source)
94-
} else {
95-
colsToUnresolve[lcol.Name] = &transformedSource{
96-
correct: lcol.Source,
97-
wrong: []string{rcol.Source},
98-
}
99-
}
100-
101-
break
102-
}
103-
}
35+
func resolveNaturalJoin(
36+
n *plan.NaturalJoin,
37+
replacements map[tableCol]tableCol,
38+
) (sql.Node, error) {
39+
// Both sides of the natural join need to be resolved in order to resolve
40+
// the natural join itself.
41+
if !n.Left.Resolved() || !n.Right.Resolved() {
42+
return n, nil
43+
}
10444

105-
if !found {
106-
left = append(left, leftCol)
45+
leftSchema := n.Left.Schema()
46+
rightSchema := n.Right.Schema()
47+
48+
var conditions, common, left, right []sql.Expression
49+
for i, lcol := range leftSchema {
50+
leftCol := expression.NewGetFieldWithTable(
51+
i,
52+
lcol.Type,
53+
lcol.Source,
54+
lcol.Name,
55+
lcol.Nullable,
56+
)
57+
if idx, rcol := findCol(rightSchema, lcol.Name); rcol != nil {
58+
common = append(common, leftCol)
59+
replacements[tableCol{strings.ToLower(rcol.Source), strings.ToLower(rcol.Name)}] = tableCol{
60+
strings.ToLower(lcol.Source), strings.ToLower(lcol.Name),
10761
}
108-
}
10962

110-
if len(conditions) == 0 {
111-
return plan.NewCrossJoin(join.Left, join.Right), nil
112-
}
113-
114-
for i, col := range rightSchema {
115-
if _, ok := seen[col.Name]; !ok {
116-
right = append(
117-
right,
63+
conditions = append(
64+
conditions,
65+
expression.NewEquals(
66+
leftCol,
11867
expression.NewGetFieldWithTable(
119-
len(leftSchema)+i,
120-
col.Type,
121-
col.Source,
122-
col.Name,
123-
col.Nullable,
68+
len(leftSchema)+idx,
69+
rcol.Type,
70+
rcol.Source,
71+
rcol.Name,
72+
rcol.Nullable,
12473
),
125-
)
126-
}
127-
}
128-
129-
projections := append(append(common, left...), right...)
130-
131-
tj := &transformedJoin{
132-
node: plan.NewProject(
133-
projections,
134-
plan.NewInnerJoin(
135-
join.Left,
136-
join.Right,
137-
expression.JoinAnd(conditions...),
13874
),
139-
),
140-
condCols: colsToUnresolve,
75+
)
76+
} else {
77+
left = append(left, leftCol)
14178
}
142-
143-
transformed = append(transformed, tj)
144-
145-
return tj.node, nil
146-
})
147-
148-
if err != nil || len(transformed) == 0 {
149-
return node, err
15079
}
15180

152-
var transformedSeen bool
153-
return node.TransformUp(func(node sql.Node) (sql.Node, error) {
154-
if ok, _ := isTransformedNode(node, transformed); ok {
155-
transformedSeen = true
156-
return node, nil
157-
}
158-
159-
if !transformedSeen {
160-
return node, nil
161-
}
162-
163-
expressioner, ok := node.(sql.Expressioner)
164-
if !ok {
165-
return node, nil
166-
}
167-
168-
return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
169-
var col, table string
170-
switch e := e.(type) {
171-
case *expression.GetField:
172-
col, table = e.Name(), e.Table()
173-
case *expression.UnresolvedColumn:
174-
col, table = e.Name(), e.Table()
175-
default:
176-
return e, nil
177-
}
178-
179-
sources, ok := colsToUnresolve[col]
180-
if !ok {
181-
return e, nil
182-
}
183-
184-
if !mustUnresolve(aliasTables, table, sources.wrong) {
185-
return e, nil
186-
}
187-
188-
return expression.NewUnresolvedQualifiedColumn(
189-
sources.correct,
190-
col,
191-
), nil
192-
})
193-
})
194-
}
195-
196-
func isTransformedNode(node sql.Node, transformed []*transformedJoin) (is bool, colsToUnresolve map[string]*transformedSource) {
197-
var project *plan.Project
198-
var join *plan.InnerJoin
199-
switch n := node.(type) {
200-
case *plan.Project:
201-
var ok bool
202-
join, ok = n.Child.(*plan.InnerJoin)
203-
if !ok {
204-
return
205-
}
206-
207-
project = n
208-
case *plan.InnerJoin:
209-
join = n
210-
211-
default:
212-
return
81+
if len(conditions) == 0 {
82+
return plan.NewCrossJoin(n.Left, n.Right), nil
21383
}
21484

215-
for _, t := range transformed {
216-
tproject, ok := t.node.(*plan.Project)
217-
if !ok {
218-
return
219-
}
220-
221-
tjoin, ok := tproject.Child.(*plan.InnerJoin)
222-
if !ok {
223-
return
224-
}
225-
226-
if project != nil && !reflect.DeepEqual(project.Projections, tproject.Projections) {
227-
continue
228-
}
229-
230-
if reflect.DeepEqual(join.Cond, tjoin.Cond) {
231-
is = true
232-
colsToUnresolve = t.condCols
85+
for i, col := range rightSchema {
86+
source := strings.ToLower(col.Source)
87+
name := strings.ToLower(col.Name)
88+
if _, ok := replacements[tableCol{source, name}]; !ok {
89+
right = append(
90+
right,
91+
expression.NewGetFieldWithTable(
92+
len(leftSchema)+i,
93+
col.Type,
94+
col.Source,
95+
col.Name,
96+
col.Nullable,
97+
),
98+
)
23399
}
234100
}
235101

236-
return
102+
return plan.NewProject(
103+
append(append(common, left...), right...),
104+
plan.NewInnerJoin(n.Left, n.Right, expression.JoinAnd(conditions...)),
105+
), nil
237106
}
238107

239-
func mustUnresolve(aliasTable map[string][]string, table string, wrongSources []string) bool {
240-
return isIn(table, wrongSources) || isAliasFor(aliasTable, table, wrongSources)
241-
}
242-
243-
func isIn(s string, l []string) bool {
244-
for _, e := range l {
245-
if s == e {
246-
return true
108+
func findCol(s sql.Schema, name string) (int, *sql.Column) {
109+
for i, c := range s {
110+
if strings.ToLower(c.Name) == strings.ToLower(name) {
111+
return i, c
247112
}
248113
}
249-
250-
return false
114+
return -1, nil
251115
}
252116

253-
func isAliasFor(aliasTable map[string][]string, table string, wrongSources []string) bool {
254-
tables, ok := aliasTable[table]
255-
if !ok {
256-
return false
257-
}
117+
func replaceExpressions(
118+
n sql.Expressioner,
119+
replacements map[tableCol]tableCol,
120+
tableAliases map[string]string,
121+
) (sql.Node, error) {
122+
return n.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
123+
switch e := e.(type) {
124+
case *expression.GetField, *expression.UnresolvedColumn:
125+
var tableName = e.(sql.Tableable).Table()
126+
if t, ok := tableAliases[strings.ToLower(tableName)]; ok {
127+
tableName = t
128+
}
258129

259-
for _, t := range tables {
260-
if isIn(t, wrongSources) {
261-
return true
130+
name := e.(sql.Nameable).Name()
131+
if col, ok := replacements[tableCol{strings.ToLower(tableName), strings.ToLower(name)}]; ok {
132+
return expression.NewUnresolvedQualifiedColumn(col.table, col.col), nil
133+
}
262134
}
263-
}
264-
265-
return false
135+
return e, nil
136+
})
266137
}

0 commit comments

Comments
 (0)