Skip to content

[1.3] connect restrictions mqtt3 #512

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 2, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,36 @@ dependencies {
testRuntimeOnly("org.slf4j:slf4j-simple:${property("slf4j.version")}")
}

/* ******************** integration Tests ******************** */

sourceSets.create("integrationTest") {
compileClasspath += sourceSets.main.get().output
runtimeClasspath += sourceSets.main.get().output
}

val integrationTestImplementation: Configuration by configurations.getting {
extendsFrom(configurations.testImplementation.get())
}
val integrationTestRuntimeOnly: Configuration by configurations.getting {
extendsFrom(configurations.testRuntimeOnly.get())
}

dependencies {
integrationTestImplementation("com.hivemq:hivemq-testcontainer-junit5:${property("hivemq-testcontainer.version")}")
integrationTestImplementation("com.hivemq:hivemq-extension-sdk:${property("hivemq-extension-sdk.version")}")
}

val integrationTest by tasks.registering(Test::class) {
group = "verification"
description = "Runs integration tests."
useJUnitPlatform()
testClassesDirs = sourceSets["integrationTest"].output.classesDirs
classpath = sourceSets["integrationTest"].runtimeClasspath
shouldRunAfter(tasks.test)
}

tasks.check { dependsOn(integrationTest) }


/* ******************** jars ******************** */

Expand Down
2 changes: 2 additions & 0 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ mockito.version=2.18.3
guava.version=24.1-jre
bouncycastle.version=1.59
paho.version=1.2.0
hivemq-testcontainer.version=2.0.0
hivemq-extension-sdk.version=4.5.0
#
# plugins
#
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package com.hivemq.client.restrictions;

import com.hivemq.client.mqtt.MqttGlobalPublishFilter;
import com.hivemq.client.mqtt.datatypes.MqttQos;
import com.hivemq.client.mqtt.mqtt3.Mqtt3Client;
import com.hivemq.client.mqtt.mqtt5.Mqtt5BlockingClient;
import com.hivemq.client.mqtt.mqtt5.Mqtt5Client;
import com.hivemq.client.mqtt.mqtt5.message.publish.Mqtt5Publish;
import com.hivemq.extension.sdk.api.ExtensionMain;
import com.hivemq.extension.sdk.api.client.ClientContext;
import com.hivemq.extension.sdk.api.client.parameter.InitializerInput;
import com.hivemq.extension.sdk.api.interceptor.puback.PubackOutboundInterceptor;
import com.hivemq.extension.sdk.api.interceptor.puback.parameter.PubackOutboundInput;
import com.hivemq.extension.sdk.api.interceptor.puback.parameter.PubackOutboundOutput;
import com.hivemq.extension.sdk.api.parameter.ExtensionStartInput;
import com.hivemq.extension.sdk.api.parameter.ExtensionStartOutput;
import com.hivemq.extension.sdk.api.parameter.ExtensionStopInput;
import com.hivemq.extension.sdk.api.parameter.ExtensionStopOutput;
import com.hivemq.extension.sdk.api.services.Services;
import com.hivemq.extension.sdk.api.services.intializer.ClientInitializer;
import com.hivemq.testcontainer.core.HiveMQExtension;
import com.hivemq.testcontainer.junit5.HiveMQTestContainerExtension;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.testcontainers.utility.MountableFile;

