Skip to content

Commit 2c5951a

Browse files
authored
Use TransportChannel in TransportHandshaker (#54684)
Currently the TransportHandshaker has a specialized codepath for sending a response. In other work, we are going to start having handshakes contribute to circuit breaking (while not being breakable). This commit moves in that direction by allowing the handshaker to responding using a standard TcpTransportChannel similar to other requests.
1 parent 9f9ade7 commit 2c5951a

File tree

10 files changed

+104
-146
lines changed

10 files changed

+104
-146
lines changed

server/src/main/java/org/elasticsearch/transport/InboundHandler.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,10 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
155155
try {
156156
messageListener.onRequestReceived(requestId, action);
157157
if (header.isHandshake()) {
158-
handshaker.handleHandshake(version, channel, requestId, stream);
158+
// Handshakes are not currently circuit broken
159+
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version,
160+
circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
161+
handshaker.handleHandshake(transportChannel, requestId, stream);
159162
} else {
160163
final RequestHandlerRegistry<T> reg = getRequestHandler(action);
161164
if (reg == null) {
@@ -168,7 +171,7 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
168171
breaker.addWithoutBreaking(messageLengthBytes);
169172
}
170173
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version,
171-
circuitBreakerService, messageLengthBytes, header.isCompressed());
174+
circuitBreakerService, messageLengthBytes, header.isCompressed(), header.isHandshake());
172175
final T request = reg.newRequest(stream);
173176
request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
174177
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
@@ -184,7 +187,7 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
184187
// the circuit breaker tripped
185188
if (transportChannel == null) {
186189
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version,
187-
circuitBreakerService, 0, header.isCompressed());
190+
circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
188191
}
189192
try {
190193
transportChannel.sendResponse(e);

server/src/main/java/org/elasticsearch/transport/TcpTransport.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,7 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P
144144
this.handshaker = new TransportHandshaker(version, threadPool,
145145
(node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId,
146146
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
147-
TransportRequestOptions.EMPTY, v, false, true),
148-
(v, channel, response, requestId) -> outboundHandler.sendResponse(v, channel, requestId,
149-
TransportHandshaker.HANDSHAKE_ACTION_NAME, response, false, true));
147+
TransportRequestOptions.EMPTY, v, false, true));
150148
this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
151149
this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, circuitBreakerService, handshaker,
152150
keepAlive);

