Skip to content

Commit 170a38a

Browse files
#136 - Fix FirstOrDefault() with discriminator
1 parent d6621c4 commit 170a38a

File tree

3 files changed

+83
-18
lines changed

3 files changed

+83
-18
lines changed

src/CouchDB.Driver/Helpers/MethodCallExpressionBuilder.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public static MethodCallExpression SubstituteWithSelect(this MethodCallExpressio
3939
return Expression.Call(typeof(Queryable), nameof(Queryable.Select), genericArgumentTypes, node.Arguments[0], selectorNode.Arguments[1]);
4040
}
4141

42-
public static MethodCallExpression SubstituteWithWhere(this MethodCallExpression node, bool negate = false)
42+
public static MethodCallExpression SubstituteWithWhere(this MethodCallExpression node, ExpressionVisitor optimizer, bool negate = false)
4343
{
4444
Check.NotNull(node, nameof(node));
4545

@@ -52,7 +52,8 @@ public static MethodCallExpression SubstituteWithWhere(this MethodCallExpression
5252
predicate = body.WrapInLambda(lambdaExpression.Parameters);
5353
}
5454

55-
return Expression.Call(typeof(Queryable), nameof(Queryable.Where), node.Method.GetGenericArguments(), node.Arguments[0], predicate);
55+
var e = Expression.Call(typeof(Queryable), nameof(Queryable.Where), node.Method.GetGenericArguments(), node.Arguments[0], predicate);
56+
return (MethodCallExpression)optimizer.Visit(e);
5657
}
5758

5859
#endregion

src/CouchDB.Driver/Query/QueryOptimizer.cs

