Skip to content

Test that malformed HTTP request is not validated #95886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -387,11 +387,6 @@ protected void initChannel(Channel ch) throws Exception {
protected HttpMessage createMessage(String[] initialLine) throws Exception {
return HttpHeadersAuthenticatorUtils.wrapAsMessageWithAuthenticationContext(super.createMessage(initialLine));
}

@Override
protected HttpMessage createInvalidMessage() {
return HttpHeadersAuthenticatorUtils.wrapAsMessageWithAuthenticationContext(super.createInvalidMessage());
}
};
} else {
decoder = new HttpRequestDecoder(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@

package org.elasticsearch.http.netty4;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.AsciiString;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchSecurityException;
Expand All @@ -26,6 +31,7 @@
import org.elasticsearch.http.netty4.internal.HttpValidator;
import org.elasticsearch.test.ESTestCase;

import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

Expand All @@ -35,6 +41,8 @@
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.WAITING_TO_START;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;

Expand Down Expand Up @@ -583,4 +591,101 @@ public void testValidationFailureForLargeMessage() {
assertThat(lastHttpContent.refCnt(), equalTo(0));
assertThat(channel.readInbound(), nullValue());
}

public void testFullRequestValidationFailure() {
assertTrue(channel.config().isAutoRead());
assertThat(netty4HttpHeaderValidator.getState(), equalTo(WAITING_TO_START));

ByteBuf buf = channel.alloc().buffer();
ByteBufUtil.copy(AsciiString.of("test full http request"), buf);
final DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/uri", buf);
channel.writeInbound(request);

// request got through to validation
assertThat(header.get(), sameInstance(request));
// channel is paused
assertThat(channel.readInbound(), nullValue());
assertFalse(channel.config().isAutoRead());

// validation fails
Exception exception = new ElasticsearchException("Boom");
listener.get().onFailure(exception);
channel.runPendingTasks();

// channel is resumed and waiting for next request
assertTrue(channel.config().isAutoRead());
assertThat(netty4HttpHeaderValidator.getState(), equalTo(WAITING_TO_START));

DefaultFullHttpRequest throughRequest = channel.readInbound();
// "through request" contains a decoder exception
assertThat(throughRequest, not(sameInstance(request)));
assertTrue(throughRequest.decoderResult().isFailure());
// the content is cleared when validation fails
assertThat(new String(ByteBufUtil.getBytes(throughRequest.content()), StandardCharsets.UTF_8), is(""));
assertThat(buf.refCnt(), is(0));
Exception cause = (Exception) throughRequest.decoderResult().cause();
assertThat(cause, equalTo(exception));
}

public void testFullRequestValidationSuccess() {
assertTrue(channel.config().isAutoRead());
assertThat(netty4HttpHeaderValidator.getState(), equalTo(WAITING_TO_START));

ByteBuf buf = channel.alloc().buffer();
ByteBufUtil.copy(AsciiString.of("test full http request"), buf);
final DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/uri", buf);
channel.writeInbound(request);

// request got through to validation
assertThat(header.get(), sameInstance(request));
// channel is paused
assertThat(channel.readInbound(), nullValue());
assertFalse(channel.config().isAutoRead());

// validation succeeds
listener.get().onResponse(null);
channel.runPendingTasks();

// channel is resumed and waiting for next request
assertTrue(channel.config().isAutoRead());
assertThat(netty4HttpHeaderValidator.getState(), equalTo(WAITING_TO_START));

DefaultFullHttpRequest throughRequest = channel.readInbound();
// request goes through unaltered
assertThat(throughRequest, sameInstance(request));
assertFalse(throughRequest.decoderResult().isFailure());
// the content is unaltered
assertThat(new String(ByteBufUtil.getBytes(throughRequest.content()), StandardCharsets.UTF_8), is("test full http request"));
assertThat(buf.refCnt(), is(1));
assertThat(throughRequest.decoderResult().cause(), nullValue());
}

public void testFullRequestWithDecoderException() {
assertTrue(channel.config().isAutoRead());
assertThat(netty4HttpHeaderValidator.getState(), equalTo(WAITING_TO_START));

ByteBuf buf = channel.alloc().buffer();
ByteBufUtil.copy(AsciiString.of("test full http request"), buf);
// a request with a decoder error prior to validation
final DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/uri", buf);
Exception cause = new ElasticsearchException("Boom");
request.setDecoderResult(DecoderResult.failure(cause));
channel.writeInbound(request);

// request goes through without invoking the validator
assertThat(header.get(), nullValue());
assertThat(listener.get(), nullValue());
// channel is NOT paused
assertTrue(channel.config().isAutoRead());
assertThat(netty4HttpHeaderValidator.getState(), equalTo(WAITING_TO_START));

DefaultFullHttpRequest throughRequest = channel.readInbound();
// request goes through unaltered
assertThat(throughRequest, sameInstance(request));
assertTrue(throughRequest.decoderResult().isFailure());
assertThat(throughRequest.decoderResult().cause(), equalTo(cause));
// the content is unaltered
assertThat(new String(ByteBufUtil.getBytes(throughRequest.content()), StandardCharsets.UTF_8), is("test full http request"));
assertThat(buf.refCnt(), is(1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http.QueryStringDecoder;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
Expand Down Expand Up @@ -53,6 +56,7 @@
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.sameInstance;
import static org.hamcrest.core.Is.is;
import static org.mockito.Mockito.mock;

Expand Down Expand Up @@ -210,6 +214,17 @@ public void testPipeliningRequestsAreReleased() throws InterruptedException {
}
}

public void testDecoderErrorSurfacedAsNettyInboundError() {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(getTestHttpHandler());
// a request with a decoder error
final DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/uri");
Exception cause = new ElasticsearchException("Boom");
request.setDecoderResult(DecoderResult.failure(cause));
embeddedChannel.writeInbound(request);
final Netty4HttpRequest nettyRequest = embeddedChannel.readInbound();
assertThat(nettyRequest.getInboundException(), sameInstance(cause));
}

public void testResumesChunkedMessage() {
final List<Object> messagesSeen = new ArrayList<>();
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(capturingHandler(messagesSeen), getTestHttpHandler());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@
*/
package org.elasticsearch.xpack.security.transport.netty4;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.HttpConstants;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.AsciiString;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.network.NetworkService;
Expand All @@ -28,6 +31,7 @@
import org.elasticsearch.http.AbstractHttpServerTransportTestCase;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.http.netty4.Netty4HttpResponse;
import org.elasticsearch.http.netty4.Netty4HttpServerTransport;
Expand All @@ -50,6 +54,7 @@
import java.nio.file.Path;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import javax.net.ssl.SSLEngine;
Expand Down Expand Up @@ -321,6 +326,7 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th
ch.pipeline().remove(pipelineHandlerName);
}
}
// STEP 0: send a "wrapped" request
var writeFuture = testThreadPool.generic().submit(() -> {
ch.writeInbound(
HttpHeadersAuthenticatorUtils.wrapAsMessageWithAuthenticationContext(
Expand Down Expand Up @@ -433,4 +439,157 @@ public void testHttpHeaderAuthnFaultyHeaderValidator() throws Exception {
testThreadPool.shutdownNow();
}
}

public void testMalformedRequestDispatchedNoAuthn() throws Exception {
final AtomicReference<Throwable> dispatchThrowableReference = new AtomicReference<>();
final AtomicInteger authnInvocationCount = new AtomicInteger();
final AtomicInteger badDispatchInvocationCount = new AtomicInteger();
final Settings settings = Settings.builder()
.put(env.settings())
.put(HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE.getKey(), "32b")
.put(HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.getKey(), "32b")
.build();
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
@Override
public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
logger.error("--> Unexpected dispatched request [" + FakeRestRequest.requestToString(channel.request()) + "]");
throw new AssertionError("Unexpected dispatched request");
}

@Override
public void dispatchBadRequest(final RestChannel channel, final ThreadContext threadContext, final Throwable cause) {
assertThat(cause, notNullValue());
dispatchThrowableReference.set(cause);
badDispatchInvocationCount.incrementAndGet();
}
};
final ThreadPool testThreadPool = new TestThreadPool(TEST_MOCK_TRANSPORT_THREAD_PREFIX);
try (
Netty4HttpServerTransport transport = Security.getHttpServerTransportWithHeadersValidator(
settings,
new NetworkService(List.of()),
testThreadPool,
xContentRegistry(),
dispatcher,
randomClusterSettings(),
new SharedGroupFactory(settings),
Tracer.NOOP,
TLSConfig.noTLS(),
null,
(httpPreRequest, channel, listener) -> {
authnInvocationCount.incrementAndGet();
throw new AssertionError("Malformed requests shouldn't be authenticated");
}
)
) {
final ChannelHandler handler = transport.configureServerChannelHandler();
assertThat(authnInvocationCount.get(), is(0));
assertThat(badDispatchInvocationCount.get(), is(0));
// case 1: invalid initial line
{
EmbeddedChannel ch = new EmbeddedChannel(handler);
ByteBuf buf = ch.alloc().buffer();
ByteBufUtil.copy(AsciiString.of("This is not a valid HTTP line"), buf);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
var writeFuture = testThreadPool.generic().submit(() -> {
ch.writeInbound(buf);
ch.flushInbound();
});
writeFuture.get();
assertThat(dispatchThrowableReference.get().toString(), containsString("NOT A VALID HTTP LINE"));
assertThat(badDispatchInvocationCount.get(), is(1));
assertThat(authnInvocationCount.get(), is(0));
}
// case 2: too long initial line
{
EmbeddedChannel ch = new EmbeddedChannel(handler);
ByteBuf buf = ch.alloc().buffer();
ByteBufUtil.copy(AsciiString.of("GET /this/is/a/valid/but/too/long/initial/line HTTP/1.1"), buf);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
var writeFuture = testThreadPool.generic().submit(() -> {
ch.writeInbound(buf);
ch.flushInbound();
});
writeFuture.get();
assertThat(dispatchThrowableReference.get().toString(), containsString("HTTP line is larger than"));
assertThat(badDispatchInvocationCount.get(), is(2));
assertThat(authnInvocationCount.get(), is(0));
}
// case 3: invalid header with no colon
{
EmbeddedChannel ch = new EmbeddedChannel(handler);
ByteBuf buf = ch.alloc().buffer();
ByteBufUtil.copy(AsciiString.of("GET /url HTTP/1.1"), buf);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
ByteBufUtil.copy(AsciiString.of("Host"), buf);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
var writeFuture = testThreadPool.generic().submit(() -> {
ch.writeInbound(buf);
ch.flushInbound();
});
writeFuture.get();
assertThat(dispatchThrowableReference.get().toString(), containsString("No colon found"));
assertThat(badDispatchInvocationCount.get(), is(3));
assertThat(authnInvocationCount.get(), is(0));
}
// case 4: invalid header longer than max allowed
{
EmbeddedChannel ch = new EmbeddedChannel(handler);
ByteBuf buf = ch.alloc().buffer();
ByteBufUtil.copy(AsciiString.of("GET /url HTTP/1.1"), buf);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
ByteBufUtil.copy(AsciiString.of("Host: this.looks.like.a.good.url.but.is.longer.than.permitted"), buf);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
var writeFuture = testThreadPool.generic().submit(() -> {
ch.writeInbound(buf);
ch.flushInbound();
});
writeFuture.get();
assertThat(dispatchThrowableReference.get().toString(), containsString("HTTP header is larger than"));
assertThat(badDispatchInvocationCount.get(), is(4));
assertThat(authnInvocationCount.get(), is(0));
}
// case 5: invalid header format
{
EmbeddedChannel ch = new EmbeddedChannel(handler);
ByteBuf buf = ch.alloc().buffer();
ByteBufUtil.copy(AsciiString.of("GET /url HTTP/1.1"), buf);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
ByteBufUtil.copy(AsciiString.of("Host: invalid host value"), buf);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
var writeFuture = testThreadPool.generic().submit(() -> {
ch.writeInbound(buf);
ch.flushInbound();
});
writeFuture.get();
assertThat(dispatchThrowableReference.get().toString(), containsString("Validation failed for header 'Host'"));
assertThat(badDispatchInvocationCount.get(), is(5));
assertThat(authnInvocationCount.get(), is(0));
}
// case 6: connection closed before all headers are sent
{
EmbeddedChannel ch = new EmbeddedChannel(handler);
ByteBuf buf = ch.alloc().buffer();
ByteBufUtil.copy(AsciiString.of("GET /url HTTP/1.1"), buf);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
ByteBufUtil.copy(AsciiString.of("Host: localhost"), buf);
ByteBufUtil.writeShortBE(buf, HttpConstants.LF);
testThreadPool.generic().submit(() -> {
ch.writeInbound(buf);
ch.flushInbound();
}).get();
testThreadPool.generic().submit(() -> ch.close().get()).get();
assertThat(dispatchThrowableReference.get().toString(), containsString("Connection closed before received headers"));
assertThat(badDispatchInvocationCount.get(), is(6));
assertThat(authnInvocationCount.get(), is(0));
}
} finally {
testThreadPool.shutdownNow();
}
}
}