Skip to content

Commit 773d4c2

Browse files
committed
Add runAfter and notifyOnce wrapper to ActionListener (#37331)
Relates #37291
1 parent 764da16 commit 773d4c2

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

server/src/main/java/org/elasticsearch/action/ActionListener.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,48 @@ static <Response> void onFailure(Iterable<ActionListener<Response>> listeners, E
142142
}
143143
ExceptionsHelper.maybeThrowRuntimeAndSuppress(exceptionList);
144144
}
145+
146+
/**
147+
* Wraps a given listener and returns a new listener which executes the provided {@code runAfter}
148+
* callback when the listener is notified via either {@code #onResponse} or {@code #onFailure}.
149+
*/
150+
static <Response> ActionListener<Response> runAfter(ActionListener<Response> delegate, Runnable runAfter) {
151+
return new ActionListener<Response>() {
152+
@Override
153+
public void onResponse(Response response) {
154+
try {
155+
delegate.onResponse(response);
156+
} finally {
157+
runAfter.run();
158+
}
159+
}
160+
161+
@Override
162+
public void onFailure(Exception e) {
163+
try {
164+
delegate.onFailure(e);
165+
} finally {
166+
runAfter.run();
167+
}
168+
}
169+
};
170+
}
171+
172+
/**
173+
* Wraps a given listener and returns a new listener which makes sure {@link #onResponse(Object)}
174+
* and {@link #onFailure(Exception)} of the provided listener will be called at most once.
175+
*/
176+
static <Response> ActionListener<Response> notifyOnce(ActionListener<Response> delegate) {
177+
return new NotifyOnceListener<Response>() {
178+
@Override
179+
protected void innerOnResponse(Response response) {
180+
delegate.onResponse(response);
181+
}
182+
183+
@Override
184+
protected void innerOnFailure(Exception e) {
185+
delegate.onFailure(e);
186+
}
187+
};
188+
}
145189
}

server/src/test/java/org/elasticsearch/action/ActionListenerTests.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@
2323

2424
import java.util.ArrayList;
2525
import java.util.List;
26+
import java.util.concurrent.atomic.AtomicBoolean;
2627
import java.util.concurrent.atomic.AtomicInteger;
2728
import java.util.concurrent.atomic.AtomicReference;
2829

30+
import static org.hamcrest.Matchers.equalTo;
31+
2932
public class ActionListenerTests extends ESTestCase {
3033

3134
public void testWrap() {
@@ -148,4 +151,54 @@ public void testOnFailure() {
148151
assertEquals("listener index " + i, "booom", excList.get(i).get().getMessage());
149152
}
150153
}
154+
155+
public void testRunAfter() {
156+
{
157+
AtomicBoolean afterSuccess = new AtomicBoolean();
158+
ActionListener<Object> listener = ActionListener.runAfter(ActionListener.wrap(r -> {}, e -> {}), () -> afterSuccess.set(true));
159+
listener.onResponse(null);
160+
assertThat(afterSuccess.get(), equalTo(true));
161+
}
162+
{
163+
AtomicBoolean afterFailure = new AtomicBoolean();
164+
ActionListener<Object> listener = ActionListener.runAfter(ActionListener.wrap(r -> {}, e -> {}), () -> afterFailure.set(true));
165+
listener.onFailure(null);
166+
assertThat(afterFailure.get(), equalTo(true));
167+
}
168+
}
169+
170+
public void testNotifyOnce() {
171+
AtomicInteger onResponseTimes = new AtomicInteger();
172+
AtomicInteger onFailureTimes = new AtomicInteger();
173+
ActionListener<Object> listener = ActionListener.notifyOnce(new ActionListener<Object>() {
174+
@Override
175+
public void onResponse(Object o) {
176+
onResponseTimes.getAndIncrement();
177+
}
178+
@Override
179+
public void onFailure(Exception e) {
180+
onFailureTimes.getAndIncrement();
181+
}
182+
});
183+
boolean success = randomBoolean();
184+
if (success) {
185+
listener.onResponse(null);
186+
} else {
187+
listener.onFailure(new RuntimeException("test"));
188+
}
189+
for (int iters = between(0, 10), i = 0; i < iters; i++) {
190+
if (randomBoolean()) {
191+
listener.onResponse(null);
192+
} else {
193+
listener.onFailure(new RuntimeException("test"));
194+
}
195+
}
196+
if (success) {
197+
assertThat(onResponseTimes.get(), equalTo(1));
198+
assertThat(onFailureTimes.get(), equalTo(0));
199+
} else {
200+
assertThat(onResponseTimes.get(), equalTo(0));
201+
assertThat(onFailureTimes.get(), equalTo(1));
202+
}
203+
}
151204
}

0 commit comments

Comments
 (0)