Skip to content

Commit adfab5f

Browse files
authored
Fix exception propagation in Async API methods (#1479) (#1485)
- Resolve an issue where exceptions thrown during thenRun, thenSupply, and related operations in the asynchronous API were not properly propagated to the completion callback. This issue was addressed by replacing `unsafeFinish` with `finish`, ensuring that exceptions are caught and correctly passed to the completion callback when executed on different threads. - Update existing Async API tests to ensure they simulate separate async thread execution. - Modify the async callback to catch and handle exceptions locally. Exceptions are now directly processed and passed as an error argument to the callback function, avoiding propagation to the parent callback. - Move `callback.onResult` outside the catch block to ensure it's not invoked twice when an exception occurs. JAVA-5562
1 parent 39d1e9a commit adfab5f

9 files changed

+647
-321
lines changed

driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java

+26
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import com.mongodb.lang.Nullable;
2020

21+
import java.util.concurrent.atomic.AtomicBoolean;
22+
2123
/**
2224
* See {@link AsyncRunnable}
2325
* <p>
@@ -33,4 +35,28 @@ public interface AsyncFunction<T, R> {
3335
* @param callback the callback
3436
*/
3537
void unsafeFinish(T value, SingleResultCallback<R> callback);
38+
39+
/**
40+
* Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
41+
*
42+
* @param callback the callback provided by the method the chain is used in.
43+
*/
44+
default void finish(final T value, final SingleResultCallback<R> callback) {
45+
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
46+
try {
47+
this.unsafeFinish(value, (v, e) -> {
48+
if (!callbackInvoked.compareAndSet(false, true)) {
49+
throw new AssertionError(String.format("Callback has been already completed. It could happen "
50+
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
51+
}
52+
callback.onResult(v, e);
53+
});
54+
} catch (Throwable t) {
55+
if (!callbackInvoked.compareAndSet(false, true)) {
56+
throw t;
57+
} else {
58+
callback.completeExceptionally(t);
59+
}
60+
}
61+
}
3662
}

driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java

+5-3
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) {
170170
return (c) -> {
171171
this.unsafeFinish((r, e) -> {
172172
if (e == null) {
173-
runnable.unsafeFinish(c);
173+
/* If 'runnable' is executed on a different thread from the one that executed the initial 'finish()',
174+
then invoking 'finish()' within 'runnable' will catch and propagate any exceptions to 'c' (the callback). */
175+
runnable.finish(c);
174176
} else {
175177
c.completeExceptionally(e);
176178
}
@@ -199,7 +201,7 @@ default AsyncRunnable thenRunIf(final Supplier<Boolean> condition, final AsyncRu
199201
return;
200202
}
201203
if (matched) {
202-
runnable.unsafeFinish(callback);
204+
runnable.finish(callback);
203205
} else {
204206
callback.complete(callback);
205207
}
@@ -216,7 +218,7 @@ default <R> AsyncSupplier<R> thenSupply(final AsyncSupplier<R> supplier) {
216218
return (c) -> {
217219
this.unsafeFinish((r, e) -> {
218220
if (e == null) {
219-
supplier.unsafeFinish(c);
221+
supplier.finish(c);
220222
} else {
221223
c.completeExceptionally(e);
222224
}

driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java

+16-8
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import com.mongodb.lang.Nullable;
2020

21+
import java.util.concurrent.atomic.AtomicBoolean;
2122
import java.util.function.Predicate;
2223

2324

@@ -54,18 +55,25 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback
5455
}
5556

5657
/**
57-
* Must be invoked at end of async chain.
58+
* Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
59+
*
60+
* @see #thenApply(AsyncFunction)
61+
* @see #thenConsume(AsyncConsumer)
62+
* @see #onErrorIf(Predicate, AsyncFunction)
5863
* @param callback the callback provided by the method the chain is used in
5964
*/
6065
default void finish(final SingleResultCallback<T> callback) {
61-
final boolean[] callbackInvoked = {false};
66+
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
6267
try {
6368
this.unsafeFinish((v, e) -> {
64-
callbackInvoked[0] = true;
69+
if (!callbackInvoked.compareAndSet(false, true)) {
70+
throw new AssertionError(String.format("Callback has been already completed. It could happen "
71+
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
72+
}
6573
callback.onResult(v, e);
6674
});
6775
} catch (Throwable t) {
68-
if (callbackInvoked[0]) {
76+
if (!callbackInvoked.compareAndSet(false, true)) {
6977
throw t;
7078
} else {
7179
callback.completeExceptionally(t);
@@ -80,9 +88,9 @@ default void finish(final SingleResultCallback<T> callback) {
8088
*/
8189
default <R> AsyncSupplier<R> thenApply(final AsyncFunction<T, R> function) {
8290
return (c) -> {
83-
this.unsafeFinish((v, e) -> {
91+
this.finish((v, e) -> {
8492
if (e == null) {
85-
function.unsafeFinish(v, c);
93+
function.finish(v, c);
8694
} else {
8795
c.completeExceptionally(e);
8896
}
@@ -99,7 +107,7 @@ default AsyncRunnable thenConsume(final AsyncConsumer<T> consumer) {
99107
return (c) -> {
100108
this.unsafeFinish((v, e) -> {
101109
if (e == null) {
102-
consumer.unsafeFinish(v, c);
110+
consumer.finish(v, c);
103111
} else {
104112
c.completeExceptionally(e);
105113
}
@@ -131,7 +139,7 @@ default AsyncSupplier<T> onErrorIf(
131139
return;
132140
}
133141
if (errorMatched) {
134-
errorFunction.unsafeFinish(e, callback);
142+
errorFunction.finish(e, callback);
135143
} else {
136144
callback.completeExceptionally(e);
137145
}

driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,7 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
595595
return;
596596
}
597597
assertNotNull(responseBuffers);
598+
T commandResult;
598599
try {
599600
updateSessionContext(sessionContext, responseBuffers);
600601
boolean commandOk =
@@ -609,13 +610,14 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
609610
}
610611
commandEventSender.sendSucceededEvent(responseBuffers);
611612

612-
T result1 = getCommandResult(decoder, responseBuffers, messageId);
613-
callback.onResult(result1, null);
613+
commandResult = getCommandResult(decoder, responseBuffers, messageId);
614614
} catch (Throwable localThrowable) {
615615
callback.onResult(null, localThrowable);
616+
return;
616617
} finally {
617618
responseBuffers.close();
618619
}
620+
callback.onResult(commandResult, null);
619621
}));
620622
}
621623
});

driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,14 @@ public void startHandshakeAsync(final InternalConnection internalConnection,
9898
callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t);
9999
} else {
100100
setSpeculativeAuthenticateResponse(helloResult);
101-
callback.onResult(createInitializationDescription(helloResult, internalConnection, startTime), null);
101+
InternalConnectionInitializationDescription initializationDescription;
102+
try {
103+
initializationDescription = createInitializationDescription(helloResult, internalConnection, startTime);
104+
} catch (Throwable localThrowable) {
105+
callback.onResult(null, localThrowable);
106+
return;
107+
}
108+
callback.onResult(initializationDescription, null);
102109
}
103110
});
104111
}

0 commit comments

Comments
 (0)