+37-15
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,31 @@ public Expression Optimize(Expression e, string? discriminator)
3030
{
3131
if (discriminator is not null)
3232
{
33-
Type? sourceType = e.Type.GetGenericArguments()[0];
34-
MethodInfo? wrapInWhere = WrapInWhereGenericMethod.MakeGenericMethod(new[] { sourceType });
35-
e = (Expression)wrapInWhere.Invoke(null, new object[] { e, discriminator });
33+
if (e.Type.IsGenericType)
34+
{
35+
Type? sourceType = e.Type.GetGenericArguments()[0];
36+
MethodInfo wrapInWhere = WrapInWhereGenericMethod.MakeGenericMethod(sourceType);
37+
e = (Expression)wrapInWhere.Invoke(null, new object[] { e, discriminator });
38+
}
39+
else
40+
{
41+
Type sourceType = e.Type;
42+
MethodInfo wrapInWhere = WrapInWhereGenericMethod.MakeGenericMethod(sourceType);
43+
44+
var rootMethodCallExpression = e as MethodCallExpression;
45+
Expression source = rootMethodCallExpression!.Arguments[0];
46+
var discriminatorWrap = (MethodCallExpression)wrapInWhere.Invoke(null, new object[] { source, discriminator });
47+
48+
49+
if (rootMethodCallExpression.Arguments.Count == 1)
50+
{
51+
e = Expression.Call(rootMethodCallExpression.Method, discriminatorWrap);
52+
}
53+
else
54+
{
55+
e = Expression.Call(rootMethodCallExpression.Method, discriminatorWrap, rootMethodCallExpression.Arguments[1]);
56+
}
57+
}
3658
}
3759

3860
e = LocalExpressions.PartialEval(e);
@@ -163,7 +185,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
163185
if (genericDefinition == QueryableMethods.AnyWithPredicate)
164186
{
165187
return node
166-
.SubstituteWithWhere()
188+
.SubstituteWithWhere(this)
167189
.WrapInTake(1)
168190
.WrapInMethodWithoutSelector(QueryableMethods.AnyWithoutPredicate);
169191
}
@@ -172,7 +194,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
172194
if (genericDefinition == QueryableMethods.All)
173195
{
174196
return node
175-
.SubstituteWithWhere(true)
197+
.SubstituteWithWhere(this, true)
176198
.WrapInTake(1)
177199
.WrapInMethodWithoutSelector(QueryableMethods.AnyWithoutPredicate);
178200
}
@@ -201,17 +223,17 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
201223
if (genericDefinition == QueryableMethods.SingleWithPredicate)
202224
{
203225
return node
204-
.SubstituteWithWhere()
205-
.SubstituteWithTake(2)
226+
.SubstituteWithWhere(this)
227+
.WrapInTake(2)
206228
.WrapInMethodWithoutSelector(QueryableMethods.SingleWithoutPredicate);
207229
}
208230

209231
// SingleOrDefault(d => condition) == Where(d => condition).Take(2).SingleOrDefault()
210232
if (genericDefinition == QueryableMethods.SingleOrDefaultWithPredicate)
211233
{
212234
return node
213-
.SubstituteWithWhere()
214-
.SubstituteWithTake(2)
235+
.SubstituteWithWhere(this)
236+
.WrapInTake(2)
215237
.WrapInMethodWithoutSelector(QueryableMethods.SingleOrDefaultWithoutPredicate);
216238
}
217239

@@ -239,17 +261,17 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
239261
if (genericDefinition == QueryableMethods.FirstWithPredicate)
240262
{
241263
return node
242-
.SubstituteWithWhere()
243-
.SubstituteWithTake(1)
264+
.SubstituteWithWhere(this)
265+
.WrapInTake(1)
244266
.WrapInMethodWithoutSelector(QueryableMethods.FirstWithoutPredicate);
245267
}
246268

247269
// FirstOrDefault(d => condition) == Where(d => condition).Take(1).FirstOrDefault()
248270
if (genericDefinition == QueryableMethods.FirstOrDefaultWithPredicate)
249271
{
250272
return node
251-
.SubstituteWithWhere()
252-
.SubstituteWithTake(1)
273+
.SubstituteWithWhere(this)
274+
.WrapInTake(1)
253275
.WrapInMethodWithoutSelector(QueryableMethods.FirstOrDefaultWithoutPredicate);
254276
}
255277

@@ -269,15 +291,15 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
269291
if (genericDefinition == QueryableMethods.LastWithPredicate)
270292
{
271293
return node
272-
.SubstituteWithWhere()
294+
.SubstituteWithWhere(this)
273295
.WrapInMethodWithoutSelector(QueryableMethods.LastWithoutPredicate);
274296
}
275297

276298
// LastOrDefault(d => condition) == Where(d => condition).LastOrDefault()
277299
if (genericDefinition == QueryableMethods.LastOrDefaultWithPredicate)
278300
{
279301
return node
280-
.SubstituteWithWhere()
302+
.SubstituteWithWhere(this)
281303
.WrapInMethodWithoutSelector(QueryableMethods.LastOrDefaultWithoutPredicate);
282304
}
283305

tests/CouchDB.Driver.UnitTests/Find/Find_Discriminator.cs

+43-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
using CouchDB.UnitTests.Models;
1+
using System;
2+
using System.Collections.Generic;
3+
using CouchDB.UnitTests.Models;
24
using System.Linq;
5+
using Flurl.Http.Testing;
36
using Xunit;
47

58
namespace CouchDB.Driver.UnitTests.Find
@@ -9,12 +12,29 @@ public class Find_Discriminator
912
private const string _databaseName = "allrebels";
1013
private readonly ICouchDatabase<Rebel> _rebels;
1114
private readonly ICouchDatabase<SimpleRebel> _simpleRebels;
15+
private readonly object _response;
1216

1317
public Find_Discriminator()
1418
{
1519
var client = new CouchClient("http://localhost");
1620
_rebels = client.GetDatabase<Rebel>(_databaseName, nameof(Rebel));
1721
_simpleRebels = client.GetDatabase<SimpleRebel>(_databaseName, nameof(SimpleRebel));
22+
23+
var mainRebel = new Rebel
24+
{
25+
Id = Guid.NewGuid().ToString(),
26+
Name = "Luke",
27+
Age = 19,
28+
Skills = new List<string> { "Force" }
29+
};
30+
var rebelsList = new List<Rebel>
31+
{
32+
mainRebel
33+
};
34+
_response = new
35+
{
36+
Docs = rebelsList
37+
};
1838
}
1939

2040
[Fact]
@@ -34,5 +54,27 @@ public void Discriminator_WithFilter()
3454
Assert.Equal(@"{""selector"":{""$and"":[{""age"":19},{""split_discriminator"":""Rebel""}]}}", json1);
3555
Assert.Equal(@"{""selector"":{""$and"":[{""age"":19},{""split_discriminator"":""SimpleRebel""}]}}", json2);
3656
}
57+
58+
[Fact]
59+
public void Discriminator_FirstOrDefault()
60+
{
61+
using var httpTest = new HttpTest();
62+
httpTest.RespondWithJson(_response);
63+
_rebels.FirstOrDefault();
64+
_simpleRebels.FirstOrDefault();
65+
Assert.Equal(@"{""selector"":{""split_discriminator"":""Rebel""},""limit"":1}", httpTest.CallLog[0].RequestBody);
66+
Assert.Equal(@"{""selector"":{""split_discriminator"":""SimpleRebel""},""limit"":1}", httpTest.CallLog[1].RequestBody);
67+
}
68+
69+
[Fact]
70+
public void Discriminator_FirstOrDefault_WithExpression()
71+
{
72+
using var httpTest = new HttpTest();
73+
httpTest.RespondWithJson(_response);
74+
_rebels.FirstOrDefault(c => c.Age == 19);
75+
_simpleRebels.FirstOrDefault(c => c.Age == 19);
76+
Assert.Equal(@"{""selector"":{""$and"":[{""split_discriminator"":""Rebel""},{""age"":19}]},""limit"":1}", httpTest.CallLog[0].RequestBody);
77+
Assert.Equal(@"{""selector"":{""$and"":[{""split_discriminator"":""SimpleRebel""},{""age"":19}]},""limit"":1}", httpTest.CallLog[1].RequestBody);
78+
}
3779
}
3880
}

0 commit comments

Comments
 (0)