Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 18e2012

Browse files
committed
function: implement regexp_matches
Signed-off-by: Miguel Molina <[email protected]>
1 parent 550cc54 commit 18e2012

File tree

5 files changed

+309
-0
lines changed

5 files changed

+309
-0
lines changed

Diff for: README.md

+1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ We support and actively test against certain third-party clients to ensure compa
103103
|`NOW()`|Returns the current timestamp.|
104104
|`NULLIF(expr1, expr2)`|Returns NULL if expr1 = expr2 is true, otherwise returns expr1.|
105105
|`POW(X, Y)`|Returns the value of X raised to the power of Y.|
106+
|`REGEXP_MATCHES(text, pattern, [flags])`|Returns an array with the matches of the pattern in the given text. Flags can be given to control certain behaviours of the regular expression. Currently, only the `i` flag is supported, to make the comparison case insensitive.|
106107
|`REPEAT(str, count)`|Returns a string consisting of the string str repeated count times.|
107108
|`REPLACE(str,from_str,to_str)`|Returns the string str with all occurrences of the string from_str replaced by the string to_str.|
108109
|`REVERSE(str)`|Returns the string str with the order of the characters reversed.|

Diff for: engine_test.go

+12
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,18 @@ var queries = []struct {
14541454
ORDER BY table_type, table_schema, table_name`,
14551455
[]sql.Row{{"mydb", "mytable", "TABLE"}},
14561456
},
1457+
{
1458+
`SELECT REGEXP_MATCHES("bopbeepbop", "bop")`,
1459+
[]sql.Row{{[]interface{}{"bop", "bop"}}},
1460+
},
1461+
{
1462+
`SELECT EXPLODE(REGEXP_MATCHES("bopbeepbop", "bop"))`,
1463+
[]sql.Row{{"bop"}, {"bop"}},
1464+
},
1465+
{
1466+
`SELECT EXPLODE(REGEXP_MATCHES("helloworld", "bop"))`,
1467+
[]sql.Row{},
1468+
},
14571469
}
14581470

14591471
func TestQueries(t *testing.T) {

Diff for: sql/expression/function/regexp_matches.go

+184
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
package function
2+
3+
import (
4+
"fmt"
5+
"regexp"
6+
"strings"
7+
8+
"github.com/src-d/go-mysql-server/sql"
9+
"github.com/src-d/go-mysql-server/sql/expression"
10+
errors "gopkg.in/src-d/go-errors.v1"
11+
)
12+
13+
// RegexpMatches returns the matches of a regular expression.
14+
type RegexpMatches struct {
15+
Text sql.Expression
16+
Pattern sql.Expression
17+
Flags sql.Expression
18+
19+
cacheable bool
20+
re *regexp.Regexp
21+
}
22+
23+
// NewRegexpMatches creates a new RegexpMatches expression.
24+
func NewRegexpMatches(args ...sql.Expression) (sql.Expression, error) {
25+
var r RegexpMatches
26+
switch len(args) {
27+
case 3:
28+
r.Flags = args[2]
29+
fallthrough
30+
case 2:
31+
r.Text = args[0]
32+
r.Pattern = args[1]
33+
default:
34+
return nil, sql.ErrInvalidArgumentNumber.New("regexp_matches", "2 or 3", len(args))
35+
}
36+
37+
if canBeCached(r.Pattern) && (r.Flags == nil || canBeCached(r.Flags)) {
38+
r.cacheable = true
39+
}
40+
41+
return &r, nil
42+
}
43+
44+
// Type implements the sql.Expression interface.
45+
func (r *RegexpMatches) Type() sql.Type { return sql.Array(sql.Text) }
46+
47+
// IsNullable implements the sql.Expression interface.
48+
func (r *RegexpMatches) IsNullable() bool { return true }
49+
50+
// Children implements the sql.Expression interface.
51+
func (r *RegexpMatches) Children() []sql.Expression {
52+
var result = []sql.Expression{r.Text, r.Pattern}
53+
if r.Flags != nil {
54+
result = append(result, r.Flags)
55+
}
56+
return result
57+
}
58+
59+
// Resolved implements the sql.Expression interface.
60+
func (r *RegexpMatches) Resolved() bool {
61+
return r.Text.Resolved() && r.Pattern.Resolved() && (r.Flags == nil || r.Flags.Resolved())
62+
}
63+
64+
// WithChildren implements the sql.Expression interface.
65+
func (r *RegexpMatches) WithChildren(children ...sql.Expression) (sql.Expression, error) {
66+
required := 2
67+
if r.Flags != nil {
68+
required = 3
69+
}
70+
71+
if len(children) != required {
72+
return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), required)
73+
}
74+
75+
return NewRegexpMatches(children...)
76+
}
77+
78+
func (r *RegexpMatches) String() string {
79+
var args []string
80+
for _, e := range r.Children() {
81+
args = append(args, e.String())
82+
}
83+
return fmt.Sprintf("regexp_matches(%s)", strings.Join(args, ", "))
84+
}
85+
86+
// Eval implements the sql.Expression interface.
87+
func (r *RegexpMatches) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
88+
span, ctx := ctx.Span("function.RegexpMatches")
89+
defer span.Finish()
90+
91+
var re *regexp.Regexp
92+
var err error
93+
if r.cacheable {
94+
if r.re == nil {
95+
r.re, err = r.compileRegex(ctx, nil)
96+
if err != nil {
97+
return nil, err
98+
}
99+
}
100+
re = r.re
101+
} else {
102+
re, err = r.compileRegex(ctx, row)
103+
if err != nil {
104+
return nil, err
105+
}
106+
}
107+
108+
text, err := r.Text.Eval(ctx, row)
109+
if err != nil {
110+
return nil, err
111+
}
112+
113+
text, err = sql.Text.Convert(text)
114+
if err != nil {
115+
return nil, err
116+
}
117+
118+
matches := re.FindAllStringSubmatch(text.(string), -1)
119+
if len(matches) == 0 {
120+
return nil, nil
121+
}
122+
123+
var result []interface{}
124+
for _, m := range matches {
125+
for _, sm := range m {
126+
result = append(result, sm)
127+
}
128+
}
129+
130+
return result, nil
131+
}
132+
133+
func (r *RegexpMatches) compileRegex(ctx *sql.Context, row sql.Row) (*regexp.Regexp, error) {
134+
pattern, err := r.Pattern.Eval(ctx, row)
135+
if err != nil {
136+
return nil, err
137+
}
138+
139+
pattern, err = sql.Text.Convert(pattern)
140+
if err != nil {
141+
return nil, err
142+
}
143+
144+
var flags string
145+
if r.Flags != nil {
146+
f, err := r.Flags.Eval(ctx, row)
147+
if err != nil {
148+
return nil, err
149+
}
150+
151+
f, err = sql.Text.Convert(f)
152+
if err != nil {
153+
return nil, err
154+
}
155+
156+
flags = f.(string)
157+
for _, f := range flags {
158+
if !validRegexpFlags[f] {
159+
return nil, errInvalidRegexpFlag.New(f)
160+
}
161+
}
162+
163+
flags = fmt.Sprintf("(?%s)", flags)
164+
}
165+
166+
return regexp.Compile(flags + pattern.(string))
167+
}
168+
169+
var errInvalidRegexpFlag = errors.NewKind("invalid regexp flag: %v")
170+
171+
var validRegexpFlags = map[rune]bool{
172+
'i': true,
173+
}
174+
175+
func canBeCached(e sql.Expression) bool {
176+
var hasCols bool
177+
expression.Inspect(e, func(e sql.Expression) bool {
178+
if _, ok := e.(*expression.GetField); ok {
179+
hasCols = true
180+
}
181+
return true
182+
})
183+
return !hasCols
184+
}

Diff for: sql/expression/function/regexp_matches_test.go

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package function
2+
3+
import (
4+
"testing"
5+
6+
"github.com/src-d/go-mysql-server/sql"
7+
"github.com/src-d/go-mysql-server/sql/expression"
8+
"github.com/stretchr/testify/require"
9+
10+
errors "gopkg.in/src-d/go-errors.v1"
11+
)
12+
13+
func TestRegexpMatches(t *testing.T) {
14+
testCases := []struct {
15+
pattern string
16+
text string
17+
flags string
18+
expected interface{}
19+
err *errors.Kind
20+
}{
21+
{
22+
`^foobar(.*)bye$`,
23+
"foobarhellobye",
24+
"",
25+
[]interface{}{"foobarhellobye", "hello"},
26+
nil,
27+
},
28+
{
29+
"bop",
30+
"bopbeepbop",
31+
"",
32+
[]interface{}{"bop", "bop"},
33+
nil,
34+
},
35+
{
36+
"bop",
37+
"bopbeepBop",
38+
"i",
39+
[]interface{}{"bop", "Bop"},
40+
nil,
41+
},
42+
{
43+
"bop",
44+
"helloworld",
45+
"",
46+
nil,
47+
nil,
48+
},
49+
{
50+
"bop",
51+
"bopbeepBop",
52+
"ix",
53+
nil,
54+
errInvalidRegexpFlag,
55+
},
56+
}
57+
58+
t.Run("cacheable", func(t *testing.T) {
59+
for _, tt := range testCases {
60+
var flags sql.Expression
61+
if tt.flags != "" {
62+
flags = expression.NewLiteral(tt.flags, sql.Text)
63+
}
64+
f, err := NewRegexpMatches(
65+
expression.NewLiteral(tt.text, sql.Text),
66+
expression.NewLiteral(tt.pattern, sql.Text),
67+
flags,
68+
)
69+
require.NoError(t, err)
70+
71+
t.Run(f.String(), func(t *testing.T) {
72+
require := require.New(t)
73+
result, err := f.Eval(sql.NewEmptyContext(), nil)
74+
if tt.err == nil {
75+
require.NoError(err)
76+
require.Equal(tt.expected, result)
77+
} else {
78+
require.Error(err)
79+
require.True(tt.err.Is(err))
80+
}
81+
})
82+
}
83+
})
84+
85+
t.Run("not cacheable", func(t *testing.T) {
86+
for _, tt := range testCases {
87+
var flags sql.Expression
88+
if tt.flags != "" {
89+
flags = expression.NewGetField(2, sql.Text, "x", false)
90+
}
91+
f, err := NewRegexpMatches(
92+
expression.NewGetField(0, sql.Text, "x", false),
93+
expression.NewGetField(1, sql.Text, "x", false),
94+
flags,
95+
)
96+
require.NoError(t, err)
97+
98+
t.Run(f.String(), func(t *testing.T) {
99+
require := require.New(t)
100+
result, err := f.Eval(sql.NewEmptyContext(), sql.Row{tt.text, tt.pattern, tt.flags})
101+
if tt.err == nil {
102+
require.NoError(err)
103+
require.Equal(tt.expected, result)
104+
} else {
105+
require.Error(err)
106+
require.True(tt.err.Is(err))
107+
}
108+
})
109+
}
110+
})
111+
}

Diff for: sql/expression/function/registry.go

+1
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,5 @@ var Defaults = []sql.Function{
9898
sql.Function1{Name: "char_length", Fn: NewCharLength},
9999
sql.Function1{Name: "character_length", Fn: NewCharLength},
100100
sql.Function1{Name: "explode", Fn: NewExplode},
101+
sql.FunctionN{Name: "regexp_matches", Fn: NewRegexpMatches},
101102
}

0 commit comments

Comments
 (0)