Skip to content

Commit 18af48f

Browse files
authored
Fix expression cloning when table changes in SelectExpression.VisitChildren (#32504)
Fixes #32234 (cherry picked from commit cf5ec40)
1 parent cdbc432 commit 18af48f

File tree

6 files changed

+160
-5
lines changed

6 files changed

+160
-5
lines changed

src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs

+47
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,53 @@ public TableReferenceUpdatingExpressionVisitor(SelectExpression oldSelect, Selec
370370
}
371371
}
372372

373+
// Note: this is conceptually the same as ColumnExpressionReplacingExpressionVisitor; I duplicated it since this is for a patch,
374+
// and we want to limit the potential risk (note that this calls the special SelectExpression.VisitChildren() with updateColumns: false,
375+
// to avoid infinite recursion).
376+
private sealed class ColumnTableReferenceUpdater : ExpressionVisitor
377+
{
378+
private readonly SelectExpression _oldSelect;
379+
private readonly SelectExpression _newSelect;
380+
381+
public ColumnTableReferenceUpdater(SelectExpression oldSelect, SelectExpression newSelect)
382+
{
383+
_oldSelect = oldSelect;
384+
_newSelect = newSelect;
385+
}
386+
387+
[return: NotNullIfNotNull("expression")]
388+
public override Expression? Visit(Expression? expression)
389+
{
390+
if (expression is ConcreteColumnExpression columnExpression
391+
&& _oldSelect._tableReferences.Find(t => ReferenceEquals(t.Table, columnExpression.Table)) is TableReferenceExpression
392+
oldTableReference
393+
&& _newSelect._tableReferences.Find(t => t.Alias == columnExpression.TableAlias) is TableReferenceExpression
394+
newTableReference
395+
&& newTableReference != oldTableReference)
396+
{
397+
return new ConcreteColumnExpression(
398+
columnExpression.Name,
399+
newTableReference,
400+
columnExpression.Type,
401+
columnExpression.TypeMapping!,
402+
columnExpression.IsNullable);
403+
}
404+
405+
return base.Visit(expression);
406+
}
407+
408+
protected override Expression VisitExtension(Expression node)
409+
{
410+
if (node is SelectExpression select)
411+
{
412+
Check.DebugAssert(!select._mutable, "Visiting mutable select expression in ColumnTableReferenceUpdater");
413+
return select.VisitChildren(this, updateColumns: false);
414+
}
415+
416+
return base.VisitExtension(node);
417+
}
418+
}
419+
373420
private sealed class IdentifierComparer : IEqualityComparer<(ColumnExpression Column, ValueComparer Comparer)>
374421
{
375422
public bool Equals((ColumnExpression Column, ValueComparer Comparer) x, (ColumnExpression Column, ValueComparer Comparer) y)

src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs

+35-5
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions;
2525
public sealed partial class SelectExpression : TableExpressionBase
2626
{
2727
private const string DiscriminatorColumnAlias = "Discriminator";
28+
2829
private static readonly bool UseOldBehavior31107 =
2930
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue31107", out var enabled31107) && enabled31107;
31+
private static readonly bool UseOldBehavior32234 =
32+
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue32234", out var enabled32234) && enabled32234;
3033

3134
private static readonly IdentifierComparer IdentifierComparerInstance = new();
3235

@@ -4612,6 +4615,9 @@ private static string GenerateUniqueAlias(HashSet<string> usedAliases, string cu
46124615

46134616
/// <inheritdoc />
46144617
protected override Expression VisitChildren(ExpressionVisitor visitor)
4618+
=> VisitChildren(visitor, updateColumns: true);
4619+
4620+
private Expression VisitChildren(ExpressionVisitor visitor, bool updateColumns)
46154621
{
46164622
if (_mutable)
46174623
{
@@ -4800,14 +4806,38 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
48004806
newSelectExpression._childIdentifiers.AddRange(
48014807
childIdentifier.Zip(_childIdentifiers).Select(e => (e.First, e.Second.Comparer)));
48024808

4803-
// Remap tableReferences in new select expression
4804-
foreach (var tableReference in newTableReferences)
4809+
// We duplicated the SelectExpression, and must therefore also update all table reference expressions to point to it.
4810+
// If any tables have changed, we must duplicate the TableReferenceExpressions and replace all ColumnExpressions to use
4811+
// them; otherwise we end up two SelectExpressions sharing the same TableReferenceExpression instance, and if that's later
4812+
// mutated, both SelectExpressions are affected (this happened in AliasUniquifier, see #32234).
4813+
4814+
// Otherwise, if no tables have changed, we mutate the TableReferenceExpressions (this was the previous code, left it for
4815+
// a more low-risk fix). Note that updateColumns is false only if we're already being called from
4816+
// ColumnTableReferenceUpdater to replace the ColumnExpressions, in which case we avoid infinite recursion.
4817+
if (tablesChanged && updateColumns && !UseOldBehavior32234)
48054818
{
4806-
tableReference.UpdateTableReference(this, newSelectExpression);
4819+
for (var i = 0; i < newTableReferences.Count; i++)
4820+
{
4821+
newTableReferences[i] = new TableReferenceExpression(newSelectExpression, _tableReferences[i].Alias);
4822+
}
4823+
4824+
var columnTableReferenceUpdater = new ColumnTableReferenceUpdater(this, newSelectExpression);
4825+
newSelectExpression = (SelectExpression)columnTableReferenceUpdater.Visit(newSelectExpression);
48074826
}
4827+
else
4828+
{
4829+
// Remap tableReferences in new select expression
4830+
foreach (var tableReference in newTableReferences)
4831+
{
4832+
tableReference.UpdateTableReference(this, newSelectExpression);
4833+
}
48084834

4809-
var tableReferenceUpdatingExpressionVisitor = new TableReferenceUpdatingExpressionVisitor(this, newSelectExpression);
4810-
tableReferenceUpdatingExpressionVisitor.Visit(newSelectExpression);
4835+
// TODO: Why does need to be done? We've already updated all table references on the new select just above, and
4836+
// no ColumnExpression in the query is every supposed to reference a TableReferenceExpression that isn't in the
4837+
// select's list... The same thing is done in all other places where TableReferenceUpdatingExpressionVisitor is used.
4838+
var tableReferenceUpdatingExpressionVisitor = new TableReferenceUpdatingExpressionVisitor(this, newSelectExpression);
4839+
tableReferenceUpdatingExpressionVisitor.Visit(newSelectExpression);
4840+
}
48114841

48124842
return newSelectExpression;
48134843
}

test/EFCore.Cosmos.FunctionalTests/Query/NorthwindMiscellaneousQueryCosmosTest.cs

+8
Original file line numberDiff line numberDiff line change
@@ -4678,6 +4678,14 @@ await AssertTranslationFailed(
46784678
AssertSql();
46794679
}
46804680

4681+
public override async Task Parameter_collection_Contains_with_projection_and_ordering(bool async)
4682+
{
4683+
await AssertTranslationFailed(
4684+
() => base.Parameter_collection_Contains_with_projection_and_ordering(async));
4685+
4686+
AssertSql();
4687+
}
4688+
46814689
private void AssertSql(params string[] expected)
46824690
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
46834691

test/EFCore.Specification.Tests/Query/NorthwindMiscellaneousQueryTestBase.cs

+16
Original file line numberDiff line numberDiff line change
@@ -5716,4 +5716,20 @@ public virtual Task Subquery_with_navigation_inside_inline_collection(bool async
57165716
=> AssertQuery(
57175717
async,
57185718
ss => ss.Set<Customer>().Where(c => new[] { 100, c.Orders.Count }.Sum() > 101));
5719+
5720+
[ConditionalTheory] // #32234
5721+
[MemberData(nameof(IsAsyncData))]
5722+
public virtual async Task Parameter_collection_Contains_with_projection_and_ordering(bool async)
5723+
{
5724+
var ids = new[] { 10248, 10249 };
5725+
5726+
await AssertQuery(
5727+
async,
5728+
ss => ss.Set<OrderDetail>()
5729+
.Where(e => ids.Contains(e.OrderID))
5730+
.GroupBy(e => e.Quantity)
5731+
.Select(g => new { g.Key, MaxTimestamp = g.Select(e => e.Order.OrderDate).Max() })
5732+
.OrderBy(x => x.MaxTimestamp)
5733+
.Select(x => x));
5734+
}
57195735
}

test/EFCore.SqlServer.FunctionalTests/Query/NorthwindMiscellaneousQuerySqlServerTest.cs

+41
Original file line numberDiff line numberDiff line change
@@ -7361,6 +7361,47 @@ FROM [Orders] AS [o]
73617361
""");
73627362
}
73637363

7364+
public override async Task Parameter_collection_Contains_with_projection_and_ordering(bool async)
7365+
{
7366+
#if DEBUG
7367+
// GroupBy debug assert. Issue #26104.
7368+
Assert.StartsWith(
7369+
"Missing alias in the list",
7370+
(await Assert.ThrowsAsync<InvalidOperationException>(
7371+
() => base.Parameter_collection_Contains_with_projection_and_ordering(async))).Message);
7372+
#else
7373+
await base.Parameter_collection_Contains_with_projection_and_ordering(async);
7374+
7375+
AssertSql(
7376+
"""
7377+
@__ids_0='[10248,10249]' (Size = 4000)
7378+
7379+
SELECT [o].[Quantity] AS [Key], (
7380+
SELECT MAX([o3].[OrderDate])
7381+
FROM [Order Details] AS [o2]
7382+
INNER JOIN [Orders] AS [o3] ON [o2].[OrderID] = [o3].[OrderID]
7383+
WHERE [o2].[OrderID] IN (
7384+
SELECT [i1].[value]
7385+
FROM OPENJSON(@__ids_0) WITH ([value] int '$') AS [i1]
7386+
) AND [o].[Quantity] = [o2].[Quantity]) AS [MaxTimestamp]
7387+
FROM [Order Details] AS [o]
7388+
WHERE [o].[OrderID] IN (
7389+
SELECT [i].[value]
7390+
FROM OPENJSON(@__ids_0) WITH ([value] int '$') AS [i]
7391+
)
7392+
GROUP BY [o].[Quantity]
7393+
ORDER BY (
7394+
SELECT MAX([o3].[OrderDate])
7395+
FROM [Order Details] AS [o2]
7396+
INNER JOIN [Orders] AS [o3] ON [o2].[OrderID] = [o3].[OrderID]
7397+
WHERE [o2].[OrderID] IN (
7398+
SELECT [i0].[value]
7399+
FROM OPENJSON(@__ids_0) WITH ([value] int '$') AS [i0]
7400+
) AND [o].[Quantity] = [o2].[Quantity])
7401+
""");
7402+
#endif
7403+
}
7404+
73647405
private void AssertSql(params string[] expected)
73657406
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
73667407

test/EFCore.Sqlite.FunctionalTests/Query/NorthwindMiscellaneousQuerySqliteTest.cs

+13
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,19 @@ public override async Task Correlated_collection_with_distinct_without_default_i
438438
public override Task Max_on_empty_sequence_throws(bool async)
439439
=> Assert.ThrowsAsync<InvalidOperationException>(() => base.Max_on_empty_sequence_throws(async));
440440

441+
public override async Task Parameter_collection_Contains_with_projection_and_ordering(bool async)
442+
{
443+
#if DEBUG
444+
// GroupBy debug assert. Issue #26104.
445+
Assert.StartsWith(
446+
"Missing alias in the list",
447+
(await Assert.ThrowsAsync<InvalidOperationException>(
448+
() => base.Parameter_collection_Contains_with_projection_and_ordering(async))).Message);
449+
#else
450+
await base.Parameter_collection_Contains_with_projection_and_ordering(async);
451+
#endif
452+
}
453+
441454
[ConditionalFact]
442455
public async Task Single_Predicate_Cancellation()
443456
=> await Assert.ThrowsAnyAsync<OperationCanceledException>(

0 commit comments

Comments
 (0)