Skip to content

Commit 441075b

Browse files
committed
Use TransportChannel in TransportHandshaker (elastic#54684)
Currently the TransportHandshaker has a specialized codepath for sending a response. In other work, we are going to start having handshakes contribute to circuit breaking (while not being breakable). This commit moves in that direction by allowing the handshaker to responding using a standard TcpTransportChannel similar to other requests.
1 parent 9cf2406 commit 441075b

File tree

10 files changed

+109
-150
lines changed

10 files changed

+109
-150
lines changed

server/src/main/java/org/elasticsearch/transport/InboundHandler.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,10 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
157157
try {
158158
messageListener.onRequestReceived(requestId, action);
159159
if (header.isHandshake()) {
160-
handshaker.handleHandshake(version, features, channel, requestId, stream);
160+
// Handshakes are not currently circuit broken
161+
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
162+
circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
163+
handshaker.handleHandshake(transportChannel, requestId, stream);
161164
} else {
162165
final RequestHandlerRegistry<T> reg = getRequestHandler(action);
163166
if (reg == null) {
@@ -170,7 +173,7 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
170173
breaker.addWithoutBreaking(messageLengthBytes);
171174
}
172175
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
173-
circuitBreakerService, messageLengthBytes, header.isCompressed());
176+
circuitBreakerService, messageLengthBytes, header.isCompressed(), header.isHandshake());
174177
final T request = reg.newRequest(stream);
175178
request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
176179
// in case we throw an exception, i.e. when the limit is hit, we don't want to verify
@@ -186,7 +189,7 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
186189
// the circuit breaker tripped
187190
if (transportChannel == null) {
188191
transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
189-
circuitBreakerService, 0, header.isCompressed());
192+
circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
190193
}
191194
try {
192195
transportChannel.sendResponse(e);

server/src/main/java/org/elasticsearch/transport/TcpTransport.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,7 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P
159159
this.handshaker = new TransportHandshaker(version, threadPool,
160160
(node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId,
161161
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
162-
TransportRequestOptions.EMPTY, v, false, true),
163-
(v, features1, channel, response, requestId) -> outboundHandler.sendResponse(v, features1, channel, requestId,
164-
TransportHandshaker.HANDSHAKE_ACTION_NAME, response, false, true));
162+
TransportRequestOptions.EMPTY, v, false, true));
165163
this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
166164
this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, circuitBreakerService, handshaker,
167165
keepAlive);

server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,11 @@ public final class TcpTransportChannel implements TransportChannel {
3939
private final CircuitBreakerService breakerService;
4040
private final long reservedBytes;
4141
private final boolean compressResponse;
42+
private final boolean isHandshake;
4243

4344
TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version,
44-
Set<String> features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse) {
45+
Set<String> features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse,
46+
boolean isHandshake) {
4547
this.version = version;
4648
this.features = features;
4749
this.channel = channel;
@@ -51,6 +53,7 @@ public final class TcpTransportChannel implements TransportChannel {
5153
this.breakerService = breakerService;
5254
this.reservedBytes = reservedBytes;
5355
this.compressResponse = compressResponse;
56+
this.isHandshake = isHandshake;
5457
}
5558

5659
@Override
@@ -61,7 +64,7 @@ public String getProfileName() {
6164
@Override
6265
public void sendResponse(TransportResponse response) throws IOException {
6366
try {
64-
outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, false);
67+
outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, isHandshake);
6568
} finally {
6669
release(false);
6770
}
@@ -102,6 +105,5 @@ public Version getVersion() {
102105
public TcpChannel getChannel() {
103106
return channel;
104107
}
105-
106108
}
107109

server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
import java.io.EOFException;
3333
import java.io.IOException;
34-
import java.util.Set;
3534
import java.util.concurrent.ConcurrentHashMap;
3635
import java.util.concurrent.ConcurrentMap;
3736
import java.util.concurrent.atomic.AtomicBoolean;
@@ -49,14 +48,11 @@ final class TransportHandshaker {
4948
private final Version version;
5049
private final ThreadPool threadPool;
5150
private final HandshakeRequestSender handshakeRequestSender;
52-
private final HandshakeResponseSender handshakeResponseSender;
5351

54-
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender,
55-
HandshakeResponseSender handshakeResponseSender) {
52+
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender) {
5653
this.version = version;
5754
this.threadPool = threadPool;
5855
this.handshakeRequestSender = handshakeRequestSender;
59-
this.handshakeResponseSender = handshakeResponseSender;
6056
}
6157

6258
void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener<Version> listener) {
@@ -88,16 +84,15 @@ void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeV
8884
}
8985
}
9086