import java.time.Duration;
import java.util.concurrent.ConcurrentLinkedQueue;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
* @author Yannick Weber
*/
public class Mqtt3SendMaximumIT {

public static final int RECEIVE_MAXIMUM = 10;
public static final @NotNull HiveMQExtension NO_PUBACK_EXTENSION = HiveMQExtension.builder()
.version("1.0.0")
.priority(100)
.name("No PUBACK Extension")
.id("no-puback-extension")
.mainClass(NoPubackExtension.class)
.build();

@RegisterExtension
public final @NotNull HiveMQTestContainerExtension hivemq =
new HiveMQTestContainerExtension().withExtension(NO_PUBACK_EXTENSION)
.withHiveMQConfig(MountableFile.forClasspathResource("/config.xml"));

@Test
void mqtt3_sendMaximum_applied() throws InterruptedException {

final Mqtt3Client publisher = Mqtt3Client.builder().serverPort(hivemq.getMqttPort()).build();
publisher.toBlocking().connectWith().restrictions().sendMaximum(RECEIVE_MAXIMUM).applyRestrictions().send();

final ConcurrentLinkedQueue<Mqtt5Publish> publishes = new ConcurrentLinkedQueue<>();
final Mqtt5BlockingClient subscriber = Mqtt5Client.builder().serverPort(hivemq.getMqttPort()).buildBlocking();
subscriber.connectWith().send();
subscriber.toAsync().publishes(MqttGlobalPublishFilter.ALL, publishes::add);
subscriber.subscribeWith().topicFilter("#").send();

for (int i = 0; i < 12; i++) {
publisher.toAsync().publishWith().topic("test").qos(MqttQos.AT_LEAST_ONCE).send();
}

Thread.sleep(10000);

assertEquals(RECEIVE_MAXIMUM, publishes.size());
}

public static class NoPubackExtension implements ExtensionMain {

@Override
public void extensionStart(
final @NotNull ExtensionStartInput extensionStartInput,
final @NotNull ExtensionStartOutput extensionStartOutput) {
Services.initializerRegistry().setClientInitializer(new MyClientInitializer());
}

@Override
public void extensionStop(
final @NotNull ExtensionStopInput extensionStopInput,
final @NotNull ExtensionStopOutput extensionStopOutput) {

}
}

public static class MyClientInitializer implements ClientInitializer {

@Override
public void initialize(
final @NotNull InitializerInput initializerInput, final @NotNull ClientContext clientContext) {
clientContext.addPubackOutboundInterceptor(new NoPubackInterceptorHandler());
}
}

public static class NoPubackInterceptorHandler implements PubackOutboundInterceptor {

@Override
public void onOutboundPuback(
final @NotNull PubackOutboundInput pubackOutboundInput,
final @NotNull PubackOutboundOutput pubackOutboundOutput) {
pubackOutboundOutput.async(Duration.ofHours(1));
}
}

}
6 changes: 6 additions & 0 deletions src/integrationTest/resources/config.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<?xml version="1.0"?>
<hivemq>
<persistence>
<mode>in-memory</mode>
</persistence>
</hivemq>
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.hivemq.client.internal.mqtt.util.MqttChecks;
import com.hivemq.client.internal.util.Checks;
import com.hivemq.client.mqtt.mqtt3.message.connect.Mqtt3ConnectRestrictionsBuilder;
import com.hivemq.client.mqtt.mqtt5.message.connect.Mqtt5ConnectRestrictionsBuilder;
import org.jetbrains.annotations.NotNull;

Expand Down Expand Up @@ -98,7 +99,7 @@ public abstract class MqttConnectRestrictionsBuilder<B extends MqttConnectRestri
}

