Skip to content

Add sni name to SSLEngine in netty transport (#33144) #33513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Sep 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ public class Netty4Transport extends TcpTransport {
intSetting("transport.netty.boss_count", 1, 1, Property.NodeScope);


protected final RecvByteBufAllocator recvByteBufAllocator;
protected final int workerCount;
protected final ByteSizeValue receivePredictorMin;
protected final ByteSizeValue receivePredictorMax;
protected volatile Bootstrap bootstrap;
protected final Map<String, ServerBootstrap> serverBootstraps = newConcurrentMap();
private final RecvByteBufAllocator recvByteBufAllocator;
private final int workerCount;
private final ByteSizeValue receivePredictorMin;
private final ByteSizeValue receivePredictorMax;
private volatile Bootstrap clientBootstrap;
private final Map<String, ServerBootstrap> serverBootstraps = newConcurrentMap();

public Netty4Transport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) {
Expand All @@ -125,7 +125,7 @@ public Netty4Transport(Settings settings, ThreadPool threadPool, NetworkService
protected void doStart() {
boolean success = false;
try {
bootstrap = createBootstrap();
clientBootstrap = createClientBootstrap();
if (NetworkService.NETWORK_SERVER.get(settings)) {
for (ProfileSettings profileSettings : profileSettings) {
createServerBootstrap(profileSettings);
Expand All @@ -141,13 +141,11 @@ protected void doStart() {
}
}

private Bootstrap createBootstrap() {
private Bootstrap createClientBootstrap() {
final Bootstrap bootstrap = new Bootstrap();
bootstrap.group(new NioEventLoopGroup(workerCount, daemonThreadFactory(settings, TRANSPORT_CLIENT_BOSS_THREAD_NAME_PREFIX)));
bootstrap.channel(NioSocketChannel.class);

bootstrap.handler(getClientChannelInitializer());

bootstrap.option(ChannelOption.TCP_NODELAY, TCP_NO_DELAY.get(settings));
bootstrap.option(ChannelOption.SO_KEEPALIVE, TCP_KEEP_ALIVE.get(settings));

Expand All @@ -166,8 +164,6 @@ private Bootstrap createBootstrap() {
final boolean reuseAddress = TCP_REUSE_ADDRESS.get(settings);
bootstrap.option(ChannelOption.SO_REUSEADDR, reuseAddress);

bootstrap.validate();

return bootstrap;
}

Expand Down Expand Up @@ -215,7 +211,7 @@ protected ChannelHandler getServerChannelInitializer(String name) {
return new ServerChannelInitializer(name);
}

protected ChannelHandler getClientChannelInitializer() {
protected ChannelHandler getClientChannelInitializer(DiscoveryNode node) {
return new ClientChannelInitializer();
}

Expand All @@ -231,7 +227,11 @@ protected final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
@Override
protected NettyTcpChannel initiateChannel(DiscoveryNode node, ActionListener<Void> listener) throws IOException {
InetSocketAddress address = node.getAddress().address();
ChannelFuture channelFuture = bootstrap.connect(address);
Bootstrap bootstrapWithHandler = clientBootstrap.clone();
bootstrapWithHandler.handler(getClientChannelInitializer(node));
bootstrapWithHandler.remoteAddress(address);
ChannelFuture channelFuture = bootstrapWithHandler.connect();

Channel channel = channelFuture.channel();
if (channel == null) {
ExceptionsHelper.maybeDieOnAnotherThread(channelFuture.cause());
Expand Down Expand Up @@ -294,9 +294,9 @@ protected void stopInternal() {
}
serverBootstraps.clear();

if (bootstrap != null) {
bootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS).awaitUninterruptibly();
bootstrap = null;
if (clientBootstrap != null) {
clientBootstrap.config().group().shutdownGracefully(0, 5, TimeUnit.SECONDS).awaitUninterruptibly();
clientBootstrap = null;
}
});
}
Expand Down
8 changes: 8 additions & 0 deletions server/src/main/java/org/elasticsearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@
import org.elasticsearch.usage.UsageService;
import org.elasticsearch.watcher.ResourceWatcherService;

import javax.net.ssl.SNIHostName;
import java.io.BufferedWriter;
import java.io.Closeable;
import java.io.IOException;
Expand Down Expand Up @@ -212,6 +213,13 @@ public abstract class Node implements Closeable {
throw new IllegalArgumentException(key + " cannot have leading or trailing whitespace " +
"[" + value + "]");
}
if (value.length() > 0 && "node.attr.server_name".equals(key)) {
try {
new SNIHostName(value);
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException("invalid node.attr.server_name [" + value + "]", e );
}
}
return value;
}, Property.NodeScope));
public static final Setting<String> BREAKER_TYPE_KEY = new Setting<>("indices.breaker.type", "hierarchy", (s) -> {
Expand Down
21 changes: 21 additions & 0 deletions server/src/test/java/org/elasticsearch/node/NodeTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,27 @@ public void testNodeAttributes() throws IOException {
assertSettingDeprecationsAndWarnings(new Setting<?>[] { NetworkModule.HTTP_ENABLED });
}

public void testServerNameNodeAttribute() throws IOException {
String attr = "valid-hostname";
Settings.Builder settings = baseSettings().put(Node.NODE_ATTRIBUTES.getKey() + "server_name", attr);
int i = 0;
try (Node node = new MockNode(settings.build(), Collections.singleton(getTestTransportPlugin()))) {
final Settings nodeSettings = randomBoolean() ? node.settings() : node.getEnvironment().settings();
assertEquals(attr, Node.NODE_ATTRIBUTES.getAsMap(nodeSettings).get("server_name"));
}

// non-LDH hostname not allowed
attr = "invalid_hostname";
settings = baseSettings().put(Node.NODE_ATTRIBUTES.getKey() + "server_name", attr);
try (Node node = new MockNode(settings.build(), Collections.singleton(getTestTransportPlugin()))) {
fail("should not allow a server_name attribute with an underscore");
} catch (IllegalArgumentException e) {
assertEquals("invalid node.attr.server_name [invalid_hostname]", e.getMessage());
}

assertSettingDeprecationsAndWarnings(new Setting<?>[] { NetworkModule.HTTP_ENABLED });
}

private static Settings.Builder baseSettings() {
final Path tempDir = createTempDir();
return Settings.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1937,16 +1937,12 @@ public void testTimeoutPerConnection() throws IOException {

public void testHandshakeWithIncompatVersion() {
assumeTrue("only tcp transport has a handshake method", serviceA.getOriginalTransport() instanceof TcpTransport);
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
Version version = Version.fromString("2.0.0");
try (MockTcpTransport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE,
new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), version);
MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, null,
Collections.emptySet())) {
try (MockTransportService service = build(Settings.EMPTY, version, null, true)) {
service.start();
service.acceptIncomingRequests();
DiscoveryNode node =
new DiscoveryNode("TS_TPC", "TS_TPC", transport.boundAddress().publishAddress(), emptyMap(), emptySet(), version0);
TransportAddress address = service.boundAddress().publishAddress();
DiscoveryNode node = new DiscoveryNode("TS_TPC", "TS_TPC", address, emptyMap(), emptySet(), version0);
ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
builder.addConnections(1,
TransportRequestOptions.Type.BULK,
Expand All @@ -1962,15 +1958,11 @@ public void testHandshakeUpdatesVersion() throws IOException {
assumeTrue("only tcp transport has a handshake method", serviceA.getOriginalTransport() instanceof TcpTransport);
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
Version version = VersionUtils.randomVersionBetween(random(), Version.CURRENT.minimumCompatibilityVersion(), Version.CURRENT);
try (MockTcpTransport transport = new MockTcpTransport(Settings.EMPTY, threadPool, BigArrays.NON_RECYCLING_INSTANCE,
new NoneCircuitBreakerService(), namedWriteableRegistry, new NetworkService(Collections.emptyList()), version);
MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, transport, version, threadPool, null,
Collections.emptySet())) {
try (MockTransportService service = build(Settings.EMPTY, version, null, true)) {
service.start();
service.acceptIncomingRequests();
DiscoveryNode node =
new DiscoveryNode("TS_TPC", "TS_TPC", transport.boundAddress().publishAddress(), emptyMap(), emptySet(),
Version.fromString("2.0.0"));
TransportAddress address = service.boundAddress().publishAddress();
DiscoveryNode node = new DiscoveryNode("TS_TPC", "TS_TPC", address, emptyMap(), emptySet(), Version.fromString("2.0.0"));
ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
builder.addConnections(1,
TransportRequestOptions.Type.BULK,
Expand Down Expand Up @@ -2689,7 +2681,7 @@ private void closeConnectionChannel(Transport.Connection connection) {
}

@SuppressForbidden(reason = "need local ephemeral port")
private InetSocketAddress getLocalEphemeral() throws UnknownHostException {
protected InetSocketAddress getLocalEphemeral() throws UnknownHostException {
return new InetSocketAddress(InetAddress.getLocalHost(), 0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import io.netty.channel.ChannelPromise;
import io.netty.handler.ssl.SslHandler;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.netty4.Netty4Transport;
Expand All @@ -26,7 +28,10 @@
import org.elasticsearch.xpack.core.ssl.SSLConfiguration;
import org.elasticsearch.xpack.core.ssl.SSLService;

import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collections;
Expand Down Expand Up @@ -105,8 +110,8 @@ protected ChannelHandler getNoSslChannelInitializer(final String name) {
}

@Override
protected ChannelHandler getClientChannelInitializer() {
return new SecurityClientChannelInitializer();
protected ChannelHandler getClientChannelInitializer(DiscoveryNode node) {
return new SecurityClientChannelInitializer(node);
}

@Override
Expand Down Expand Up @@ -166,16 +171,28 @@ protected ServerChannelInitializer getSslChannelInitializer(final String name, f
private class SecurityClientChannelInitializer extends ClientChannelInitializer {

private final boolean hostnameVerificationEnabled;
private final SNIHostName serverName;

SecurityClientChannelInitializer() {
SecurityClientChannelInitializer(DiscoveryNode node) {
this.hostnameVerificationEnabled = sslEnabled && sslConfiguration.verificationMode().isHostnameVerificationEnabled();
String configuredServerName = node.getAttributes().get("server_name");
if (configuredServerName != null) {
try {
serverName = new SNIHostName(configuredServerName);
} catch (IllegalArgumentException e) {
throw new ConnectTransportException(node, "invalid DiscoveryNode server_name [" + configuredServerName + "]", e);
}
} else {
serverName = null;
}
}

@Override
protected void initChannel(Channel ch) throws Exception {
super.initChannel(ch);
if (sslEnabled) {
ch.pipeline().addFirst(new ClientSslHandlerInitializer(sslConfiguration, sslService, hostnameVerificationEnabled));
ch.pipeline().addFirst(new ClientSslHandlerInitializer(sslConfiguration, sslService, hostnameVerificationEnabled,
serverName));
}
}
}
Expand All @@ -185,11 +202,14 @@ private static class ClientSslHandlerInitializer extends ChannelOutboundHandlerA
private final boolean hostnameVerificationEnabled;
private final SSLConfiguration sslConfiguration;
private final SSLService sslService;
private final SNIServerName serverName;

private ClientSslHandlerInitializer(SSLConfiguration sslConfiguration, SSLService sslService, boolean hostnameVerificationEnabled) {
private ClientSslHandlerInitializer(SSLConfiguration sslConfiguration, SSLService sslService, boolean hostnameVerificationEnabled,
SNIServerName serverName) {
this.sslConfiguration = sslConfiguration;
this.hostnameVerificationEnabled = hostnameVerificationEnabled;
this.sslService = sslService;
this.serverName = serverName;
}

@Override
Expand All @@ -206,6 +226,11 @@ public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress,
}

sslEngine.setUseClientMode(true);
if (serverName != null) {
SSLParameters sslParameters = sslEngine.getSSLParameters();
sslParameters.setServerNames(Collections.singletonList(serverName));
sslEngine.setSSLParameters(sslParameters);
}
ctx.pipeline().replace(this, "ssl", new SslHandler(sslEngine));
super.connect(ctx, remoteAddress, localAddress, promise);
}
Expand Down
Loading