-
Notifications
You must be signed in to change notification settings - Fork 25.2k
SQL: Fix incorrect parameter resolution #63710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
620d8b4
40254dc
4d0568e
db8ad10
03e9a75
7ae10b2
b61461d
94d734d
c205e9d
f2e46de
2cefcc6
fa1d31c
ecd3087
541b4c9
6bb184d
9cd3dc0
09cfe08
f7fead3
bee8a7d
62a2f90
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -139,10 +139,10 @@ | |
|
||
abstract class ExpressionBuilder extends IdentifierBuilder { | ||
|
||
private final Map<Token, SqlTypedParamValue> params; | ||
private final Map<Token, SqlParser.SqlParameter> params; | ||
private final ZoneId zoneId; | ||
|
||
ExpressionBuilder(Map<Token, SqlTypedParamValue> params, ZoneId zoneId) { | ||
ExpressionBuilder(Map<Token, SqlParser.SqlParameter> params, ZoneId zoneId) { | ||
this.params = params; | ||
this.zoneId = zoneId; | ||
} | ||
|
@@ -165,7 +165,15 @@ public Expression visitSelectExpression(SelectExpressionContext ctx) { | |
Expression exp = expression(ctx.expression()); | ||
String alias = visitIdentifier(ctx.identifier()); | ||
Source source = source(ctx); | ||
return alias != null ? new Alias(source, alias, exp) : new UnresolvedAlias(source, exp); | ||
if (alias != null) { | ||
return new Alias(source, alias, exp); | ||
} | ||
if (exp instanceof Literal && "?".equals(exp.source().text())) { | ||
int paramIndex = param(ctx).getIndex(); | ||
// all indexes related to JDBC or databases usually start with 1 | ||
return new Alias(source, "?" + paramIndex, exp); | ||
} | ||
return new UnresolvedAlias(source, exp); | ||
} | ||
|
||
@Override | ||
|
@@ -700,7 +708,7 @@ public Literal visitIntegerLiteral(IntegerLiteralContext ctx) { | |
|
||
@Override | ||
public Literal visitParamLiteral(ParamLiteralContext ctx) { | ||
SqlTypedParamValue param = param(ctx.PARAM()); | ||
SqlTypedParamValue param = param(ctx.PARAM()).getValue(); | ||
DataType dataType = SqlDataTypes.fromTypeName(param.type); | ||
Source source = source(ctx); | ||
if (dataType == null) { | ||
|
@@ -745,15 +753,17 @@ String string(StringContext ctx) { | |
if (ctx == null) { | ||
return null; | ||
} | ||
SqlTypedParamValue param = param(ctx.PARAM()); | ||
if (param != null) { | ||
return param.value != null ? param.value.toString() : null; | ||
} else { | ||
return unquoteString(ctx.getText()); | ||
SqlParser.SqlParameter sqlParam = param(ctx.PARAM()); | ||
if (sqlParam != null) { | ||
SqlTypedParamValue param = sqlParam.getValue(); | ||
if (param != null) { | ||
return param.value != null ? param.value.toString() : null; | ||
} | ||
} | ||
return unquoteString(ctx.getText()); | ||
} | ||
|
||
private SqlTypedParamValue param(TerminalNode node) { | ||
private SqlParser.SqlParameter param(TerminalNode node) { | ||
if (node == null) { | ||
return null; | ||
} | ||
|
@@ -767,6 +777,16 @@ private SqlTypedParamValue param(TerminalNode node) { | |
return params.get(token); | ||
} | ||
|
||
private SqlParser.SqlParameter param(ParserRuleContext ctx) { | ||
if (!ctx.getStart().equals(ctx.getStop())) { | ||
throw new ParsingException(source(ctx), "Single PARAM literal expected"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Incorrect style (use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: any reason why this is not part of |
||
} | ||
if (params.containsKey(ctx.getStart()) == false) { | ||
throw new ParsingException(source(ctx), "Unexpected parameter"); | ||
} | ||
return params.get(ctx.getStart()); | ||
} | ||
|
||
@Override | ||
public Literal visitDateEscapedLiteral(DateEscapedLiteralContext ctx) { | ||
String string = string(ctx.string()); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,7 +104,7 @@ private <T> T invokeParser(String sql, | |
lexer.removeErrorListeners(); | ||
lexer.addErrorListener(ERROR_LISTENER); | ||
|
||
Map<Token, SqlTypedParamValue> paramTokens = new HashMap<>(); | ||
Map<Token, SqlParameter> paramTokens = new HashMap<>(); | ||
TokenSource tokenSource = new ParametrizedTokenSource(lexer, paramTokens, params); | ||
|
||
CommonTokenStream tokenStream = new CommonTokenStream(tokenSource); | ||
|
@@ -231,6 +231,27 @@ public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int | |
} | ||
}; | ||
|
||
public static class SqlParameter { | ||
public final Integer index; | ||
public final SqlTypedParamValue value; | ||
|
||
/** | ||
* @param index Index of the SQL parameter. Index of first parameter is 1 (same as in JDBC). | ||
*/ | ||
public SqlParameter(Integer index, SqlTypedParamValue value) { | ||
this.index = index; | ||
this.value = value; | ||
} | ||
|
||
public Integer getIndex() { | ||
return index; | ||
} | ||
|
||
public SqlTypedParamValue getValue() { | ||
return value; | ||
} | ||
} | ||
|
||
/** | ||
* Finds all parameter tokens (?) and associates them with actual parameter values | ||
* <p> | ||
|
@@ -240,26 +261,28 @@ public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int | |
private static class ParametrizedTokenSource implements TokenSource { | ||
|
||
private TokenSource delegate; | ||
private Map<Token, SqlTypedParamValue> paramTokens; | ||
private int param; | ||
private Map<Token, SqlParameter> paramTokens; | ||
private int paramIndex; | ||
private List<SqlTypedParamValue> params; | ||
|
||
ParametrizedTokenSource(TokenSource delegate, Map<Token, SqlTypedParamValue> paramTokens, List<SqlTypedParamValue> params) { | ||
ParametrizedTokenSource(TokenSource delegate, | ||
Map<Token, SqlParameter> paramTokens, | ||
List<SqlTypedParamValue> params) { | ||
this.delegate = delegate; | ||
this.paramTokens = paramTokens; | ||
this.params = params; | ||
param = 0; | ||
paramIndex = 0; | ||
} | ||
|
||
@Override | ||
public Token nextToken() { | ||
Token token = delegate.nextToken(); | ||
if (token.getType() == SqlBaseLexer.PARAM) { | ||
if (param >= params.size()) { | ||
if (paramIndex >= params.size()) { | ||
throw new ParsingException("Not enough actual parameters {} ", params.size()); | ||
} | ||
paramTokens.put(token, params.get(param)); | ||
param++; | ||
paramTokens.put(token, new SqlParameter(paramIndex+1, params.get(paramIndex))); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. incorrect formatting There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: will do, again would be great to add it to |
||
paramIndex++; | ||
} | ||
return token; | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.sql.parser; | ||
|
||
import org.elasticsearch.test.ESTestCase; | ||
import org.elasticsearch.xpack.ql.expression.Alias; | ||
import org.elasticsearch.xpack.ql.expression.Literal; | ||
import org.elasticsearch.xpack.ql.expression.NamedExpression; | ||
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThan; | ||
import org.elasticsearch.xpack.ql.plan.logical.Filter; | ||
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; | ||
import org.elasticsearch.xpack.ql.plan.logical.Project; | ||
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue; | ||
|
||
import java.util.List; | ||
|
||
import static org.elasticsearch.xpack.ql.type.DateUtils.UTC; | ||
import static org.hamcrest.Matchers.equalTo; | ||
import static org.hamcrest.Matchers.everyItem; | ||
import static org.hamcrest.Matchers.isA; | ||
import static org.hamcrest.Matchers.startsWith; | ||
|
||
public class ParamLiteralTests extends ESTestCase { | ||
|
||
private final SqlParser parser = new SqlParser(); | ||
|
||
private LogicalPlan parse(String sql, SqlTypedParamValue... parameters) { | ||
return parser.createStatement(sql, List.of(parameters), UTC); | ||
} | ||
|
||
public void testMultipleParamLiteralsWithUnresolvedAliases() { | ||
LogicalPlan logicalPlan = parse("SELECT ?, ? FROM test", | ||
new SqlTypedParamValue("integer", 100), | ||
new SqlTypedParamValue("integer", 200) | ||
); | ||
List<? extends NamedExpression> projections = ((Project) logicalPlan.children().get(0)).projections(); | ||
assertThat(projections, everyItem(isA(Alias.class))); | ||
assertThat(projections.get(0).toString(), startsWith("100 AS ?1#")); | ||
assertThat(projections.get(1).toString(), startsWith("200 AS ?2#")); | ||
} | ||
|
||
public void testMultipleParamLiteralsWithUnresolvedAliasesAndWhereClause() { | ||
LogicalPlan logicalPlan = parse("SELECT ?, ? FROM test WHERE 1 < ?", | ||
new SqlTypedParamValue("integer", 100), | ||
new SqlTypedParamValue("integer", 200), | ||
new SqlTypedParamValue("integer", 300) | ||
); | ||
Project project = (Project) logicalPlan.children().get(0); | ||
List<? extends NamedExpression> projections = project.projections(); | ||
assertThat(projections, everyItem(isA(Alias.class))); | ||
assertThat(projections.get(0).toString(), startsWith("100 AS ?1#")); | ||
assertThat(projections.get(1).toString(), startsWith("200 AS ?2#")); | ||
assertThat(project.children().get(0), isA(Filter.class)); | ||
Filter filter = (Filter) project.children().get(0); | ||
assertThat(filter.condition(), isA(LessThan.class)); | ||
LessThan condition = (LessThan) filter.condition(); | ||
assertThat(condition.left(), isA(Literal.class)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In *QL (and Elasticsearch) codebase, the general style for these assertions is either There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. or |
||
assertThat(condition.right(), isA(Literal.class)); | ||
assertThat(((Literal)condition.right()).value(), equalTo(300)); | ||
} | ||
|
||
public void testParamLiteralsWithUnresolvedAliasesAndMixedTypes() { | ||
LogicalPlan logicalPlan = parse("SELECT ?, ? FROM test", | ||
new SqlTypedParamValue("integer", 100), | ||
new SqlTypedParamValue("text", "200") | ||
); | ||
List<? extends NamedExpression> projections = ((Project) logicalPlan.children().get(0)).projections(); | ||
assertThat(projections, everyItem(isA(Alias.class))); | ||
assertThat(projections.get(0).toString(), startsWith("100 AS ?1#")); | ||
assertThat(projections.get(1).toString(), startsWith("200 AS ?2#")); | ||
} | ||
|
||
public void testParamLiteralsWithResolvedAndUnresolvedAliases() { | ||
LogicalPlan logicalPlan = parse("SELECT ?, ? as x, ? FROM test", | ||
new SqlTypedParamValue("integer", 100), | ||
new SqlTypedParamValue("integer", 200), | ||
new SqlTypedParamValue("integer", 300) | ||
); | ||
List<? extends NamedExpression> projections = ((Project) logicalPlan.children().get(0)).projections(); | ||
assertThat(projections, everyItem(isA(Alias.class))); | ||
assertThat(projections.get(0).toString(), startsWith("100 AS ?1#")); | ||
assertThat(projections.get(1).toString(), startsWith("200 AS x#"));; | ||
assertThat(projections.get(2).toString(), startsWith("300 AS ?3#"));; | ||
} | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please revert. The PRs should contain the minimal number of changes needed and not affect the rest of the code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, leftover from prev commits.