@@ -30,9 +30,31 @@ public Expression Optimize(Expression e, string? discriminator)
30
30
{
31
31
if ( discriminator is not null )
32
32
{
33
- Type ? sourceType = e . Type . GetGenericArguments ( ) [ 0 ] ;
34
- MethodInfo ? wrapInWhere = WrapInWhereGenericMethod . MakeGenericMethod ( new [ ] { sourceType } ) ;
35
- e = ( Expression ) wrapInWhere . Invoke ( null , new object [ ] { e , discriminator } ) ;
33
+ if ( e . Type . IsGenericType )
34
+ {
35
+ Type ? sourceType = e . Type . GetGenericArguments ( ) [ 0 ] ;
36
+ MethodInfo wrapInWhere = WrapInWhereGenericMethod . MakeGenericMethod ( sourceType ) ;
37
+ e = ( Expression ) wrapInWhere . Invoke ( null , new object [ ] { e , discriminator } ) ;
38
+ }
39
+ else
40
+ {
41
+ Type sourceType = e . Type ;
42
+ MethodInfo wrapInWhere = WrapInWhereGenericMethod . MakeGenericMethod ( sourceType ) ;
43
+
44
+ var rootMethodCallExpression = e as MethodCallExpression ;
45
+ Expression source = rootMethodCallExpression ! . Arguments [ 0 ] ;
46
+ var discriminatorWrap = ( MethodCallExpression ) wrapInWhere . Invoke ( null , new object [ ] { source , discriminator } ) ;
47
+
48
+
49
+ if ( rootMethodCallExpression . Arguments . Count == 1 )
50
+ {
51
+ e = Expression . Call ( rootMethodCallExpression . Method , discriminatorWrap ) ;
52
+ }
53
+ else
54
+ {
55
+ e = Expression . Call ( rootMethodCallExpression . Method , discriminatorWrap , rootMethodCallExpression . Arguments [ 1 ] ) ;
56
+ }
57
+ }
36
58
}
37
59
38
60
e = LocalExpressions . PartialEval ( e ) ;
@@ -163,7 +185,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
163
185
if ( genericDefinition == QueryableMethods . AnyWithPredicate )
164
186
{
165
187
return node
166
- . SubstituteWithWhere ( )
188
+ . SubstituteWithWhere ( this )
167
189
. WrapInTake ( 1 )
168
190
. WrapInMethodWithoutSelector ( QueryableMethods . AnyWithoutPredicate ) ;
169
191
}
@@ -172,7 +194,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
172
194
if ( genericDefinition == QueryableMethods . All )
173
195
{
174
196
return node
175
- . SubstituteWithWhere ( true )
197
+ . SubstituteWithWhere ( this , true )
176
198
. WrapInTake ( 1 )
177
199
. WrapInMethodWithoutSelector ( QueryableMethods . AnyWithoutPredicate ) ;
178
200
}
@@ -201,17 +223,17 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
201
223
if ( genericDefinition == QueryableMethods . SingleWithPredicate )
202
224
{
203
225
return node
204
- . SubstituteWithWhere ( )
205
- . SubstituteWithTake ( 2 )
226
+ . SubstituteWithWhere ( this )
227
+ . WrapInTake ( 2 )
206
228
. WrapInMethodWithoutSelector ( QueryableMethods . SingleWithoutPredicate ) ;
207
229
}
208
230
209
231
// SingleOrDefault(d => condition) == Where(d => condition).Take(2).SingleOrDefault()
210
232
if ( genericDefinition == QueryableMethods . SingleOrDefaultWithPredicate )
211
233
{
212
234
return node
213
- . SubstituteWithWhere ( )
214
- . SubstituteWithTake ( 2 )
235
+ . SubstituteWithWhere ( this )
236
+ . WrapInTake ( 2 )
215
237
. WrapInMethodWithoutSelector ( QueryableMethods . SingleOrDefaultWithoutPredicate ) ;
216
238
}
217
239
@@ -239,17 +261,17 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
239
261
if ( genericDefinition == QueryableMethods . FirstWithPredicate )
240
262
{
241
263
return node
242
- . SubstituteWithWhere ( )
243
- . SubstituteWithTake ( 1 )
264
+ . SubstituteWithWhere ( this )
265
+ . WrapInTake ( 1 )
244
266
. WrapInMethodWithoutSelector ( QueryableMethods . FirstWithoutPredicate ) ;
245
267
}
246
268
247
269
// FirstOrDefault(d => condition) == Where(d => condition).Take(1).FirstOrDefault()
248
270
if ( genericDefinition == QueryableMethods . FirstOrDefaultWithPredicate )
249
271
{
250
272
return node
251
- . SubstituteWithWhere ( )
252
- . SubstituteWithTake ( 1 )
273
+ . SubstituteWithWhere ( this )
274
+ . WrapInTake ( 1 )
253
275
. WrapInMethodWithoutSelector ( QueryableMethods . FirstOrDefaultWithoutPredicate ) ;
254
276
}
255
277
@@ -269,15 +291,15 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
269
291
if ( genericDefinition == QueryableMethods . LastWithPredicate )
270
292
{
271
293
return node
272
- . SubstituteWithWhere ( )
294
+ . SubstituteWithWhere ( this )
273
295
. WrapInMethodWithoutSelector ( QueryableMethods . LastWithoutPredicate ) ;
274
296
}
275
297
276
298
// LastOrDefault(d => condition) == Where(d => condition).LastOrDefault()
277
299
if ( genericDefinition == QueryableMethods . LastOrDefaultWithPredicate )
278
300
{
279
301
return node
280
- . SubstituteWithWhere ( )
302
+ . SubstituteWithWhere ( this )
281
303
. WrapInMethodWithoutSelector ( QueryableMethods . LastOrDefaultWithoutPredicate ) ;
282
304
}
283
305
0 commit comments