diff --git a/src/main/java/com/hivemq/client/internal/mqtt/MqttWebSocketConfigImpl.java b/src/main/java/com/hivemq/client/internal/mqtt/MqttWebSocketConfigImpl.java index b2a08a16f..0207f371b 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/MqttWebSocketConfigImpl.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/MqttWebSocketConfigImpl.java @@ -20,6 +20,9 @@ import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; +import java.util.Map; +import java.util.Objects; + /** * @author David Katz * @author Christian Hoff @@ -28,23 +31,26 @@ public class MqttWebSocketConfigImpl implements MqttWebSocketConfig { static final @NotNull MqttWebSocketConfigImpl DEFAULT = new MqttWebSocketConfigImpl(DEFAULT_SERVER_PATH, DEFAULT_QUERY_STRING, DEFAULT_MQTT_SUBPROTOCOL, - DEFAULT_HANDSHAKE_TIMEOUT_MS); + DEFAULT_HANDSHAKE_TIMEOUT_MS, DEFAULT_HTTP_HEADERS); private final @NotNull String serverPath; private final @NotNull String queryString; private final @NotNull String subprotocol; private final int handshakeTimeoutMs; + private final Map httpHeaders; MqttWebSocketConfigImpl( final @NotNull String serverPath, final @NotNull String queryString, final @NotNull String subprotocol, - final int handshakeTimeoutMs) { + final int handshakeTimeoutMs, + final @NotNull Map httpHeaders) { this.serverPath = serverPath; this.queryString = queryString; this.subprotocol = subprotocol; this.handshakeTimeoutMs = handshakeTimeoutMs; + this.httpHeaders = httpHeaders; } @Override @@ -67,6 +73,11 @@ public int getHandshakeTimeoutMs() { return handshakeTimeoutMs; } + @Override + public @NotNull Map getHttpHeaders() { + return httpHeaders; + } + @Override public MqttWebSocketConfigImplBuilder.@NotNull Default extend() { return new MqttWebSocketConfigImplBuilder.Default(this); @@ -83,7 +94,8 @@ public boolean equals(final @Nullable Object o) { final MqttWebSocketConfigImpl that = (MqttWebSocketConfigImpl) o; return serverPath.equals(that.serverPath) && queryString.equals(that.queryString) && - subprotocol.equals(that.subprotocol) && (handshakeTimeoutMs == that.handshakeTimeoutMs); + subprotocol.equals(that.subprotocol) && (handshakeTimeoutMs == that.handshakeTimeoutMs) && + Objects.equals(httpHeaders, that.httpHeaders); } @Override @@ -92,6 +104,7 @@ public int hashCode() { result = 31 * result + queryString.hashCode(); result = 31 * result + subprotocol.hashCode(); result = 31 * result + Integer.hashCode(handshakeTimeoutMs); + result = 31 * result + httpHeaders.hashCode(); return result; } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/MqttWebSocketConfigImplBuilder.java b/src/main/java/com/hivemq/client/internal/mqtt/MqttWebSocketConfigImplBuilder.java index 808207498..215349d93 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/MqttWebSocketConfigImplBuilder.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/MqttWebSocketConfigImplBuilder.java @@ -21,6 +21,7 @@ import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; +import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.function.Function; @@ -33,6 +34,7 @@ public abstract class MqttWebSocketConfigImplBuilder httpHeaders = MqttWebSocketConfigImpl.DEFAULT_HTTP_HEADERS; MqttWebSocketConfigImplBuilder() {} @@ -70,8 +72,14 @@ public abstract class MqttWebSocketConfigImplBuilder httpHeaders) { + Checks.notNull(httpHeaders, "Http headers"); + this.httpHeaders = httpHeaders; + return self(); + } + public @NotNull MqttWebSocketConfigImpl build() { - return new MqttWebSocketConfigImpl(serverPath, queryString, subprotocol, handshakeTimeoutMs); + return new MqttWebSocketConfigImpl(serverPath, queryString, subprotocol, handshakeTimeoutMs, httpHeaders); } public static class Default extends MqttWebSocketConfigImplBuilder implements MqttWebSocketConfigBuilder { diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/websocket/MqttWebSocketHttpHeaders.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/websocket/MqttWebSocketHttpHeaders.java new file mode 100644 index 000000000..fce57d9f4 --- /dev/null +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/websocket/MqttWebSocketHttpHeaders.java @@ -0,0 +1,44 @@ +/* + * Copyright 2018-present HiveMQ and the HiveMQ Community + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hivemq.client.internal.mqtt.handler.websocket; + +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.HttpRequest; +import org.jetbrains.annotations.NotNull; + +import java.util.Map; + +public class MqttWebSocketHttpHeaders extends ChannelDuplexHandler { + + public static final String HTTP_HEADERS = "http.headers"; + private final @NotNull Map httpHeaders; + + public MqttWebSocketHttpHeaders(@NotNull final Map httpHeaders) { + this.httpHeaders = httpHeaders; + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg instanceof HttpRequest) { + final HttpRequest request = (HttpRequest) msg; + this.httpHeaders.forEach((key, value) -> request.headers().set(key, value)); + } + super.write(ctx, msg, promise); + } +} diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/websocket/MqttWebSocketInitializer.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/websocket/MqttWebSocketInitializer.java index ed1c4c3f4..0e29b8b1c 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/websocket/MqttWebSocketInitializer.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/websocket/MqttWebSocketInitializer.java @@ -79,6 +79,7 @@ public void initChannel( channel.pipeline() .addLast(HTTP_CODEC_NAME, new HttpClientCodec()) .addLast(HTTP_AGGREGATOR_NAME, new HttpObjectAggregator(65_535)) + .addLast(MqttWebSocketHttpHeaders.HTTP_HEADERS, new MqttWebSocketHttpHeaders(webSocketConfig.getHttpHeaders())) .addLast(MqttWebsocketHandshakeHandler.NAME, new MqttWebsocketHandshakeHandler(handshaker, webSocketConfig.getHandshakeTimeoutMs(), onSuccess, onError)) diff --git a/src/main/java/com/hivemq/client/mqtt/MqttWebSocketConfig.java b/src/main/java/com/hivemq/client/mqtt/MqttWebSocketConfig.java index 41b17ff85..a7cad5f9b 100644 --- a/src/main/java/com/hivemq/client/mqtt/MqttWebSocketConfig.java +++ b/src/main/java/com/hivemq/client/mqtt/MqttWebSocketConfig.java @@ -20,6 +20,9 @@ import com.hivemq.client.internal.mqtt.MqttWebSocketConfigImplBuilder; import org.jetbrains.annotations.NotNull; +import java.util.LinkedHashMap; +import java.util.Map; + /** * Configuration for a WebSocket transport to use by {@link MqttClient MQTT clients}. * @@ -51,6 +54,12 @@ public interface MqttWebSocketConfig { * @since 1.2 */ int DEFAULT_HANDSHAKE_TIMEOUT_MS = 10_000; + /** + * The default map of headers. + * @since 1.2.3 + */ + @NotNull Map DEFAULT_HTTP_HEADERS = new LinkedHashMap<>(); + /** * Creates a builder for a WebSocket configuration. @@ -82,6 +91,12 @@ public interface MqttWebSocketConfig { */ int getHandshakeTimeoutMs(); + /** + * @return map of already set headers. + * @since 1.2.3 + */ + @NotNull Map getHttpHeaders(); + /** * Creates a builder for extending this WebSocket configuration. * @@ -89,4 +104,5 @@ public interface MqttWebSocketConfig { * @since 1.1 */ @NotNull MqttWebSocketConfigBuilder extend(); + } diff --git a/src/main/java/com/hivemq/client/mqtt/MqttWebSocketConfigBuilderBase.java b/src/main/java/com/hivemq/client/mqtt/MqttWebSocketConfigBuilderBase.java index 79a0f1f82..c414ef601 100644 --- a/src/main/java/com/hivemq/client/mqtt/MqttWebSocketConfigBuilderBase.java +++ b/src/main/java/com/hivemq/client/mqtt/MqttWebSocketConfigBuilderBase.java @@ -20,6 +20,7 @@ import com.hivemq.client.annotations.DoNotImplement; import org.jetbrains.annotations.NotNull; +import java.util.Map; import java.util.concurrent.TimeUnit; /** @@ -71,4 +72,13 @@ public interface MqttWebSocketConfigBuilderBase httpHeaders); }