Skip to content

SQL: Fix incorrect parameter resolution #63710

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Nov 4, 2020
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -956,8 +956,7 @@ private boolean hasUnresolvedAliases(List<? extends NamedExpression> expressions

private List<NamedExpression> assignAliases(List<? extends NamedExpression> exprs) {
List<NamedExpression> newExpr = new ArrayList<>(exprs.size());
for (int i = 0; i < exprs.size(); i++) {
NamedExpression expr = exprs.get(i);
for (NamedExpression expr : exprs) {
NamedExpression transformed = (NamedExpression) expr.transformUp(ua -> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please revert. The PRs should contain the minimal number of changes needed and not affect the rest of the code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, leftover from prev commits.

Expression child = ua.child();
if (child instanceof NamedExpression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@
import org.antlr.v4.runtime.Token;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.SingleStatementContext;
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;

import java.time.ZoneId;
import java.util.Map;

class AstBuilder extends CommandBuilder {
/**
* Create AST Builder
* @param params a map between '?' tokens that represent parameters and the actual parameter values
* @param params a map between '?' tokens that represent parameters
* and the parameter indexes and values
* @param zoneId user specified timezone in the session
*/
AstBuilder(Map<Token, SqlTypedParamValue> params, ZoneId zoneId) {
AstBuilder(Map<Token, SqlParser.SqlParameter> params, ZoneId zoneId) {
super(params, zoneId);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import org.elasticsearch.xpack.sql.plan.logical.command.sys.SysColumns;
import org.elasticsearch.xpack.sql.plan.logical.command.sys.SysTables;
import org.elasticsearch.xpack.sql.plan.logical.command.sys.SysTypes;
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;

import java.time.ZoneId;
import java.util.ArrayList;
Expand All @@ -45,7 +44,7 @@

abstract class CommandBuilder extends LogicalPlanBuilder {

protected CommandBuilder(Map<Token, SqlTypedParamValue> params, ZoneId zoneId) {
protected CommandBuilder(Map<Token, SqlParser.SqlParameter> params, ZoneId zoneId) {
super(params, zoneId);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@

abstract class ExpressionBuilder extends IdentifierBuilder {

private final Map<Token, SqlTypedParamValue> params;
private final Map<Token, SqlParser.SqlParameter> params;
private final ZoneId zoneId;

ExpressionBuilder(Map<Token, SqlTypedParamValue> params, ZoneId zoneId) {
ExpressionBuilder(Map<Token, SqlParser.SqlParameter> params, ZoneId zoneId) {
this.params = params;
this.zoneId = zoneId;
}
Expand All @@ -165,7 +165,15 @@ public Expression visitSelectExpression(SelectExpressionContext ctx) {
Expression exp = expression(ctx.expression());
String alias = visitIdentifier(ctx.identifier());
Source source = source(ctx);
return alias != null ? new Alias(source, alias, exp) : new UnresolvedAlias(source, exp);
if (alias != null) {
return new Alias(source, alias, exp);
}
if (exp instanceof Literal && "?".equals(exp.source().text())) {
int paramIndex = param(ctx).getIndex();
// all indexes related to JDBC or databases usually start with 1
return new Alias(source, "?" + paramIndex, exp);
}
return new UnresolvedAlias(source, exp);
}

@Override
Expand Down Expand Up @@ -700,7 +708,7 @@ public Literal visitIntegerLiteral(IntegerLiteralContext ctx) {

@Override
public Literal visitParamLiteral(ParamLiteralContext ctx) {
SqlTypedParamValue param = param(ctx.PARAM());
SqlTypedParamValue param = param(ctx.PARAM()).getValue();
DataType dataType = SqlDataTypes.fromTypeName(param.type);
Source source = source(ctx);
if (dataType == null) {
Expand Down Expand Up @@ -745,15 +753,17 @@ String string(StringContext ctx) {
if (ctx == null) {
return null;
}
SqlTypedParamValue param = param(ctx.PARAM());
if (param != null) {
return param.value != null ? param.value.toString() : null;
} else {
return unquoteString(ctx.getText());
SqlParser.SqlParameter sqlParam = param(ctx.PARAM());
if (sqlParam != null) {
SqlTypedParamValue param = sqlParam.getValue();
if (param != null) {
return param.value != null ? param.value.toString() : null;
}
}
return unquoteString(ctx.getText());
}

private SqlTypedParamValue param(TerminalNode node) {
private SqlParser.SqlParameter param(TerminalNode node) {
if (node == null) {
return null;
}
Expand All @@ -767,6 +777,16 @@ private SqlTypedParamValue param(TerminalNode node) {
return params.get(token);
}

private SqlParser.SqlParameter param(ParserRuleContext ctx) {
if (!ctx.getStart().equals(ctx.getStop())) {
throw new ParsingException(source(ctx), "Single PARAM literal expected");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect style (use == false instead of !).
I'm not sure what this check tries to prevent...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: any reason why this is not part of checkStyle and the precommit checks?
I will remove this check, it was more like a sanity check (during unit & integ test) for me to see if there are any possibility where the start and stop token won't be the same. This should never happen.

}
if (params.containsKey(ctx.getStart()) == false) {
throw new ParsingException(source(ctx), "Unexpected parameter");
}
return params.get(ctx.getStart());
}

@Override
public Literal visitDateEscapedLiteral(DateEscapedLiteralContext ctx) {
String string = string(ctx.string());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import org.elasticsearch.xpack.sql.plan.logical.Pivot;
import org.elasticsearch.xpack.sql.plan.logical.SubQueryAlias;
import org.elasticsearch.xpack.sql.plan.logical.With;
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;
import org.elasticsearch.xpack.sql.session.SingletonExecutable;

import java.time.ZoneId;
Expand All @@ -64,7 +63,7 @@

abstract class LogicalPlanBuilder extends ExpressionBuilder {

protected LogicalPlanBuilder(Map<Token, SqlTypedParamValue> params, ZoneId zoneId) {
protected LogicalPlanBuilder(Map<Token, SqlParser.SqlParameter> params, ZoneId zoneId) {
super(params, zoneId);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private <T> T invokeParser(String sql,
lexer.removeErrorListeners();
lexer.addErrorListener(ERROR_LISTENER);

Map<Token, SqlTypedParamValue> paramTokens = new HashMap<>();
Map<Token, SqlParameter> paramTokens = new HashMap<>();
TokenSource tokenSource = new ParametrizedTokenSource(lexer, paramTokens, params);

CommonTokenStream tokenStream = new CommonTokenStream(tokenSource);
Expand Down Expand Up @@ -231,6 +231,27 @@ public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int
}
};

public static class SqlParameter {
public final Integer index;
public final SqlTypedParamValue value;

/**
* @param index Index of the SQL parameter. Index of first parameter is 1 (same as in JDBC).
*/
public SqlParameter(Integer index, SqlTypedParamValue value) {
this.index = index;
this.value = value;
}

public Integer getIndex() {
return index;
}

public SqlTypedParamValue getValue() {
return value;
}
}

/**
* Finds all parameter tokens (?) and associates them with actual parameter values
* <p>
Expand All @@ -240,26 +261,28 @@ public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int
private static class ParametrizedTokenSource implements TokenSource {

private TokenSource delegate;
private Map<Token, SqlTypedParamValue> paramTokens;
private int param;
private Map<Token, SqlParameter> paramTokens;
private int paramIndex;
private List<SqlTypedParamValue> params;

ParametrizedTokenSource(TokenSource delegate, Map<Token, SqlTypedParamValue> paramTokens, List<SqlTypedParamValue> params) {
ParametrizedTokenSource(TokenSource delegate,
Map<Token, SqlParameter> paramTokens,
List<SqlTypedParamValue> params) {
this.delegate = delegate;
this.paramTokens = paramTokens;
this.params = params;
param = 0;
paramIndex = 0;
}

@Override
public Token nextToken() {
Token token = delegate.nextToken();
if (token.getType() == SqlBaseLexer.PARAM) {
if (param >= params.size()) {
if (paramIndex >= params.size()) {
throw new ParsingException("Not enough actual parameters {} ", params.size());
}
paramTokens.put(token, params.get(param));
param++;
paramTokens.put(token, new SqlParameter(paramIndex+1, params.get(paramIndex)));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

incorrect formatting paramIndex + 1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: will do, again would be great to add it to checkStyle

paramIndex++;
}
return token;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.sql.parser;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ql.expression.Alias;
import org.elasticsearch.xpack.ql.expression.Literal;
import org.elasticsearch.xpack.ql.expression.NamedExpression;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.ql.plan.logical.Filter;
import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.ql.plan.logical.Project;
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;

import java.util.List;

import static org.elasticsearch.xpack.ql.type.DateUtils.UTC;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.everyItem;
import static org.hamcrest.Matchers.isA;
import static org.hamcrest.Matchers.startsWith;

public class ParamLiteralTests extends ESTestCase {

private final SqlParser parser = new SqlParser();

private LogicalPlan parse(String sql, SqlTypedParamValue... parameters) {
return parser.createStatement(sql, List.of(parameters), UTC);
}

public void testMultipleParamLiteralsWithUnresolvedAliases() {
LogicalPlan logicalPlan = parse("SELECT ?, ? FROM test",
new SqlTypedParamValue("integer", 100),
new SqlTypedParamValue("integer", 200)
);
List<? extends NamedExpression> projections = ((Project) logicalPlan.children().get(0)).projections();
assertThat(projections, everyItem(isA(Alias.class)));
assertThat(projections.get(0).toString(), startsWith("100 AS ?1#"));
assertThat(projections.get(1).toString(), startsWith("200 AS ?2#"));
}

public void testMultipleParamLiteralsWithUnresolvedAliasesAndWhereClause() {
LogicalPlan logicalPlan = parse("SELECT ?, ? FROM test WHERE 1 < ?",
new SqlTypedParamValue("integer", 100),
new SqlTypedParamValue("integer", 200),
new SqlTypedParamValue("integer", 300)
);
Project project = (Project) logicalPlan.children().get(0);
List<? extends NamedExpression> projections = project.projections();
assertThat(projections, everyItem(isA(Alias.class)));
assertThat(projections.get(0).toString(), startsWith("100 AS ?1#"));
assertThat(projections.get(1).toString(), startsWith("200 AS ?2#"));
assertThat(project.children().get(0), isA(Filter.class));
Filter filter = (Filter) project.children().get(0);
assertThat(filter.condition(), isA(LessThan.class));
LessThan condition = (LessThan) filter.condition();
assertThat(condition.left(), isA(Literal.class));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In *QL (and Elasticsearch) codebase, the general style for these assertions is either assertThat(X, instanceof(Y)) (more in ES) or assertTrue(X instanceof Y) (more in *QL). Mockito isA seems to be used in the x-pack Security plugin only. I'd suggest going with a *QL-wide consistent approach and use the assertTrue variant, if it's not too much trouble.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or assertEquals(Literal.class, condition.left().getClass()

assertThat(condition.right(), isA(Literal.class));
assertThat(((Literal)condition.right()).value(), equalTo(300));
}

public void testParamLiteralsWithUnresolvedAliasesAndMixedTypes() {
LogicalPlan logicalPlan = parse("SELECT ?, ? FROM test",
new SqlTypedParamValue("integer", 100),
new SqlTypedParamValue("text", "200")
);
List<? extends NamedExpression> projections = ((Project) logicalPlan.children().get(0)).projections();
assertThat(projections, everyItem(isA(Alias.class)));
assertThat(projections.get(0).toString(), startsWith("100 AS ?1#"));
assertThat(projections.get(1).toString(), startsWith("200 AS ?2#"));
}

public void testParamLiteralsWithResolvedAndUnresolvedAliases() {
LogicalPlan logicalPlan = parse("SELECT ?, ? as x, ? FROM test",
new SqlTypedParamValue("integer", 100),
new SqlTypedParamValue("integer", 200),
new SqlTypedParamValue("integer", 300)
);
List<? extends NamedExpression> projections = ((Project) logicalPlan.children().get(0)).projections();
assertThat(projections, everyItem(isA(Alias.class)));
assertThat(projections.get(0).toString(), startsWith("100 AS ?1#"));
assertThat(projections.get(1).toString(), startsWith("200 AS x#"));;
assertThat(projections.get(2).toString(), startsWith("300 AS ?3#"));;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.FieldAttribute;
import org.elasticsearch.xpack.ql.expression.Literal;
import org.elasticsearch.xpack.ql.expression.ReferenceAttribute;
import org.elasticsearch.xpack.ql.expression.function.FunctionDefinition;
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.ql.expression.function.aggregate.Count;
Expand Down Expand Up @@ -68,6 +69,7 @@
import org.elasticsearch.xpack.sql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.sql.planner.QueryFolder.FoldAggregate.GroupingContext;
import org.elasticsearch.xpack.sql.planner.QueryTranslator.QueryTranslation;
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;
import org.elasticsearch.xpack.sql.querydsl.agg.AggFilter;
import org.elasticsearch.xpack.sql.querydsl.agg.GroupByDateHistogram;
import org.elasticsearch.xpack.sql.querydsl.container.MetricAggRef;
Expand Down Expand Up @@ -100,7 +102,9 @@
import static org.elasticsearch.xpack.sql.util.DateUtils.UTC;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.everyItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.isA;
import static org.hamcrest.Matchers.startsWith;

public class QueryTranslatorTests extends ESTestCase {
Expand Down Expand Up @@ -133,7 +137,11 @@ private LogicalPlan plan(String sql, ZoneId zoneId) {
}

private PhysicalPlan optimizeAndPlan(String sql) {
return planner.plan(optimizer.optimize(plan(sql)), true);
return optimizeAndPlan(plan(sql));
}

private PhysicalPlan optimizeAndPlan(LogicalPlan plan) {
return planner.plan(optimizer.optimize(plan),true);
}

private QueryTranslation translate(Expression condition) {
Expand All @@ -144,6 +152,10 @@ private QueryTranslation translateWithAggs(Expression condition) {
return QueryTranslator.toQuery(condition, true);
}

private LogicalPlan parameterizedSql(String sql, SqlTypedParamValue... params) {
return analyzer.analyze(parser.createStatement(sql, Arrays.asList(params), org.elasticsearch.xpack.ql.type.DateUtils.UTC), true);
}

public void testTermEqualityAnalyzer() {
LogicalPlan p = plan("SELECT some.string FROM test WHERE some.string = 'value'");
assertTrue(p instanceof Project);
Expand Down Expand Up @@ -2239,4 +2251,22 @@ public void testScriptsInsideAggregateFunctions_WithDateField_AndExtendedStats()
+ "InternalSqlScriptUtils.asDateTime(params.a0),InternalSqlScriptUtils.asDateTime(params.v0)))\",\"lang\":\"painless\","
+ "\"params\":{\"v0\":\"2020-05-03T00:00:00.000Z\"}},\"gap_policy\":\"skip\"}}}}}}"));
}

public void testFoldingWithParamsWithoutIndex() {
PhysicalPlan p = optimizeAndPlan(parameterizedSql("SELECT ?, ? FROM test",
new SqlTypedParamValue("integer", 100),
new SqlTypedParamValue("integer", 200)));
assertThat(p.output(), everyItem(isA(ReferenceAttribute.class)));
assertThat(p.output().get(0).toString(), startsWith("?1{r}#"));
assertThat(p.output().get(1).toString(), startsWith("?2{r}#"));
}

public void testFoldingWithMixedParamsWithoutAlias() {
PhysicalPlan p = optimizeAndPlan(parameterizedSql("SELECT ?, ? FROM test",
new SqlTypedParamValue("integer", 100),
new SqlTypedParamValue("text", "200")));
assertThat(p.output(), everyItem(isA(ReferenceAttribute.class)));
assertThat(p.output().get(0).toString(), startsWith("?1{r}#"));
assertThat(p.output().get(1).toString(), startsWith("?2{r}#"));
}
}