Skip to content

Added custom http headers to MqttWebSocket #500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.Map;
import java.util.Objects;

/**
* @author David Katz
* @author Christian Hoff
Expand All @@ -28,23 +31,26 @@ public class MqttWebSocketConfigImpl implements MqttWebSocketConfig {

static final @NotNull MqttWebSocketConfigImpl DEFAULT =
new MqttWebSocketConfigImpl(DEFAULT_SERVER_PATH, DEFAULT_QUERY_STRING, DEFAULT_MQTT_SUBPROTOCOL,
DEFAULT_HANDSHAKE_TIMEOUT_MS);
DEFAULT_HANDSHAKE_TIMEOUT_MS, DEFAULT_HTTP_HEADERS);

private final @NotNull String serverPath;
private final @NotNull String queryString;
private final @NotNull String subprotocol;
private final int handshakeTimeoutMs;
private final Map<String, String> httpHeaders;

MqttWebSocketConfigImpl(
final @NotNull String serverPath,
final @NotNull String queryString,
final @NotNull String subprotocol,
final int handshakeTimeoutMs) {
final int handshakeTimeoutMs,
final @NotNull Map<String, String> httpHeaders) {

this.serverPath = serverPath;
this.queryString = queryString;
this.subprotocol = subprotocol;
this.handshakeTimeoutMs = handshakeTimeoutMs;
this.httpHeaders = httpHeaders;
}

@Override
Expand All @@ -67,6 +73,11 @@ public int getHandshakeTimeoutMs() {
return handshakeTimeoutMs;
}

@Override
public @NotNull Map<String, String> getHttpHeaders() {
return httpHeaders;
}

@Override
public MqttWebSocketConfigImplBuilder.@NotNull Default extend() {
return new MqttWebSocketConfigImplBuilder.Default(this);
Expand All @@ -83,7 +94,8 @@ public boolean equals(final @Nullable Object o) {
final MqttWebSocketConfigImpl that = (MqttWebSocketConfigImpl) o;

return serverPath.equals(that.serverPath) && queryString.equals(that.queryString) &&
subprotocol.equals(that.subprotocol) && (handshakeTimeoutMs == that.handshakeTimeoutMs);
subprotocol.equals(that.subprotocol) && (handshakeTimeoutMs == that.handshakeTimeoutMs) &&
Objects.equals(httpHeaders, that.httpHeaders);
}

@Override
Expand All @@ -92,6 +104,7 @@ public int hashCode() {
result = 31 * result + queryString.hashCode();
result = 31 * result + subprotocol.hashCode();
result = 31 * result + Integer.hashCode(handshakeTimeoutMs);
result = 31 * result + httpHeaders.hashCode();
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

Expand All @@ -33,6 +34,7 @@ public abstract class MqttWebSocketConfigImplBuilder<B extends MqttWebSocketConf
private @NotNull String queryString = MqttWebSocketConfigImpl.DEFAULT_QUERY_STRING;
private @NotNull String subprotocol = MqttWebSocketConfigImpl.DEFAULT_MQTT_SUBPROTOCOL;
private int handshakeTimeoutMs = MqttWebSocketConfigImpl.DEFAULT_HANDSHAKE_TIMEOUT_MS;
private @NotNull Map<String, String> httpHeaders = MqttWebSocketConfigImpl.DEFAULT_HTTP_HEADERS;

MqttWebSocketConfigImplBuilder() {}

Expand Down Expand Up @@ -70,8 +72,14 @@ public abstract class MqttWebSocketConfigImplBuilder<B extends MqttWebSocketConf
return self();
}

public @NotNull B httpHeaders(final @Nullable Map<String, String> httpHeaders) {
Checks.notNull(httpHeaders, "Http headers");
this.httpHeaders = httpHeaders;
return self();
}

public @NotNull MqttWebSocketConfigImpl build() {
return new MqttWebSocketConfigImpl(serverPath, queryString, subprotocol, handshakeTimeoutMs);
return new MqttWebSocketConfigImpl(serverPath, queryString, subprotocol, handshakeTimeoutMs, httpHeaders);
}

public static class Default extends MqttWebSocketConfigImplBuilder<Default> implements MqttWebSocketConfigBuilder {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2018-present HiveMQ and the HiveMQ Community
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hivemq.client.internal.mqtt.handler.websocket;

import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpRequest;
import org.jetbrains.annotations.NotNull;

import java.util.Map;

public class MqttWebSocketHttpHeaders extends ChannelDuplexHandler {

public static final String HTTP_HEADERS = "http.headers";
private final @NotNull Map<String, String> httpHeaders;

public MqttWebSocketHttpHeaders(@NotNull final Map<String, String> httpHeaders) {
this.httpHeaders = httpHeaders;
}

@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof HttpRequest) {
final HttpRequest request = (HttpRequest) msg;
this.httpHeaders.forEach((key, value) -> request.headers().set(key, value));
}
super.write(ctx, msg, promise);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public void initChannel(
channel.pipeline()
.addLast(HTTP_CODEC_NAME, new HttpClientCodec())
.addLast(HTTP_AGGREGATOR_NAME, new HttpObjectAggregator(65_535))
.addLast(MqttWebSocketHttpHeaders.HTTP_HEADERS, new MqttWebSocketHttpHeaders(webSocketConfig.getHttpHeaders()))
.addLast(MqttWebsocketHandshakeHandler.NAME,
new MqttWebsocketHandshakeHandler(handshaker, webSocketConfig.getHandshakeTimeoutMs(),
onSuccess, onError))
Expand Down
16 changes: 16 additions & 0 deletions src/main/java/com/hivemq/client/mqtt/MqttWebSocketConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import com.hivemq.client.internal.mqtt.MqttWebSocketConfigImplBuilder;
import org.jetbrains.annotations.NotNull;

import java.util.LinkedHashMap;
import java.util.Map;

/**
* Configuration for a WebSocket transport to use by {@link MqttClient MQTT clients}.
*
Expand Down Expand Up @@ -51,6 +54,12 @@ public interface MqttWebSocketConfig {
* @since 1.2
*/
int DEFAULT_HANDSHAKE_TIMEOUT_MS = 10_000;
/**
* The default map of headers.
* @since 1.2.3
*/
@NotNull Map<String, String> DEFAULT_HTTP_HEADERS = new LinkedHashMap<>();


/**
* Creates a builder for a WebSocket configuration.
Expand Down Expand Up @@ -82,11 +91,18 @@ public interface MqttWebSocketConfig {
*/
int getHandshakeTimeoutMs();

/**
* @return map of already set headers.
* @since 1.2.3
*/
@NotNull Map<String, String> getHttpHeaders();

/**
* Creates a builder for extending this WebSocket configuration.
*
* @return the created builder.
* @since 1.1
*/
@NotNull MqttWebSocketConfigBuilder extend();

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.hivemq.client.annotations.DoNotImplement;
import org.jetbrains.annotations.NotNull;

import java.util.Map;
import java.util.concurrent.TimeUnit;

/**
Expand Down Expand Up @@ -71,4 +72,13 @@ public interface MqttWebSocketConfigBuilderBase<B extends MqttWebSocketConfigBui
*/
@CheckReturnValue
@NotNull B handshakeTimeout(long timeout, @NotNull TimeUnit timeUnit);

/**
* Sets the {@link MqttWebSocketConfig#getHttpHeaders() headers}.
*
* @param httpHeaders http headers.
* @return the builder.
*/
@CheckReturnValue
@NotNull B httpHeaders(@NotNull Map<String, String> httpHeaders);
}