Skip to content

Commit 7154635

Browse files
authored
expression: Add regexp support (#105)
1 parent 2a9c67d commit 7154635

File tree

4 files changed

+114
-4
lines changed

4 files changed

+114
-4
lines changed

sql/expression/comparison.go

+45-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package expression
22

33
import (
44
"fmt"
5+
"regexp"
56

67
"github.com/gitql/gitql/sql"
78
)
@@ -42,6 +43,50 @@ func (c *Equals) TransformUp(f func(sql.Expression) sql.Expression) sql.Expressi
4243
return f(NewEquals(lc, rc))
4344
}
4445

46+
func (e Equals) Name() string {
47+
return e.Left.Name() + "==" + e.Right.Name()
48+
}
49+
50+
type Regexp struct {
51+
Comparison
52+
}
53+
54+
func NewRegexp(left sql.Expression, right sql.Expression) *Regexp {
55+
// FIXME: enable this again
56+
// checkEqualTypes(left, right)
57+
return &Regexp{Comparison{BinaryExpression{left, right}, left.Type()}}
58+
}
59+
60+
func (e Regexp) Eval(row sql.Row) interface{} {
61+
l := e.Left.Eval(row)
62+
r := e.Right.Eval(row)
63+
64+
sl, okl := l.(string)
65+
sr, okr := r.(string)
66+
67+
if !okl || !okr {
68+
return e.ChildType.Compare(l, r) == 0
69+
}
70+
71+
reg, err := regexp.Compile(sr)
72+
if err != nil {
73+
return false
74+
}
75+
76+
return reg.MatchString(sl)
77+
}
78+
79+
func (c *Regexp) TransformUp(f func(sql.Expression) sql.Expression) sql.Expression {
80+
lc := c.BinaryExpression.Left.TransformUp(f)
81+
rc := c.BinaryExpression.Right.TransformUp(f)
82+
83+
return f(NewRegexp(lc, rc))
84+
}
85+
86+
func (e Regexp) Name() string {
87+
return e.Left.Name() + " REGEXP " + e.Right.Name()
88+
}
89+
4590
type GreaterThan struct {
4691
Comparison
4792
}
@@ -139,7 +184,3 @@ func checkEqualTypes(a sql.Expression, b sql.Expression) {
139184
panic(fmt.Errorf("both types should be equal: %v and %v\n", a, b))
140185
}
141186
}
142-
143-
func (e Equals) Name() string {
144-
return e.Left.Name() + "==" + e.Right.Name()
145-
}

sql/expression/comparison_test.go

+55
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ const (
1212
testEqual = 1
1313
testLess = 2
1414
testGreater = 3
15+
testRegexp = 4
16+
testNotRegexp = 5
1517
)
1618

1719
var comparisonCases = map[sql.Type]map[int][][]interface{}{
@@ -45,6 +47,33 @@ var comparisonCases = map[sql.Type]map[int][][]interface{}{
4547
},
4648
}
4749

50+
var likeComparisonCases = map[sql.Type]map[int][][]interface{}{
51+
sql.String: {
52+
testRegexp: {
53+
{"foobar", ".*bar"},
54+
{"foobarfoo", ".*bar.*"},
55+
{"bar", "bar"},
56+
{"barfoo", "bar.*"},
57+
},
58+
testNotRegexp: {
59+
{"foobara", ".*bar$"},
60+
{"foofoo", ".*bar.*"},
61+
{"bara", "bar$"},
62+
{"abarfoo", "^bar.*"},
63+
},
64+
},
65+
sql.Integer: {
66+
testRegexp: {
67+
{int32(1), int32(1)},
68+
{int32(0), int32(0)},
69+
},
70+
testNotRegexp: {
71+
{int32(-1), int32(0)},
72+
{int32(1), int32(2)},
73+
},
74+
},
75+
}
76+
4877
func TestComparisons_Equals(t *testing.T) {
4978
assert := require.New(t)
5079
for resultType, cmpCase := range comparisonCases {
@@ -122,3 +151,29 @@ func TestComparisons_GreaterThan(t *testing.T) {
122151
}
123152
}
124153
}
154+
155+
func TestComparisons_Regexp(t *testing.T) {
156+
assert := require.New(t)
157+
for resultType, cmpCase := range likeComparisonCases {
158+
get0 := NewGetField(0, resultType, "col1")
159+
assert.NotNil(get0)
160+
get1 := NewGetField(1, resultType, "col2")
161+
assert.NotNil(get1)
162+
eq := NewRegexp(get0, get1)
163+
assert.NotNil(eq)
164+
assert.Equal(sql.Boolean, eq.Type())
165+
for cmpResult, cases := range cmpCase {
166+
for _, pair := range cases {
167+
row := sql.NewRow(pair[0], pair[1])
168+
assert.NotNil(row)
169+
cmp := eq.Eval(row)
170+
assert.NotNil(cmp)
171+
if cmpResult == testRegexp {
172+
assert.Equal(true, cmp)
173+
} else {
174+
assert.Equal(false, cmp)
175+
}
176+
}
177+
}
178+
}
179+
}

sql/parse/parse.go

+2
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,8 @@ func comparisonExprToExpression(c *sqlparser.ComparisonExpr) (sql.Expression,
306306
switch c.Operator {
307307
default:
308308
return nil, errUnsupportedFeature(c.Operator)
309+
case sqlparser.RegexpStr:
310+
return expression.NewRegexp(left, right), nil
309311
case sqlparser.EqualStr:
310312
return expression.NewEquals(left, right), nil
311313
case sqlparser.LessThanStr:

sql/parse/parse_test.go

+12
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,18 @@ var fixtures = map[string]sql.Node{
161161
[]sql.Expression{},
162162
plan.NewUnresolvedTable("t1"),
163163
),
164+
`SELECT a FROM t1 where a regexp '.*test.*';`: plan.NewProject(
165+
[]sql.Expression{
166+
expression.NewUnresolvedColumn("a"),
167+
},
168+
plan.NewFilter(
169+
expression.NewRegexp(
170+
expression.NewUnresolvedColumn("a"),
171+
expression.NewLiteral(".*test.*", sql.String),
172+
),
173+
plan.NewUnresolvedTable("t1"),
174+
),
175+
),
164176
}
165177

166178
func TestParse(t *testing.T) {

0 commit comments

Comments
 (0)