23
23
import java .beans .FeatureDescriptor ;
24
24
import java .util .Collections ;
25
25
import java .util .List ;
26
+ import java .util .Map ;
26
27
import java .util .Optional ;
27
28
import java .util .function .BiFunction ;
29
+ import java .util .function .Function ;
28
30
import java .util .stream .Collectors ;
29
31
32
+ import org .reactivestreams .Publisher ;
33
+
30
34
import org .springframework .beans .BeansException ;
31
- import org .springframework .beans . factory . BeanFactory ;
32
- import org .springframework .beans . factory . BeanFactoryAware ;
35
+ import org .springframework .context . ApplicationContext ;
36
+ import org .springframework .context . ApplicationContextAware ;
33
37
import org .springframework .core .convert .ConversionService ;
34
38
import org .springframework .dao .DataAccessException ;
35
39
import org .springframework .dao .OptimisticLockingFailureException ;
36
40
import org .springframework .dao .TransientDataAccessResourceException ;
37
41
import org .springframework .data .mapping .IdentifierAccessor ;
38
42
import org .springframework .data .mapping .MappingException ;
39
43
import org .springframework .data .mapping .PersistentPropertyAccessor ;
44
+ import org .springframework .data .mapping .callback .ReactiveEntityCallbacks ;
40
45
import org .springframework .data .mapping .context .MappingContext ;
41
46
import org .springframework .data .projection .ProjectionInformation ;
42
47
import org .springframework .data .projection .SpelAwareProxyProjectionFactory ;
48
+ import org .springframework .data .r2dbc .mapping .OutboundRow ;
49
+ import org .springframework .data .r2dbc .mapping .SettableValue ;
50
+ import org .springframework .data .r2dbc .mapping .event .AfterConvertCallback ;
51
+ import org .springframework .data .r2dbc .mapping .event .AfterSaveCallback ;
52
+ import org .springframework .data .r2dbc .mapping .event .BeforeConvertCallback ;
53
+ import org .springframework .data .r2dbc .mapping .event .BeforeSaveCallback ;
43
54
import org .springframework .data .relational .core .mapping .RelationalPersistentEntity ;
44
55
import org .springframework .data .relational .core .mapping .RelationalPersistentProperty ;
45
56
import org .springframework .data .relational .core .query .Criteria ;
51
62
import org .springframework .data .relational .core .sql .SqlIdentifier ;
52
63
import org .springframework .data .relational .core .sql .Table ;
53
64
import org .springframework .data .util .ProxyUtils ;
65
+ import org .springframework .lang .Nullable ;
54
66
import org .springframework .util .Assert ;
55
67
56
68
/**
65
77
* @author Bogdan Ilchyshyn
66
78
* @since 1.1
67
79
*/
68
- public class R2dbcEntityTemplate implements R2dbcEntityOperations , BeanFactoryAware {
80
+ public class R2dbcEntityTemplate implements R2dbcEntityOperations , ApplicationContextAware {
69
81
70
82
private final DatabaseClient databaseClient ;
71
83
@@ -75,6 +87,8 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw
75
87
76
88
private final SpelAwareProxyProjectionFactory projectionFactory ;
77
89
90
+ private @ Nullable ReactiveEntityCallbacks entityCallbacks ;
91
+
78
92
/**
79
93
* Create a new {@link R2dbcEntityTemplate} given {@link DatabaseClient}.
80
94
*
@@ -111,11 +125,34 @@ public DatabaseClient getDatabaseClient() {
111
125
112
126
/*
113
127
* (non-Javadoc)
114
- * @see org.springframework.beans.factory.BeanFactoryAware#setBeanFactory (org.springframework.beans.factory.BeanFactory )
128
+ * @see org.springframework.context.ApplicationContextAware#setApplicationContext (org.springframework.context.ApplicationContext )
115
129
*/
116
130
@ Override
117
- public void setBeanFactory (BeanFactory beanFactory ) throws BeansException {
118
- this .projectionFactory .setBeanFactory (beanFactory );
131
+ public void setApplicationContext (ApplicationContext applicationContext ) throws BeansException {
132
+
133
+ if (entityCallbacks == null ) {
134
+ setEntityCallbacks (ReactiveEntityCallbacks .create (applicationContext ));
135
+ }
136
+
137
+ projectionFactory .setBeanFactory (applicationContext );
138
+ projectionFactory .setBeanClassLoader (applicationContext .getClassLoader ());
139
+ }
140
+
141
+ /**
142
+ * Set the {@link ReactiveEntityCallbacks} instance to use when invoking
143
+ * {@link org.springframework.data.mapping.callback.ReactiveEntityCallbacks callbacks} like the
144
+ * {@link BeforeSaveCallback}.
145
+ * <p />
146
+ * Overrides potentially existing {@link ReactiveEntityCallbacks}.
147
+ *
148
+ * @param entityCallbacks must not be {@literal null}.
149
+ * @throws IllegalArgumentException if the given instance is {@literal null}.
150
+ * @since 1.2
151
+ */
152
+ public void setEntityCallbacks (ReactiveEntityCallbacks entityCallbacks ) {
153
+
154
+ Assert .notNull (entityCallbacks , "EntityCallbacks must not be null!" );
155
+ this .entityCallbacks = entityCallbacks ;
119
156
}
120
157
121
158
// -------------------------------------------------------------------------
@@ -248,10 +285,27 @@ public <T> Flux<T> select(Query query, Class<T> entityClass) throws DataAccessEx
248
285
Assert .notNull (query , "Query must not be null" );
249
286
Assert .notNull (entityClass , "entity class must not be null" );
250
287
251
- return doSelect (query , entityClass , getTableName (entityClass ), entityClass ).all ();
288
+ SqlIdentifier tableName = getTableName (entityClass );
289
+ return doSelect (query , entityClass , tableName , entityClass , RowsFetchSpec ::all );
290
+ }
291
+
292
+ @ SuppressWarnings ("unchecked" )
293
+ <T , P extends Publisher <T >> P doSelect (Query query , Class <?> entityClass , SqlIdentifier tableName ,
294
+ Class <T > returnType , Function <RowsFetchSpec <T >, P > resultHandler ) {
295
+
296
+ RowsFetchSpec <T > fetchSpec = doSelect (query , entityClass , tableName , returnType );
297
+
298
+ P result = resultHandler .apply (fetchSpec );
299
+
300
+ if (result instanceof Mono ) {
301
+ return (P ) ((Mono <?>) result ).flatMap (it -> maybeCallAfterConvert (it , tableName ));
302
+ }
303
+
304
+ return (P ) ((Flux <?>) result ).flatMap (it -> maybeCallAfterConvert (it , tableName ));
252
305
}
253
306
254
- <T > RowsFetchSpec <T > doSelect (Query query , Class <?> entityClass , SqlIdentifier tableName , Class <T > returnType ) {
307
+ private <T > RowsFetchSpec <T > doSelect (Query query , Class <?> entityClass , SqlIdentifier tableName ,
308
+ Class <T > returnType ) {
255
309
256
310
StatementMapper statementMapper = dataAccessStrategy .getStatementMapper ().forType (entityClass );
257
311
@@ -295,7 +349,7 @@ <T> RowsFetchSpec<T> doSelect(Query query, Class<?> entityClass, SqlIdentifier t
295
349
*/
296
350
@ Override
297
351
public <T > Mono <T > selectOne (Query query , Class <T > entityClass ) throws DataAccessException {
298
- return doSelect (query .limit (2 ), entityClass , getTableName (entityClass ), entityClass ). one ( );
352
+ return doSelect (query .limit (2 ), entityClass , getTableName (entityClass ), entityClass , RowsFetchSpec :: one );
299
353
}
300
354
301
355
/*
@@ -377,14 +431,33 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
377
431
378
432
RelationalPersistentEntity <T > persistentEntity = getRequiredEntity (entity );
379
433
380
- T entityToInsert = setVersionIfNecessary (persistentEntity , entity );
434
+ return Mono .defer (() -> maybeCallBeforeConvert (setVersionIfNecessary (persistentEntity , entity ), tableName )
435
+ .flatMap (beforeConvert -> {
381
436
382
- return this .databaseClient .insert () //
383
- .into (persistentEntity .getType ()) //
384
- .table (tableName ).using (entityToInsert ) //
385
- .map (this .dataAccessStrategy .getConverter ().populateIdIfNecessary (entityToInsert )) //
386
- .first () //
387
- .defaultIfEmpty (entityToInsert );
437
+ OutboundRow outboundRow = dataAccessStrategy .getOutboundRow (beforeConvert );
438
+
439
+ return maybeCallBeforeSave (beforeConvert , outboundRow , tableName ).flatMap (entityToSave -> {
440
+
441
+ StatementMapper mapper = dataAccessStrategy .getStatementMapper ();
442
+ StatementMapper .InsertSpec insert = mapper .createInsert (tableName );
443
+
444
+ for (SqlIdentifier column : outboundRow .keySet ()) {
445
+ SettableValue settableValue = outboundRow .get (column );
446
+ if (settableValue .hasValue ()) {
447
+ insert = insert .withColumn (column , settableValue );
448
+ }
449
+ }
450
+
451
+ PreparedOperation <?> operation = mapper .getMappedObject (insert );
452
+
453
+ return this .databaseClient .execute (operation ) //
454
+ .filter (statement -> statement .returnGeneratedValues ())
455
+ .map (this .dataAccessStrategy .getConverter ().populateIdIfNecessary (entityToSave )) //
456
+ .first () //
457
+ .defaultIfEmpty (entityToSave ) //
458
+ .flatMap (saved -> maybeCallAfterSave (saved , outboundRow , tableName ));
459
+ });
460
+ }));
388
461
}
389
462
390
463
@ SuppressWarnings ("unchecked" )
@@ -413,37 +486,62 @@ public <T> Mono<T> update(T entity) throws DataAccessException {
413
486
414
487
Assert .notNull (entity , "Entity must not be null" );
415
488
489
+ return doUpdate (entity , getRequiredEntity (entity ).getTableName ());
490
+ }
491
+
492
+ private <T > Mono <T > doUpdate (T entity , SqlIdentifier tableName ) {
493
+
416
494
RelationalPersistentEntity <T > persistentEntity = getRequiredEntity (entity );
417
495
418
- DatabaseClient .TypedUpdateSpec <T > updateMatchingSpec = this .databaseClient .update () //
419
- .table (persistentEntity .getType ()) //
420
- .table (persistentEntity .getTableName ());
496
+ return maybeCallBeforeConvert (entity , tableName ).flatMap (beforeConvert -> {
421
497
422
- DatabaseClient .UpdateSpec matching ;
423
- T entityToUpdate ;
424
- if (persistentEntity .hasVersionProperty ()) {
498
+ OutboundRow outboundRow = dataAccessStrategy .getOutboundRow (entity );
425
499
426
- Criteria criteria = createMatchingVersionCriteria (entity , persistentEntity );
427
- entityToUpdate = incrementVersion (persistentEntity , entity );
428
- matching = updateMatchingSpec .using (entityToUpdate ).matching (criteria );
429
- } else {
430
- entityToUpdate = entity ;
431
- matching = updateMatchingSpec .using (entity );
432
- }
500
+ return maybeCallBeforeSave (beforeConvert , outboundRow , tableName ) //
501
+ .flatMap (entityToSave -> {
433
502
434
- return matching .fetch () //
435
- .rowsUpdated () //
436
- .flatMap (rowsUpdated -> rowsUpdated == 0 ? handleMissingUpdate (entityToUpdate , persistentEntity )
437
- : Mono .just (entityToUpdate ));
438
- }
503
+ SqlIdentifier idColumn = persistentEntity .getRequiredIdProperty ().getColumnName ();
504
+ SettableValue id = outboundRow .remove (idColumn );
505
+ Criteria criteria = Criteria .where (dataAccessStrategy .toSql (idColumn )).is (id );
506
+
507
+ T saved ;
508
+
509
+ if (persistentEntity .hasVersionProperty ()) {
510
+ criteria = criteria .and (createMatchingVersionCriteria (entity , persistentEntity ));
511
+ saved = incrementVersion (persistentEntity , entity , outboundRow );
512
+ } else {
513
+ saved = entityToSave ;
514
+ }
515
+
516
+ Update update = Update .from ((Map ) outboundRow );
439
517
440
- private <T > Mono <? extends T > handleMissingUpdate (T entity , RelationalPersistentEntity <T > persistentEntity ) {
518
+ StatementMapper mapper = dataAccessStrategy .getStatementMapper ();
519
+ StatementMapper .UpdateSpec updateSpec = mapper .createUpdate (tableName , update ).withCriteria (criteria );
441
520
442
- return Mono .error (persistentEntity .hasVersionProperty ()
443
- ? new OptimisticLockingFailureException (formatOptimisticLockingExceptionMessage (entity , persistentEntity ))
444
- : new TransientDataAccessResourceException (formatTransientEntityExceptionMessage (entity , persistentEntity )));
521
+ PreparedOperation <?> operation = mapper .getMappedObject (updateSpec );
522
+
523
+ return this .databaseClient .execute (operation ) //
524
+ .fetch () //
525
+ .rowsUpdated () //
526
+ .handle ((rowsUpdated , sink ) -> {
527
+
528
+ if (rowsUpdated != 0 ) {
529
+ return ;
530
+ }
531
+
532
+ if (persistentEntity .hasVersionProperty ()) {
533
+ sink .error (new OptimisticLockingFailureException (
534
+ formatOptimisticLockingExceptionMessage (saved , persistentEntity )));
535
+ } else {
536
+ sink .error (new TransientDataAccessResourceException (
537
+ formatTransientEntityExceptionMessage (saved , persistentEntity )));
538
+ }
539
+ }).then (maybeCallAfterSave (saved , outboundRow , tableName ));
540
+ });
541
+ });
445
542
}
446
543
544
+
447
545
private <T > String formatOptimisticLockingExceptionMessage (T entity , RelationalPersistentEntity <T > persistentEntity ) {
448
546
449
547
return String .format ("Failed to update table [%s]. Version does not match for row with Id [%s]." ,
@@ -457,7 +555,7 @@ private <T> String formatTransientEntityExceptionMessage(T entity, RelationalPer
457
555
}
458
556
459
557
@ SuppressWarnings ("unchecked" )
460
- private <T > T incrementVersion (RelationalPersistentEntity <T > persistentEntity , T entity ) {
558
+ private <T > T incrementVersion (RelationalPersistentEntity <T > persistentEntity , T entity , OutboundRow outboundRow ) {
461
559
462
560
PersistentPropertyAccessor <?> propertyAccessor = persistentEntity .getPropertyAccessor (entity );
463
561
RelationalPersistentProperty versionProperty = persistentEntity .getVersionProperty ();
@@ -471,6 +569,8 @@ private <T> T incrementVersion(RelationalPersistentEntity<T> persistentEntity, T
471
569
Class <?> versionPropertyType = versionProperty .getType ();
472
570
propertyAccessor .setProperty (versionProperty , conversionService .convert (newVersionValue , versionPropertyType ));
473
571
572
+ outboundRow .put (versionProperty .getColumnName (), SettableValue .from (newVersionValue ));
573
+
474
574
return (T ) propertyAccessor .getBean ();
475
575
}
476
576
@@ -502,6 +602,42 @@ public <T> Mono<T> delete(T entity) throws DataAccessException {
502
602
return delete (getByIdQuery (entity , persistentEntity ), persistentEntity .getType ()).thenReturn (entity );
503
603
}
504
604
605
+ protected <T > Mono <T > maybeCallBeforeConvert (T object , SqlIdentifier table ) {
606
+
607
+ if (entityCallbacks != null ) {
608
+ return entityCallbacks .callback (BeforeConvertCallback .class , object , table );
609
+ }
610
+
611
+ return Mono .just (object );
612
+ }
613
+
614
+ protected <T > Mono <T > maybeCallBeforeSave (T object , OutboundRow row , SqlIdentifier table ) {
615
+
616
+ if (entityCallbacks != null ) {
617
+ return entityCallbacks .callback (BeforeSaveCallback .class , object , row , table );
618
+ }
619
+
620
+ return Mono .just (object );
621
+ }
622
+
623
+ protected <T > Mono <T > maybeCallAfterSave (T object , OutboundRow row , SqlIdentifier table ) {
624
+
625
+ if (entityCallbacks != null ) {
626
+ return entityCallbacks .callback (AfterSaveCallback .class , object , row , table );
627
+ }
628
+
629
+ return Mono .just (object );
630
+ }
631
+
632
+ protected <T > Mono <T > maybeCallAfterConvert (T object , SqlIdentifier table ) {
633
+
634
+ if (entityCallbacks != null ) {
635
+ return entityCallbacks .callback (AfterConvertCallback .class , object , table );
636
+ }
637
+
638
+ return Mono .just (object );
639
+ }
640
+
505
641
private <T > Query getByIdQuery (T entity , RelationalPersistentEntity <?> persistentEntity ) {
506
642
if (!persistentEntity .hasIdProperty ()) {
507
643
throw new MappingException ("No id property found for object of type " + persistentEntity .getType () + "!" );
0 commit comments