server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ public final class TcpTransportChannel implements TransportChannel {
3737
private final CircuitBreakerService breakerService;
3838
private final long reservedBytes;
3939
private final boolean compressResponse;
40+
private final boolean isHandshake;
4041

4142
TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version,
42-
CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse) {
43+
CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse, boolean isHandshake) {
4344
this.version = version;
4445
this.channel = channel;
4546
this.outboundHandler = outboundHandler;
@@ -48,6 +49,7 @@ public final class TcpTransportChannel implements TransportChannel {
4849
this.breakerService = breakerService;
4950
this.reservedBytes = reservedBytes;
5051
this.compressResponse = compressResponse;
52+
this.isHandshake = isHandshake;
5153
}
5254

5355
@Override
@@ -58,7 +60,7 @@ public String getProfileName() {
5860
@Override
5961
public void sendResponse(TransportResponse response) throws IOException {
6062
try {
61-
outboundHandler.sendResponse(version, channel, requestId, action, response, compressResponse, false);
63+
outboundHandler.sendResponse(version, channel, requestId, action, response, compressResponse, isHandshake);
6264
} finally {
6365
release(false);
6466
}
@@ -99,6 +101,5 @@ public Version getVersion() {
99101
public TcpChannel getChannel() {
100102
return channel;
101103
}
102-
103104
}
104105

server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,11 @@ final class TransportHandshaker {
4949
private final Version version;
5050
private final ThreadPool threadPool;
5151
private final HandshakeRequestSender handshakeRequestSender;
52-
private final HandshakeResponseSender handshakeResponseSender;
5352

54-
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender,
55-
HandshakeResponseSender handshakeResponseSender) {
53+
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender) {
5654
this.version = version;
5755
this.threadPool = threadPool;
5856
this.handshakeRequestSender = handshakeRequestSender;
59-
this.handshakeResponseSender = handshakeResponseSender;
6057
}
6158

6259
void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener<Version> listener) {
@@ -88,16 +85,15 @@ void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeV
8885
}
8986
}
9087

91-
void handleHandshake(Version version, TcpChannel channel, long requestId, StreamInput stream) throws IOException {
88+
void handleHandshake(TransportChannel channel, long requestId, StreamInput stream) throws IOException {
9289
// Must read the handshake request to exhaust the stream
9390
HandshakeRequest handshakeRequest = new HandshakeRequest(stream);
9491
final int nextByte = stream.read();
9592
if (nextByte != -1) {
9693
throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action ["
9794
+ TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting");
9895
}
99-
HandshakeResponse response = new HandshakeResponse(this.version);
100-
handshakeResponseSender.sendResponse(version, channel, response, requestId);
96+
channel.sendResponse(new HandshakeResponse(this.version));
10197
}
10298

10399
TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
@@ -228,12 +224,4 @@ interface HandshakeRequestSender {
228224

229225
void sendRequest(DiscoveryNode node, TcpChannel channel, long requestId, Version version) throws IOException;
230226
}
231-
232-
@FunctionalInterface
233-
interface HandshakeResponseSender {
234-
235-
void sendResponse(Version version, TcpChannel channel, TransportResponse response, long requestId) throws IOException;
236-
237-
}
238-
239227
}

server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
import org.elasticsearch.test.transport.CapturingTransport;
5858
import org.elasticsearch.threadpool.TestThreadPool;
5959
import org.elasticsearch.threadpool.ThreadPool;
60-
import org.elasticsearch.transport.TransportChannel;
60+
import org.elasticsearch.transport.TestTransportChannel;
6161
import org.elasticsearch.transport.TransportResponse;
6262
import org.elasticsearch.transport.TransportService;
6363
import org.junit.After;
@@ -366,14 +366,15 @@ public void testOperationExecution() throws Exception {
366366
final TransportBroadcastByNodeAction.BroadcastByNodeTransportRequestHandler handler =
367367
action.new BroadcastByNodeTransportRequestHandler();
368368

369-
TestTransportChannel channel = new TestTransportChannel();
369+
final PlainActionFuture<TransportResponse> future = PlainActionFuture.newFuture();
370+
TestTransportChannel channel = new TestTransportChannel(future);
370371

371372
handler.messageReceived(action.new NodeRequest(nodeId, new Request(), new ArrayList<>(shards)), channel, null);
372373

373374
// check the operation was executed only on the expected shards
374375
assertEquals(shards, action.getResults().keySet());
375376

376-
TransportResponse response = channel.getCapturedResponse();
377+
TransportResponse response = future.actionGet();
377378
assertTrue(response instanceof TransportBroadcastByNodeAction.NodeResponse);
378379
TransportBroadcastByNodeAction.NodeResponse nodeResponse = (TransportBroadcastByNodeAction.NodeResponse) response;
379380

@@ -469,32 +470,4 @@ public void testResultAggregation() throws ExecutionException, InterruptedExcept
469470
assertEquals("failed shards", totalFailedShards, response.getFailedShards());
470471
assertEquals("accumulated exceptions", totalFailedShards, response.getShardFailures().length);
471472
}
472-
473-
public class TestTransportChannel implements TransportChannel {
474-
private TransportResponse capturedResponse;
475-
476-
public TransportResponse getCapturedResponse() {
477-
return capturedResponse;
478-
}
479-
480-
@Override
481-
public String getProfileName() {
482-
return "";
483-
}
484-
485-
@Override
486-
public void sendResponse(TransportResponse response) throws IOException {
487-
capturedResponse = response;
488-
}
489-
490-
@Override
491-
public void sendResponse(Exception exception) throws IOException {
492-
}
493-
494-
@Override
495-
public String getChannelType() {
496-
return "test";
497-
}
498-
499-
}
500473
}

server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
import org.elasticsearch.test.transport.MockTransportService;
7979
import org.elasticsearch.threadpool.TestThreadPool;
8080
import org.elasticsearch.threadpool.ThreadPool;
81+
import org.elasticsearch.transport.TestTransportChannel;
8182
import org.elasticsearch.transport.Transport;
8283
import org.elasticsearch.transport.TransportChannel;
8384
import org.elasticsearch.transport.TransportException;
@@ -817,7 +818,7 @@ public void testSeqNoIsSetOnPrimary() {
817818
Request request = new Request(shardId);
818819
TransportReplicationAction.ConcreteShardRequest<Request> concreteShardRequest =
819820
new TransportReplicationAction.ConcreteShardRequest<>(request, routingEntry.allocationId().getId(), primaryTerm);
820-
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
821+
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
821822

822823

823824
final IndexShard shard = mockIndexShard(shardId, clusterService);
@@ -981,7 +982,7 @@ public void testPrimaryActionRejectsWrongAidOrWrongTerm() throws Exception {
981982
setState(clusterService, state(index, true, ShardRoutingState.STARTED));
982983
final ShardRouting primary = clusterService.state().routingTable().shardRoutingTable(shardId).primaryShard();
983984
final long primaryTerm = clusterService.state().metadata().index(shardId.getIndexName()).primaryTerm(shardId.id());
984-
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
985+
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
985986
final boolean wrongAllocationId = randomBoolean();
986987
final long requestTerm = wrongAllocationId && randomBoolean() ? primaryTerm : primaryTerm + randomIntBetween(1, 10);
987988
Request request = new Request(shardId).timeout("1ms");
@@ -1018,7 +1019,7 @@ public void testReplicaActionRejectsWrongAid() throws Exception {
10181019
state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build();
10191020
setState(clusterService, state);
10201021

1021-
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
1022+
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
10221023
Request request = new Request(shardId).timeout("1ms");
10231024
action.handleReplicaRequest(
10241025
new TransportReplicationAction.ConcreteReplicaRequest<>(request, "_not_a_valid_aid_", randomNonNegativeLong(),
@@ -1062,7 +1063,7 @@ protected ReplicaResult shardOperationOnReplica(Request request, IndexShard repl
10621063
return new ReplicaResult();
10631064
}
10641065
};
1065-
final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
1066+
final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
10661067
final Request request = new Request(shardId);
10671068
final long checkpoint = randomNonNegativeLong();
10681069
final long maxSeqNoOfUpdatesOrDeletes = randomNonNegativeLong();
@@ -1130,7 +1131,7 @@ protected ReplicaResult shardOperationOnReplica(Request request, IndexShard repl
11301131
return new ReplicaResult();
11311132
}
11321133
};
1133-
final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
1134+
final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
11341135
final Request request = new Request(shardId);
11351136
final long checkpoint = randomNonNegativeLong();
11361137
final long maxSeqNoOfUpdates = randomNonNegativeLong();
@@ -1371,29 +1372,8 @@ private IndexShard mockIndexShard(ShardId shardId, ClusterService clusterService
13711372
/**
13721373
* Transport channel that is needed for replica operation testing.
13731374
*/
1374-
public TransportChannel createTransportChannel(final PlainActionFuture<TestResponse> listener) {
1375-
return new TransportChannel() {
1376-
1377-
@Override
1378-
public String getProfileName() {
1379-
return "";
1380-
}
1381-
1382-
@Override
1383-
public void sendResponse(TransportResponse response) {
1384-
listener.onResponse(((TestResponse) response));
1385-
}
1386-
1387-
@Override
1388-
public void sendResponse(Exception exception) {
1389-
listener.onFailure(exception);
1390-
}
1391-
1392-
@Override
1393-
public String getChannelType() {
1394-
return "replica_test";
1395-
}
1396-
};
1375+
public TransportChannel createTransportChannel(final PlainActionFuture<TransportResponse> listener) {
1376+
return new TestTransportChannel(listener);
13971377
}
13981378

13991379
}

server/src/test/java/org/elasticsearch/cluster/coordination/NodeJoinTests.java

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.apache.logging.log4j.message.ParameterizedMessage;
2222
import org.elasticsearch.Version;
23+
import org.elasticsearch.action.ActionListener;
2324
import org.elasticsearch.cluster.ClusterName;
2425
import org.elasticsearch.cluster.ClusterState;
2526
import org.elasticsearch.cluster.ESAllocationTestCase;
@@ -44,8 +45,8 @@
4445
import org.elasticsearch.threadpool.TestThreadPool;
4546
import org.elasticsearch.threadpool.ThreadPool;
4647
import org.elasticsearch.transport.RequestHandlerRegistry;
48+
import org.elasticsearch.transport.TestTransportChannel;
4749
import org.elasticsearch.transport.Transport;
48-
import org.elasticsearch.transport.TransportChannel;
4950
import org.elasticsearch.transport.TransportRequest;
5051
import org.elasticsearch.transport.TransportResponse;
5152
import org.elasticsearch.transport.TransportService;
@@ -229,29 +230,22 @@ private SimpleFuture joinNodeAsync(final JoinRequest joinRequest) {
229230
try {
230231
final RequestHandlerRegistry<JoinRequest> joinHandler = (RequestHandlerRegistry<JoinRequest>)
231232
transport.getRequestHandler(JoinHelper.JOIN_ACTION_NAME);
232-
joinHandler.processMessageReceived(joinRequest, new TransportChannel() {
233-
@Override
234-
public String getProfileName() {
235-
return "dummy";
236-
}
237-
238-
@Override
239-
public String getChannelType() {
240-
return "dummy";
241-
}
233+
final ActionListener<TransportResponse> listener = new ActionListener<>() {
242234

243235
@Override
244-
public void sendResponse(TransportResponse response) {
236+
public void onResponse(TransportResponse transportResponse) {
245237
logger.debug("{} completed", future);
246238
future.markAsDone();
247239
}
248240

249241
@Override
250-
public void sendResponse(Exception e) {
242+
public void onFailure(Exception e) {
251243
logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
252244
future.markAsFailed(e);
253245
}
254-
});
246+
};
247+
248+
joinHandler.processMessageReceived(joinRequest, new TestTransportChannel(listener));
255249
} catch (Exception e) {
256250
logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
257251
future.markAsFailed(e);
@@ -402,27 +396,17 @@ public void testJoinFollowerWithHigherTerm() throws Exception {
402396
private void handleStartJoinFrom(DiscoveryNode node, long term) throws Exception {
403397
final RequestHandlerRegistry<StartJoinRequest> startJoinHandler = (RequestHandlerRegistry<StartJoinRequest>)
404398
transport.getRequestHandler(JoinHelper.START_JOIN_ACTION_NAME);
405-
startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TransportChannel() {
399+
startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TestTransportChannel(new ActionListener<>() {
406400
@Override
407-
public String getProfileName() {
408-
return "dummy";
409-
}
401+
public void onResponse(TransportResponse transportResponse) {
410402

411-
@Override
412-
public String getChannelType() {
413-
return "dummy";
414403
}
415404

416405
@Override
417-
public void sendResponse(TransportResponse response) {
418-
419-
}
420-
421-
@Override
422-
public void sendResponse(Exception exception) {
406+
public void onFailure(Exception e) {
423407
fail();
424408
}
425-
});
409+
}));
426410
deterministicTaskQueue.runAllRunnableTasks();
427411
assertFalse(isLocalNodeElectedMaster());
428412
assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.CANDIDATE));
@@ -432,27 +416,19 @@ private void handleFollowerCheckFrom(DiscoveryNode node, long term) throws Excep
432416
final RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest> followerCheckHandler =
433417
(RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest>)
434418
transport.getRequestHandler(FollowersChecker.FOLLOWER_CHECK_ACTION_NAME);
435-
followerCheckHandler.processMessageReceived(new FollowersChecker.FollowerCheckRequest(term, node), new TransportChannel() {
436-
@Override
437-
public String getProfileName() {
438-
return "dummy";
439-
}
440-
441-
@Override
442-
public String getChannelType() {
443-
return "dummy";
444-
}
445-
419+
final TestTransportChannel channel = new TestTransportChannel(new ActionListener<>() {
446420
@Override
447-
public void sendResponse(TransportResponse response) {
421+
public void onResponse(TransportResponse transportResponse) {
448422

449423
}
450424

451425
@Override
452-
public void sendResponse(Exception exception) {
426+
public void onFailure(Exception e) {
453427
fail();
454428
}
455429
});
430+
followerCheckHandler.processMessageReceived(new FollowersChecker.FollowerCheckRequest(term, node), channel);
431+
// Will throw exception if failed
456432
deterministicTaskQueue.runAllRunnableTasks();
457433
assertFalse(isLocalNodeElectedMaster());
458434
assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.FOLLOWER));

0 commit comments

Comments
 (0)