diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpressionExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpressionExtensions.cs index 4f604ea50a4..1aaa5200dae 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpressionExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpressionExtensions.cs @@ -19,6 +19,9 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions { internal static class AstExpressionExtensions { + public static bool IsConstant(this AstExpression expression, BsonValue value) + => expression is AstConstantExpression constantExpression && constantExpression.Value.Equals(value); + public static bool IsInt32Constant(this AstExpression expression, out int value) { if (expression is AstConstantExpression constantExpression && diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs index 7098f6f4456..39ae78ba982 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs @@ -352,6 +352,15 @@ public override AstNode VisitFilterField(AstFilterField node) public override AstNode VisitGetFieldExpression(AstGetFieldExpression node) { + // { $getField : { field : , input : { $firstOrLast : "$_elements" } } } => { __agg0 : { $firstOrLast : } } + "$__agg0" + if (IsGetFieldChainOnFirstOrLastElement(node, out var firstOrLastOperator, out var rootFieldExpression)) + { + var unaryAccumulatorOperator = firstOrLastOperator == AstUnaryOperator.First ? AstUnaryAccumulatorOperator.First : AstUnaryAccumulatorOperator.Last; + var accumulatorExpression = AstExpression.UnaryAccumulator(unaryAccumulatorOperator, rootFieldExpression); + var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression); + return AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName); + } + if (node.FieldName is AstConstantExpression constantFieldName && constantFieldName.Value.IsString && constantFieldName.Value.AsString == "_elements") @@ -360,6 +369,32 @@ public override AstNode VisitGetFieldExpression(AstGetFieldExpression node) } return base.VisitGetFieldExpression(node); + + bool IsGetFieldChainOnFirstOrLastElement(AstGetFieldExpression getFieldExpression, out AstUnaryOperator firstOrLastOperator, out AstExpression rootFieldExpression) + { + if (getFieldExpression.Input is AstGetFieldExpression innerGetFieldExpression && + IsGetFieldChainOnFirstOrLastElement(innerGetFieldExpression, out firstOrLastOperator, out rootFieldExpression)) + { + rootFieldExpression = AstExpression.GetField(rootFieldExpression, getFieldExpression.FieldName); + return true; + } + + if (getFieldExpression.Input is AstUnaryExpression unaryExpression && + unaryExpression.Operator is var unaryOperator && + (unaryOperator is AstUnaryOperator.First or AstUnaryOperator.Last) && + unaryExpression.Arg is AstGetFieldExpression innerMostGetFieldExpression && + innerMostGetFieldExpression.Input.IsRootVar() && + innerMostGetFieldExpression.FieldName.IsConstant("_elements")) + { + firstOrLastOperator = unaryOperator; + rootFieldExpression = AstExpression.GetField(AstExpression.RootVar, getFieldExpression.FieldName); + return true; + } + + firstOrLastOperator = default; + rootFieldExpression = null; + return false; + } } public override AstNode VisitMapExpression(AstMapExpression node) diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5529Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5529Tests.cs new file mode 100644 index 00000000000..928b88061b0 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5529Tests.cs @@ -0,0 +1,90 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp5529Tests : LinqIntegrationTest +{ + public CSharp5529Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Theory] + [InlineData(1, 1, """{ $group: { _id : 1, __agg0 : { $first : "$X" } } }""", 1)] + [InlineData(1, 2, """{ $group: { _id : 1, __agg0 : { $last : "$X" } } }""", 2)] + [InlineData(2, 1, """{ $group: { _id : 1, __agg0 : { $first : "$D.Y" } } }""", 11)] + [InlineData(2, 2, """{ $group: { _id : 1, __agg0 : { $last : "$D.Y" } } }""", 22)] + [InlineData(3, 1, """{ $group: { _id : 1, __agg0 : { $first : "$D.E.Z" } } }""", 111)] + [InlineData(3, 2, """{ $group: { _id : 1, __agg0 : { $last : "$D.E.Z" } } }""", 222)] + public void First_or_Last_optimization_should_work(int level, int firstOrLast, string expectedGroupStage, int expectedResult) + { + var collection = Fixture.Collection; + + var queryable = (level, firstOrLast) switch + { + (1, 1) => collection.Aggregate().Group(x => 1, g => g.First().X), + (1, 2) => collection.Aggregate().Group(x => 1, g => g.Last().X), + (2, 1) => collection.Aggregate().Group(x => 1, g => g.First().D.Y), + (2, 2) => collection.Aggregate().Group(x => 1, g => g.Last().D.Y), + (3, 1) => collection.Aggregate().Group(x => 1, g => g.First().D.E.Z), + (3, 2) => collection.Aggregate().Group(x => 1, g => g.Last().D.E.Z), + _ => throw new ArgumentException() + }; + + var stages = Translate(collection,queryable); + AssertStages( + stages, + expectedGroupStage, + """{ $project : { _v : "$__agg0", _id : 0 } }"""); + + var result = queryable.Single(); + result.Should().Be(expectedResult); + } + public class C + { + public int Id { get; set; } + public int X { get; set; } + + public D D { get; set; } + } + + public class D + { + public E E { get; set; } + public int Y { get; set; } + } + + public class E + { + public int Z { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new C { Id = 1, X = 1, D = new D { E = new E { Z = 111 }, Y = 11 } }, + new C { Id = 2, X = 2, D = new D { E = new E { Z = 222 }, Y = 22 } }, + ]; + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/MongoQueryableTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/MongoQueryableTests.cs index fecd7fddc6e..8c7dec1d253 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/MongoQueryableTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/MongoQueryableTests.cs @@ -411,8 +411,8 @@ public void GroupBy_select_anonymous_type_method() Assert(query, 2, - "{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }", - "{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }"); + "{ $group: { _id: '$A', __agg0: { $first: '$B'} } }", + "{ $project: { Key: '$_id', FirstB: '$__agg0', _id: 0 } }"); query = CreateQuery() .GroupBy(x => x.A) @@ -434,8 +434,8 @@ group p by p.A into g Assert(query, 2, - "{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }", - "{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }"); + "{ $group: { _id: '$A', __agg0: { $first: '$B'} } }", + "{ $project: { Key: '$_id', FirstB: '$__agg0', _id: 0 } }"); query = from p in CreateQuery() group p by p.A into g @@ -484,9 +484,9 @@ public void GroupBy_where_select_anonymous_type_with_duplicate_accumulators_meth Assert(query, 1, - "{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }", + "{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'}, __agg1 : { $first : '$B' } } }", "{ $match: { '__agg0.B' : 'Balloon' } }", - "{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }"); + "{ $project: { Key: '$_id', FirstB: '$__agg1', _id: 0 } }"); query = CreateQuery() .GroupBy(x => x.A) @@ -511,9 +511,9 @@ where g.First().B == "Balloon" Assert(query, 1, - "{ $group: { _id: '$A', __agg0: { $first: '$$ROOT'} } }", + "{ $group: { _id: '$A', __agg0: { $first: '$$ROOT' }, __agg1 : { $first : '$B' } } }", "{ $match: { '__agg0.B' : 'Balloon' } }", - "{ $project: { Key: '$_id', FirstB: '$__agg0.B', _id: 0 } }"); + "{ $project: { Key: '$_id', FirstB: '$__agg1', _id: 0 } }"); } #endif @@ -525,8 +525,8 @@ public void GroupBy_with_resultSelector_anonymous_type_method() Assert(query, 2, - "{ $group: { _id : '$A', __agg0 : { $first: '$$ROOT'} } }", - "{ $project : { Key : '$_id', FirstB : '$__agg0.B', _id : 0 } }"); + "{ $group: { _id : '$A', __agg0 : { $first: '$B'} } }", + "{ $project : { Key : '$_id', FirstB : '$__agg0', _id : 0 } }"); query = CreateQuery() .GroupBy(x => x.A, (k, s) => new { Key = k, FirstB = s.Select(x => x.B).First() }); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs index acdfb1db2e0..45bbf7067af 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/AggregateGroupTranslatorTests.cs @@ -39,8 +39,8 @@ public void Should_translate_using_non_anonymous_type_with_default_constructor() AssertStages( result.Stages, - "{ $group : { _id : '$A', __agg0 : { $first : '$$ROOT' } } }", - "{ $project : { Property : '$_id', Field : '$__agg0.B', _id : 0 } }"); + "{ $group : { _id : '$A', __agg0 : { $first : '$B' } } }", + "{ $project : { Property : '$_id', Field : '$__agg0', _id : 0 } }"); result.Value.Property.Should().Be("Amazing"); result.Value.Field.Should().Be("Baby"); @@ -53,8 +53,8 @@ public void Should_translate_using_non_anonymous_type_with_parameterized_constru AssertStages( result.Stages, - "{ $group : { _id : '$A', __agg0 : { $first : '$$ROOT' } } }", - "{ $project : { Property : '$_id', Field : '$__agg0.B', _id : 0 } }"); + "{ $group : { _id : '$A', __agg0 : { $first : '$B' } } }", + "{ $project : { Property : '$_id', Field : '$__agg0', _id : 0 } }"); result.Value.Property.Should().Be("Amazing"); result.Value.Field.Should().Be("Baby"); @@ -236,8 +236,8 @@ public void Should_translate_first_with_normalization() AssertStages( result.Stages, - "{ $group : { _id : '$A', __agg0 : { $first : '$$ROOT' } } }", - "{ $project : { B : '$__agg0.B', _id : 0 } }"); + "{ $group : { _id : '$A', __agg0 : { $first : '$B' } } }", + "{ $project : { B : '$__agg0', _id : 0 } }"); result.Value.B.Should().Be("Baby"); } @@ -262,8 +262,8 @@ public void Should_translate_last_with_normalization() AssertStages( result.Stages, - "{ $group : { _id : '$A', __agg0 : { $last : '$$ROOT' } } }", - "{ $project : { B : '$__agg0.B', _id : 0 } }"); + "{ $group : { _id : '$A', __agg0 : { $last : '$B' } } }", + "{ $project : { B : '$__agg0', _id : 0 } }"); result.Value.B.Should().Be("Baby"); } @@ -492,8 +492,8 @@ public void Should_translate_complex_selector() _id : '$A', __agg0 : { $sum : 1 }, __agg1 : { $sum : { $add : ['$C.E.F', '$C.E.H'] } }, - __agg2 : { $first : '$$ROOT' }, - __agg3 : { $last : '$$ROOT' }, + __agg2 : { $first : '$B' }, + __agg3 : { $last : '$K' }, __agg4 : { $min : { $add : ['$C.E.F', '$C.E.H'] } }, __agg5 : { $max : { $add : ['$C.E.F', '$C.E.H'] } } } @@ -503,8 +503,8 @@ public void Should_translate_complex_selector() $project : { Count : '$__agg0', Sum : '$__agg1', - First : '$__agg2.B', - Last : '$__agg3.K', + First : '$__agg2', + Last : '$__agg3', Min : '$__agg4', Max : '$__agg5', _id : 0 diff --git a/tests/MongoDB.Driver.Tests/Samples/AggregationSample.cs b/tests/MongoDB.Driver.Tests/Samples/AggregationSample.cs index 459ca1de33f..ed26da058fe 100644 --- a/tests/MongoDB.Driver.Tests/Samples/AggregationSample.cs +++ b/tests/MongoDB.Driver.Tests/Samples/AggregationSample.cs @@ -108,7 +108,7 @@ where g.Sum(x => x.Population) > 20000 select new { State = g.Key, TotalPopulation = g.Sum(x => x.Population) }; var stages = Linq3TestHelpers.Translate(collection, queryable); - var expectedStages = + var expectedStages = new[] { "{ $group : { _id : '$state', __agg0 : { $sum : '$pop' } } }", @@ -173,13 +173,13 @@ public async Task Largest_and_smallest_cities_by_state() .SortBy(x => x.State); var pipelineTranslation = pipeline.ToString(); - var expectedTranslation = + var expectedTranslation = "aggregate([" + "{ \"$group\" : { \"_id\" : { \"State\" : \"$state\", \"City\" : \"$city\" }, \"__agg0\" : { \"$sum\" : \"$pop\" } } }, " + "{ \"$project\" : { \"StateAndCity\" : \"$_id\", \"Population\" : \"$__agg0\", \"_id\" : 0 } }, " + "{ \"$sort\" : { \"Population\" : 1 } }, " + - "{ \"$group\" : { \"_id\" : \"$StateAndCity.State\", \"__agg0\" : { \"$last\" : \"$$ROOT\" }, \"__agg1\" : { \"$first\" : \"$$ROOT\" } } }, " + - "{ \"$project\" : { \"State\" : \"$_id\", \"BiggestCity\" : \"$__agg0.StateAndCity.City\", \"BiggestPopulation\" : \"$__agg0.Population\", \"SmallestCity\" : \"$__agg1.StateAndCity.City\", \"SmallestPopulation\" : \"$__agg1.Population\", \"_id\" : 0 } }, " + + "{ \"$group\" : { \"_id\" : \"$StateAndCity.State\", \"__agg0\" : { \"$last\" : \"$StateAndCity.City\" }, \"__agg1\" : { \"$last\" : \"$Population\" }, \"__agg2\" : { \"$first\" : \"$StateAndCity.City\" }, \"__agg3\" : { \"$first\" : \"$Population\" } } }, " + + "{ \"$project\" : { \"State\" : \"$_id\", \"BiggestCity\" : \"$__agg0\", \"BiggestPopulation\" : \"$__agg1\", \"SmallestCity\" : \"$__agg2\", \"SmallestPopulation\" : \"$__agg3\", \"_id\" : 0 } }, " + "{ \"$project\" : { \"State\" : \"$State\", \"BiggestCity\" : { \"Name\" : \"$BiggestCity\", \"Population\" : \"$BiggestPopulation\" }, \"SmallestCity\" : { \"Name\" : \"$SmallestCity\", \"Population\" : \"$SmallestPopulation\" }, \"_id\" : 0 } }, " + "{ \"$sort\" : { \"State\" : 1 } }])"; pipelineTranslation.Should().Be(expectedTranslation);