91-
void handleHandshake(Version version, Set<String> features, TcpChannel channel, long requestId, StreamInput stream) throws IOException {
87+
void handleHandshake(TransportChannel channel, long requestId, StreamInput stream) throws IOException {
9288
// Must read the handshake request to exhaust the stream
9389
HandshakeRequest handshakeRequest = new HandshakeRequest(stream);
9490
final int nextByte = stream.read();
9591
if (nextByte != -1) {
9692
throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action ["
9793
+ TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting");
9894
}
99-
HandshakeResponse response = new HandshakeResponse(this.version);
100-
handshakeResponseSender.sendResponse(version, features, channel, response, requestId);
95+
channel.sendResponse(new HandshakeResponse(this.version));
10196
}
10297

10398
TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
@@ -228,11 +223,4 @@ interface HandshakeRequestSender {
228223

229224
void sendRequest(DiscoveryNode node, TcpChannel channel, long requestId, Version version) throws IOException;
230225
}
231-
232-
@FunctionalInterface
233-
interface HandshakeResponseSender {
234-
235-
void sendResponse(Version version, Set<String> features, TcpChannel channel, TransportResponse response, long requestId)
236-
throws IOException;
237-
}
238226
}

server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
import org.elasticsearch.test.transport.CapturingTransport;
5858
import org.elasticsearch.threadpool.TestThreadPool;
5959
import org.elasticsearch.threadpool.ThreadPool;
60-
import org.elasticsearch.transport.TransportChannel;
60+
import org.elasticsearch.transport.TestTransportChannel;
6161
import org.elasticsearch.transport.TransportResponse;
6262
import org.elasticsearch.transport.TransportService;
6363
import org.junit.After;
@@ -366,14 +366,15 @@ public void testOperationExecution() throws Exception {
366366
final TransportBroadcastByNodeAction.BroadcastByNodeTransportRequestHandler handler =
367367
action.new BroadcastByNodeTransportRequestHandler();
368368

369-
TestTransportChannel channel = new TestTransportChannel();
369+
final PlainActionFuture<TransportResponse> future = PlainActionFuture.newFuture();
370+
TestTransportChannel channel = new TestTransportChannel(future);
370371

371372
handler.messageReceived(action.new NodeRequest(nodeId, new Request(), new ArrayList<>(shards)), channel, null);
372373

373374
// check the operation was executed only on the expected shards
374375
assertEquals(shards, action.getResults().keySet());
375376

376-
TransportResponse response = channel.getCapturedResponse();
377+
TransportResponse response = future.actionGet();
377378
assertTrue(response instanceof TransportBroadcastByNodeAction.NodeResponse);
378379
TransportBroadcastByNodeAction.NodeResponse nodeResponse = (TransportBroadcastByNodeAction.NodeResponse) response;
379380

@@ -469,32 +470,4 @@ public void testResultAggregation() throws ExecutionException, InterruptedExcept
469470
assertEquals("failed shards", totalFailedShards, response.getFailedShards());
470471
assertEquals("accumulated exceptions", totalFailedShards, response.getShardFailures().length);
471472
}
472-
473-
public class TestTransportChannel implements TransportChannel {
474-
private TransportResponse capturedResponse;
475-
476-
public TransportResponse getCapturedResponse() {
477-
return capturedResponse;
478-
}
479-
480-
@Override
481-
public String getProfileName() {
482-
return "";
483-
}
484-
485-
@Override
486-
public void sendResponse(TransportResponse response) throws IOException {
487-
capturedResponse = response;
488-
}
489-
490-
@Override
491-
public void sendResponse(Exception exception) throws IOException {
492-
}
493-
494-
@Override
495-
public String getChannelType() {
496-
return "test";
497-
}
498-
499-
}
500473
}

