Skip to content

Commit 7a407f5

Browse files
authored
Handle WRAP ops during SSL read (#41611)
It is possible that a WRAP operation can occur while decrypting handshake data in TLS 1.3. The SSLDriver does not currently handle this well as it does not have access to the outbound buffer during read call. This commit moves the buffer into the Driver to fix this issue. Data wrapped during a read call will be queued for writing after the read call is complete.
1 parent 40216b4 commit 7a407f5

File tree

4 files changed

+77
-71
lines changed

4 files changed

+77
-71
lines changed

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,16 @@
99
import org.elasticsearch.core.internal.io.IOUtils;
1010
import org.elasticsearch.nio.FlushOperation;
1111
import org.elasticsearch.nio.InboundChannelBuffer;
12+
import org.elasticsearch.nio.NioSelector;
1213
import org.elasticsearch.nio.NioSocketChannel;
13-
import org.elasticsearch.nio.Page;
1414
import org.elasticsearch.nio.ReadWriteHandler;
1515
import org.elasticsearch.nio.SocketChannelContext;
16-
import org.elasticsearch.nio.NioSelector;
1716
import org.elasticsearch.nio.WriteOperation;
1817

1918
import javax.net.ssl.SSLEngine;
2019
import java.io.IOException;
21-
import java.nio.ByteBuffer;
2220
import java.nio.channels.ClosedChannelException;
21+
import java.util.LinkedList;
2322
import java.util.concurrent.TimeUnit;
2423
import java.util.function.BiConsumer;
2524
import java.util.function.Consumer;
@@ -37,8 +36,7 @@ public final class SSLChannelContext extends SocketChannelContext {
3736
private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {};
3837

3938
private final SSLDriver sslDriver;
40-
private final SSLOutboundBuffer outboundBuffer;
41-
private FlushOperation encryptedFlush;
39+
private final LinkedList<FlushOperation> encryptedFlushes = new LinkedList<>();
4240
private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER;
4341

4442
SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
@@ -51,14 +49,16 @@ public final class SSLChannelContext extends SocketChannelContext {
5149
Predicate<NioSocketChannel> allowChannelPredicate) {
5250
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
5351
this.sslDriver = sslDriver;
54-
// TODO: When the bytes are actually recycled, we need to test that they are released on context close
55-
this.outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
5652
}
5753

5854
@Override
5955
public void register() throws IOException {
6056
super.register();
6157
sslDriver.init();
58+
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
59+
if (outboundBuffer.hasEncryptedBytesToFlush()) {
60+
encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation());
61+
}
6262
}
6363

6464
@Override
@@ -98,11 +98,12 @@ public void flushChannel() throws IOException {
9898
try {
9999
// Attempt to encrypt application write data. The encrypted data ends up in the
100100
// outbound write buffer.
101-
sslDriver.write(unencryptedFlush, outboundBuffer);
101+
sslDriver.write(unencryptedFlush);
102+
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
102103
if (outboundBuffer.hasEncryptedBytesToFlush() == false) {
103104
break;
104105
}
105-
encryptedFlush = outboundBuffer.buildNetworkFlushOperation();
106+
encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation());
106107
// Flush the write buffer to the channel
107108
flushEncryptedOperation();
108109
} catch (IOException e) {
@@ -115,10 +116,11 @@ public void flushChannel() throws IOException {
115116
// We are not ready for application writes, check if the driver has non-application writes. We
116117
// only want to continue producing new writes if the outbound write buffer is fully flushed.
117118
while (pendingChannelFlush() == false && sslDriver.needsNonApplicationWrite()) {
118-
sslDriver.nonApplicationWrite(outboundBuffer);
119+
sslDriver.nonApplicationWrite();
119120
// If non-application writes were produced, flush the outbound write buffer.
121+
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
120122
if (outboundBuffer.hasEncryptedBytesToFlush()) {
121-
encryptedFlush = outboundBuffer.buildNetworkFlushOperation();
123+
encryptedFlushes.addFirst(outboundBuffer.buildNetworkFlushOperation());
122124
flushEncryptedOperation();
123125
}
124126
}
@@ -127,14 +129,14 @@ public void flushChannel() throws IOException {
127129

128130
private void flushEncryptedOperation() throws IOException {
129131
try {
132+
FlushOperation encryptedFlush = encryptedFlushes.getFirst();
130133
flushToChannel(encryptedFlush);
131134
if (encryptedFlush.isFullyFlushed()) {
132135
getSelector().executeListener(encryptedFlush.getListener(), null);
133-
encryptedFlush = null;
136+
encryptedFlushes.removeFirst();
134137
}
135138
} catch (IOException e) {
136-
getSelector().executeFailedListener(encryptedFlush.getListener(), e);
137-
encryptedFlush = null;
139+
getSelector().executeFailedListener(encryptedFlushes.removeFirst().getListener(), e);
138140
throw e;
139141
}
140142
}
@@ -163,6 +165,11 @@ public int read() throws IOException {
163165
sslDriver.read(channelBuffer);
164166

165167
handleReadBytes();
168+
// It is possible that a read call produced non-application bytes to flush
169+
SSLOutboundBuffer outboundBuffer = sslDriver.getOutboundBuffer();
170+
if (outboundBuffer.hasEncryptedBytesToFlush()) {
171+
encryptedFlushes.addLast(outboundBuffer.buildNetworkFlushOperation());
172+
}
166173

167174
return bytesRead;
168175
}
@@ -190,10 +197,11 @@ public void closeFromSelector() throws IOException {
190197
getSelector().assertOnSelectorThread();
191198
if (channel.isOpen()) {
192199
closeTimeoutCanceller.run();
193-
if (encryptedFlush != null) {
200+
for (FlushOperation encryptedFlush : encryptedFlushes) {
194201
getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException());
195202
}
196-
IOUtils.close(super::closeFromSelector, outboundBuffer::close, sslDriver::close);
203+
encryptedFlushes.clear();
204+
IOUtils.close(super::closeFromSelector, sslDriver::close);
197205
}
198206
}
199207

@@ -208,7 +216,7 @@ private void channelCloseTimeout() {
208216
}
209217

210218
private boolean pendingChannelFlush() {
211-
return encryptedFlush != null;
219+
return encryptedFlushes.isEmpty() == false;
212220
}
213221

214222
private static class CloseNotifyOperation implements WriteOperation {

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import org.elasticsearch.nio.FlushOperation;
99
import org.elasticsearch.nio.InboundChannelBuffer;
10+
import org.elasticsearch.nio.Page;
1011
import org.elasticsearch.nio.utils.ExceptionsHelper;
1112

1213
import javax.net.ssl.SSLEngine;
@@ -32,14 +33,14 @@
3233
*
3334
* Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be
3435
* called to determine if this driver needs to produce more data to advance the handshake or close process.
35-
* If that method returns true, {@link #nonApplicationWrite(SSLOutboundBuffer)} should be called (and the
36+
* If that method returns true, {@link #nonApplicationWrite()} should be called (and the
3637
* data produced then flushed to the channel) until no further non-application writes are needed.
3738
*
3839
* If no non-application writes are needed, {@link #readyForApplicationWrites()} can be called to determine
3940
* if the driver is ready to consume application data. (Note: It is possible that
4041
* {@link #readyForApplicationWrites()} and {@link #needsNonApplicationWrite()} can both return false if the
4142
* driver is waiting on non-application data from the peer.) If the driver indicates it is ready for
42-
* application writes, {@link #write(FlushOperation, SSLOutboundBuffer)} can be called. This method will
43+
* application writes, {@link #write(FlushOperation)} can be called. This method will
4344
* encrypt flush operation application data and place it in the outbound buffer for flushing to a channel.
4445
*
4546
* If you are ready to close the channel {@link #initiateClose()} should be called. After that is called, the
@@ -53,6 +54,8 @@ public class SSLDriver implements AutoCloseable {
5354
private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {});
5455

5556
private final SSLEngine engine;
57+
// TODO: When the bytes are actually recycled, we need to test that they are released on driver close
58+
private final SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
5659
private final boolean isClientMode;
5760
// This should only be accessed by the network thread associated with this channel, so nothing needs to
5861
// be volatile.
@@ -107,6 +110,10 @@ public ByteBuffer getNetworkReadBuffer() {
107110
return networkReadBuffer;
108111
}
109112

113+
public SSLOutboundBuffer getOutboundBuffer() {
114+
return outboundBuffer;
115+
}
116+
110117
public void read(InboundChannelBuffer buffer) throws SSLException {
111118
Mode modePriorToRead;
112119
do {
@@ -125,14 +132,14 @@ public boolean needsNonApplicationWrite() {
125132
return currentMode.needsNonApplicationWrite();
126133
}
127134

128-
public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
129-
return currentMode.write(applicationBytes, outboundBuffer);
135+
public int write(FlushOperation applicationBytes) throws SSLException {
136+
return currentMode.write(applicationBytes);
130137
}
131138

132-
public void nonApplicationWrite(SSLOutboundBuffer outboundBuffer) throws SSLException {
139+
public void nonApplicationWrite() throws SSLException {
133140
assert currentMode.isApplication() == false : "Should not be called if driver is in application mode";
134141
if (currentMode.isApplication() == false) {
135-
currentMode.write(EMPTY_FLUSH_OPERATION, outboundBuffer);
142+
currentMode.write(EMPTY_FLUSH_OPERATION);
136143
} else {
137144
throw new AssertionError("Attempted to non-application write from invalid mode: " + currentMode.modeName());
138145
}
@@ -148,6 +155,7 @@ public boolean isClosed() {
148155

149156
@Override
150157
public void close() throws SSLException {
158+
outboundBuffer.close();
151159
ArrayList<SSLException> closingExceptions = new ArrayList<>(2);
152160
closingInternal();
153161
CloseMode closeMode = (CloseMode) this.currentMode;
@@ -276,7 +284,7 @@ private interface Mode {
276284

277285
void read(InboundChannelBuffer buffer) throws SSLException;
278286

279-
int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException;
287+
int write(FlushOperation applicationBytes) throws SSLException;
280288

281289
boolean needsNonApplicationWrite();
282290

@@ -296,18 +304,17 @@ private class HandshakeMode implements Mode {
296304

297305
private void startHandshake() throws SSLException {
298306
handshakeStatus = engine.getHandshakeStatus();
299-
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
300-
handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_WRAP) {
307+
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
301308
try {
302-
handshake(null);
309+
handshake();
303310
} catch (SSLException e) {
304311
closingInternal();
305312
throw e;
306313
}
307314
}
308315
}
309316

310-
private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException {
317+
private void handshake() throws SSLException {
311318
boolean continueHandshaking = true;
312319
while (continueHandshaking) {
313320
switch (handshakeStatus) {
@@ -316,15 +323,7 @@ private void handshake(SSLOutboundBuffer outboundBuffer) throws SSLException {
316323
continueHandshaking = false;
317324
break;
318325
case NEED_WRAP:
319-
if (outboundBuffer != null) {
320-
handshakeStatus = wrap(outboundBuffer).getHandshakeStatus();
321-
// If we need NEED_TASK we should run the tasks immediately
322-
if (handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_TASK) {
323-
continueHandshaking = false;
324-
}
325-
} else {
326-
continueHandshaking = false;
327-
}
326+
handshakeStatus = wrap(outboundBuffer).getHandshakeStatus();
328327
break;
329328
case NEED_TASK:
330329
runTasks();
@@ -351,7 +350,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException {
351350
try {
352351
SSLEngineResult result = unwrap(buffer);
353352
handshakeStatus = result.getHandshakeStatus();
354-
handshake(null);
353+
handshake();
355354
// If we are done handshaking we should exit the handshake read
356355
continueUnwrap = result.bytesConsumed() > 0 && currentMode.isHandshake();
357356
} catch (SSLException e) {
@@ -362,9 +361,9 @@ public void read(InboundChannelBuffer buffer) throws SSLException {
362361
}
363362

364363
@Override
365-
public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
364+
public int write(FlushOperation applicationBytes) throws SSLException {
366365
try {
367-
handshake(outboundBuffer);
366+
handshake();
368367
} catch (SSLException e) {
369368
closingInternal();
370369
throw e;
@@ -444,7 +443,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException {
444443
}
445444

446445
@Override
447-
public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
446+
public int write(FlushOperation applicationBytes) throws SSLException {
448447
boolean continueWrap = true;
449448
int totalBytesProduced = 0;
450449
while (continueWrap && applicationBytes.isFullyFlushed() == false) {
@@ -538,7 +537,7 @@ public void read(InboundChannelBuffer buffer) throws SSLException {
538537
}
539538

540539
@Override
541-
public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuffer) throws SSLException {
540+
public int write(FlushOperation applicationBytes) throws SSLException {
542541
int bytesProduced = 0;
543542
if (engine.isOutboundDone() == false) {
544543
bytesProduced += wrap(outboundBuffer).bytesProduced();
@@ -549,6 +548,8 @@ public int write(FlushOperation applicationBytes, SSLOutboundBuffer outboundBuff
549548
closeInboundAndSwallowPeerDidNotCloseException();
550549
}
551550
}
551+
} else {
552+
needToSendClose = false;
552553
}
553554
return bytesProduced;
554555
}

0 commit comments

Comments
 (0)