Skip to content

Commit 16f2fec

Browse files
committed
Added custom http headers to MqttWebSocket
1 parent 9d84ea9 commit 16f2fec

File tree

6 files changed

+81
-5
lines changed

6 files changed

+81
-5
lines changed

Diff for: src/main/java/com/hivemq/client/internal/mqtt/MqttWebSocketConfigImpl.java

+16-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
import org.jetbrains.annotations.NotNull;
2121
import org.jetbrains.annotations.Nullable;
2222

23+
import java.util.Map;
24+
import java.util.Objects;
25+
2326
/**
2427
* @author David Katz
2528
* @author Christian Hoff
@@ -28,23 +31,26 @@ public class MqttWebSocketConfigImpl implements MqttWebSocketConfig {
2831

2932
static final @NotNull MqttWebSocketConfigImpl DEFAULT =
3033
new MqttWebSocketConfigImpl(DEFAULT_SERVER_PATH, DEFAULT_QUERY_STRING, DEFAULT_MQTT_SUBPROTOCOL,
31-
DEFAULT_HANDSHAKE_TIMEOUT_MS);
34+
DEFAULT_HANDSHAKE_TIMEOUT_MS, DEFAULT_HTTP_HEADERS);
3235

3336
private final @NotNull String serverPath;
3437
private final @NotNull String queryString;
3538
private final @NotNull String subprotocol;
3639
private final int handshakeTimeoutMs;
40+
private final Map<String, String> httpHeaders;
3741

3842
MqttWebSocketConfigImpl(
3943
final @NotNull String serverPath,
4044
final @NotNull String queryString,
4145
final @NotNull String subprotocol,
42-
final int handshakeTimeoutMs) {
46+
final int handshakeTimeoutMs,
47+
final @NotNull Map<String, String> httpHeaders) {
4348

4449
this.serverPath = serverPath;
4550
this.queryString = queryString;
4651
this.subprotocol = subprotocol;
4752
this.handshakeTimeoutMs = handshakeTimeoutMs;
53+
this.httpHeaders = httpHeaders;
4854
}
4955

5056
@Override
@@ -67,6 +73,11 @@ public int getHandshakeTimeoutMs() {
6773
return handshakeTimeoutMs;
6874
}
6975

76+
@Override
77+
public @NotNull Map<String, String> getHttpHeaders() {
78+
return httpHeaders;
79+
}
80+
7081
@Override
7182
public MqttWebSocketConfigImplBuilder.@NotNull Default extend() {
7283
return new MqttWebSocketConfigImplBuilder.Default(this);
@@ -83,7 +94,8 @@ public boolean equals(final @Nullable Object o) {
8394
final MqttWebSocketConfigImpl that = (MqttWebSocketConfigImpl) o;
8495

8596
return serverPath.equals(that.serverPath) && queryString.equals(that.queryString) &&
86-
subprotocol.equals(that.subprotocol) && (handshakeTimeoutMs == that.handshakeTimeoutMs);
97+
subprotocol.equals(that.subprotocol) && (handshakeTimeoutMs == that.handshakeTimeoutMs) &&
98+
Objects.equals(httpHeaders, that.httpHeaders);
8799
}
88100

89101
@Override
@@ -92,6 +104,7 @@ public int hashCode() {
92104
result = 31 * result + queryString.hashCode();
93105
result = 31 * result + subprotocol.hashCode();
94106
result = 31 * result + Integer.hashCode(handshakeTimeoutMs);
107+
result = 31 * result + httpHeaders.hashCode();
95108
return result;
96109
}
97110
}

Diff for: src/main/java/com/hivemq/client/internal/mqtt/MqttWebSocketConfigImplBuilder.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.jetbrains.annotations.NotNull;
2222
import org.jetbrains.annotations.Nullable;
2323

24+
import java.util.Map;
2425
import java.util.concurrent.TimeUnit;
2526
import java.util.function.Function;
2627

@@ -33,6 +34,7 @@ public abstract class MqttWebSocketConfigImplBuilder<B extends MqttWebSocketConf
3334
private @NotNull String queryString = MqttWebSocketConfigImpl.DEFAULT_QUERY_STRING;
3435
private @NotNull String subprotocol = MqttWebSocketConfigImpl.DEFAULT_MQTT_SUBPROTOCOL;
3536
private int handshakeTimeoutMs = MqttWebSocketConfigImpl.DEFAULT_HANDSHAKE_TIMEOUT_MS;
37+
private @NotNull Map<String, String> httpHeaders = MqttWebSocketConfigImpl.DEFAULT_HTTP_HEADERS;
3638

3739
MqttWebSocketConfigImplBuilder() {}
3840

@@ -70,8 +72,14 @@ public abstract class MqttWebSocketConfigImplBuilder<B extends MqttWebSocketConf
7072
return self();
7173
}
7274

75+
public @NotNull B httpHeaders(final @Nullable Map<String, String> httpHeaders) {
76+
Checks.notNull(httpHeaders, "Http headers");
77+
this.httpHeaders = httpHeaders;
78+
return self();
79+
}
80+
7381
public @NotNull MqttWebSocketConfigImpl build() {
74-
return new MqttWebSocketConfigImpl(serverPath, queryString, subprotocol, handshakeTimeoutMs);
82+
return new MqttWebSocketConfigImpl(serverPath, queryString, subprotocol, handshakeTimeoutMs, httpHeaders);
7583
}
7684

7785
public static class Default extends MqttWebSocketConfigImplBuilder<Default> implements MqttWebSocketConfigBuilder {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.hivemq.client.internal.mqtt.handler.websocket;
2+
3+
import io.netty.channel.ChannelDuplexHandler;
4+
import io.netty.channel.ChannelHandlerContext;
5+
import io.netty.channel.ChannelPromise;
6+
import io.netty.handler.codec.http.HttpRequest;
7+
import org.jetbrains.annotations.NotNull;
8+
9+
import java.util.Map;
10+
11+
public class MqttWebSocketHttpHeaders extends ChannelDuplexHandler {
12+
13+
public static final String HTTP_HEADERS = "http.headers";
14+
private final @NotNull Map<String, String> httpHeaders;
15+
16+
public MqttWebSocketHttpHeaders(@NotNull final Map<String, String> httpHeaders) {
17+
this.httpHeaders = httpHeaders;
18+
}
19+
20+
@Override
21+
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
22+
if (msg instanceof HttpRequest) {
23+
final HttpRequest request = (HttpRequest) msg;
24+
this.httpHeaders.forEach((key, value) -> request.headers().set(key, value));
25+
}
26+
super.write(ctx, msg, promise);
27+
}
28+
}

Diff for: src/main/java/com/hivemq/client/internal/mqtt/handler/websocket/MqttWebSocketInitializer.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ public void initChannel(
8282
.addLast(MqttWebsocketHandshakeHandler.NAME,
8383
new MqttWebsocketHandshakeHandler(handshaker, webSocketConfig.getHandshakeTimeoutMs(),
8484
onSuccess, onError))
85-
.addLast(MqttWebSocketCodec.NAME, mqttWebSocketCodec);
85+
.addLast(MqttWebSocketCodec.NAME, mqttWebSocketCodec)
86+
.addLast(MqttWebSocketHttpHeaders.HTTP_HEADERS, new MqttWebSocketHttpHeaders(webSocketConfig.getHttpHeaders()));
8687
}
8788
}

Diff for: src/main/java/com/hivemq/client/mqtt/MqttWebSocketConfig.java

+16
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
import com.hivemq.client.internal.mqtt.MqttWebSocketConfigImplBuilder;
2121
import org.jetbrains.annotations.NotNull;
2222

23+
import java.util.LinkedHashMap;
24+
import java.util.Map;
25+
2326
/**
2427
* Configuration for a WebSocket transport to use by {@link MqttClient MQTT clients}.
2528
*
@@ -51,6 +54,12 @@ public interface MqttWebSocketConfig {
5154
* @since 1.2
5255
*/
5356
int DEFAULT_HANDSHAKE_TIMEOUT_MS = 10_000;
57+
/**
58+
* The default map of headers.
59+
* @since 1.2.3
60+
*/
61+
@NotNull Map<String, String> DEFAULT_HTTP_HEADERS = new LinkedHashMap<>();
62+
5463

