diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index ffa0f2bfec73f..980c3af036f9f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineFilters; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineLimits; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineOrderBy; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownCompletion; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEnrich; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEval; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownRegexExtract; @@ -190,6 +191,7 @@ protected static Batch operators() { new PruneLiteralsInOrderBy(), new PushDownAndCombineLimits(), new PushDownAndCombineFilters(), + new PushDownCompletion(), new PushDownEval(), new PushDownRegexExtract(), new PushDownEnrich(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFilters.java index f1f139bc2b0f2..7577b106c4845 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFilters.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes; @@ -70,6 +71,10 @@ protected LogicalPlan rule(Filter filter) { // Push down filters that do not rely on attributes created by RegexExtract var attributes = AttributeSet.of(Expressions.asAttributes(re.extractedFields())); plan = maybePushDownPastUnary(filter, re, attributes::contains, NO_OP); + } else if (child instanceof Completion completion) { + // Push down filters that do not rely on attributes created by Cpmpletion + var attributes = AttributeSet.of(completion.generatedAttributes()); + plan = maybePushDownPastUnary(filter, completion, attributes::contains, NO_OP); } else if (child instanceof Enrich enrich) { // Push down filters that do not rely on attributes created by Enrich var attributes = AttributeSet.of(Expressions.asAttributes(enrich.enrichFields())); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java index dca4dfbd533df..67e2edef667b4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes; @@ -38,7 +39,11 @@ public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) { // We want to preserve the duplicated() value of the smaller limit, so we'll use replaceChild. return parentLimitValue < childLimitValue ? limit.replaceChild(childLimit.child()) : childLimit; } else if (limit.child() instanceof UnaryPlan unary) { - if (unary instanceof Eval || unary instanceof Project || unary instanceof RegexExtract || unary instanceof Enrich) { + if (unary instanceof Eval + || unary instanceof Project + || unary instanceof RegexExtract + || unary instanceof Enrich + || unary instanceof Completion) { return unary.replaceChild(limit.replaceChild(unary.child())); } else if (unary instanceof MvExpand) { // MV_EXPAND can increase the number of rows, so we cannot just push the limit down diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownCompletion.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownCompletion.java new file mode 100644 index 0000000000000..d74e90fb0569f --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownCompletion.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical; + +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; + +public final class PushDownCompletion extends OptimizerRules.OptimizerRule { + @Override + protected LogicalPlan rule(Completion p) { + return PushDownUtils.pushGeneratingPlanPastProjectAndOrderBy(p); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 879187c382f10..99de6b0df6d90 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -100,6 +100,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneRedundantOrderBy; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineLimits; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownCompletion; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEnrich; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEval; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownRegexExtract; @@ -123,6 +124,7 @@ import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate; import org.elasticsearch.xpack.esql.plan.logical.TopN; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig; @@ -162,6 +164,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; import static org.elasticsearch.xpack.esql.EsqlTestUtils.localSource; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; import static org.elasticsearch.xpack.esql.EsqlTestUtils.referenceAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.singleValue; import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; @@ -176,6 +179,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; import static org.elasticsearch.xpack.esql.core.type.DataType.LONG; +import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT; import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.EQ; import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GT; import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.GTE; @@ -5556,6 +5560,17 @@ record PushdownShadowingGeneratingPlanTestCase( ) ), new PushDownEnrich() + ), + // | COMPLETION CONCAT(some text, x) WITH inferenceID AS y + new PushdownShadowingGeneratingPlanTestCase( + (plan, attr) -> new Completion( + EMPTY, + plan, + randomLiteral(TEXT), + new Concat(EMPTY, randomLiteral(TEXT), List.of(attr)), + new ReferenceAttribute(EMPTY, "y", KEYWORD) + ), + new PushDownCompletion() ) }; /** diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java index e30de7d818f9d..d8cae21257201 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java @@ -8,12 +8,15 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical; import org.elasticsearch.index.IndexMode; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow; import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike; @@ -28,6 +31,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; import java.util.ArrayList; @@ -45,9 +49,12 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOrEqualOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.referenceAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.rlike; import static org.elasticsearch.xpack.esql.EsqlTestUtils.wildcardLike; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; +import static org.mockito.Mockito.mock; public class PushDownAndCombineFiltersTests extends ESTestCase { @@ -238,6 +245,53 @@ public void testSelectivelyPushDownFilterPastFunctionAgg() { assertEquals(expected, new PushDownAndCombineFilters().apply(fb)); } + // from ... | where a > 1 | COMPLETION "some prompt" WITH reranker AS completion | where b < 2 and match(completion, some text) + // => ... | where a > 1 AND b < 2| COMPLETION "some prompt" WITH reranker AS completion | match(completion, some text) + public void testPushDownFilterPastCompletion() { + FieldAttribute a = getFieldAttribute("a"); + FieldAttribute b = getFieldAttribute("b"); + EsRelation relation = relation(List.of(a, b)); + + GreaterThan conditionA = greaterThanOf(getFieldAttribute("a"), ONE); + Filter filterA = new Filter(EMPTY, relation, conditionA); + + Completion completion = completion(filterA); + + LessThan conditionB = lessThanOf(getFieldAttribute("b"), TWO); + Match conditionCompletion = new Match( + EMPTY, + completion.targetField(), + randomLiteral(DataType.TEXT), + mock(Expression.class), + mock(QueryBuilder.class) + ); + Filter filterB = new Filter(EMPTY, completion, new And(EMPTY, conditionB, conditionCompletion)); + + LogicalPlan expectedOptimizedPlan = new Filter( + EMPTY, + new Completion( + EMPTY, + new Filter(EMPTY, relation, new And(EMPTY, conditionA, conditionB)), + completion.inferenceId(), + completion.prompt(), + completion.targetField() + ), + conditionCompletion + ); + + assertEquals(expectedOptimizedPlan, new PushDownAndCombineFilters().apply(filterB)); + } + + private static Completion completion(LogicalPlan child) { + return new Completion( + EMPTY, + child, + randomLiteral(DataType.TEXT), + randomLiteral(DataType.TEXT), + referenceAttribute(randomIdentifier(), DataType.TEXT) + ); + } + private static EsRelation relation() { return relation(List.of()); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java new file mode 100644 index 0000000000000..f5d88618b352d --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimitsTests.java @@ -0,0 +1,159 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.expression.Order; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; +import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; + +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; +import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT; +import static org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizerTests.relation; + +public class PushDownAndCombineLimitsTests extends ESTestCase { + + private static class PushDownLimitTestCase { + private final Class clazz; + private final BiFunction planBuilder; + private final BiConsumer planChecker; + + PushDownLimitTestCase( + Class clazz, + BiFunction planBuilder, + BiConsumer planChecker + ) { + this.clazz = clazz; + this.planBuilder = planBuilder; + this.planChecker = planChecker; + } + + public PlanType buildPlan(LogicalPlan child, Attribute attr) { + return planBuilder.apply(child, attr); + } + + public void checkOptimizedPlan(LogicalPlan basePlan, LogicalPlan optimizedPlan) { + planChecker.accept(as(basePlan, clazz), as(optimizedPlan, clazz)); + } + } + + private static final List> PUSHABLE_LIMIT_TEST_CASES = List.of( + new PushDownLimitTestCase<>( + Eval.class, + (plan, attr) -> new Eval(EMPTY, plan, List.of(new Alias(EMPTY, "y", new ToInteger(EMPTY, attr)))), + (basePlan, optimizedPlan) -> { + assertEquals(basePlan.source(), optimizedPlan.source()); + assertEquals(basePlan.fields(), optimizedPlan.fields()); + } + ), + new PushDownLimitTestCase<>( + Completion.class, + (plan, attr) -> new Completion(EMPTY, plan, randomLiteral(TEXT), randomLiteral(TEXT), attr), + (basePlan, optimizedPlan) -> { + assertEquals(basePlan.source(), optimizedPlan.source()); + assertEquals(basePlan.inferenceId(), optimizedPlan.inferenceId()); + assertEquals(basePlan.prompt(), optimizedPlan.prompt()); + assertEquals(basePlan.targetField(), optimizedPlan.targetField()); + } + ) + ); + + private static final List> NON_PUSHABLE_LIMIT_TEST_CASES = List.of( + new PushDownLimitTestCase<>( + Filter.class, + (plan, attr) -> new Filter(EMPTY, plan, new Equals(EMPTY, attr, new Literal(EMPTY, "right", TEXT))), + (basePlan, optimizedPlan) -> { + assertEquals(basePlan.source(), optimizedPlan.source()); + assertEquals(basePlan.condition(), optimizedPlan.condition()); + } + ), + new PushDownLimitTestCase<>( + OrderBy.class, + (plan, attr) -> new OrderBy(EMPTY, plan, List.of(new Order(EMPTY, attr, Order.OrderDirection.DESC, null))), + (basePlan, optimizedPlan) -> { + assertEquals(basePlan.source(), optimizedPlan.source()); + assertEquals(basePlan.order(), optimizedPlan.order()); + } + ) + ); + + public void testPushableLimit() { + FieldAttribute a = getFieldAttribute("a"); + FieldAttribute b = getFieldAttribute("b"); + EsRelation relation = relation().withAttributes(List.of(a, b)); + + for (PushDownLimitTestCase pushableLimitTestCase : PUSHABLE_LIMIT_TEST_CASES) { + int precedingLimitValue = randomIntBetween(1, 10_000); + Limit precedingLimit = new Limit(EMPTY, new Literal(EMPTY, precedingLimitValue, INTEGER), relation); + + LogicalPlan pushableLimitTestPlan = pushableLimitTestCase.buildPlan(precedingLimit, a); + + int pushableLimitValue = randomIntBetween(1, 10_000); + Limit pushableLimit = new Limit(EMPTY, new Literal(EMPTY, pushableLimitValue, INTEGER), pushableLimitTestPlan); + + LogicalPlan optimizedPlan = optimizePlan(pushableLimit); + + pushableLimitTestCase.checkOptimizedPlan(pushableLimitTestPlan, optimizedPlan); + + assertEquals( + as(optimizedPlan, UnaryPlan.class).child(), + new Limit(EMPTY, new Literal(EMPTY, Math.min(pushableLimitValue, precedingLimitValue), INTEGER), relation) + ); + } + } + + public void testNonPushableLimit() { + FieldAttribute a = getFieldAttribute("a"); + FieldAttribute b = getFieldAttribute("b"); + EsRelation relation = relation().withAttributes(List.of(a, b)); + + for (PushDownLimitTestCase nonPushableLimitTestCase : NON_PUSHABLE_LIMIT_TEST_CASES) { + int precedingLimitValue = randomIntBetween(1, 10_000); + Limit precedingLimit = new Limit(EMPTY, new Literal(EMPTY, precedingLimitValue, INTEGER), relation); + UnaryPlan nonPushableLimitTestPlan = nonPushableLimitTestCase.buildPlan(precedingLimit, a); + int nonPushableLimitValue = randomIntBetween(1, 10_000); + Limit nonPushableLimit = new Limit(EMPTY, new Literal(EMPTY, nonPushableLimitValue, INTEGER), nonPushableLimitTestPlan); + Limit optimizedPlan = as(optimizePlan(nonPushableLimit), Limit.class); + nonPushableLimitTestCase.checkOptimizedPlan(nonPushableLimitTestPlan, optimizedPlan.child()); + assertEquals( + optimizedPlan, + new Limit( + EMPTY, + new Literal(EMPTY, Math.min(nonPushableLimitValue, precedingLimitValue), INTEGER), + nonPushableLimitTestPlan + ) + ); + assertEquals(as(optimizedPlan.child(), UnaryPlan.class).child(), nonPushableLimitTestPlan.child()); + } + } + + private LogicalPlan optimizePlan(LogicalPlan plan) { + return new PushDownAndCombineLimits().apply(plan, unboundLogicalOptimizerContext()); + } +}