diff --git a/extensions/vertx-http/deployment/pom.xml b/extensions/vertx-http/deployment/pom.xml index a4ad672f739f7..8da2520f5adb5 100644 --- a/extensions/vertx-http/deployment/pom.xml +++ b/extensions/vertx-http/deployment/pom.xml @@ -12,6 +12,10 @@ quarkus-vertx-http-deployment Quarkus - Vert.x - HTTP - Deployment + + 2.0.0-M23 + + io.quarkus @@ -139,6 +143,12 @@ smallrye-certificate-generator-junit5 test + + org.apache.directory.server + apacheds-protocol-dns + ${apacheds-protocol-dns.version} + test + diff --git a/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/TrustedForwarderDnsResolveTest.java b/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/TrustedForwarderDnsResolveTest.java new file mode 100644 index 0000000000000..f223bc2a34dd5 --- /dev/null +++ b/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/TrustedForwarderDnsResolveTest.java @@ -0,0 +1,55 @@ +package io.quarkus.vertx.http.proxy; + +import org.hamcrest.Matchers; +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.vertx.http.ForwardedHandlerInitializer; +import io.quarkus.vertx.http.proxy.fakedns.DnameRecordEncoder; +import io.quarkus.vertx.http.proxy.fakedns.DnsMessageEncoder; +import io.quarkus.vertx.http.proxy.fakedns.FakeDNSServer; +import io.restassured.RestAssured; + +public class TrustedForwarderDnsResolveTest { + + private FakeDNSServer dnsServer; + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addClasses(ForwardedHandlerInitializer.class, DnameRecordEncoder.class, DnsMessageEncoder.class, + FakeDNSServer.class) + .addAsResource(new StringAsset("quarkus.http.proxy.proxy-address-forwarding=true\n" + + "quarkus.http.proxy.allow-forwarded=true\n" + + "quarkus.http.proxy.enable-forwarded-host=true\n" + + "quarkus.http.proxy.enable-forwarded-prefix=true\n" + + "quarkus.vertx.resolver.servers=127.0.0.1:53530\n" + + "quarkus.http.proxy.trusted-proxies=trusted.example.com"), + "application.properties")); + + @BeforeEach + public void setUp() throws Exception { + dnsServer = new FakeDNSServer(); + dnsServer.start(); + } + + @AfterEach + public void tearDown() { + dnsServer.stop(); + } + + @Test + public void testTrustedProxyResolved() { + dnsServer.addRecordsToStore("trusted.example.com", "127.0.0.3", "127.0.0.2", "127.0.0.1"); + RestAssured.given() + .header("Forwarded", "proto=http;for=backend2:5555;host=somehost2") + .get("/path") + .then() + .body(Matchers.equalTo("http|somehost2|backend2:5555|/path|http://somehost2/path")); + } + +} diff --git a/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/fakedns/DnameRecordEncoder.java b/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/fakedns/DnameRecordEncoder.java new file mode 100644 index 0000000000000..a72c57e83da70 --- /dev/null +++ b/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/fakedns/DnameRecordEncoder.java @@ -0,0 +1,14 @@ +package io.quarkus.vertx.http.proxy.fakedns; + +import org.apache.directory.server.dns.io.encoder.ResourceRecordEncoder; +import org.apache.directory.server.dns.messages.ResourceRecord; +import org.apache.directory.server.dns.store.DnsAttribute; +import org.apache.mina.core.buffer.IoBuffer; + +public class DnameRecordEncoder extends ResourceRecordEncoder { + + protected void putResourceRecordData(IoBuffer byteBuffer, ResourceRecord record) { + String domainName = record.get(DnsAttribute.DOMAIN_NAME); + putDomainName(byteBuffer, domainName); + } +} diff --git a/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/fakedns/DnsMessageEncoder.java b/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/fakedns/DnsMessageEncoder.java new file mode 100644 index 0000000000000..371f5a7dc3bd0 --- /dev/null +++ b/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/fakedns/DnsMessageEncoder.java @@ -0,0 +1,198 @@ +package io.quarkus.vertx.http.proxy.fakedns; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ +import org.apache.directory.server.dns.io.encoder.*; +import org.apache.directory.server.dns.messages.DnsMessage; +import org.apache.directory.server.dns.messages.MessageType; +import org.apache.directory.server.dns.messages.OpCode; +import org.apache.directory.server.dns.messages.QuestionRecord; +import org.apache.directory.server.dns.messages.RecordType; +import org.apache.directory.server.dns.messages.ResourceRecord; +import org.apache.directory.server.dns.messages.ResponseCode; +import org.apache.directory.server.i18n.I18n; +import org.apache.mina.core.buffer.IoBuffer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An encoder for DNS messages. The primary usage of the DnsMessageEncoder is + * to call the encode(ByteBuffer, DnsMessage) method which will + * write the message to the outgoing ByteBuffer according to the DnsMessage + * encoding in RFC-1035. + * + * @author Apache Directory Project + * @version $Rev$, $Date$ + */ +public class DnsMessageEncoder { + /** the log for this class */ + private static final Logger log = LoggerFactory.getLogger(DnsMessageEncoder.class); + + /** + * A Hashed Adapter mapping record types to their encoders. + */ + private static final Map DEFAULT_ENCODERS; + + static { + Map map = new HashMap(); + + map.put(RecordType.SOA, new StartOfAuthorityRecordEncoder()); + map.put(RecordType.A, new AddressRecordEncoder()); + map.put(RecordType.NS, new NameServerRecordEncoder()); + map.put(RecordType.CNAME, new CanonicalNameRecordEncoder()); + map.put(RecordType.PTR, new PointerRecordEncoder()); + map.put(RecordType.MX, new MailExchangeRecordEncoder()); + map.put(RecordType.SRV, new ServerSelectionRecordEncoder()); + map.put(RecordType.TXT, new TextRecordEncoder()); + map.put(RecordType.DNAME, new DnameRecordEncoder()); + + DEFAULT_ENCODERS = Collections.unmodifiableMap(map); + } + + /** + * Encodes the {@link DnsMessage} into the {@link IoBuffer}. + * + * @param byteBuffer + * @param message + */ + public void encode(IoBuffer byteBuffer, DnsMessage message) { + byteBuffer.putShort((short) message.getTransactionId()); + + byte header = (byte) 0x00; + header |= encodeMessageType(message.getMessageType()); + header |= encodeOpCode(message.getOpCode()); + header |= encodeAuthoritativeAnswer(message.isAuthoritativeAnswer()); + header |= encodeTruncated(message.isTruncated()); + header |= encodeRecursionDesired(message.isRecursionDesired()); + byteBuffer.put(header); + + header = (byte) 0x00; + header |= encodeRecursionAvailable(message.isRecursionAvailable()); + header |= encodeResponseCode(message.getResponseCode()); + byteBuffer.put(header); + + byteBuffer + .putShort((short) (message.getQuestionRecords() != null ? message.getQuestionRecords().size() : 0)); + byteBuffer.putShort((short) (message.getAnswerRecords() != null ? message.getAnswerRecords().size() : 0)); + byteBuffer.putShort((short) (message.getAuthorityRecords() != null ? message.getAuthorityRecords().size() + : 0)); + byteBuffer.putShort((short) (message.getAdditionalRecords() != null ? message.getAdditionalRecords().size() + : 0)); + + putQuestionRecords(byteBuffer, message.getQuestionRecords()); + putResourceRecords(byteBuffer, message.getAnswerRecords()); + putResourceRecords(byteBuffer, message.getAuthorityRecords()); + putResourceRecords(byteBuffer, message.getAdditionalRecords()); + } + + private void putQuestionRecords(IoBuffer byteBuffer, List questions) { + if (questions == null) { + return; + } + + QuestionRecordEncoder encoder = new QuestionRecordEncoder(); + + Iterator it = questions.iterator(); + + while (it.hasNext()) { + QuestionRecord question = it.next(); + encoder.put(byteBuffer, question); + } + } + + private void putResourceRecords(IoBuffer byteBuffer, List records) { + if (records == null) { + return; + } + + Iterator it = records.iterator(); + + while (it.hasNext()) { + ResourceRecord record = it.next(); + + try { + put(byteBuffer, record); + } catch (IOException ioe) { + log.error(ioe.getLocalizedMessage(), ioe); + } + } + } + + private void put(IoBuffer byteBuffer, ResourceRecord record) throws IOException { + RecordType type = record.getRecordType(); + + RecordEncoder encoder = DEFAULT_ENCODERS.get(type); + + if (encoder == null) { + throw new IOException(I18n.err(I18n.ERR_597, type)); + } + + encoder.put(byteBuffer, record); + } + + private byte encodeMessageType(MessageType messageType) { + byte oneBit = (byte) (messageType.convert() & 0x01); + return (byte) (oneBit << 7); + } + + private byte encodeOpCode(OpCode opCode) { + byte fourBits = (byte) (opCode.convert() & 0x0F); + return (byte) (fourBits << 3); + } + + private byte encodeAuthoritativeAnswer(boolean authoritative) { + if (authoritative) { + return (byte) ((byte) 0x01 << 2); + } + return (byte) 0; + } + + private byte encodeTruncated(boolean truncated) { + if (truncated) { + return (byte) ((byte) 0x01 << 1); + } + return 0; + } + + private byte encodeRecursionDesired(boolean recursionDesired) { + if (recursionDesired) { + return (byte) 0x01; + } + return 0; + } + + private byte encodeRecursionAvailable(boolean recursionAvailable) { + if (recursionAvailable) { + return (byte) ((byte) 0x01 << 7); + } + return 0; + } + + private byte encodeResponseCode(ResponseCode responseCode) { + return (byte) (responseCode.convert() & 0x0F); + } +} diff --git a/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/fakedns/FakeDNSServer.java b/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/fakedns/FakeDNSServer.java new file mode 100644 index 0000000000000..8d321c055eca9 --- /dev/null +++ b/extensions/vertx-http/deployment/src/test/java/io/quarkus/vertx/http/proxy/fakedns/FakeDNSServer.java @@ -0,0 +1,532 @@ +/* + * Copyright (c) 2011-2019 Contributors to the Eclipse Foundation + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 + * which is available at https://www.apache.org/licenses/LICENSE-2.0. + * + * SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 + */ + +package io.quarkus.vertx.http.proxy.fakedns; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.directory.server.dns.DnsServer; +import org.apache.directory.server.dns.io.encoder.ResourceRecordEncoder; +import org.apache.directory.server.dns.messages.*; +import org.apache.directory.server.dns.protocol.DnsProtocolHandler; +import org.apache.directory.server.dns.protocol.DnsTcpDecoder; +import org.apache.directory.server.dns.protocol.DnsUdpDecoder; +import org.apache.directory.server.dns.protocol.DnsUdpEncoder; +import org.apache.directory.server.dns.store.DnsAttribute; +import org.apache.directory.server.dns.store.RecordStore; +import org.apache.directory.server.protocol.shared.transport.TcpTransport; +import org.apache.directory.server.protocol.shared.transport.Transport; +import org.apache.directory.server.protocol.shared.transport.UdpTransport; +import org.apache.mina.core.buffer.IoBuffer; +import org.apache.mina.core.service.IoAcceptor; +import org.apache.mina.core.session.IoSession; +import org.apache.mina.filter.codec.ProtocolCodecFactory; +import org.apache.mina.filter.codec.ProtocolCodecFilter; +import org.apache.mina.filter.codec.ProtocolDecoder; +import org.apache.mina.filter.codec.ProtocolEncoder; +import org.apache.mina.filter.codec.ProtocolEncoderOutput; +import org.apache.mina.transport.socket.DatagramSessionConfig; + +/** + * @author Norman Maurer + */ +public final class FakeDNSServer extends DnsServer { + + public static RecordStore A_store(Map entries) { + return questionRecord -> entries.entrySet().stream().map(entry -> { + return a(entry.getKey(), 100).ipAddress(entry.getValue()); + }).collect(Collectors.toSet()); + } + + public static RecordStore A_store(Function entries) { + return questionRecord -> { + String res = entries.apply(questionRecord.getDomainName()); + if (res != null) { + return Collections.singleton(a(questionRecord.getDomainName(), 100).ipAddress(res)); + } + return Collections.emptySet(); + }; + } + + public static final int PORT = 53530; + public static final String IP_ADDRESS = "127.0.0.1"; + + private String ipAddress = IP_ADDRESS; + private int port = PORT; + private volatile RecordStore store; + private List acceptors; + private final Deque currentMessage = new ArrayDeque<>(); + + public FakeDNSServer() { + } + + public RecordStore store() { + return store; + } + + public FakeDNSServer store(RecordStore store) { + this.store = store; + return this; + } + + public synchronized DnsMessage pollMessage() { + return currentMessage.poll(); + } + + public InetSocketAddress localAddress() { + return (InetSocketAddress) getTransports()[0].getAcceptor().getLocalAddress(); + } + + public FakeDNSServer ipAddress(String ipAddress) { + this.ipAddress = ipAddress; + return this; + } + + public FakeDNSServer port(int p) { + port = p; + return this; + } + + public FakeDNSServer testResolveA(final String ipAddress) { + return testResolveA(Collections.singletonMap("vertx.io", ipAddress)); + } + + public FakeDNSServer testResolveA(Map entries) { + return store(A_store(entries)); + } + + public FakeDNSServer testResolveA(Function entries) { + return store(A_store(entries)); + } + + public FakeDNSServer testResolveAAAA(final String ipAddress) { + return store(questionRecord -> Collections.singleton(aaaa("vertx.io", 100).ipAddress(ipAddress))); + } + + public FakeDNSServer testResolveMX(final int prio, final String mxRecord) { + return store(questionRecord -> Collections.singleton(mx("vertx.io", 100) + .set(DnsAttribute.MX_PREFERENCE, String.valueOf(prio)) + .set(DnsAttribute.DOMAIN_NAME, mxRecord))); + } + + public FakeDNSServer testResolveTXT(final String txt) { + return store(questionRecord -> Collections.singleton(txt("vertx.io", 100) + .set(DnsAttribute.CHARACTER_STRING, txt))); + } + + public FakeDNSServer testResolveNS(final String ns) { + return store(questionRecord -> Collections.singleton(ns("vertx.io", 100) + .set(DnsAttribute.DOMAIN_NAME, ns))); + } + + public FakeDNSServer testResolveCNAME(final String cname) { + return store(questionRecord -> Collections.singleton(cname("vertx.io", 100) + .set(DnsAttribute.DOMAIN_NAME, cname))); + } + + public FakeDNSServer testResolvePTR(final String ptr) { + return store(questionRecord -> Collections.singleton(ptr("vertx.io", 100) + .set(DnsAttribute.DOMAIN_NAME, ptr))); + } + + public FakeDNSServer testResolveSRV(String name, int priority, int weight, int port, String target) { + return store(questionRecord -> Collections.singleton(srv(name, 100) + .set(DnsAttribute.SERVICE_PRIORITY, priority) + .set(DnsAttribute.SERVICE_WEIGHT, weight) + .set(DnsAttribute.SERVICE_PORT, port) + .set(DnsAttribute.DOMAIN_NAME, target))); + } + + public FakeDNSServer testResolveDNAME(final String dname) { + return store(questionRecord -> Collections.singleton(dname("vertx.io", 100) + .set(DnsAttribute.DOMAIN_NAME, dname))); + } + + public FakeDNSServer testResolveSRV2(final int priority, final int weight, final int basePort, final String target) { + return store(questionRecord -> IntStream + .range(0, 2) + .mapToObj(i -> srv(target, 100) + .set(DnsAttribute.SERVICE_PRIORITY, priority) + .set(DnsAttribute.SERVICE_WEIGHT, weight) + .set(DnsAttribute.SERVICE_PORT, basePort + i) + .set(DnsAttribute.DOMAIN_NAME, "svc" + i + ".vertx.io.")) + .collect(Collectors.toSet())); + } + + public static Record a(String domainName, int ttl) { + return new Record(domainName, RecordType.A, RecordClass.IN, ttl); + } + + public static Record aaaa(String domainName, int ttl) { + return new Record(domainName, RecordType.AAAA, RecordClass.IN, ttl); + } + + public static Record mx(String domainName, int ttl) { + return new Record(domainName, RecordType.MX, RecordClass.IN, ttl); + } + + public static Record txt(String domainName, int ttl) { + return new Record(domainName, RecordType.TXT, RecordClass.IN, ttl); + } + + public static Record ns(String domainName, int ttl) { + return new Record(domainName, RecordType.NS, RecordClass.IN, ttl); + } + + public static Record cname(String domainName, int ttl) { + return new Record(domainName, RecordType.CNAME, RecordClass.IN, ttl); + } + + public static Record ptr(String domainName, int ttl) { + return new Record(domainName, RecordType.PTR, RecordClass.IN, ttl); + } + + public static Record srv(String domainName, int ttl) { + return new Record(domainName, RecordType.SRV, RecordClass.IN, ttl); + } + + public static Record dname(String domainName, int ttl) { + return new Record(domainName, RecordType.DNAME, RecordClass.IN, ttl); + } + + public static Record record(String domainName, RecordType recordType, RecordClass recordClass, int ttl) { + return new Record(domainName, recordType, recordClass, ttl); + } + + public static class Record extends HashMap implements ResourceRecord { + + private final String domainName; + private final RecordType recordType; + private final RecordClass recordClass; + private final int ttl; + + public Record(String domainName, RecordType recordType, RecordClass recordClass, int ttl) { + this.domainName = domainName; + this.recordType = recordType; + this.recordClass = recordClass; + this.ttl = ttl; + } + + public Record ipAddress(String ipAddress) { + return set(DnsAttribute.IP_ADDRESS, ipAddress); + } + + public Record set(String name, Object value) { + put(name, "" + value); + return this; + } + + @Override + public String getDomainName() { + return domainName; + } + + @Override + public RecordType getRecordType() { + return recordType; + } + + @Override + public RecordClass getRecordClass() { + return recordClass; + } + + @Override + public int getTimeToLive() { + return ttl; + } + + @Override + public String get(String id) { + return get((Object) id); + } + } + + public FakeDNSServer testLookup4(String ip) { + return store(questionRecord -> { + Set set = new HashSet<>(); + if (questionRecord.getRecordType() == RecordType.A) { + set.add(a("vertx.io", 100).ipAddress(ip)); + } + return set; + }); + } + + public FakeDNSServer testLookup6(String ip) { + return store(questionRecord -> { + Set set = new HashSet<>(); + if (questionRecord.getRecordType() == RecordType.AAAA) { + set.add(aaaa("vertx.io", 100).ipAddress(ip)); + } + return set; + }); + } + + public FakeDNSServer testLookupNonExisting() { + return store(questionRecord -> null); + } + + public FakeDNSServer testReverseLookup(final String ptr) { + return store(questionRecord -> Collections.singleton(ptr(ptr, 100) + .set(DnsAttribute.DOMAIN_NAME, "vertx.io"))); + } + + public FakeDNSServer testResolveASameServer(final String ipAddress) { + return store(A_store(Collections.singletonMap("vertx.io", ipAddress))); + } + + public FakeDNSServer testLookup4CNAME(final String cname, final String ip) { + return store(questionRecord -> { + // use LinkedHashSet since the order of the result records has to be preserved to make sure the unit test fails + Set set = new LinkedHashSet<>(); + ResourceRecordModifier rm = new ResourceRecordModifier(); + set.add(cname("vertx.io", 100).set(DnsAttribute.DOMAIN_NAME, cname)); + set.add(a(cname, 100).ipAddress(ip)); + return set; + }); + } + + @Override + public void start() throws IOException { + + DnsProtocolHandler handler = new DnsProtocolHandler(this, question -> { + RecordStore actual = store; + if (actual == null) { + return Collections.emptySet(); + } else { + return actual.getRecords(question); + } + }) { + @Override + public void sessionCreated(IoSession session) { + // Use our own codec to support AAAA testing + if (session.getTransportMetadata().isConnectionless()) { + session.getFilterChain().addFirst("codec", new ProtocolCodecFilter(new TestDnsProtocolUdpCodecFactory())); + } else { + session.getFilterChain().addFirst("codec", new ProtocolCodecFilter(new TestDnsProtocolTcpCodecFactory())); + } + } + + @Override + public void messageReceived(IoSession session, Object message) { + if (message instanceof DnsMessage) { + synchronized (FakeDNSServer.this) { + currentMessage.add((DnsMessage) message); + } + } + super.messageReceived(session, message); + } + }; + + UdpTransport udpTransport = new UdpTransport(ipAddress, port); + ((DatagramSessionConfig) udpTransport.getAcceptor().getSessionConfig()).setReuseAddress(true); + TcpTransport tcpTransport = new TcpTransport(ipAddress, port); + tcpTransport.getAcceptor().getSessionConfig().setReuseAddress(true); + + setTransports(udpTransport, tcpTransport); + + for (Transport transport : getTransports()) { + IoAcceptor acceptor = transport.getAcceptor(); + + acceptor.setHandler(handler); + + // Start the listener + acceptor.bind(); + } + } + + @Override + public void stop() { + for (Transport transport : getTransports()) { + transport.getAcceptor().dispose(); + } + } + + public static class VertxResourceRecord implements ResourceRecord { + + private final String ipAddress; + private final String domainName; + private boolean isTruncated; + + public VertxResourceRecord(String domainName, String ipAddress) { + this.domainName = domainName; + this.ipAddress = ipAddress; + } + + public boolean isTruncated() { + return isTruncated; + } + + public VertxResourceRecord setTruncated(boolean truncated) { + isTruncated = truncated; + return this; + } + + @Override + public String getDomainName() { + return domainName; + } + + @Override + public RecordType getRecordType() { + return RecordType.A; + } + + @Override + public RecordClass getRecordClass() { + return RecordClass.IN; + } + + @Override + public int getTimeToLive() { + return 100; + } + + @Override + public String get(String id) { + return DnsAttribute.IP_ADDRESS.equals(id) ? ipAddress : null; + } + } + + private static final ResourceRecordEncoder TestAAAARecordEncoder = new ResourceRecordEncoder() { + @Override + protected void putResourceRecordData(IoBuffer ioBuffer, ResourceRecord resourceRecord) { + if (!resourceRecord.get(DnsAttribute.IP_ADDRESS).equals("::1")) { + throw new IllegalStateException("Only supposed to be used with IPV6 address of ::1"); + } + // encode the ::1 + ioBuffer.put(new byte[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }); + } + }; + + private final DnsMessageEncoder encoder = new DnsMessageEncoder(); + + private void encode(DnsMessage dnsMessage, IoBuffer buf) { + + // Hack + if (dnsMessage.getAnswerRecords().size() == 1 && dnsMessage.getAnswerRecords().get(0) instanceof VertxResourceRecord) { + VertxResourceRecord vrr = (VertxResourceRecord) dnsMessage.getAnswerRecords().get(0); + + DnsMessageModifier modifier = new DnsMessageModifier(); + modifier.setTransactionId(dnsMessage.getTransactionId()); + modifier.setMessageType(dnsMessage.getMessageType()); + modifier.setOpCode(dnsMessage.getOpCode()); + modifier.setAuthoritativeAnswer(dnsMessage.isAuthoritativeAnswer()); + modifier.setTruncated(dnsMessage.isTruncated()); + modifier.setRecursionDesired(dnsMessage.isRecursionDesired()); + modifier.setRecursionAvailable(dnsMessage.isRecursionAvailable()); + modifier.setReserved(dnsMessage.isReserved()); + modifier.setAcceptNonAuthenticatedData(dnsMessage.isAcceptNonAuthenticatedData()); + modifier.setResponseCode(dnsMessage.getResponseCode()); + modifier.setQuestionRecords(dnsMessage.getQuestionRecords()); + modifier.setAnswerRecords(dnsMessage.getAnswerRecords()); + modifier.setAuthorityRecords(dnsMessage.getAuthorityRecords()); + modifier.setAdditionalRecords(dnsMessage.getAdditionalRecords()); + + modifier.setTruncated(vrr.isTruncated); + + dnsMessage = modifier.getDnsMessage(); + } + + encoder.encode(buf, dnsMessage); + + for (ResourceRecord record : dnsMessage.getAnswerRecords()) { + // This is a hack to allow to also test for AAAA resolution as DnsMessageEncoder does not support it and it + // is hard to extend, because the interesting methods are private... + // In case of RecordType.AAAA we need to encode the RecordType by ourself + if (record.getRecordType() == RecordType.AAAA) { + try { + TestAAAARecordEncoder.put(buf, record); + } catch (IOException e) { + // Should never happen + throw new IllegalStateException(e); + } + } + } + } + + /** + * ProtocolCodecFactory which allows to test AAAA resolution + */ + private final class TestDnsProtocolUdpCodecFactory implements ProtocolCodecFactory { + @Override + public ProtocolEncoder getEncoder(IoSession session) throws Exception { + return new DnsUdpEncoder() { + + @Override + public void encode(IoSession session, Object message, ProtocolEncoderOutput out) { + IoBuffer buf = IoBuffer.allocate(1024); + FakeDNSServer.this.encode((DnsMessage) message, buf); + buf.flip(); + out.write(buf); + } + }; + } + + @Override + public ProtocolDecoder getDecoder(IoSession session) throws Exception { + return new DnsUdpDecoder(); + } + } + + /** + * ProtocolCodecFactory which allows to test AAAA resolution + */ + private final class TestDnsProtocolTcpCodecFactory implements ProtocolCodecFactory { + @Override + public ProtocolEncoder getEncoder(IoSession session) throws Exception { + return new DnsUdpEncoder() { + + @Override + public void encode(IoSession session, Object message, ProtocolEncoderOutput out) { + IoBuffer buf = IoBuffer.allocate(1024); + buf.putShort((short) 0); + FakeDNSServer.this.encode((DnsMessage) message, buf); + encoder.encode(buf, (DnsMessage) message); + int end = buf.position(); + short recordLength = (short) (end - 2); + buf.rewind(); + buf.putShort(recordLength); + buf.position(end); + buf.flip(); + out.write(buf); + } + }; + } + + @Override + public ProtocolDecoder getDecoder(IoSession session) throws Exception { + return new DnsTcpDecoder(); + } + } + + public void addRecordsToStore(String domainName, String... entries) { + Set records = new LinkedHashSet<>(); + Function createRecord = ipAddress -> new VertxResourceRecord(domainName, ipAddress); + for (String e : entries) { + records.add(createRecord.apply(e)); + } + store(x -> records); + } +} diff --git a/extensions/vertx-http/runtime/src/main/java/io/quarkus/vertx/http/runtime/ForwardedProxyHandler.java b/extensions/vertx-http/runtime/src/main/java/io/quarkus/vertx/http/runtime/ForwardedProxyHandler.java index a34288d0ad72b..dfe1d9f51a833 100644 --- a/extensions/vertx-http/runtime/src/main/java/io/quarkus/vertx/http/runtime/ForwardedProxyHandler.java +++ b/extensions/vertx-http/runtime/src/main/java/io/quarkus/vertx/http/runtime/ForwardedProxyHandler.java @@ -2,20 +2,28 @@ import static io.quarkus.vertx.http.runtime.TrustedProxyCheck.denyAll; +import java.net.Inet4Address; +import java.net.Inet6Address; import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collection; import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.function.Supplier; +import java.util.stream.Collectors; import org.jboss.logging.Logger; import io.smallrye.common.net.Inet; -import io.vertx.core.AsyncResult; +import io.vertx.core.Future; import io.vertx.core.Handler; import io.vertx.core.Vertx; import io.vertx.core.dns.DnsClient; import io.vertx.core.http.HttpServerRequest; +import io.vertx.core.net.SocketAddress; import io.vertx.core.net.impl.SocketAddressImpl; /** @@ -75,30 +83,28 @@ private void lookupHostNamesAndHandleRequest(HttpServerRequest event, // we do not cache result as IP address may change, and we advise users to use IP or CIDR final var entry = iterator.next(); final String hostName = entry.getKey(); - dnsClient.lookup(hostName, - new Handler>() { - @Override - public void handle(AsyncResult stringAsyncResult) { - if (stringAsyncResult.succeeded() && stringAsyncResult.result() != null) { - var trustedIP = Inet.parseInetAddress(stringAsyncResult.result()); - if (trustedIP != null) { - // create proxy check for resolved IP and proceed with the lookup - lookupHostNamesAndHandleRequest(event, iterator, - builder.withTrustedIP(trustedIP, entry.getValue()), dnsClient); - } else { - logInvalidIpAddress(hostName); - // ignore this hostname proxy check and proceed with the lookup - lookupHostNamesAndHandleRequest(event, iterator, builder, dnsClient); - } - } else { - // inform we can't cope without IP - logDnsLookupFailure(hostName); - // ignore this hostname proxy check and proceed with the lookup - lookupHostNamesAndHandleRequest(event, iterator, builder, dnsClient); - } - } - }); + resolveHostNameToAllIpAddresses(dnsClient, hostName, event.remoteAddress(), results -> { + if (!results.isEmpty()) { + Set trustedIPs = results.stream().map(Inet::parseInetAddress).filter(Objects::nonNull) + .collect(Collectors.toSet()); + if (!trustedIPs.isEmpty()) { + // create proxy check for resolved IP and proceed with the lookup + lookupHostNamesAndHandleRequest(event, iterator, + builder.withTrustedIP(trustedIPs, entry.getValue()), dnsClient); + } else { + logInvalidIpAddress(hostName); + // ignore this hostname proxy check and proceed with the lookup + lookupHostNamesAndHandleRequest(event, iterator, builder, dnsClient); + } + } else { + // inform we can't cope without IP + logDnsLookupFailure(hostName); + // ignore this hostname proxy check and proceed with the lookup + lookupHostNamesAndHandleRequest(event, iterator, builder, dnsClient); + } + }); + } else { // DNS lookup is done if (builder.hasProxyChecks()) { @@ -110,6 +116,38 @@ public void handle(AsyncResult stringAsyncResult) { } } + private void resolveHostNameToAllIpAddresses(DnsClient dnsClient, String hostName, SocketAddress callersSocketAddress, + Handler> handler) { + ArrayList>> results = new ArrayList<>(); + InetAddress proxyIP = null; + if (callersSocketAddress != null) { + proxyIP = ((SocketAddressImpl) callersSocketAddress).ipAddress(); + } + // Match the lookup with the address type of the caller + if (proxyIP == null || proxyIP instanceof Inet4Address) { + results.add(dnsClient.resolveA(hostName)); + } + if (proxyIP == null || proxyIP instanceof Inet6Address) { + results.add(dnsClient.resolveAAAA(hostName)); + } + processFutures(results, new ArrayList<>(), handler); + } + + private void processFutures(ArrayList>> future, Collection results, + Handler> handler) { + if (!future.isEmpty()) { + Future> poll = future.remove(0); + poll.onComplete(result -> { + if (result.succeeded() && result.result() != null) { + results.addAll(result.result()); + } + processFutures(future, results, handler); + }); + } else { + handler.handle(results); + } + } + private void resolveProxyIpAndHandleRequest(HttpServerRequest event, TrustedProxyCheck.TrustedProxyCheckBuilder builder) { InetAddress proxyIP = ((SocketAddressImpl) event.remoteAddress()).ipAddress(); @@ -121,28 +159,26 @@ private void resolveProxyIpAndHandleRequest(HttpServerRequest event, if (proxyIP == null) { // perform DNS lookup, then create proxy check and handle request final String hostName = Objects.requireNonNull(event.remoteAddress().hostName()); - vertx.get().createDnsClient().lookup(hostName, - new Handler>() { - @Override - public void handle(AsyncResult stringAsyncResult) { - TrustedProxyCheck proxyCheck; - if (stringAsyncResult.succeeded()) { - // use resolved IP to build proxy check - final var proxyIP = Inet.parseInetAddress(stringAsyncResult.result()); - if (proxyIP != null) { - proxyCheck = builder.build(proxyIP, event.remoteAddress().port()); - } else { - logInvalidIpAddress(hostName); - proxyCheck = denyAll(); - } + resolveHostNameToAllIpAddresses(vertx.get().createDnsClient(), hostName, null, + results -> { + TrustedProxyCheck proxyCheck; + if (!results.isEmpty()) { + // use resolved IP to build proxy check + Set proxyIPs = results.stream().map(Inet::parseInetAddress).filter(Objects::nonNull) + .collect(Collectors.toSet()); + if (!proxyIPs.isEmpty()) { + proxyCheck = builder.build(proxyIPs, event.remoteAddress().port()); } else { - // we can't cope without IP => ignore headers - logDnsLookupFailure(hostName); + logInvalidIpAddress(hostName); proxyCheck = denyAll(); } - - handleForwardedServerRequest(event, proxyCheck); + } else { + // we can't cope without IP => ignore headers + logDnsLookupFailure(hostName); + proxyCheck = denyAll(); } + + handleForwardedServerRequest(event, proxyCheck); }); } else { // we have proxy IP => create proxy check and handle request diff --git a/extensions/vertx-http/runtime/src/main/java/io/quarkus/vertx/http/runtime/TrustedProxyCheck.java b/extensions/vertx-http/runtime/src/main/java/io/quarkus/vertx/http/runtime/TrustedProxyCheck.java index 2fa0f75f678df..9ad56ffb05403 100644 --- a/extensions/vertx-http/runtime/src/main/java/io/quarkus/vertx/http/runtime/TrustedProxyCheck.java +++ b/extensions/vertx-http/runtime/src/main/java/io/quarkus/vertx/http/runtime/TrustedProxyCheck.java @@ -2,6 +2,7 @@ import java.net.InetAddress; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -62,7 +63,7 @@ static TrustedProxyCheckBuilder builder(List parts) { return new TrustedProxyCheckBuilder(hostNameToPort, proxyChecks); } - TrustedProxyCheckBuilder withTrustedIP(InetAddress trustedIP, int trustedPort) { + TrustedProxyCheckBuilder withTrustedIP(Collection trustedIP, int trustedPort) { final List> proxyChecks = new ArrayList<>(this.proxyChecks); proxyChecks.add(createNewIpCheck(trustedIP, trustedPort)); return new TrustedProxyCheckBuilder(null, proxyChecks); @@ -87,6 +88,20 @@ public boolean isProxyAllowed() { }; } + TrustedProxyCheck build(Collection proxyIPs, int proxyPort) { + Objects.requireNonNull(proxyIPs); + return () -> { + for (BiPredicate proxyCheck : proxyChecks) { + for (InetAddress proxyIP : proxyIPs) { + if (proxyCheck.test(proxyIP, proxyPort)) { + return true; + } + } + } + return false; + }; + } + boolean hasHostNames() { return hasHostNames(this.hostNameToPort); } @@ -115,6 +130,20 @@ private boolean isPortOk(int port) { }; } + static BiPredicate createNewIpCheck(Collection trustedIP, int trustedPort) { + final boolean doNotCheckPort = trustedPort == 0; + return new BiPredicate<>() { + @Override + public boolean test(InetAddress proxyIP, Integer proxyPort) { + return isPortOk(proxyPort) && trustedIP.contains(proxyIP); + } + + private boolean isPortOk(int port) { + return doNotCheckPort || port == trustedPort; + } + }; + } + final class TrustedProxyCheckPart { final BiPredicate proxyCheck; diff --git a/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/core/runtime/config/AddressResolverConfiguration.java b/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/core/runtime/config/AddressResolverConfiguration.java index c2c8fb8c3afef..342b7ff8825e3 100644 --- a/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/core/runtime/config/AddressResolverConfiguration.java +++ b/extensions/vertx/runtime/src/main/java/io/quarkus/vertx/core/runtime/config/AddressResolverConfiguration.java @@ -61,7 +61,7 @@ public interface AddressResolverConfiguration { /** * Set the list of DNS server addresses, an address is the IP of the dns server, followed by an optional - * colon and a port, e.g {@code 8.8.8.8} or {code 192.168.0.1:40000}. When the list is empty, the resolver + * colon and a port, e.g {@code 8.8.8.8} or {@code 192.168.0.1:40000}. When the list is empty, the resolver * will use the list of the system DNS server addresses from the environment, if that list cannot be retrieved * it will use Google's public DNS servers {@code "8.8.8.8"} and {@code "8.8.4.4"}. **/