@@ -84,13 +84,19 @@ public virtual async Task<TEntity> GetAndIncludeAsync(TId id, string relationshi
84
84
85
85
public virtual async Task < TEntity > CreateAsync ( TEntity entity )
86
86
{
87
- AttachHasManyPointers ( ) ;
87
+ AttachRelationships ( ) ;
88
88
_dbSet . Add ( entity ) ;
89
89
90
90
await _context . SaveChangesAsync ( ) ;
91
91
return entity ;
92
92
}
93
93
94
+ protected virtual void AttachRelationships ( )
95
+ {
96
+ AttachHasManyPointers ( ) ;
97
+ AttachHasOnePointers ( ) ;
98
+ }
99
+
94
100
/// <summary>
95
101
/// This is used to allow creation of HasMany relationships when the
96
102
/// dependent side of the relationship already exists.
@@ -107,6 +113,18 @@ private void AttachHasManyPointers()
107
113
}
108
114
}
109
115
116
+ /// <summary>
117
+ /// This is used to allow creation of HasOne relationships when the
118
+ /// independent side of the relationship already exists.
119
+ /// </summary>
120
+ private void AttachHasOnePointers ( )
121
+ {
122
+ var relationships = _jsonApiContext . HasOneRelationshipPointers . Get ( ) ;
123
+ foreach ( var relationship in relationships )
124
+ if ( _context . Entry ( relationship . Value ) . State == EntityState . Detached && _context . EntityIsTracked ( relationship . Value ) == false )
125
+ _context . Entry ( relationship . Value ) . State = EntityState . Unchanged ;
126
+ }
127
+
110
128
public virtual async Task < TEntity > UpdateAsync ( TId id , TEntity entity )
111
129
{
112
130
var oldEntity = await GetAsync ( id ) ;
@@ -185,17 +203,23 @@ public virtual async Task<IEnumerable<TEntity>> PageAsync(IQueryable<TEntity> en
185
203
186
204
public async Task < int > CountAsync ( IQueryable < TEntity > entities )
187
205
{
188
- return await entities . CountAsync ( ) ;
206
+ return ( entities is IAsyncEnumerable < TEntity > )
207
+ ? await entities . CountAsync ( )
208
+ : entities . Count ( ) ;
189
209
}
190
210
191
- public Task < TEntity > FirstOrDefaultAsync ( IQueryable < TEntity > entities )
211
+ public async Task < TEntity > FirstOrDefaultAsync ( IQueryable < TEntity > entities )
192
212
{
193
- return entities . FirstOrDefaultAsync ( ) ;
213
+ return ( entities is IAsyncEnumerable < TEntity > )
214
+ ? await entities . FirstOrDefaultAsync ( )
215
+ : entities . FirstOrDefault ( ) ;
194
216
}
195
217
196
218
public async Task < IReadOnlyList < TEntity > > ToListAsync ( IQueryable < TEntity > entities )
197
219
{
198
- return await entities . ToListAsync ( ) ;
220
+ return ( entities is IAsyncEnumerable < TEntity > )
221
+ ? await entities . ToListAsync ( )
222
+ : entities . ToList ( ) ;
199
223
}
200
224
}
201
225
}
0 commit comments