diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index 4b4cbd3414554..c8b96b8bf6da2 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -97,12 +97,12 @@ public class Netty4Transport extends TcpTransport { intSetting("transport.netty.boss_count", 1, 1, Property.NodeScope); - protected final RecvByteBufAllocator recvByteBufAllocator; - protected final int workerCount; - protected final ByteSizeValue receivePredictorMin; - protected final ByteSizeValue receivePredictorMax; - protected volatile Bootstrap bootstrap; - protected final Map serverBootstraps = newConcurrentMap(); + private final RecvByteBufAllocator recvByteBufAllocator; + private final int workerCount; + private final ByteSizeValue receivePredictorMin; + private final ByteSizeValue receivePredictorMax; + private volatile Bootstrap clientBootstrap; + private final Map serverBootstraps = newConcurrentMap(); public Netty4Transport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { @@ -125,7 +125,7 @@ public Netty4Transport(Settings settings, ThreadPool threadPool, NetworkService protected void doStart() { boolean success = false; try { - bootstrap = createBootstrap(); + clientBootstrap = createClientBootstrap(); if (NetworkService.NETWORK_SERVER.get(settings)) { for (ProfileSettings profileSettings : profileSettings) { createServerBootstrap(profileSettings); @@ -141,13 +141,11 @@ protected void doStart() { } } - private Bootstrap createBootstrap() { + private Bootstrap createClientBootstrap() { final Bootstrap bootstrap = new Bootstrap(); bootstrap.group(new NioEventLoopGroup(workerCount, daemonThreadFactory(settings, TRANSPORT_CLIENT_BOSS_THREAD_NAME_PREFIX))); bootstrap.channel(NioSocketChannel.class); - bootstrap.handler(getClientChannelInitializer()); - bootstrap.option(ChannelOption.TCP_NODELAY, TCP_NO_DELAY.get(settings)); bootstrap.option(ChannelOption.SO_KEEPALIVE, TCP_KEEP_ALIVE.get(settings)); @@ -166,8 +164,6 @@ private Bootstrap createBootstrap() { final boolean reuseAddress = TCP_REUSE_ADDRESS.get(settings); bootstrap.option(ChannelOption.SO_REUSEADDR, reuseAddress); - bootstrap.validate(); - return bootstrap; } @@ -215,7 +211,7 @@ protected ChannelHandler getServerChannelInitializer(String name) { return new ServerChannelInitializer(name); } - protected ChannelHandler getClientChannelInitializer() { + protected ChannelHandler getClientChannelInitializer(DiscoveryNode node) { return new ClientChannelInitializer(); } @@ -231,7 +227,11 @@ protected final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) @Override protected NettyTcpChannel initiateChannel(DiscoveryNode node, ActionListener listener) throws IOException { InetSocketAddress address = node.getAddress().address(); - ChannelFuture channelFuture = bootstrap.connect(address); + Bootstrap bootstrapWithHandler = clientBootstrap.clone(); + bootstrapWithHandler.handler(getClientChannelInitializer(node)); + bootstrapWithHandler.remoteAddress(address); + ChannelFuture channelFuture = bootstrapWithHandler.connect(); + Channel channel = channelFuture.channel(); if (channel == null) { ExceptionsHelper.maybeDieOnAnotherThread(channelFuture.cause()); @@ -294,9 +294,9 @@ protected void stopInternal() { } serverBootstraps.clear(); - if (bootstrap != null) { - bootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS).awaitUninterruptibly(); - bootstrap = null; + if (clientBootstrap != null) { + clientBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS).awaitUninterruptibly(); + clientBootstrap = null; } }); } diff --git a/server/src/main/java/org/elasticsearch/node/Node.java b/server/src/main/java/org/elasticsearch/node/Node.java index 7dcd6b30ca8fd..ae783d1c8fba0 100644 --- a/server/src/main/java/org/elasticsearch/node/Node.java +++ b/server/src/main/java/org/elasticsearch/node/Node.java @@ -153,6 +153,7 @@ import org.elasticsearch.usage.UsageService; import org.elasticsearch.watcher.ResourceWatcherService; +import javax.net.ssl.SNIHostName; import java.io.BufferedWriter; import java.io.Closeable; import java.io.IOException; @@ -212,6 +213,13 @@ public abstract class Node implements Closeable { throw new IllegalArgumentException(key + " cannot have leading or trailing whitespace " + "[" + value + "]"); } + if (value.length() > 0 && "node.attr.server_name".equals(key)) { + try { + new SNIHostName(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("invalid node.attr.server_name [" + value + "]", e ); + } + } return value; }, Property.NodeScope)); public static final Setting BREAKER_TYPE_KEY = new Setting<>("indices.breaker.type", "hierarchy", (s) -> { diff --git a/server/src/test/java/org/elasticsearch/node/NodeTests.java b/server/src/test/java/org/elasticsearch/node/NodeTests.java index 254823791d5cd..52cd91ac5f6cc 100644 --- a/server/src/test/java/org/elasticsearch/node/NodeTests.java +++ b/server/src/test/java/org/elasticsearch/node/NodeTests.java @@ -150,6 +150,27 @@ public void testNodeAttributes() throws IOException { assertSettingDeprecationsAndWarnings(new Setting[] { NetworkModule.HTTP_ENABLED }); } + public void testServerNameNodeAttribute() throws IOException { + String attr = "valid-hostname"; + Settings.Builder settings = baseSettings().put(Node.NODE_ATTRIBUTES.getKey() + "server_name", attr); + int i = 0; + try (Node node = new MockNode(settings.build(), Collections.singleton(getTestTransportPlugin()))) { + final Settings nodeSettings = randomBoolean() ? node.settings() : node.getEnvironment().settings(); + assertEquals(attr, Node.NODE_ATTRIBUTES.getAsMap(nodeSettings).get("server_name")); + } + + // non-LDH hostname not allowed + attr = "invalid_hostname"; + settings = baseSettings().put(Node.NODE_ATTRIBUTES.getKey() + "server_name", attr); + try (Node node = new MockNode(settings.build(), Collections.singleton(getTestTransportPlugin()))) { + fail("should not allow a server_name attribute with an underscore"); + } catch (IllegalArgumentException e) { + assertEquals("invalid node.attr.server_name [invalid_hostname]", e.getMessage()); + } + + assertSettingDeprecationsAndWarnings(new Setting[] { NetworkModule.HTTP_ENABLED }); + } + private static Settings.Builder baseSettings() { final Path tempDir = createTempDir(); return Settings.builder() diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 8c654ab8883f6..725aad17d7053 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -1937,16 +1937,12 @@ public void testTimeoutPerConnection() throws IOException { public void testHandshakeWithIncompatVersion() { assumeTrue("only tcp transport has a handshake method", serviceA.getOriginalTransport() instanceof TcpTransport); - NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); Version version = Version.fromString("2.0.0"); - try (MockTcpTransport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE, - new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), version); - MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, null, - Collections.emptySet())) { + try (MockTransportService service = build(Settings.EMPTY, version, null, true)) { service.start(); service.acceptIncomingRequests(); - DiscoveryNode node = - new DiscoveryNode("TS_TPC", "TS_TPC", transport.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); + TransportAddress address = service.boundAddress().publishAddress(); + DiscoveryNode node = new DiscoveryNode("TS_TPC", "TS_TPC", address, emptyMap(), emptySet(), version0); ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); builder.addConnections(1, TransportRequestOptions.Type.BULK, @@ -1962,15 +1958,11 @@ public void testHandshakeUpdatesVersion() throws IOException { assumeTrue("only tcp transport has a handshake method", serviceA.getOriginalTransport() instanceof TcpTransport); NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); Version version = VersionUtils.randomVersionBetween(random(), Version.CURRENT.minimumCompatibilityVersion(), Version.CURRENT); - try (MockTcpTransport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE, - new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), version); - MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, null, - Collections.emptySet())) { + try (MockTransportService service = build(Settings.EMPTY, version, null, true)) { service.start(); service.acceptIncomingRequests(); - DiscoveryNode node = - new DiscoveryNode("TS_TPC", "TS_TPC", transport.boundAddress().publishAddress(), emptyMap(), emptySet(), - Version.fromString("2.0.0")); + TransportAddress address = service.boundAddress().publishAddress(); + DiscoveryNode node = new DiscoveryNode("TS_TPC", "TS_TPC", address, emptyMap(), emptySet(), Version.fromString("2.0.0")); ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); builder.addConnections(1, TransportRequestOptions.Type.BULK, @@ -2689,7 +2681,7 @@ private void closeConnectionChannel(Transport.Connection connection) { } @SuppressForbidden(reason = "need local ephemeral port") - private InetSocketAddress getLocalEphemeral() throws UnknownHostException { + protected InetSocketAddress getLocalEphemeral() throws UnknownHostException { return new InetSocketAddress(InetAddress.getLocalHost(), 0); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java index 1773b9664bbd4..9253f29741c8e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/transport/netty4/SecurityNetty4Transport.java @@ -12,12 +12,14 @@ import io.netty.channel.ChannelPromise; import io.netty.handler.ssl.SslHandler; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.netty4.Netty4Transport; @@ -26,7 +28,10 @@ import org.elasticsearch.xpack.core.ssl.SSLConfiguration; import org.elasticsearch.xpack.core.ssl.SSLService; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.Collections; @@ -105,8 +110,8 @@ protected ChannelHandler getNoSslChannelInitializer(final String name) { } @Override - protected ChannelHandler getClientChannelInitializer() { - return new SecurityClientChannelInitializer(); + protected ChannelHandler getClientChannelInitializer(DiscoveryNode node) { + return new SecurityClientChannelInitializer(node); } @Override @@ -166,16 +171,28 @@ protected ServerChannelInitializer getSslChannelInitializer(final String name, f private class SecurityClientChannelInitializer extends ClientChannelInitializer { private final boolean hostnameVerificationEnabled; + private final SNIHostName serverName; - SecurityClientChannelInitializer() { + SecurityClientChannelInitializer(DiscoveryNode node) { this.hostnameVerificationEnabled = sslEnabled && sslConfiguration.verificationMode().isHostnameVerificationEnabled(); + String configuredServerName = node.getAttributes().get("server_name"); + if (configuredServerName != null) { + try { + serverName = new SNIHostName(configuredServerName); + } catch (IllegalArgumentException e) { + throw new ConnectTransportException(node, "invalid DiscoveryNode server_name [" + configuredServerName + "]", e); + } + } else { + serverName = null; + } } @Override protected void initChannel(Channel ch) throws Exception { super.initChannel(ch); if (sslEnabled) { - ch.pipeline().addFirst(new ClientSslHandlerInitializer(sslConfiguration, sslService, hostnameVerificationEnabled)); + ch.pipeline().addFirst(new ClientSslHandlerInitializer(sslConfiguration, sslService, hostnameVerificationEnabled, + serverName)); } } } @@ -185,11 +202,14 @@ private static class ClientSslHandlerInitializer extends ChannelOutboundHandlerA private final boolean hostnameVerificationEnabled; private final SSLConfiguration sslConfiguration; private final SSLService sslService; + private final SNIServerName serverName; - private ClientSslHandlerInitializer(SSLConfiguration sslConfiguration, SSLService sslService, boolean hostnameVerificationEnabled) { + private ClientSslHandlerInitializer(SSLConfiguration sslConfiguration, SSLService sslService, boolean hostnameVerificationEnabled, + SNIServerName serverName) { this.sslConfiguration = sslConfiguration; this.hostnameVerificationEnabled = hostnameVerificationEnabled; this.sslService = sslService; + this.serverName = serverName; } @Override @@ -206,6 +226,11 @@ public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, } sslEngine.setUseClientMode(true); + if (serverName != null) { + SSLParameters sslParameters = sslEngine.getSSLParameters(); + sslParameters.setServerNames(Collections.singletonList(serverName)); + sslEngine.setSSLParameters(sslParameters); + } ctx.pipeline().replace(this, "ssl", new SslHandler(sslEngine)); super.connect(ctx, remoteAddress, localAddress, promise); } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4TransportTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4TransportTests.java new file mode 100644 index 0000000000000..5181f3a747ead --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SimpleSecurityNetty4TransportTests.java @@ -0,0 +1,383 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.security.transport.netty4; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.ssl.SslHandler; +import org.elasticsearch.Version; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.SuppressForbidden; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; +import org.elasticsearch.common.network.NetworkService; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.env.TestEnvironment; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.node.Node; +import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.AbstractSimpleTransportTestCase; +import org.elasticsearch.transport.BindTransportException; +import org.elasticsearch.transport.ConnectTransportException; +import org.elasticsearch.transport.ConnectionProfile; +import org.elasticsearch.transport.TcpChannel; +import org.elasticsearch.transport.TcpTransport; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportRequestOptions; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.common.socket.SocketAccess; +import org.elasticsearch.xpack.core.security.transport.netty4.SecurityNetty4Transport; +import org.elasticsearch.xpack.core.ssl.SSLConfiguration; +import org.elasticsearch.xpack.core.ssl.SSLService; + +import javax.net.SocketFactory; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIMatcher; +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSocket; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketTimeoutException; +import java.net.UnknownHostException; +import java.nio.file.Path; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.elasticsearch.xpack.core.security.SecurityField.setting; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; + +public class SimpleSecurityNetty4TransportTests extends AbstractSimpleTransportTestCase { + + private static final ConnectionProfile SINGLE_CHANNEL_PROFILE; + + static { + ConnectionProfile.Builder builder = new ConnectionProfile.Builder(); + builder.addConnections(1, + TransportRequestOptions.Type.BULK, + TransportRequestOptions.Type.PING, + TransportRequestOptions.Type.RECOVERY, + TransportRequestOptions.Type.REG, + TransportRequestOptions.Type.STATE); + SINGLE_CHANNEL_PROFILE = builder.build(); + } + + private SSLService createSSLService() { + Path testnodeCert = getDataPath("/org/elasticsearch/xpack/security/transport/ssl/certs/simple/testnode.crt"); + Path testnodeKey = getDataPath("/org/elasticsearch/xpack/security/transport/ssl/certs/simple/testnode.pem"); + MockSecureSettings secureSettings = new MockSecureSettings(); + secureSettings.setString("xpack.ssl.secure_key_passphrase", "testnode"); + Settings settings = Settings.builder() + .put("xpack.security.transport.ssl.enabled", true) + .put("xpack.ssl.key", testnodeKey) + .put("xpack.ssl.certificate", testnodeCert) + .put("path.home", createTempDir()) + .setSecureSettings(secureSettings) + .build(); + try { + return new SSLService(settings, TestEnvironment.newEnvironment(settings)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public MockTransportService nettyFromThreadPool(Settings settings, ThreadPool threadPool, final Version version, + ClusterSettings clusterSettings, boolean doHandshake) { + NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); + NetworkService networkService = new NetworkService(Collections.emptyList()); + Settings settings1 = Settings.builder() + .put(settings) + .put("xpack.security.transport.ssl.enabled", true).build(); + Transport transport = new SecurityNetty4Transport(settings1, threadPool, + networkService, BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, + new NoneCircuitBreakerService(), createSSLService()) { + + @Override + protected Version executeHandshake(DiscoveryNode node, TcpChannel channel, TimeValue timeout) throws IOException, + InterruptedException { + if (doHandshake) { + return super.executeHandshake(node, channel, timeout); + } else { + return version.minimumCompatibilityVersion(); + } + } + + @Override + protected Version getCurrentVersion() { + return version; + } + + }; + MockTransportService mockTransportService = + MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, clusterSettings, + Collections.emptySet()); + mockTransportService.start(); + return mockTransportService; + } + + @Override + protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) { + settings = Settings.builder().put(settings) + .put(TcpTransport.PORT.getKey(), "0") + .build(); + MockTransportService transportService = nettyFromThreadPool(settings, threadPool, version, clusterSettings, doHandshake); + transportService.start(); + return transportService; + } + + public void testConnectException() throws UnknownHostException { + try { + serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876), + emptyMap(), emptySet(), Version.CURRENT)); + fail("Expected ConnectTransportException"); + } catch (ConnectTransportException e) { + assertThat(e.getMessage(), containsString("connect_exception")); + assertThat(e.getMessage(), containsString("[127.0.0.1:9876]")); + Throwable cause = e.getCause(); + assertThat(cause, instanceOf(IOException.class)); + } + } + + public void testBindUnavailableAddress() { + // this is on a lower level since it needs access to the TransportService before it's started + int port = serviceA.boundAddress().publishAddress().getPort(); + Settings settings = Settings.builder() + .put(Node.NODE_NAME_SETTING.getKey(), "foobar") + .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "") + .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") + .put("transport.tcp.port", port) + .build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> { + MockTransportService transportService = nettyFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true); + try { + transportService.start(); + } finally { + transportService.stop(); + transportService.close(); + } + }); + assertEquals("Failed to bind to [" + port + "]", bindTransportException.getMessage()); + } + + @SuppressForbidden(reason = "Need to open socket connection") + public void testRenegotiation() throws Exception { + SSLService sslService = createSSLService(); + final SSLConfiguration sslConfiguration = sslService.getSSLConfiguration("xpack.ssl"); + SocketFactory factory = sslService.sslSocketFactory(sslConfiguration); + try (SSLSocket socket = (SSLSocket) factory.createSocket()) { + SocketAccess.doPrivileged(() -> socket.connect(serviceA.boundAddress().publishAddress().address())); + + CountDownLatch handshakeLatch = new CountDownLatch(1); + HandshakeCompletedListener firstListener = event -> handshakeLatch.countDown(); + socket.addHandshakeCompletedListener(firstListener); + socket.startHandshake(); + handshakeLatch.await(); + socket.removeHandshakeCompletedListener(firstListener); + + OutputStreamStreamOutput stream = new OutputStreamStreamOutput(socket.getOutputStream()); + stream.writeByte((byte) 'E'); + stream.writeByte((byte) 'S'); + stream.writeInt(-1); + stream.flush(); + + socket.startHandshake(); + CountDownLatch renegotiationLatch = new CountDownLatch(1); + HandshakeCompletedListener secondListener = event -> renegotiationLatch.countDown(); + socket.addHandshakeCompletedListener(secondListener); + + AtomicReference error = new AtomicReference<>(); + CountDownLatch catchReadErrorsLatch = new CountDownLatch(1); + Thread renegotiationThread = new Thread(() -> { + try { + socket.setSoTimeout(50); + socket.getInputStream().read(); + } catch (SocketTimeoutException e) { + // Ignore. We expect a timeout. + } catch (IOException e) { + error.set(e); + } finally { + catchReadErrorsLatch.countDown(); + } + }); + renegotiationThread.start(); + renegotiationLatch.await(); + socket.removeHandshakeCompletedListener(secondListener); + catchReadErrorsLatch.await(); + + assertNull(error.get()); + + stream.writeByte((byte) 'E'); + stream.writeByte((byte) 'S'); + stream.writeInt(-1); + stream.flush(); + } + } + + // TODO: These tests currently rely on plaintext transports + + @Override + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/33285") + public void testTcpHandshake() { + } + + // TODO: These tests as configured do not currently work with the security transport + + @Override + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/33285") + public void testTransportProfilesWithPortAndHost() { + } + + public void testSNIServerNameIsPropagated() throws Exception { + SSLService sslService = createSSLService(); + final ServerBootstrap serverBootstrap = new ServerBootstrap(); + boolean success = false; + try { + serverBootstrap.group(new NioEventLoopGroup(1)); + serverBootstrap.channel(NioServerSocketChannel.class); + + final String sniIp = "sni-hostname"; + final SNIHostName sniHostName = new SNIHostName(sniIp); + final CountDownLatch latch = new CountDownLatch(2); + serverBootstrap.childHandler(new ChannelInitializer() { + + @Override + protected void initChannel(Channel ch) { + SSLEngine serverEngine = sslService.createSSLEngine(sslService.getSSLConfiguration(setting("transport.ssl.")), + null, -1); + serverEngine.setUseClientMode(false); + SSLParameters sslParameters = serverEngine.getSSLParameters(); + sslParameters.setSNIMatchers(Collections.singletonList(new SNIMatcher(0) { + @Override + public boolean matches(SNIServerName sniServerName) { + if (sniHostName.equals(sniServerName)) { + latch.countDown(); + return true; + } else { + return false; + } + } + })); + serverEngine.setSSLParameters(sslParameters); + final SslHandler sslHandler = new SslHandler(serverEngine); + sslHandler.handshakeFuture().addListener(future -> latch.countDown()); + ch.pipeline().addFirst("sslhandler", sslHandler); + } + }); + serverBootstrap.validate(); + ChannelFuture serverFuture = serverBootstrap.bind(getLocalEphemeral()); + serverFuture.await(); + InetSocketAddress serverAddress = (InetSocketAddress) serverFuture.channel().localAddress(); + + try (MockTransportService serviceC = build( + Settings.builder() + .put("name", "TS_TEST") + .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "") + .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") + .build(), + version0, + null, true)) { + serviceC.acceptIncomingRequests(); + + HashMap attributes = new HashMap<>(); + attributes.put("server_name", sniIp); + DiscoveryNode node = new DiscoveryNode("server_node_id", new TransportAddress(serverAddress), attributes, + EnumSet.allOf(DiscoveryNode.Role.class), Version.CURRENT); + + new Thread(() -> { + try { + serviceC.connectToNode(node, SINGLE_CHANNEL_PROFILE); + } catch (ConnectTransportException ex) { + // Ignore. The other side is not setup to do the ES handshake. So this will fail. + } + }).start(); + + latch.await(); + serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS); + success = true; + } + } finally { + if (success == false) { + serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS); + } + } + } + + public void testInvalidSNIServerName() throws Exception { + SSLService sslService = createSSLService(); + final ServerBootstrap serverBootstrap = new ServerBootstrap(); + boolean success = false; + try { + serverBootstrap.group(new NioEventLoopGroup(1)); + serverBootstrap.channel(NioServerSocketChannel.class); + + final String sniIp = "invalid_hostname"; + serverBootstrap.childHandler(new ChannelInitializer() { + + @Override + protected void initChannel(Channel ch) { + SSLEngine serverEngine = sslService.createSSLEngine(sslService.getSSLConfiguration(setting("transport.ssl.")), + null, -1); + serverEngine.setUseClientMode(false); + final SslHandler sslHandler = new SslHandler(serverEngine); + ch.pipeline().addFirst("sslhandler", sslHandler); + } + }); + serverBootstrap.validate(); + ChannelFuture serverFuture = serverBootstrap.bind(getLocalEphemeral()); + serverFuture.await(); + InetSocketAddress serverAddress = (InetSocketAddress) serverFuture.channel().localAddress(); + + try (MockTransportService serviceC = build( + Settings.builder() + .put("name", "TS_TEST") + .put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "") + .put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING") + .build(), + version0, + null, true)) { + serviceC.acceptIncomingRequests(); + + HashMap attributes = new HashMap<>(); + attributes.put("server_name", sniIp); + DiscoveryNode node = new DiscoveryNode("server_node_id", new TransportAddress(serverAddress), attributes, + EnumSet.allOf(DiscoveryNode.Role.class), Version.CURRENT); + + ConnectTransportException connectException = expectThrows(ConnectTransportException.class, + () -> serviceC.connectToNode(node, SINGLE_CHANNEL_PROFILE)); + + assertThat(connectException.getMessage(), containsString("invalid DiscoveryNode server_name [invalid_hostname]")); + + serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS); + success = true; + } + } finally { + if (success == false) { + serverBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS); + } + } + } +}