public static class Default extends MqttConnectRestrictionsBuilder<Default>
implements Mqtt5ConnectRestrictionsBuilder {
implements Mqtt5ConnectRestrictionsBuilder, Mqtt3ConnectRestrictionsBuilder {

public Default() {}

Expand All @@ -113,9 +114,9 @@ public Default() {}
}

public static class Nested<P> extends MqttConnectRestrictionsBuilder<Nested<P>>
implements Mqtt5ConnectRestrictionsBuilder.Nested<P> {
implements Mqtt5ConnectRestrictionsBuilder.Nested<P>, Mqtt3ConnectRestrictionsBuilder.Nested<P> {

Nested(
public Nested(
final @NotNull MqttConnectRestrictions restrictions,
final @NotNull Function<? super MqttConnectRestrictions, P> parentConsumer) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
import com.hivemq.client.internal.mqtt.message.auth.mqtt3.Mqtt3SimpleAuthView;
import com.hivemq.client.internal.mqtt.message.connect.MqttConnect;
import com.hivemq.client.internal.mqtt.message.connect.MqttConnectRestrictions;
import com.hivemq.client.internal.mqtt.message.connect.MqttConnectRestrictionsBuilder;
import com.hivemq.client.internal.mqtt.message.publish.MqttWillPublish;
import com.hivemq.client.internal.mqtt.message.publish.mqtt3.Mqtt3PublishView;
import com.hivemq.client.internal.util.Checks;
import com.hivemq.client.mqtt.mqtt3.message.auth.Mqtt3SimpleAuth;
import com.hivemq.client.mqtt.mqtt3.message.connect.Mqtt3Connect;
import com.hivemq.client.mqtt.mqtt3.message.connect.Mqtt3ConnectRestrictions;
import com.hivemq.client.mqtt.mqtt3.message.publish.Mqtt3Publish;
import com.hivemq.client.mqtt.mqtt5.message.connect.Mqtt5ConnectRestrictions;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand All @@ -38,26 +42,28 @@
@Immutable
public class Mqtt3ConnectView implements Mqtt3Connect {

public static final @NotNull Mqtt3ConnectView DEFAULT = of(DEFAULT_KEEP_ALIVE, DEFAULT_CLEAN_SESSION, null, null);
public static final @NotNull Mqtt3ConnectView DEFAULT = of(DEFAULT_KEEP_ALIVE, DEFAULT_CLEAN_SESSION, null, null, MqttConnectRestrictions.DEFAULT);

private static @NotNull MqttConnect delegate(
final int keepAlive,
final boolean cleanSession,
final @Nullable MqttSimpleAuth simpleAuth,
final @Nullable MqttWillPublish willPublish) {
final @Nullable MqttWillPublish willPublish,
final @NotNull MqttConnectRestrictions restrictions) {

return new MqttConnect(keepAlive, cleanSession, cleanSession ? 0 : MqttConnect.NO_SESSION_EXPIRY,
MqttConnectRestrictions.DEFAULT, simpleAuth, null, willPublish,
restrictions, simpleAuth, null, willPublish,
MqttUserPropertiesImpl.NO_USER_PROPERTIES);
}

static @NotNull Mqtt3ConnectView of(
final int keepAlive,
final boolean cleanSession,
final @Nullable MqttSimpleAuth simpleAuth,
final @Nullable MqttWillPublish willPublish) {
final @Nullable MqttWillPublish willPublish,
final @NotNull MqttConnectRestrictions mqttConnectRestrictions) {

return new Mqtt3ConnectView(delegate(keepAlive, cleanSession, simpleAuth, willPublish));
return new Mqtt3ConnectView(delegate(keepAlive, cleanSession, simpleAuth, willPublish, mqttConnectRestrictions));
}

public static @NotNull Mqtt3ConnectView of(final @NotNull MqttConnect delegate) {
Expand All @@ -80,6 +86,11 @@ public boolean isCleanSession() {
return delegate.isCleanStart();
}

@Override
public @NotNull Mqtt3ConnectRestrictions getRestrictions() {
return delegate.getRestrictions();
}

@Override
public @NotNull Optional<Mqtt3SimpleAuth> getSimpleAuth() {
return Optional.ofNullable(getRawSimpleAuth());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@
import com.hivemq.client.internal.mqtt.message.auth.mqtt3.Mqtt3SimpleAuthView;
import com.hivemq.client.internal.mqtt.message.auth.mqtt3.Mqtt3SimpleAuthViewBuilder;
import com.hivemq.client.internal.mqtt.message.connect.MqttConnect;
import com.hivemq.client.internal.mqtt.message.connect.MqttConnectRestrictions;
import com.hivemq.client.internal.mqtt.message.connect.MqttConnectRestrictionsBuilder;
import com.hivemq.client.internal.mqtt.message.publish.MqttWillPublish;
import com.hivemq.client.internal.mqtt.message.publish.mqtt3.Mqtt3PublishView;
import com.hivemq.client.internal.mqtt.message.publish.mqtt3.Mqtt3PublishViewBuilder;
import com.hivemq.client.internal.util.Checks;
import com.hivemq.client.mqtt.mqtt3.message.auth.Mqtt3SimpleAuth;
import com.hivemq.client.mqtt.mqtt3.message.connect.Mqtt3ConnectBuilder;
import com.hivemq.client.mqtt.mqtt3.message.connect.Mqtt3ConnectRestrictions;
import com.hivemq.client.mqtt.mqtt3.message.connect.Mqtt3ConnectRestrictionsBuilder;
import com.hivemq.client.mqtt.mqtt3.message.publish.Mqtt3Publish;
import com.hivemq.client.mqtt.mqtt5.message.connect.Mqtt5ConnectRestrictions;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand All @@ -41,6 +46,7 @@ public abstract class Mqtt3ConnectViewBuilder<B extends Mqtt3ConnectViewBuilder<
private boolean cleanSession = Mqtt3ConnectView.DEFAULT_CLEAN_SESSION;
private @Nullable MqttSimpleAuth simpleAuth;
private @Nullable MqttWillPublish willPublish;
private @NotNull MqttConnectRestrictions restrictions = MqttConnectRestrictions.DEFAULT;

Mqtt3ConnectViewBuilder() {}

Expand All @@ -50,6 +56,7 @@ public abstract class Mqtt3ConnectViewBuilder<B extends Mqtt3ConnectViewBuilder<
cleanSession = delegate.isCleanStart();
simpleAuth = delegate.getRawSimpleAuth();
willPublish = delegate.getRawWillPublish();
restrictions = delegate.getRestrictions();
}

abstract @NotNull B self();
Expand All @@ -69,6 +76,15 @@ public abstract class Mqtt3ConnectViewBuilder<B extends Mqtt3ConnectViewBuilder<
return self();
}

public @NotNull B restrictions(final @Nullable Mqtt3ConnectRestrictions restrictions) {
this.restrictions = Checks.notImplemented(restrictions, MqttConnectRestrictions.class, "Connect restrictions");
return self();
}

public MqttConnectRestrictionsBuilder.@NotNull Nested<B> restrictions() {
return new MqttConnectRestrictionsBuilder.Nested<>(restrictions, this::restrictions);
}

public @NotNull B simpleAuth(final @Nullable Mqtt3SimpleAuth simpleAuth) {
this.simpleAuth = (simpleAuth == null) ? null :
Checks.notImplemented(simpleAuth, Mqtt3SimpleAuthView.class, "Simple auth").getDelegate();
Expand All @@ -90,7 +106,7 @@ public abstract class Mqtt3ConnectViewBuilder<B extends Mqtt3ConnectViewBuilder<
}

public @NotNull Mqtt3ConnectView build() {
return Mqtt3ConnectView.of(keepAliveSeconds, cleanSession, simpleAuth, willPublish);
return Mqtt3ConnectView.of(keepAliveSeconds, cleanSession, simpleAuth, willPublish, restrictions);
}

public static class Default extends Mqtt3ConnectViewBuilder<Default> implements Mqtt3ConnectBuilder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.hivemq.client.mqtt.mqtt3.message.Mqtt3MessageType;
import com.hivemq.client.mqtt.mqtt3.message.auth.Mqtt3SimpleAuth;
import com.hivemq.client.mqtt.mqtt3.message.publish.Mqtt3Publish;
import com.hivemq.client.mqtt.mqtt5.message.connect.Mqtt5ConnectRestrictions;
import org.jetbrains.annotations.NotNull;

import java.util.Optional;
Expand Down Expand Up @@ -67,6 +68,11 @@ public interface Mqtt3Connect extends Mqtt3Message {
*/
boolean isCleanSession();

/**
* @return the restrictions set from the client.
*/
@NotNull Mqtt3ConnectRestrictions getRestrictions();

/**
* @return the optional simple authentication and/or authorization related data of this Connect message.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import com.hivemq.client.mqtt.mqtt3.message.auth.Mqtt3SimpleAuthBuilder;
import com.hivemq.client.mqtt.mqtt3.message.publish.Mqtt3Publish;
import com.hivemq.client.mqtt.mqtt3.message.publish.Mqtt3WillPublishBuilder;
import com.hivemq.client.mqtt.mqtt5.message.connect.Mqtt5Connect;
import com.hivemq.client.mqtt.mqtt5.message.connect.Mqtt5ConnectRestrictions;
import com.hivemq.client.mqtt.mqtt5.message.connect.Mqtt5ConnectRestrictionsBuilder;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand Down Expand Up @@ -63,6 +66,27 @@ public interface Mqtt3ConnectBuilderBase<B extends Mqtt3ConnectBuilderBase<B>> {
@CheckReturnValue
@NotNull B cleanSession(boolean cleanSession);

/**
* Sets the {@link Mqtt3Connect#getRestrictions() restrictions} from the client.
*
* @param restrictions the restrictions from the client.
* @return the builder.
*/
@CheckReturnValue
@NotNull B restrictions(@NotNull Mqtt3ConnectRestrictions restrictions);

/**
* Fluent counterpart of {@link #restrictions(Mqtt3ConnectRestrictions)}.
* <p>
* Calling {@link Mqtt3ConnectRestrictionsBuilder.Nested#applyRestrictions()} on the returned builder has the effect
* of {@link Mqtt3ConnectRestrictions#extend() extending} the current restrictions.
*
* @return the fluent builder for the restrictions.
* @see #restrictions(Mqtt3ConnectRestrictions)
*/
@CheckReturnValue
Mqtt3ConnectRestrictionsBuilder.@NotNull Nested<? extends B> restrictions();

/**
* Sets the optional {@link Mqtt3Connect#getSimpleAuth() simple authentication and/or authorization related data}.
*
Expand Down
Loading