5564
/**
5665
* Creates a builder for a WebSocket configuration.
@@ -82,11 +91,18 @@ public interface MqttWebSocketConfig {
8291
*/
8392
int getHandshakeTimeoutMs();
8493

94+
/**
95+
* @return map of already set headers.
96+
* @since 1.2.3
97+
*/
98+
@NotNull Map<String, String> getHttpHeaders();
99+
85100
/**
86101
* Creates a builder for extending this WebSocket configuration.
87102
*
88103
* @return the created builder.
89104
* @since 1.1
90105
*/
91106
@NotNull MqttWebSocketConfigBuilder extend();
107+
92108
}

Diff for: src/main/java/com/hivemq/client/mqtt/MqttWebSocketConfigBuilderBase.java

+10
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import com.hivemq.client.annotations.DoNotImplement;
2121
import org.jetbrains.annotations.NotNull;
2222

23+
import java.util.Map;
2324
import java.util.concurrent.TimeUnit;
2425

2526
/**
@@ -71,4 +72,13 @@ public interface MqttWebSocketConfigBuilderBase<B extends MqttWebSocketConfigBui
7172
*/
7273
@CheckReturnValue
7374
@NotNull B handshakeTimeout(long timeout, @NotNull TimeUnit timeUnit);
75+
76+
/**
77+
* Sets the {@link MqttWebSocketConfig#getHttpHeaders() headers}.
78+
*
79+
* @param httpHeaders http headers.
80+
* @return the builder.
81+
*/
82+
@CheckReturnValue
83+
@NotNull B httpHeaders(@NotNull Map<String, String> httpHeaders);
7484
}

0 commit comments

Comments
 (0)