server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
import org.elasticsearch.test.transport.MockTransportService;
7979
import org.elasticsearch.threadpool.TestThreadPool;
8080
import org.elasticsearch.threadpool.ThreadPool;
81+
import org.elasticsearch.transport.TestTransportChannel;
8182
import org.elasticsearch.transport.Transport;
8283
import org.elasticsearch.transport.TransportChannel;
8384
import org.elasticsearch.transport.TransportException;
@@ -817,7 +818,7 @@ public void testSeqNoIsSetOnPrimary() {
817818
Request request = new Request(shardId);
818819
TransportReplicationAction.ConcreteShardRequest<Request> concreteShardRequest =
819820
new TransportReplicationAction.ConcreteShardRequest<>(request, routingEntry.allocationId().getId(), primaryTerm);
820-
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
821+
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
821822

822823

823824
final IndexShard shard = mockIndexShard(shardId, clusterService);
@@ -981,7 +982,7 @@ public void testPrimaryActionRejectsWrongAidOrWrongTerm() throws Exception {
981982
setState(clusterService, state(index, true, ShardRoutingState.STARTED));
982983
final ShardRouting primary = clusterService.state().routingTable().shardRoutingTable(shardId).primaryShard();
983984
final long primaryTerm = clusterService.state().metadata().index(shardId.getIndexName()).primaryTerm(shardId.id());
984-
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
985+
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
985986
final boolean wrongAllocationId = randomBoolean();
986987
final long requestTerm = wrongAllocationId && randomBoolean() ? primaryTerm : primaryTerm + randomIntBetween(1, 10);
987988
Request request = new Request(shardId).timeout("1ms");
@@ -1018,7 +1019,7 @@ public void testReplicaActionRejectsWrongAid() throws Exception {
10181019
state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build();
10191020
setState(clusterService, state);
10201021

1021-
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
1022+
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
10221023
Request request = new Request(shardId).timeout("1ms");
10231024
action.handleReplicaRequest(
10241025
new TransportReplicationAction.ConcreteReplicaRequest<>(request, "_not_a_valid_aid_", randomNonNegativeLong(),
@@ -1062,7 +1063,7 @@ protected ReplicaResult shardOperationOnReplica(Request request, IndexShard repl
10621063
return new ReplicaResult();
10631064
}
10641065
};
1065-
final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
1066+
final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
10661067
final Request request = new Request(shardId);
10671068
final long checkpoint = randomNonNegativeLong();
10681069
final long maxSeqNoOfUpdatesOrDeletes = randomNonNegativeLong();
@@ -1130,7 +1131,7 @@ protected ReplicaResult shardOperationOnReplica(Request request, IndexShard repl
11301131
return new ReplicaResult();
11311132
}
11321133
};
1133-
final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
1134+
final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
11341135
final Request request = new Request(shardId);
11351136
final long checkpoint = randomNonNegativeLong();
11361137
final long maxSeqNoOfUpdates = randomNonNegativeLong();
@@ -1371,29 +1372,8 @@ private IndexShard mockIndexShard(ShardId shardId, ClusterService clusterService
13711372
/**
13721373
* Transport channel that is needed for replica operation testing.
13731374
*/
1374-
public TransportChannel createTransportChannel(final PlainActionFuture<TestResponse> listener) {
1375-
return new TransportChannel() {
1376-
1377-
@Override
1378-
public String getProfileName() {
1379-
return "";
1380-
}
1381-
1382-
@Override
1383-
public void sendResponse(TransportResponse response) {
1384-
listener.onResponse(((TestResponse) response));
1385-
}
1386-
1387-
@Override
1388-
public void sendResponse(Exception exception) {
1389-
listener.onFailure(exception);
1390-
}
1391-
1392-
@Override
1393-
public String getChannelType() {
1394-
return "replica_test";
1395-
}
1396-
};
1375+
public TransportChannel createTransportChannel(final PlainActionFuture<TransportResponse> listener) {
1376+
return new TestTransportChannel(listener);
13971377
}
13981378

13991379
}

0 commit comments

Comments
 (0)