diff --git a/reactor/src/main/java/com/hivemq/client/internal/rx/reactor/operators/FluxWithSingleCombine.java b/reactor/src/main/java/com/hivemq/client/internal/rx/reactor/operators/FluxWithSingleCombine.java index a8aee9b16..0ec15cc81 100644 --- a/reactor/src/main/java/com/hivemq/client/internal/rx/reactor/operators/FluxWithSingleCombine.java +++ b/reactor/src/main/java/com/hivemq/client/internal/rx/reactor/operators/FluxWithSingleCombine.java @@ -30,7 +30,6 @@ import reactor.util.context.Context; import java.util.concurrent.atomic.AtomicLongFieldUpdater; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; /** * @author Silvio Giebl @@ -50,17 +49,17 @@ public void subscribe(final @NotNull CoreSubscriber subscriber) private static class CombineSubscriber implements CoreWithSingleSubscriber, Subscription { + private static final @NotNull Object COMPLETE = new Object(); @SuppressWarnings("rawtypes") - static final @NotNull AtomicLongFieldUpdater REQUESTED = + private static final @NotNull AtomicLongFieldUpdater REQUESTED = AtomicLongFieldUpdater.newUpdater(CombineSubscriber.class, "requested"); - @SuppressWarnings("rawtypes") - static final @NotNull AtomicReferenceFieldUpdater QUEUED = - AtomicReferenceFieldUpdater.newUpdater(CombineSubscriber.class, Object.class, "queued"); private final @NotNull CoreSubscriber subscriber; private @Nullable Subscription subscription; - volatile long requested; - volatile @Nullable Object queued; + private volatile long requested; + + private @Nullable Object queued; + private @Nullable Object done; CombineSubscriber(final @NotNull CoreSubscriber subscriber) { this.subscriber = subscriber; @@ -84,30 +83,36 @@ public void onNext(final @NotNull F f) { private void next(final @NotNull Object next) { if (REQUESTED.get(this) == 0) { - QUEUED.set(this, next); - if ((REQUESTED.get(this) != 0) && (QUEUED.getAndSet(this, null)) != null) { - Operators.produced(REQUESTED, this, 1); - subscriber.onNext(next); + synchronized (this) { + if (REQUESTED.get(this) == 0) { + queued = next; + return; + } } - } else { - Operators.produced(REQUESTED, this, 1); - subscriber.onNext(next); } + Operators.produced(REQUESTED, this, 1); + subscriber.onNext(next); } @Override - public void onError(final @NotNull Throwable throwable) { - final Object next = QUEUED.get(this); - if ((next == null) || !QUEUED.compareAndSet(this, next, new TerminalElement(next, throwable))) { - subscriber.onError(throwable); + public void onComplete() { + synchronized (this) { + if (queued != null) { + done = COMPLETE; + } else { + subscriber.onComplete(); + } } } @Override - public void onComplete() { - final Object next = QUEUED.get(this); - if ((next == null) || !QUEUED.compareAndSet(this, next, new TerminalElement(next, null))) { - subscriber.onComplete(); + public void onError(final @NotNull Throwable error) { + synchronized (this) { + if (queued != null) { + done = error; + } else { + subscriber.onError(error); + } } } @@ -115,26 +120,30 @@ public void onComplete() { public void request(long n) { assert subscription != null; if (n > 0) { - if (REQUESTED.get(this) == 0) { - final Object next = QUEUED.getAndSet(this, null); - if (next != null) { - if (next instanceof TerminalElement) { - final TerminalElement terminalElement = (TerminalElement) next; - subscriber.onNext(terminalElement.element); - if (terminalElement.error == null) { - subscriber.onComplete(); - } else { - subscriber.onError(terminalElement.error); - } - return; - } else { - subscriber.onNext(next); + if (Operators.addCap(REQUESTED, this, n) == 0) { + synchronized (this) { + final Object queued = this.queued; + if (queued != null) { + this.queued = null; + Operators.produced(REQUESTED, this, 1); + subscriber.onNext(queued); n--; + final Object done = this.done; + if (done != null) { + this.done = null; + if (done instanceof Throwable) { + subscriber.onError((Throwable) done); + } else { + subscriber.onComplete(); + } + return; + } + } + if (n > 0) { + subscription.request(n); } } - } - if (n > 0) { - Operators.addCap(REQUESTED, this, n); + } else { subscription.request(n); } } @@ -265,15 +274,4 @@ private static class SingleElement { this.element = element; } } - - private static class TerminalElement { - - final @NotNull Object element; - final @Nullable Throwable error; - - TerminalElement(final @NotNull Object element, final @Nullable Throwable error) { - this.element = element; - this.error = error; - } - } } diff --git a/reactor/src/test/java/com/hivemq/client/rx/reactor/FluxWithSingleTest.java b/reactor/src/test/java/com/hivemq/client/rx/reactor/FluxWithSingleTest.java index 860e3242a..cf1e00119 100644 --- a/reactor/src/test/java/com/hivemq/client/rx/reactor/FluxWithSingleTest.java +++ b/reactor/src/test/java/com/hivemq/client/rx/reactor/FluxWithSingleTest.java @@ -253,16 +253,15 @@ void mapBoth_multiple(final @NotNull FluxWithSingle fluxW final AtomicInteger nextCounter = new AtomicInteger(); final AtomicInteger singleCounter = new AtomicInteger(); - fluxWithSingle // - .mapBoth(s -> { - nextCounter.incrementAndGet(); - assertNotEquals("test_thread", Thread.currentThread().getName()); - return s + "-1"; - }, stringBuilder -> { - assertEquals(1, singleCounter.incrementAndGet()); - assertNotEquals("test_thread", Thread.currentThread().getName()); - return stringBuilder.append("-1"); - }).mapBoth(s -> { + fluxWithSingle.mapBoth(s -> { + nextCounter.incrementAndGet(); + assertNotEquals("test_thread", Thread.currentThread().getName()); + return s + "-1"; + }, stringBuilder -> { + assertEquals(1, singleCounter.incrementAndGet()); + assertNotEquals("test_thread", Thread.currentThread().getName()); + return stringBuilder.append("-1"); + }).mapBoth(s -> { nextCounter.incrementAndGet(); assertNotEquals("test_thread", Thread.currentThread().getName()); return s + "-2"; diff --git a/src/main/java/com/hivemq/client/internal/mqtt/MqttBlockingClient.java b/src/main/java/com/hivemq/client/internal/mqtt/MqttBlockingClient.java index d573e10b0..49eee61c3 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/MqttBlockingClient.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/MqttBlockingClient.java @@ -17,6 +17,7 @@ package com.hivemq.client.internal.mqtt; +import com.hivemq.client.internal.mqtt.exceptions.MqttClientStateExceptions; import com.hivemq.client.internal.mqtt.message.connect.MqttConnect; import com.hivemq.client.internal.mqtt.message.disconnect.MqttDisconnect; import com.hivemq.client.internal.mqtt.message.publish.MqttPublish; @@ -96,6 +97,9 @@ public class MqttBlockingClient implements Mqtt5BlockingClient { public @NotNull Mqtt5SubAck subscribe(final @Nullable Mqtt5Subscribe subscribe) { final MqttSubscribe mqttSubscribe = MqttChecks.subscribe(subscribe); try { + if (!getState().isConnectedOrReconnect()) { + throw MqttClientStateExceptions.notConnected(); + } return handleSubAck(delegate.subscribeUnsafe(mqttSubscribe).blockingGet()); } catch (final RuntimeException e) { throw AsyncRuntimeException.fillInStackTrace(e); @@ -113,6 +117,9 @@ public class MqttBlockingClient implements Mqtt5BlockingClient { public @NotNull Mqtt5UnsubAck unsubscribe(final @Nullable Mqtt5Unsubscribe unsubscribe) { final MqttUnsubscribe mqttUnsubscribe = MqttChecks.unsubscribe(unsubscribe); try { + if (!getState().isConnectedOrReconnect()) { + throw MqttClientStateExceptions.notConnected(); + } return handleUnsubAck(delegate.unsubscribeUnsafe(mqttUnsubscribe).blockingGet()); } catch (final RuntimeException e) { throw AsyncRuntimeException.fillInStackTrace(e); diff --git a/src/main/java/com/hivemq/client/internal/mqtt/MqttClientConfig.java b/src/main/java/com/hivemq/client/internal/mqtt/MqttClientConfig.java index ca5c49d03..3ddc760cd 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/MqttClientConfig.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/MqttClientConfig.java @@ -69,6 +69,8 @@ public class MqttClientConfig implements Mqtt5ClientConfig { private volatile @Nullable MqttClientConnectionConfig connectionConfig; private @NotNull MqttClientTransportConfigImpl currentTransportConfig; private @Nullable SslContext currentSslContext; + private boolean resubscribeIfSessionExpired; + private boolean republishIfSessionExpired; public MqttClientConfig( final @NotNull MqttVersion mqttVersion, final @NotNull MqttClientIdentifierImpl clientIdentifier, @@ -189,7 +191,7 @@ public void releaseEventLoop() { if (--eventLoopAcquires == 0) { final EventLoop eventLoop = this.eventLoop; final long eventLoopAcquireCount = this.eventLoopAcquireCount; - assert eventLoop != null; + assert eventLoop != null : "eventLoopAcquires was > 0 -> eventLoop != null"; eventLoop.execute(() -> { // release eventLoop after all tasks are finished synchronized (state) { if (eventLoopAcquireCount == this.eventLoopAcquireCount) { // eventLoop has not been reacquired @@ -251,6 +253,22 @@ public void setCurrentSslContext(final @Nullable SslContext currentSslContext) { this.currentSslContext = currentSslContext; } + public boolean isResubscribeIfSessionExpired() { + return resubscribeIfSessionExpired; + } + + public void setResubscribeIfSessionExpired(final boolean resubscribeIfSessionExpired) { + this.resubscribeIfSessionExpired = resubscribeIfSessionExpired; + } + + public boolean isRepublishIfSessionExpired() { + return republishIfSessionExpired; + } + + public void setRepublishIfSessionExpired(final boolean republishIfSessionExpired) { + this.republishIfSessionExpired = republishIfSessionExpired; + } + public static class ConnectDefaults { private static final @NotNull ConnectDefaults EMPTY = new ConnectDefaults(null, null, null); diff --git a/src/main/java/com/hivemq/client/internal/mqtt/codec/decoder/mqtt5/Mqtt5ConnAckDecoder.java b/src/main/java/com/hivemq/client/internal/mqtt/codec/decoder/mqtt5/Mqtt5ConnAckDecoder.java index 6ec217497..b1271bc77 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/codec/decoder/mqtt5/Mqtt5ConnAckDecoder.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/codec/decoder/mqtt5/Mqtt5ConnAckDecoder.java @@ -167,7 +167,7 @@ public class Mqtt5ConnAckDecoder implements MqttMessageDecoder { throw new MqttDecoderException(Mqtt5DisconnectReasonCode.PROTOCOL_ERROR, "wrong maximum Qos"); } maximumQos = MqttQos.fromCode(maximumQosCode); - assert maximumQos != null; + assert maximumQos != null : "maximumQosCode = 0 or = 1"; maximumQosPresent = true; restrictionsPresent |= maximumQos != DEFAULT_MAXIMUM_QOS; break; diff --git a/src/main/java/com/hivemq/client/internal/mqtt/codec/encoder/mqtt5/Mqtt5SubscribeEncoder.java b/src/main/java/com/hivemq/client/internal/mqtt/codec/encoder/mqtt5/Mqtt5SubscribeEncoder.java index 81454c96d..e6390ce21 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/codec/encoder/mqtt5/Mqtt5SubscribeEncoder.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/codec/encoder/mqtt5/Mqtt5SubscribeEncoder.java @@ -109,18 +109,7 @@ private void encodePayload(final @NotNull MqttStatefulSubscribe message, final @ final MqttSubscription subscription = subscriptions.get(i); subscription.getTopicFilter().encode(out); - - int subscriptionOptions = 0; - subscriptionOptions |= subscription.getRetainHandling().getCode() << 4; - if (subscription.isRetainAsPublished()) { - subscriptionOptions |= 0b0000_1000; - } - if (subscription.isNoLocal()) { - subscriptionOptions |= 0b0000_0100; - } - subscriptionOptions |= subscription.getQos().getCode(); - - out.writeByte(subscriptionOptions); + out.writeByte(subscription.encodeSubscriptionOptions()); } } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicFilterImpl.java b/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicFilterImpl.java index cb1682b11..1a26d1e26 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicFilterImpl.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicFilterImpl.java @@ -28,6 +28,8 @@ import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; +import java.util.Arrays; + /** * @author Silvio Giebl * @see MqttTopicFilter @@ -286,6 +288,11 @@ int getFilterByteStart() { return toString(); } + public @Nullable byte[] getPrefix() { + final int filterByteStart = getFilterByteStart(); + return (filterByteStart == 0) ? null : Arrays.copyOfRange(toBinary(), 0, filterByteStart - 1); + } + @Override public boolean matches(final @Nullable MqttTopic topic) { return matches(MqttChecks.topic(topic)); diff --git a/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicIterator.java b/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicIterator.java index 83864eff4..1be2048c8 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicIterator.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicIterator.java @@ -18,8 +18,6 @@ package com.hivemq.client.internal.mqtt.datatypes; import com.hivemq.client.internal.util.ByteArrayUtil; -import com.hivemq.client.mqtt.datatypes.MqttTopic; -import com.hivemq.client.mqtt.datatypes.MqttTopicFilter; import org.jetbrains.annotations.NotNull; import java.util.Arrays; @@ -84,7 +82,7 @@ public boolean hasMultiLevelWildcard() { throw new NoSuchElementException(); } start = end + 1; - end = ByteArrayUtil.indexOf(array, start, (byte) MqttTopic.TOPIC_LEVEL_SEPARATOR); + end = ByteArrayUtil.indexOf(array, start, (byte) MqttTopicImpl.TOPIC_LEVEL_SEPARATOR); return this; } @@ -103,7 +101,7 @@ public boolean forwardIfEqual(final @NotNull MqttTopicLevels levels) { final byte[] levelsArray = levels.getArray(); final int levelsEnd = levels.getEnd(); final int to = end + levelsArray.length - levelsEnd; - if ((to <= allEnd) && ((to == allEnd) || (array[to] == MqttTopic.TOPIC_LEVEL_SEPARATOR)) && + if ((to <= allEnd) && ((to == allEnd) || (array[to] == MqttTopicImpl.TOPIC_LEVEL_SEPARATOR)) && ByteArrayUtil.equals(array, end + 1, to, levelsArray, levelsEnd + 1, levelsArray.length)) { start = end = to; return true; @@ -171,7 +169,7 @@ public boolean forwardIfMatch(final @NotNull MqttTopicLevels levels) { if (array[index] == lb) { index++; levelsIndex++; - } else if (lb == MqttTopicFilter.SINGLE_LEVEL_WILDCARD) { + } else if (lb == MqttTopicFilterImpl.SINGLE_LEVEL_WILDCARD) { while ((index < allEnd) && (array[index] != MqttTopicImpl.TOPIC_LEVEL_SEPARATOR)) { index++; } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicLevel.java b/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicLevel.java index def52c448..59f4902c6 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicLevel.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicLevel.java @@ -18,8 +18,8 @@ package com.hivemq.client.internal.mqtt.datatypes; import com.hivemq.client.internal.util.ByteArray; -import com.hivemq.client.mqtt.datatypes.MqttTopicFilter; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; import java.util.Arrays; @@ -32,7 +32,7 @@ public class MqttTopicLevel extends ByteArray { private static final @NotNull MqttTopicLevel SINGLE_LEVEL_WILDCARD = - new MqttTopicLevel(new byte[]{MqttTopicFilter.SINGLE_LEVEL_WILDCARD}); + new MqttTopicLevel(new byte[]{MqttTopicFilterImpl.SINGLE_LEVEL_WILDCARD}); static @NotNull MqttTopicLevel of(final @NotNull byte[] array, final int start, final int end) { if (isSingleLevelWildcard(array, start, end)) { @@ -42,7 +42,7 @@ public class MqttTopicLevel extends ByteArray { } private static boolean isSingleLevelWildcard(final @NotNull byte[] array, final int start, final int end) { - return ((end - start) == 1) && (array[start] == MqttTopicFilter.SINGLE_LEVEL_WILDCARD); + return ((end - start) == 1) && (array[start] == MqttTopicFilterImpl.SINGLE_LEVEL_WILDCARD); } MqttTopicLevel(final @NotNull byte[] array) { @@ -60,4 +60,28 @@ public boolean isSingleLevelWildcard() { public @NotNull MqttTopicLevel trim() { return this; } + + public @Nullable MqttTopicFilterImpl toFilter(final @Nullable byte[] prefix, final boolean multiLevelWildcard) { + final byte[] bytes; + if (prefix != null) { + if (multiLevelWildcard) { + bytes = new byte[prefix.length + 1 + array.length + 2]; + bytes[bytes.length - 2] = MqttTopicImpl.TOPIC_LEVEL_SEPARATOR; + bytes[bytes.length - 1] = MqttTopicFilterImpl.MULTI_LEVEL_WILDCARD; + } else { + bytes = new byte[prefix.length + 1 + array.length]; + } + System.arraycopy(prefix, 0, bytes, 0, prefix.length); + bytes[prefix.length] = MqttTopicImpl.TOPIC_LEVEL_SEPARATOR; + System.arraycopy(array, 0, bytes, prefix.length + 1, array.length); + } else if (multiLevelWildcard) { + bytes = new byte[array.length + 2]; + System.arraycopy(array, 0, bytes, 0, array.length); + bytes[bytes.length - 2] = MqttTopicImpl.TOPIC_LEVEL_SEPARATOR; + bytes[bytes.length - 1] = MqttTopicFilterImpl.MULTI_LEVEL_WILDCARD; + } else { + bytes = array; + } + return MqttTopicFilterImpl.of(bytes); + } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicLevels.java b/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicLevels.java index 586f15a16..8180e2809 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicLevels.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicLevels.java @@ -18,7 +18,6 @@ package com.hivemq.client.internal.mqtt.datatypes; import com.hivemq.client.internal.util.ByteArrayUtil; -import com.hivemq.client.mqtt.datatypes.MqttTopic; import org.jetbrains.annotations.NotNull; import java.util.Arrays; @@ -40,7 +39,7 @@ public class MqttTopicLevels extends MqttTopicLevel { final byte[] array2 = level2.trim().getArray(); final byte[] array = new byte[array1.length + 1 + array2.length]; System.arraycopy(array1, 0, array, 0, array1.length); - array[array1.length] = MqttTopic.TOPIC_LEVEL_SEPARATOR; + array[array1.length] = MqttTopicImpl.TOPIC_LEVEL_SEPARATOR; System.arraycopy(array2, 0, array, array1.length + 1, array2.length); return new MqttTopicLevels(array, level1.length()); } @@ -61,7 +60,7 @@ protected int getEnd() { if (index == array.length) { return this; } - assert array[index] == MqttTopic.TOPIC_LEVEL_SEPARATOR; + assert array[index] == MqttTopicImpl.TOPIC_LEVEL_SEPARATOR : "topic levels must only be split on /"; if (index == firstEnd) { return MqttTopicLevel.of(array, 0, firstEnd); } @@ -69,9 +68,9 @@ protected int getEnd() { } public @NotNull MqttTopicLevel after(final int index) { - assert array[index] == MqttTopic.TOPIC_LEVEL_SEPARATOR; + assert array[index] == MqttTopicImpl.TOPIC_LEVEL_SEPARATOR : "topic levels must only be split on /"; final int start = index + 1; - final int end = ByteArrayUtil.indexOf(array, start, (byte) MqttTopic.TOPIC_LEVEL_SEPARATOR); + final int end = ByteArrayUtil.indexOf(array, start, (byte) MqttTopicImpl.TOPIC_LEVEL_SEPARATOR); if (end == array.length) { return MqttTopicLevel.of(array, start, array.length); } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/connect/MqttConnAckFlow.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/connect/MqttConnAckFlow.java index 3b02595e3..ad190cd4c 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/connect/MqttConnAckFlow.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/connect/MqttConnAckFlow.java @@ -60,15 +60,15 @@ boolean setDone() { return true; } - void onSuccess(final @NotNull Mqtt5ConnAck t) { + void onSuccess(final @NotNull Mqtt5ConnAck connAck) { if (observer != null) { - observer.onSuccess(t); + observer.onSuccess(connAck); } } - void onError(final @NotNull Throwable t) { + void onError(final @NotNull Throwable error) { if (observer != null) { - observer.onError(t); + observer.onError(error); } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/connect/MqttConnAckSingle.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/connect/MqttConnAckSingle.java index e5e403190..e256c00da 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/connect/MqttConnAckSingle.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/connect/MqttConnAckSingle.java @@ -162,6 +162,9 @@ private static void reconnect( } }); }, reconnector.getDelay(TimeUnit.NANOSECONDS), TimeUnit.NANOSECONDS); + clientConfig.setResubscribeIfSessionExpired(reconnector.isResubscribeIfSessionExpired()); + clientConfig.setRepublishIfSessionExpired(reconnector.isRepublishIfSessionExpired()); + reconnector.afterOnDisconnected(); } else { clientConfig.getRawState().set(DISCONNECTED); clientConfig.releaseEventLoop(); diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttGlobalIncomingPublishFlow.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttGlobalIncomingPublishFlow.java index f47dca586..2a322773a 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttGlobalIncomingPublishFlow.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttGlobalIncomingPublishFlow.java @@ -28,7 +28,7 @@ /** * @author Silvio Giebl */ -class MqttGlobalIncomingPublishFlow extends MqttIncomingPublishFlow { +public class MqttGlobalIncomingPublishFlow extends MqttIncomingPublishFlow { private final @NotNull MqttGlobalPublishFilter filter; private @Nullable Handle handle; @@ -43,7 +43,7 @@ class MqttGlobalIncomingPublishFlow extends MqttIncomingPublishFlow { @Override void runCancel() { - incomingQosHandler.getIncomingPublishFlows().cancelGlobal(this); + incomingPublishService.incomingPublishFlows.cancelGlobal(this); super.runCancel(); } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttGlobalIncomingPublishFlowable.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttGlobalIncomingPublishFlowable.java index bc813de58..9d2b03569 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttGlobalIncomingPublishFlowable.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttGlobalIncomingPublishFlowable.java @@ -18,6 +18,7 @@ package com.hivemq.client.internal.mqtt.handler.publish.incoming; import com.hivemq.client.internal.mqtt.MqttClientConfig; +import com.hivemq.client.internal.mqtt.handler.subscribe.MqttSubscriptionHandler; import com.hivemq.client.internal.mqtt.ioc.ClientComponent; import com.hivemq.client.mqtt.MqttGlobalPublishFilter; import com.hivemq.client.mqtt.mqtt5.message.publish.Mqtt5Publish; @@ -44,15 +45,11 @@ public MqttGlobalIncomingPublishFlowable( protected void subscribeActual(final @NotNull Subscriber subscriber) { final ClientComponent clientComponent = clientConfig.getClientComponent(); final MqttIncomingQosHandler incomingQosHandler = clientComponent.incomingQosHandler(); - final MqttIncomingPublishFlows incomingPublishFlows = incomingQosHandler.getIncomingPublishFlows(); + final MqttSubscriptionHandler subscriptionHandler = clientComponent.subscriptionHandler(); final MqttGlobalIncomingPublishFlow flow = new MqttGlobalIncomingPublishFlow(subscriber, clientConfig, incomingQosHandler, filter); subscriber.onSubscribe(flow); - flow.getEventLoop().execute(() -> { - if (flow.init()) { - incomingPublishFlows.subscribeGlobal(flow); - } - }); + subscriptionHandler.subscribeGlobal(flow); } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishFlow.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishFlow.java index 924d7bfbd..4974b73bf 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishFlow.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishFlow.java @@ -35,7 +35,7 @@ /** * @author Silvio Giebl */ -public abstract class MqttIncomingPublishFlow extends FlowWithEventLoop +abstract class MqttIncomingPublishFlow extends FlowWithEventLoop implements Emitter, Subscription, Runnable { private static final int STATE_NO_NEW_REQUESTS = 0; @@ -43,7 +43,7 @@ public abstract class MqttIncomingPublishFlow extends FlowWithEventLoop private static final int STATE_BLOCKED = 2; final @NotNull Subscriber subscriber; - final @NotNull MqttIncomingQosHandler incomingQosHandler; + final @NotNull MqttIncomingPublishService incomingPublishService; private long requested; private final @NotNull AtomicLong newRequested = new AtomicLong(); @@ -62,7 +62,7 @@ public abstract class MqttIncomingPublishFlow extends FlowWithEventLoop super(clientConfig); this.subscriber = subscriber; - this.incomingQosHandler = incomingQosHandler; + incomingPublishService = incomingQosHandler.incomingPublishService; } @CallByThread("Netty EventLoop") @@ -84,25 +84,26 @@ public void onComplete() { if ((referenced == 0) && setDone()) { subscriber.onComplete(); } else { - incomingQosHandler.getIncomingPublishService().drain(); + incomingPublishService.drain(); } } @CallByThread("Netty EventLoop") @Override - public void onError(final @NotNull Throwable t) { + public void onError(final @NotNull Throwable error) { if (done) { - if (t != error) { - RxJavaPlugins.onError(t); + // multiple calls with the same error are expected if flow was subscribed with multiple topic filters + if (error != this.error) { + RxJavaPlugins.onError(error); } return; } - error = t; + this.error = error; done = true; if ((referenced == 0) && setDone()) { - subscriber.onError(t); + subscriber.onError(error); } else { - incomingQosHandler.getIncomingPublishService().drain(); + incomingPublishService.drain(); } } @@ -134,7 +135,7 @@ public void request(final long n) { @Override public void run() { // only executed if was blocking if (referenced > 0) { // is blocking - incomingQosHandler.getIncomingPublishService().drain(); + incomingPublishService.drain(); } } @@ -176,7 +177,7 @@ protected void onCancel() { @CallByThread("Netty EventLoop") void runCancel() { // always executed if cancelled if (referenced > 0) { // is blocking - incomingQosHandler.getIncomingPublishService().drain(); + incomingPublishService.drain(); } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishFlows.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishFlows.java index 8dcdfc43c..f907a1655 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishFlows.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishFlows.java @@ -21,11 +21,9 @@ import com.hivemq.client.internal.mqtt.datatypes.MqttTopicFilterImpl; import com.hivemq.client.internal.mqtt.ioc.ClientScope; import com.hivemq.client.internal.mqtt.message.publish.MqttStatefulPublish; -import com.hivemq.client.internal.mqtt.message.subscribe.MqttStatefulSubscribe; +import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscribe; import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscription; -import com.hivemq.client.internal.mqtt.message.subscribe.suback.MqttSubAck; -import com.hivemq.client.internal.mqtt.message.unsubscribe.MqttStatefulUnsubscribe; -import com.hivemq.client.internal.mqtt.message.unsubscribe.unsuback.MqttUnsubAck; +import com.hivemq.client.internal.mqtt.message.unsubscribe.MqttUnsubscribe; import com.hivemq.client.internal.mqtt.message.unsubscribe.unsuback.mqtt3.Mqtt3UnsubAckView; import com.hivemq.client.internal.util.collections.HandleList; import com.hivemq.client.internal.util.collections.HandleList.Handle; @@ -37,6 +35,8 @@ import org.jetbrains.annotations.Nullable; import javax.inject.Inject; +import java.util.List; +import java.util.Map; /** * @author Silvio Giebl @@ -45,65 +45,76 @@ @NotThreadSafe public class MqttIncomingPublishFlows { - private final @NotNull MqttSubscriptionFlows subscriptionFlows; + private final @NotNull MqttSubscribedPublishFlows subscribedFlows; private final @Nullable HandleList @NotNull [] globalFlows; @Inject - MqttIncomingPublishFlows(final @NotNull MqttSubscriptionFlows subscriptionFlows) { - this.subscriptionFlows = subscriptionFlows; + MqttIncomingPublishFlows() { + subscribedFlows = new MqttSubscribedPublishFlowTree(); //noinspection unchecked globalFlows = new HandleList[MqttGlobalPublishFilter.values().length]; } public void subscribe( - final @NotNull MqttStatefulSubscribe subscribe, final @Nullable MqttSubscribedPublishFlow flow) { + final @NotNull MqttSubscribe subscribe, final int subscriptionIdentifier, + final @Nullable MqttSubscribedPublishFlow flow) { - final ImmutableList subscriptions = subscribe.stateless().getSubscriptions(); + final ImmutableList subscriptions = subscribe.getSubscriptions(); //noinspection ForLoopReplaceableByForEach for (int i = 0; i < subscriptions.size(); i++) { - subscribe(subscriptions.get(i).getTopicFilter(), flow); + subscribedFlows.subscribe(subscriptions.get(i), subscriptionIdentifier, flow); } } - void subscribe(final @NotNull MqttTopicFilterImpl topicFilter, final @Nullable MqttSubscribedPublishFlow flow) { - subscriptionFlows.subscribe(topicFilter, flow); - } - public void subAck( - final @NotNull MqttStatefulSubscribe subscribe, final @NotNull MqttSubAck subAck, - final @Nullable MqttSubscribedPublishFlow flow) { + final @NotNull MqttSubscribe subscribe, final int subscriptionIdentifier, + final @NotNull ImmutableList reasonCodes) { - final ImmutableList subscriptions = subscribe.stateless().getSubscriptions(); - final ImmutableList reasonCodes = subAck.getReasonCodes(); + final ImmutableList subscriptions = subscribe.getSubscriptions(); final boolean countNotMatching = subscriptions.size() > reasonCodes.size(); for (int i = 0; i < subscriptions.size(); i++) { - if (countNotMatching || reasonCodes.get(i).isError()) { - remove(subscriptions.get(i).getTopicFilter(), flow); - } + subscribedFlows.suback(subscriptions.get(i).getTopicFilter(), subscriptionIdentifier, + countNotMatching || reasonCodes.get(i).isError()); } } - void remove(final @NotNull MqttTopicFilterImpl topicFilter, final @Nullable MqttSubscribedPublishFlow flow) { - subscriptionFlows.remove(topicFilter, flow); - } + public void unsubscribe( + final @NotNull MqttUnsubscribe unsubscribe, + final @NotNull ImmutableList reasonCodes) { - public void unsubscribe(final @NotNull MqttStatefulUnsubscribe unsubscribe, final @NotNull MqttUnsubAck unsubAck) { - final ImmutableList topicFilters = unsubscribe.stateless().getTopicFilters(); - final ImmutableList reasonCodes = unsubAck.getReasonCodes(); + final ImmutableList topicFilters = unsubscribe.getTopicFilters(); final boolean allSuccess = reasonCodes == Mqtt3UnsubAckView.REASON_CODES_ALL_SUCCESS; for (int i = 0; i < topicFilters.size(); i++) { if (allSuccess || !reasonCodes.get(i).isError()) { - unsubscribe(topicFilters.get(i)); + subscribedFlows.unsubscribe(topicFilters.get(i)); } } } - void unsubscribe(final @NotNull MqttTopicFilterImpl topicFilter) { - subscriptionFlows.unsubscribe(topicFilter, null); + void cancel(final @NotNull MqttSubscribedPublishFlow flow) { + subscribedFlows.cancel(flow); } - void cancel(final @NotNull MqttSubscribedPublishFlow flow) { - subscriptionFlows.cancel(flow); + public void subscribeGlobal(final @NotNull MqttGlobalIncomingPublishFlow flow) { + final int filter = flow.getFilter().ordinal(); + HandleList globalFlowsForFilter = globalFlows[filter]; + if (globalFlowsForFilter == null) { + globalFlowsForFilter = new HandleList<>(); + globalFlows[filter] = globalFlowsForFilter; + } + flow.setHandle(globalFlowsForFilter.add(flow)); + } + + void cancelGlobal(final @NotNull MqttGlobalIncomingPublishFlow flow) { + final int filter = flow.getFilter().ordinal(); + final HandleList globalFlowsForFilter = globalFlows[filter]; + final Handle handle = flow.getHandle(); + if ((globalFlowsForFilter != null) && (handle != null)) { + globalFlowsForFilter.remove(handle); + if (globalFlowsForFilter.isEmpty()) { + globalFlows[filter] = null; + } + } } @NotNull HandleList findMatching(final @NotNull MqttStatefulPublish publish) { @@ -115,7 +126,7 @@ void cancel(final @NotNull MqttSubscribedPublishFlow flow) { void findMatching( final @NotNull MqttStatefulPublish publish, final @NotNull MqttMatchingPublishFlows matchingFlows) { - subscriptionFlows.findMatching(publish.stateless().getTopic(), matchingFlows); + subscribedFlows.findMatching(publish.stateless().getTopic(), matchingFlows); if (matchingFlows.subscriptionFound) { add(matchingFlows, globalFlows[MqttGlobalPublishFilter.SUBSCRIBED.ordinal()]); } else { @@ -127,30 +138,19 @@ void findMatching( } } - void subscribeGlobal(final @NotNull MqttGlobalIncomingPublishFlow flow) { - final int filter = flow.getFilter().ordinal(); - HandleList globalFlow = globalFlows[filter]; - if (globalFlow == null) { - globalFlow = new HandleList<>(); - globalFlows[filter] = globalFlow; - } - flow.setHandle(globalFlow.add(flow)); - } + private static void add( + final @NotNull MqttMatchingPublishFlows matchingPublishFlows, + final @Nullable HandleList globalFlows) { - void cancelGlobal(final @NotNull MqttGlobalIncomingPublishFlow flow) { - final int filter = flow.getFilter().ordinal(); - final HandleList globalFlow = globalFlows[filter]; - assert globalFlow != null; - final Handle handle = flow.getHandle(); - assert handle != null; - globalFlow.remove(handle); - if (globalFlow.isEmpty()) { - globalFlows[filter] = null; + if (globalFlows != null) { + for (Handle h = globalFlows.getFirst(); h != null; h = h.getNext()) { + matchingPublishFlows.add(h.getElement()); + } } } - void clear(final @NotNull Throwable cause) { - subscriptionFlows.clear(cause); + public void clear(final @NotNull Throwable cause) { + subscribedFlows.clear(cause); for (int i = 0; i < globalFlows.length; i++) { final HandleList globalFlow = globalFlows[i]; if (globalFlow != null) { @@ -162,14 +162,7 @@ void clear(final @NotNull Throwable cause) { } } - private static void add( - final @NotNull HandleList target, - final @Nullable HandleList source) { - - if (source != null) { - for (Handle h = source.getFirst(); h != null; h = h.getNext()) { - target.add(h.getElement()); - } - } + public @NotNull Map<@NotNull Integer, @NotNull List<@NotNull MqttSubscription>> getSubscriptions() { + return subscribedFlows.getSubscriptions(); } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishFlowsWithId.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishFlowsWithId.java deleted file mode 100644 index 003aba51c..000000000 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishFlowsWithId.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Copyright 2018 dc-square and the HiveMQ MQTT Client Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package com.hivemq.client.internal.mqtt.handler.publish.incoming; - -import com.hivemq.client.internal.annotations.NotThreadSafe; -import com.hivemq.client.internal.mqtt.datatypes.MqttTopicFilterImpl; -import com.hivemq.client.internal.mqtt.ioc.ClientScope; -import com.hivemq.client.internal.mqtt.message.publish.MqttStatefulPublish; -import com.hivemq.client.internal.mqtt.message.subscribe.MqttStatefulSubscribe; -import com.hivemq.client.internal.mqtt.message.subscribe.suback.MqttSubAck; -import com.hivemq.client.internal.util.collections.ImmutableIntList; -import com.hivemq.client.internal.util.collections.IntIndex; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; - -import javax.inject.Inject; - -import static com.hivemq.client.internal.mqtt.message.subscribe.MqttStatefulSubscribe.DEFAULT_NO_SUBSCRIPTION_IDENTIFIER; - -/** - * @author Silvio Giebl - */ -@ClientScope -@NotThreadSafe -public class MqttIncomingPublishFlowsWithId extends MqttIncomingPublishFlows { - - private static final @NotNull IntIndex.Spec INDEX_SPEC = - new IntIndex.Spec<>(MqttSubscribedPublishFlow::getSubscriptionIdentifier); - - private final @NotNull IntIndex flowsWithIdsIndex = new IntIndex<>(INDEX_SPEC); - private final @NotNull MqttSubscriptionFlows flowsWithIds; - - @Inject - MqttIncomingPublishFlowsWithId( - final @NotNull MqttSubscriptionFlows flowsWithoutIds, final @NotNull MqttSubscriptionFlows flowsWithIds) { - - super(flowsWithoutIds); - this.flowsWithIds = flowsWithIds; - } - - @Override - public void subscribe( - final @NotNull MqttStatefulSubscribe subscribe, final @Nullable MqttSubscribedPublishFlow flow) { - - if (flow != null) { - final int subscriptionIdentifier = subscribe.getSubscriptionIdentifier(); - if (subscriptionIdentifier != DEFAULT_NO_SUBSCRIPTION_IDENTIFIER) { - flow.setSubscriptionIdentifier(subscriptionIdentifier); - flowsWithIdsIndex.put(flow); - } - } - super.subscribe(subscribe, flow); - } - - @Override - void subscribe(final @NotNull MqttTopicFilterImpl topicFilter, final @Nullable MqttSubscribedPublishFlow flow) { - if ((flow != null) && (flow.getSubscriptionIdentifier() != DEFAULT_NO_SUBSCRIPTION_IDENTIFIER)) { - flowsWithIds.subscribe(topicFilter, flow); - } else { - super.subscribe(topicFilter, flow); - } - } - - @Override - public void subAck( - final @NotNull MqttStatefulSubscribe subscribe, final @NotNull MqttSubAck subAck, - final @Nullable MqttSubscribedPublishFlow flow) { - - super.subAck(subscribe, subAck, flow); - if (flow != null) { - final int subscriptionIdentifier = subscribe.getSubscriptionIdentifier(); - if ((subscriptionIdentifier != DEFAULT_NO_SUBSCRIPTION_IDENTIFIER) && flow.getTopicFilters().isEmpty()) { - flowsWithIdsIndex.remove(subscriptionIdentifier); - } - } - } - - @Override - void remove(final @NotNull MqttTopicFilterImpl topicFilter, final @Nullable MqttSubscribedPublishFlow flow) { - if ((flow != null) && (flow.getSubscriptionIdentifier() != DEFAULT_NO_SUBSCRIPTION_IDENTIFIER)) { - flowsWithIds.remove(topicFilter, flow); - } else { - super.remove(topicFilter, flow); - } - } - - @Override - void unsubscribe(final @NotNull MqttTopicFilterImpl topicFilter) { - flowsWithIds.unsubscribe(topicFilter, this::unsubscribed); - super.unsubscribe(topicFilter); - } - - private void unsubscribed(final @NotNull MqttSubscribedPublishFlow flow) { - flowsWithIdsIndex.remove(flow.getSubscriptionIdentifier()); - } - - @Override - void cancel(final @NotNull MqttSubscribedPublishFlow flow) { - final int subscriptionIdentifier = flow.getSubscriptionIdentifier(); - if (subscriptionIdentifier != DEFAULT_NO_SUBSCRIPTION_IDENTIFIER) { - flowsWithIdsIndex.remove(subscriptionIdentifier); - flowsWithIds.cancel(flow); - } else { - super.cancel(flow); - } - } - - @Override - void findMatching( - final @NotNull MqttStatefulPublish publish, final @NotNull MqttMatchingPublishFlows matchingFlows) { - - final ImmutableIntList subscriptionIdentifiers = publish.getSubscriptionIdentifiers(); - if (!subscriptionIdentifiers.isEmpty()) { - for (int i = 0; i < subscriptionIdentifiers.size(); i++) { - final MqttSubscribedPublishFlow flow = flowsWithIdsIndex.get(subscriptionIdentifiers.get(i)); - if (flow != null) { - matchingFlows.add(flow); - } - } - if (matchingFlows.isEmpty()) { - flowsWithIds.findMatching(publish.stateless().getTopic(), matchingFlows); - } else { - matchingFlows.subscriptionFound = true; - } - } - super.findMatching(publish, matchingFlows); - } - - @Override - public void clear(final @NotNull Throwable cause) { - flowsWithIdsIndex.clear(); - flowsWithIds.clear(cause); - super.clear(cause); - } -} diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishService.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishService.java index 85aaba86d..c0f00f03a 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishService.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingPublishService.java @@ -39,6 +39,7 @@ class MqttIncomingPublishService { private static final boolean QOS_0_DROP_OLDEST = true; // TODO configurable private final @NotNull MqttIncomingQosHandler incomingQosHandler; + final @NotNull MqttIncomingPublishFlows incomingPublishFlows; private final @NotNull ChunkedArrayQueue qos0Queue = new ChunkedArrayQueue<>(32); private final @NotNull ChunkedArrayQueue.Iterator qos0It = qos0Queue.iterator(); @@ -49,8 +50,12 @@ class MqttIncomingPublishService { private int runIndex; private int blockingFlowCount; - MqttIncomingPublishService(final @NotNull MqttIncomingQosHandler incomingQosHandler) { + MqttIncomingPublishService( + final @NotNull MqttIncomingQosHandler incomingQosHandler, + final @NotNull MqttIncomingPublishFlows incomingPublishFlows) { + this.incomingQosHandler = incomingQosHandler; + this.incomingPublishFlows = incomingPublishFlows; } @CallByThread("Netty EventLoop") @@ -96,8 +101,7 @@ boolean onPublishQos1Or2(final @NotNull MqttStatefulPublish publish, final int r @CallByThread("Netty EventLoop") private @NotNull HandleList onPublish(final @NotNull MqttStatefulPublish publish) { - final HandleList flows = - incomingQosHandler.getIncomingPublishFlows().findMatching(publish); + final HandleList flows = incomingPublishFlows.findMatching(publish); if (flows.isEmpty()) { LOGGER.warn("No publish flow registered for {}.", publish); } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingQosHandler.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingQosHandler.java index bcd4ba618..0cde1dee7 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingQosHandler.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttIncomingQosHandler.java @@ -59,12 +59,13 @@ public class MqttIncomingQosHandler extends MqttSessionAwareHandler new IntIndex.Spec<>(MqttMessage.WithId::getPacketIdentifier); private final @NotNull MqttClientConfig clientConfig; - private final @NotNull MqttIncomingPublishFlows incomingPublishFlows; - private final @NotNull MqttIncomingPublishService incomingPublishService; + final @NotNull MqttIncomingPublishService incomingPublishService; + // valid for session private final @NotNull IntIndex messages = new IntIndex<>(INDEX_SPEC); // contains StatefulPublish with AT_LEAST_ONCE/EXACTLY_ONCE, MqttPubAck or MqttPubRec + // valid for connection private int receiveMaximum; @Inject @@ -73,16 +74,15 @@ public class MqttIncomingQosHandler extends MqttSessionAwareHandler final @NotNull MqttIncomingPublishFlows incomingPublishFlows) { this.clientConfig = clientConfig; - this.incomingPublishFlows = incomingPublishFlows; - incomingPublishService = new MqttIncomingPublishService(this); + incomingPublishService = new MqttIncomingPublishService(this, incomingPublishFlows); } @Override public void onSessionStartOrResume( final @NotNull MqttClientConnectionConfig connectionConfig, final @NotNull EventLoop eventLoop) { - super.onSessionStartOrResume(connectionConfig, eventLoop); receiveMaximum = connectionConfig.getReceiveMaximum(); + super.onSessionStartOrResume(connectionConfig, eventLoop); } @Override @@ -237,7 +237,6 @@ private void writePubComp(final @NotNull ChannelHandlerContext ctx, final @NotNu @Override public void onSessionEnd(final @NotNull Throwable cause) { super.onSessionEnd(cause); - incomingPublishFlows.clear(cause); messages.clear(); } @@ -273,12 +272,4 @@ public void onSessionEnd(final @NotNull Throwable cause) { } return pubCompBuilder.build(); } - - @NotNull MqttIncomingPublishFlows getIncomingPublishFlows() { - return incomingPublishFlows; - } - - @NotNull MqttIncomingPublishService getIncomingPublishService() { - return incomingPublishService; - } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlow.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlow.java index 458eaa50b..ccea014a8 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlow.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlow.java @@ -20,7 +20,6 @@ import com.hivemq.client.internal.mqtt.MqttClientConfig; import com.hivemq.client.internal.mqtt.datatypes.MqttTopicFilterImpl; import com.hivemq.client.internal.mqtt.handler.subscribe.MqttSubscriptionFlow; -import com.hivemq.client.internal.mqtt.message.subscribe.MqttStatefulSubscribe; import com.hivemq.client.internal.mqtt.message.subscribe.suback.MqttSubAck; import com.hivemq.client.internal.util.collections.HandleList; import com.hivemq.client.mqtt.mqtt5.message.publish.Mqtt5Publish; @@ -35,7 +34,6 @@ public class MqttSubscribedPublishFlow extends MqttIncomingPublishFlow implements MqttSubscriptionFlow { private final @NotNull HandleList topicFilters; - private int subscriptionIdentifier = MqttStatefulSubscribe.DEFAULT_NO_SUBSCRIPTION_IDENTIFIER; MqttSubscribedPublishFlow( final @NotNull Subscriber subscriber, final @NotNull MqttClientConfig clientConfig, @@ -55,19 +53,11 @@ public void onSuccess(final @NotNull MqttSubAck subAck) { @Override void runCancel() { - incomingQosHandler.getIncomingPublishFlows().cancel(this); + incomingPublishService.incomingPublishFlows.cancel(this); super.runCancel(); } @NotNull HandleList getTopicFilters() { return topicFilters; } - - int getSubscriptionIdentifier() { - return subscriptionIdentifier; - } - - void setSubscriptionIdentifier(final int subscriptionIdentifier) { - this.subscriptionIdentifier = subscriptionIdentifier; - } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowTree.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowTree.java similarity index 61% rename from src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowTree.java rename to src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowTree.java index 6f7c50d84..bfec73271 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowTree.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowTree.java @@ -19,33 +19,34 @@ import com.hivemq.client.internal.annotations.NotThreadSafe; import com.hivemq.client.internal.mqtt.datatypes.*; -import com.hivemq.client.internal.util.collections.HandleList; +import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscription; import com.hivemq.client.internal.util.collections.HandleList.Handle; import com.hivemq.client.internal.util.collections.Index; import com.hivemq.client.internal.util.collections.NodeList; +import com.hivemq.client.mqtt.datatypes.MqttQos; +import com.hivemq.client.mqtt.mqtt5.message.subscribe.Mqtt5RetainHandling; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; -import javax.inject.Inject; -import java.util.function.Consumer; +import java.util.*; /** * @author Silvio Giebl */ @NotThreadSafe -public class MqttSubscriptionFlowTree implements MqttSubscriptionFlows { +public class MqttSubscribedPublishFlowTree implements MqttSubscribedPublishFlows { private @Nullable TopicTreeNode rootNode; - @Inject - MqttSubscriptionFlowTree() {} + MqttSubscribedPublishFlowTree() {} @Override public void subscribe( - final @NotNull MqttTopicFilterImpl topicFilter, final @Nullable MqttSubscribedPublishFlow flow) { + final @NotNull MqttSubscription subscription, final int subscriptionIdentifier, + final @Nullable MqttSubscribedPublishFlow flow) { - final TopicTreeEntry entry = (flow == null) ? null : new TopicTreeEntry(flow, topicFilter); - final MqttTopicIterator topicIterator = MqttTopicIterator.of(topicFilter); + final TopicTreeEntry entry = new TopicTreeEntry(subscription, subscriptionIdentifier, flow); + final MqttTopicIterator topicIterator = MqttTopicIterator.of(subscription.getTopicFilter()); TopicTreeNode node = rootNode; if (node == null) { rootNode = node = new TopicTreeNode(null, null); @@ -56,24 +57,23 @@ public void subscribe( } @Override - public void remove(final @NotNull MqttTopicFilterImpl topicFilter, final @Nullable MqttSubscribedPublishFlow flow) { + public void suback( + final @NotNull MqttTopicFilterImpl topicFilter, final int subscriptionIdentifier, final boolean error) { + final MqttTopicIterator topicIterator = MqttTopicIterator.of(topicFilter); TopicTreeNode node = rootNode; while (node != null) { - node = node.remove(topicIterator, flow); + node = node.suback(topicIterator, subscriptionIdentifier, error); } compact(); } @Override - public void unsubscribe( - final @NotNull MqttTopicFilterImpl topicFilter, - final @Nullable Consumer unsubscribedCallback) { - + public void unsubscribe(final @NotNull MqttTopicFilterImpl topicFilter) { final MqttTopicIterator topicIterator = MqttTopicIterator.of(topicFilter); TopicTreeNode node = rootNode; while (node != null) { - node = node.unsubscribe(topicIterator, unsubscribedCallback); + node = node.unsubscribe(topicIterator); } compact(); } @@ -109,6 +109,20 @@ public void clear(final @NotNull Throwable cause) { rootNode = null; } + @Override + public @NotNull Map<@NotNull Integer, @NotNull List<@NotNull MqttSubscription>> getSubscriptions() { + final Map> map = new TreeMap<>(Comparator.reverseOrder()); + if (rootNode != null) { + final Queue nodes = new LinkedList<>(); + nodes.add(new IteratorNode(rootNode, null)); + while (!nodes.isEmpty()) { + final IteratorNode node = nodes.poll(); + node.node.getSubscriptions(node.parentTopicLevels, map, nodes); + } + } + return map; + } + private void compact() { if ((rootNode != null) && rootNode.isEmpty()) { rootNode = null; @@ -117,12 +131,23 @@ private void compact() { private static class TopicTreeEntry extends NodeList.Node { - final @NotNull MqttSubscribedPublishFlow flow; - final @NotNull Handle handle; - - TopicTreeEntry(final @NotNull MqttSubscribedPublishFlow flow, final @NotNull MqttTopicFilterImpl topicFilter) { + final int subscriptionIdentifier; + final byte subscriptionOptions; + final @Nullable byte[] topicFilterPrefix; + @Nullable MqttSubscribedPublishFlow flow; + @Nullable Handle handle; + boolean acknowledged; + + TopicTreeEntry( + final @NotNull MqttSubscription subscription, final int subscriptionIdentifier, + final @Nullable MqttSubscribedPublishFlow flow) { + + this.subscriptionIdentifier = subscriptionIdentifier; + subscriptionOptions = subscription.encodeSubscriptionOptions(); + final MqttTopicFilterImpl topicFilter = subscription.getTopicFilter(); + this.topicFilterPrefix = topicFilter.getPrefix(); this.flow = flow; - this.handle = flow.getTopicFilters().add(topicFilter); + handle = (flow == null) ? null : flow.getTopicFilters().add(topicFilter); } } @@ -137,8 +162,6 @@ private static class TopicTreeNode { private @Nullable TopicTreeNode singleLevel; private @Nullable NodeList entries; private @Nullable NodeList multiLevelEntries; - private int subscriptions; - private int multiLevelSubscriptions; TopicTreeNode(final @Nullable TopicTreeNode parent, final @Nullable MqttTopicLevel topicLevel) { this.parent = parent; @@ -146,7 +169,7 @@ private static class TopicTreeNode { } @Nullable TopicTreeNode subscribe( - final @NotNull MqttTopicIterator topicIterator, final @Nullable TopicTreeEntry entry) { + final @NotNull MqttTopicIterator topicIterator, final @NotNull TopicTreeEntry entry) { if (topicIterator.hasNext()) { final MqttTopicLevel nextLevel = topicIterator.next(); @@ -171,98 +194,95 @@ private static class TopicTreeNode { return getNext(node, topicIterator); } if (topicIterator.hasMultiLevelWildcard()) { - if (entry != null) { - if (multiLevelEntries == null) { - multiLevelEntries = new NodeList<>(); - } - multiLevelEntries.add(entry); + if (multiLevelEntries == null) { + multiLevelEntries = new NodeList<>(); } - multiLevelSubscriptions++; + multiLevelEntries.add(entry); } else { - if (entry != null) { - if (entries == null) { - entries = new NodeList<>(); - } - entries.add(entry); + if (entries == null) { + entries = new NodeList<>(); } - subscriptions++; + entries.add(entry); } return null; } - @Nullable TopicTreeNode remove( - final @NotNull MqttTopicIterator topicIterator, final @Nullable MqttSubscribedPublishFlow flow) { + @Nullable TopicTreeNode suback( + final @NotNull MqttTopicIterator topicIterator, final int subscriptionIdentifier, final boolean error) { if (topicIterator.hasNext()) { return traverseNext(topicIterator); } if (topicIterator.hasMultiLevelWildcard()) { - if (remove(multiLevelEntries, flow)) { + if (suback(multiLevelEntries, subscriptionIdentifier, error)) { multiLevelEntries = null; } - multiLevelSubscriptions--; } else { - if (remove(entries, flow)) { + if (suback(entries, subscriptionIdentifier, error)) { entries = null; } - subscriptions--; } compact(); return null; } - private static boolean remove( - final @Nullable NodeList entries, final @Nullable MqttSubscribedPublishFlow flow) { + private static boolean suback( + final @Nullable NodeList entries, final int subscriptionIdentifier, + final boolean error) { - if ((entries != null) && (flow != null)) { + if (entries != null) { for (TopicTreeEntry entry = entries.getFirst(); entry != null; entry = entry.getNext()) { - if (entry.flow == flow) { - entry.flow.getTopicFilters().remove(entry.handle); + if (entry.subscriptionIdentifier == subscriptionIdentifier) { + if (!error) { + entry.acknowledged = true; + return false; + } + if (entry.flow != null) { + assert entry.handle != null : "entry.flow != null -> entry.handle != null"; + entry.flow.getTopicFilters().remove(entry.handle); + } entries.remove(entry); - break; + return entries.isEmpty(); } } - return entries.isEmpty(); } return false; } - @Nullable TopicTreeNode unsubscribe( - final @NotNull MqttTopicIterator topicIterator, - final @Nullable Consumer unsubscribedCallback) { - + @Nullable TopicTreeNode unsubscribe(final @NotNull MqttTopicIterator topicIterator) { if (topicIterator.hasNext()) { return traverseNext(topicIterator); } if (topicIterator.hasMultiLevelWildcard()) { - unsubscribe(multiLevelEntries, unsubscribedCallback); - multiLevelEntries = null; - multiLevelSubscriptions = 0; + if (unsubscribe(multiLevelEntries)) { + multiLevelEntries = null; + } } else { - unsubscribe(entries, unsubscribedCallback); - entries = null; - subscriptions = 0; + if (unsubscribe(entries)) { + entries = null; + } } compact(); return null; } - private static void unsubscribe( - final @Nullable NodeList entries, - final @Nullable Consumer unsubscribedCallback) { - + private static boolean unsubscribe(final @Nullable NodeList entries) { if (entries != null) { for (TopicTreeEntry entry = entries.getFirst(); entry != null; entry = entry.getNext()) { - final MqttSubscribedPublishFlow flow = entry.flow; - flow.getTopicFilters().remove(entry.handle); - if (flow.getTopicFilters().isEmpty()) { - flow.onComplete(); - if (unsubscribedCallback != null) { - unsubscribedCallback.accept(flow); + if (entry.acknowledged) { + if (entry.flow != null) { + assert entry.handle != null : "entry.flow != null -> entry.handle != null"; + entry.flow.getTopicFilters().remove(entry.handle); + if (entry.flow.getTopicFilters().isEmpty()) { + entry.flow.onComplete(); + } } + entries.remove(entry); } } + return entries.isEmpty(); } + return false; } @Nullable TopicTreeNode cancel( @@ -272,30 +292,25 @@ private static void unsubscribe( return traverseNext(topicIterator); } if (topicIterator.hasMultiLevelWildcard()) { - if (cancel(multiLevelEntries, flow)) { - multiLevelEntries = null; - } + cancel(multiLevelEntries, flow); } else { - if (cancel(entries, flow)) { - entries = null; - } + cancel(entries, flow); } return null; } - private static boolean cancel( + private static void cancel( final @Nullable NodeList entries, final @NotNull MqttSubscribedPublishFlow flow) { if (entries != null) { for (TopicTreeEntry entry = entries.getFirst(); entry != null; entry = entry.getNext()) { if (entry.flow == flow) { - entries.remove(entry); + entry.flow = null; + entry.handle = null; break; } } - return entries.isEmpty(); } - return false; } @Nullable TopicTreeNode findMatching( @@ -303,9 +318,6 @@ private static boolean cancel( if (topicIterator.hasNext()) { add(matchingFlows, multiLevelEntries); - if (multiLevelSubscriptions != 0) { - matchingFlows.subscriptionFound = true; - } final MqttTopicLevel nextLevel = topicIterator.next(); final TopicTreeNode nextNode = (next == null) ? null : next.get(nextLevel); final TopicTreeNode singleLevel = this.singleLevel; @@ -332,19 +344,19 @@ private static boolean cancel( } add(matchingFlows, entries); add(matchingFlows, multiLevelEntries); - if ((subscriptions != 0) || (multiLevelSubscriptions != 0)) { - matchingFlows.subscriptionFound = true; - } return null; } private static void add( - final @NotNull HandleList target, - final @Nullable NodeList source) { + final @NotNull MqttMatchingPublishFlows matchingFlows, + final @Nullable NodeList entries) { - if (source != null) { - for (TopicTreeEntry entry = source.getFirst(); entry != null; entry = entry.getNext()) { - target.add(entry.flow); + if (entries != null) { + matchingFlows.subscriptionFound = true; + for (TopicTreeEntry entry = entries.getFirst(); entry != null; entry = entry.getNext()) { + if (entry.flow != null) { + matchingFlows.add(entry.flow); + } } } } @@ -372,7 +384,9 @@ private static void add( private static void clear(final @NotNull NodeList entries, final @NotNull Throwable cause) { for (TopicTreeEntry entry = entries.getFirst(); entry != null; entry = entry.getNext()) { - entry.flow.onError(cause); + if ((entry.flow != null) && entry.acknowledged) { + entry.flow.onError(cause); + } } } @@ -390,7 +404,7 @@ private static void clear(final @NotNull NodeList entries, final if (topicLevelBefore.isSingleLevelWildcard()) { singleLevel = nodeBefore; } else { - assert next != null; + assert next != null : "node must be in next -> next != null"; next.put(nodeBefore); } node.parent = nodeBefore; @@ -451,7 +465,7 @@ private static void clear(final @NotNull NodeList entries, final } private void compact() { - if ((parent != null) && ((subscriptions + multiLevelSubscriptions) == 0)) { + if ((parent != null) && (entries == null) && (multiLevelEntries == null)) { final boolean hasSingleLevel = singleLevel != null; final boolean hasNext = next != null; if (!hasSingleLevel && !hasNext) { @@ -466,10 +480,10 @@ private void compact() { } private void fuse(final @NotNull TopicTreeNode child) { - assert parent != null; - assert topicLevel != null; - assert child.parent == this; - assert child.topicLevel != null; + assert parent != null : "parent = null -> this = root node, root node must not be fused"; + assert topicLevel != null : "topicLevel = null -> this = root node, root node must not be fused"; + assert child.parent == this : "this must only be fused with its child"; + assert child.topicLevel != null : "child.topicLevel = null -> child = root node, root node has no parent"; final TopicTreeNode parent = this.parent; final MqttTopicLevels fusedTopicLevel = MqttTopicLevels.concat(topicLevel, child.topicLevel); child.parent = parent; @@ -477,17 +491,17 @@ private void fuse(final @NotNull TopicTreeNode child) { if (fusedTopicLevel.isSingleLevelWildcard()) { parent.singleLevel = child; } else { - assert parent.next != null; + assert parent.next != null : "this must be in parent.next -> parent.next != null"; parent.next.put(child); } } private void removeNext(final @NotNull TopicTreeNode node) { - assert node.topicLevel != null; + assert node.topicLevel != null : "topicLevel = null -> node = root node, root node has no parent"; if (node.topicLevel.isSingleLevelWildcard()) { singleLevel = null; } else { - assert next != null; + assert next != null : "node must be in next -> next != null"; next.remove(node.topicLevel); if (next.size() == 0) { next = null; @@ -496,7 +510,73 @@ private void removeNext(final @NotNull TopicTreeNode node) { } boolean isEmpty() { - return ((subscriptions + multiLevelSubscriptions) == 0) && (singleLevel == null) && (next == null); + return (next == null) && (singleLevel == null) && (entries == null) && (multiLevelEntries == null); + } + + public void getSubscriptions( + final @Nullable MqttTopicLevel parentTopicLevels, + final @NotNull Map<@NotNull Integer, @NotNull List<@NotNull MqttSubscription>> map, + final @NotNull Queue<@NotNull IteratorNode> nodes) { + + final MqttTopicLevel topicLevels = ((parentTopicLevels == null) || (topicLevel == null)) ? topicLevel : + MqttTopicLevels.concat(parentTopicLevels, topicLevel); + if (topicLevels != null) { + if (entries != null) { + getSubscriptions(entries, topicLevels, false, map); + } + if (multiLevelEntries != null) { + getSubscriptions(multiLevelEntries, topicLevels, true, map); + } + } + if (next != null) { + next.forEach(node -> nodes.add(new IteratorNode(node, topicLevels))); + } + if (singleLevel != null) { + nodes.add(new IteratorNode(singleLevel, topicLevels)); + } + } + + private static void getSubscriptions( + final @NotNull NodeList entries, final @NotNull MqttTopicLevel topicLevels, + final boolean multiLevelWildcard, + final @NotNull Map<@NotNull Integer, @NotNull List<@NotNull MqttSubscription>> map) { + + boolean exactFound = false; + for (TopicTreeEntry entry = entries.getLast(); entry != null; entry = entry.getPrev()) { + if (entry.acknowledged) { + if (entry.topicFilterPrefix == null) { + if (exactFound) { + continue; + } + exactFound = true; + } + final MqttTopicFilterImpl topicFilter = + topicLevels.toFilter(entry.topicFilterPrefix, multiLevelWildcard); + assert topicFilter != null : "reconstructed topic filter must be valid"; + final MqttQos qos = MqttSubscription.decodeQos(entry.subscriptionOptions); + assert qos != null : "reconstructed qos must be valid"; + final boolean noLocal = MqttSubscription.decodeNoLocal(entry.subscriptionOptions); + final Mqtt5RetainHandling retainHandling = + MqttSubscription.decodeRetainHandling(entry.subscriptionOptions); + assert retainHandling != null : "reconstructed retain handling must be valid"; + final boolean retainAsPublished = + MqttSubscription.decodeRetainAsPublished(entry.subscriptionOptions); + final MqttSubscription subscription = + new MqttSubscription(topicFilter, qos, noLocal, retainHandling, retainAsPublished); + map.computeIfAbsent(entry.subscriptionIdentifier, k -> new LinkedList<>()).add(subscription); + } + } + } + } + + private static class IteratorNode { + + final @NotNull TopicTreeNode node; + final @Nullable MqttTopicLevel parentTopicLevels; + + IteratorNode(final @NotNull TopicTreeNode node, final @Nullable MqttTopicLevel parentTopicLevels) { + this.node = node; + this.parentTopicLevels = parentTopicLevels; } } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowable.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowable.java index 9e97e32d6..78a3e1f58 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowable.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowable.java @@ -18,7 +18,6 @@ package com.hivemq.client.internal.mqtt.handler.publish.incoming; import com.hivemq.client.internal.mqtt.MqttClientConfig; -import com.hivemq.client.internal.mqtt.exceptions.MqttClientStateExceptions; import com.hivemq.client.internal.mqtt.handler.subscribe.MqttSubscriptionHandler; import com.hivemq.client.internal.mqtt.ioc.ClientComponent; import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscribe; @@ -26,7 +25,6 @@ import com.hivemq.client.mqtt.mqtt5.message.subscribe.suback.Mqtt5SubAck; import com.hivemq.client.rx.FlowableWithSingle; import com.hivemq.client.rx.reactivestreams.WithSingleSubscriber; -import io.reactivex.internal.subscriptions.EmptySubscription; import org.jetbrains.annotations.NotNull; import org.reactivestreams.Subscriber; @@ -47,18 +45,14 @@ public MqttSubscribedPublishFlowable( @Override protected void subscribeActual(final @NotNull Subscriber subscriber) { - if (clientConfig.getState().isConnectedOrReconnect()) { - final ClientComponent clientComponent = clientConfig.getClientComponent(); - final MqttIncomingQosHandler incomingQosHandler = clientComponent.incomingQosHandler(); - final MqttSubscriptionHandler subscriptionHandler = clientComponent.subscriptionHandler(); + final ClientComponent clientComponent = clientConfig.getClientComponent(); + final MqttIncomingQosHandler incomingQosHandler = clientComponent.incomingQosHandler(); + final MqttSubscriptionHandler subscriptionHandler = clientComponent.subscriptionHandler(); - final MqttSubscribedPublishFlow flow = - new MqttSubscribedPublishFlow(subscriber, clientConfig, incomingQosHandler); - subscriber.onSubscribe(flow); - subscriptionHandler.subscribe(subscribe, flow); - } else { - EmptySubscription.error(MqttClientStateExceptions.notConnected(), subscriber); - } + final MqttSubscribedPublishFlow flow = + new MqttSubscribedPublishFlow(subscriber, clientConfig, incomingQosHandler); + subscriber.onSubscribe(flow); + subscriptionHandler.subscribe(subscribe, flow); } @Override diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlows.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlows.java similarity index 67% rename from src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlows.java rename to src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlows.java index a6a17bd1b..c86eb3b7c 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlows.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlows.java @@ -20,28 +20,32 @@ import com.hivemq.client.internal.annotations.NotThreadSafe; import com.hivemq.client.internal.mqtt.datatypes.MqttTopicFilterImpl; import com.hivemq.client.internal.mqtt.datatypes.MqttTopicImpl; +import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscription; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; -import java.util.function.Consumer; +import java.util.List; +import java.util.Map; /** * @author Silvio Giebl */ @NotThreadSafe -public interface MqttSubscriptionFlows { +public interface MqttSubscribedPublishFlows { - void subscribe(@NotNull MqttTopicFilterImpl topicFilter, @Nullable MqttSubscribedPublishFlow flow); + void subscribe( + @NotNull MqttSubscription subscription, int subscriptionIdentifier, + @Nullable MqttSubscribedPublishFlow flow); - void remove(@NotNull MqttTopicFilterImpl topicFilter, @Nullable MqttSubscribedPublishFlow flow); + void suback(@NotNull MqttTopicFilterImpl topicFilter, int subscriptionIdentifier, boolean error); - void unsubscribe( - @NotNull MqttTopicFilterImpl topicFilter, - @Nullable Consumer unsubscribedCallback); + void unsubscribe(@NotNull MqttTopicFilterImpl topicFilter); void cancel(@NotNull MqttSubscribedPublishFlow flow); void findMatching(@NotNull MqttTopicImpl topic, @NotNull MqttMatchingPublishFlows matchingFlows); void clear(@NotNull Throwable cause); + + @NotNull Map<@NotNull Integer, @NotNull List<@NotNull MqttSubscription>> getSubscriptions(); } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowList.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowList.java deleted file mode 100644 index 84bd6b5f7..000000000 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowList.java +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Copyright 2018 dc-square and the HiveMQ MQTT Client Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package com.hivemq.client.internal.mqtt.handler.publish.incoming; - -import com.hivemq.client.internal.annotations.NotThreadSafe; -import com.hivemq.client.internal.mqtt.datatypes.MqttTopicFilterImpl; -import com.hivemq.client.internal.mqtt.datatypes.MqttTopicImpl; -import com.hivemq.client.internal.util.collections.HandleList; -import com.hivemq.client.internal.util.collections.HandleList.Handle; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; - -import javax.inject.Inject; -import java.util.HashMap; -import java.util.function.Consumer; - -/** - * @author Silvio Giebl - */ -@NotThreadSafe -public class MqttSubscriptionFlowList implements MqttSubscriptionFlows { - - private final @NotNull HandleList flows; - private final @NotNull HashMap subscribedTopicFilters; - - @Inject - MqttSubscriptionFlowList() { - flows = new HandleList<>(); - subscribedTopicFilters = new HashMap<>(); - } - - @Override - public void subscribe( - final @NotNull MqttTopicFilterImpl topicFilter, final @Nullable MqttSubscribedPublishFlow flow) { - - if (flow != null) { - final HandleList topicFilters = flow.getTopicFilters(); - if (topicFilters.isEmpty()) { - flows.add(flow); - } - topicFilters.add(topicFilter); - } - final Integer count = subscribedTopicFilters.put(topicFilter, 1); - if (count != null) { - subscribedTopicFilters.put(topicFilter, count + 1); - } - } - - @Override - public void remove(final @NotNull MqttTopicFilterImpl topicFilter, final @Nullable MqttSubscribedPublishFlow flow) { - if (flow != null) { - final HandleList topicFilters = flow.getTopicFilters(); - for (Handle h = topicFilters.getFirst(); h != null; h = h.getNext()) { - if (topicFilter.equals(h.getElement())) { - topicFilters.remove(h); - break; - } - } - if (topicFilters.isEmpty()) { - cancel(flow); - } - } - final Integer count = subscribedTopicFilters.remove(topicFilter); - if ((count != null) && (count > 1)) { - subscribedTopicFilters.put(topicFilter, count - 1); - } - } - - @Override - public void unsubscribe( - final @NotNull MqttTopicFilterImpl topicFilter, - final @Nullable Consumer unsubscribedCallback) { - - for (Handle h = flows.getFirst(); h != null; h = h.getNext()) { - final MqttSubscribedPublishFlow flow = h.getElement(); - final HandleList topicFilters = flow.getTopicFilters(); - for (Handle h2 = topicFilters.getFirst(); h2 != null; h2 = h2.getNext()) { - if (topicFilter.equals(h2.getElement())) { - topicFilters.remove(h2); - } - } - if (topicFilters.isEmpty()) { - flows.remove(h); - flow.onComplete(); - if (unsubscribedCallback != null) { - unsubscribedCallback.accept(flow); - } - } - } - subscribedTopicFilters.remove(topicFilter); - } - - @Override - public void cancel(final @NotNull MqttSubscribedPublishFlow flow) { - for (Handle h = flows.getFirst(); h != null; h = h.getNext()) { - if (h.getElement() == flow) { - flows.remove(h); - break; - } - } - } - - @Override - public void findMatching( - final @NotNull MqttTopicImpl topic, final @NotNull MqttMatchingPublishFlows matchingFlows) { - - for (Handle h = flows.getFirst(); h != null; h = h.getNext()) { - final MqttSubscribedPublishFlow flow = h.getElement(); - for (Handle h2 = flow.getTopicFilters().getFirst(); h2 != null; h2 = h2.getNext()) { - if (h2.getElement().matches(topic)) { - matchingFlows.add(flow); - break; - } - } - } - if (!matchingFlows.isEmpty()) { - matchingFlows.subscriptionFound = true; - return; - } - for (final MqttTopicFilterImpl subscribedTopicFilter : subscribedTopicFilters.keySet()) { - if (subscribedTopicFilter.matches(topic)) { - matchingFlows.subscriptionFound = true; - return; - } - } - } - - @Override - public void clear(final @NotNull Throwable cause) { - for (Handle h = flows.getFirst(); h != null; h = h.getNext()) { - h.getElement().onError(cause); - } - flows.clear(); - subscribedTopicFilters.clear(); - } -} diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckFlowableFlow.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckFlowableFlow.java index 88649357c..5ff576ed4 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckFlowableFlow.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckFlowableFlow.java @@ -45,12 +45,11 @@ class MqttAckFlowableFlow extends MqttAckFlow implements Subscription, Runnable private final @NotNull Subscriber subscriber; private final @NotNull MqttOutgoingQosHandler outgoingQosHandler; - private long requestedNettyLocal; + private long requested; private final @NotNull AtomicLong newRequested = new AtomicLong(); private final @NotNull AtomicInteger requestState = new AtomicInteger(STATE_NO_NEW_REQUESTS); private volatile long acknowledged; - private long acknowledgedNettyLocal; private final @NotNull AtomicLong published = new AtomicLong(); private @Nullable Throwable error; // synced over volatile published @@ -106,26 +105,26 @@ public void run() { } requested = addNewRequested(); } - if (requestedNettyLocal != Long.MAX_VALUE) { - requestedNettyLocal -= emitted; + if (this.requested != Long.MAX_VALUE) { + this.requested -= emitted; } acknowledged(acknowledged); } @CallByThread("Netty EventLoop") @Override - void acknowledged(final long acknowledged) { - if (acknowledged > 0) { - final long acknowledgedLocal = this.acknowledgedNettyLocal += acknowledged; - this.acknowledged = acknowledgedLocal; - if ((acknowledgedLocal == published.get()) && setDone()) { + void acknowledged(final long newAcknowledged) { + if (newAcknowledged > 0) { + final long acknowledged = this.acknowledged + newAcknowledged; + this.acknowledged = acknowledged; + if ((acknowledged == published.get()) && setDone()) { if (error != null) { subscriber.onError(error); } else { subscriber.onComplete(); } } - outgoingQosHandler.request(acknowledged); + outgoingQosHandler.request(newAcknowledged); } } @@ -138,14 +137,14 @@ void onComplete(final long published) { } } - void onError(final @NotNull Throwable t, final long published) { - error = t; + void onError(final @NotNull Throwable error, final long published) { + this.error = error; if (!this.published.compareAndSet(0, published)) { - RxJavaPlugins.onError(t); + RxJavaPlugins.onError(error); return; } if ((acknowledged == published) && setDone()) { - subscriber.onError(t); + subscriber.onError(error); } } @@ -164,10 +163,7 @@ public void request(final long n) { @CallByThread("Netty EventLoop") private long requested() { - if (requestedNettyLocal <= 0) { - return addNewRequested(); - } - return requestedNettyLocal; + return (requested > 0) ? requested : addNewRequested(); } @CallByThread("Netty EventLoop") @@ -182,7 +178,7 @@ private long addNewRequested() { // requestState is afterwards set to STATE_NEW_REQUESTS although newRequested is reset to 0. // If request is not called until the next invocation of this method, newRequested may be 0. if (newRequested > 0) { - return requestedNettyLocal = BackpressureHelper.addCap(requestedNettyLocal, newRequested); + return requested = BackpressureHelper.addCap(requested, newRequested); } } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckSingle.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckSingle.java index bddd653a7..43cddbe5c 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckSingle.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckSingle.java @@ -82,14 +82,23 @@ private static class Flow extends MqttAckFlow implements Disposable { @Override void onNext(final @NotNull MqttPublishResult result) { if (result.acknowledged()) { - onNextUnsafe(result); + done(result); } else { this.result = result; } } @CallByThread("Netty EventLoop") - private void onNextUnsafe(final @NotNull MqttPublishResult result) { + @Override + void acknowledged(final long acknowledged) { + final MqttPublishResult result = this.result; + assert (acknowledged == 1) && (result != null) : "a single publish must be acknowledged exactly once"; + this.result = null; + done(result); + } + + @CallByThread("Netty EventLoop") + private void done(final @NotNull MqttPublishResult result) { if (setDone()) { final Throwable error = result.getRawError(); if (error == null) { @@ -100,17 +109,5 @@ private void onNextUnsafe(final @NotNull MqttPublishResult result) { } outgoingQosHandler.request(1); } - - @CallByThread("Netty EventLoop") - @Override - void acknowledged(final long acknowledged) { - final MqttPublishResult result = this.result; - if ((acknowledged != 1) || (result == null)) { - throw new IllegalStateException( - "A single publish must be acknowledged exactly once. This must not happen and is a bug."); - } - this.result = null; - onNextUnsafe(result); - } } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckSingleFlowable.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckSingleFlowable.java index 8e61362eb..c1798d9f9 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckSingleFlowable.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttAckSingleFlowable.java @@ -66,6 +66,7 @@ private static class Flow extends MqttAckFlow implements Subscription, Runnable private static final int STATE_NONE = 0; private static final int STATE_RESULT = 1; private static final int STATE_REQUESTED = 2; + private static final int STATE_CANCELLED = 3; private final @NotNull Subscriber subscriber; private final @NotNull MqttOutgoingQosHandler outgoingQosHandler; @@ -87,25 +88,26 @@ private static class Flow extends MqttAckFlow implements Subscription, Runnable @CallByThread("Netty EventLoop") @Override void onNext(final @NotNull MqttPublishResult result) { - if (state.get() == STATE_REQUESTED) { - onNextUnsafe(result); - } else { - this.result = result; - if (!state.compareAndSet(STATE_NONE, STATE_RESULT)) { - this.result = null; - onNextUnsafe(result); - } else if (isCancelled()) { - this.result = null; - if (result.acknowledged()) { - acknowledged(1); + switch (state.get()) { + case STATE_NONE: + if (state.compareAndSet(STATE_NONE, STATE_RESULT)) { + this.result = result; + } else { + onNext(result); } - } + break; + case STATE_REQUESTED: + subscriber.onNext(result); + done(result); + break; + case STATE_CANCELLED: + done(result); + break; } } @CallByThread("Netty EventLoop") - private void onNextUnsafe(final @NotNull MqttPublishResult result) { - subscriber.onNext(result); + private void done(final @NotNull MqttPublishResult result) { if (result.acknowledged()) { acknowledged(1); } @@ -114,10 +116,7 @@ private void onNextUnsafe(final @NotNull MqttPublishResult result) { @CallByThread("Netty EventLoop") @Override void acknowledged(final long acknowledged) { - if (acknowledged != 1) { - throw new IllegalStateException( - "A single publish must be acknowledged exactly once. This must not happen and is a bug."); - } + assert acknowledged == 1 : "a single publish must be acknowledged exactly once"; if (setDone()) { subscriber.onComplete(); } @@ -126,14 +125,14 @@ void acknowledged(final long acknowledged) { @Override public void request(final long n) { - if ((n > 0) && !isCancelled() && (state.getAndSet(STATE_REQUESTED) == STATE_RESULT)) { + if ((n > 0) && (state.getAndSet(STATE_REQUESTED) == STATE_RESULT)) { eventLoop.execute(this); } } @Override protected void onCancel() { - if (state.get() == STATE_RESULT) { + if (state.getAndSet(STATE_CANCELLED) == STATE_RESULT) { eventLoop.execute(this); } } @@ -144,13 +143,10 @@ public void run() { final MqttPublishResult result = this.result; if (result != null) { this.result = null; - if (isCancelled()) { - if (result.acknowledged()) { - acknowledged(1); - } - } else { - onNextUnsafe(result); + if (!isCancelled()) { + subscriber.onNext(result); } + done(result); } } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttOutgoingQosHandler.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttOutgoingQosHandler.java index 9ab9db3e2..2dd728d75 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttOutgoingQosHandler.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttOutgoingQosHandler.java @@ -48,6 +48,7 @@ import com.hivemq.client.internal.util.UnsignedDataTypes; import com.hivemq.client.internal.util.collections.IntIndex; import com.hivemq.client.internal.util.collections.NodeList; +import com.hivemq.client.mqtt.MqttClientState; import com.hivemq.client.mqtt.datatypes.MqttQos; import com.hivemq.client.mqtt.exceptions.ConnectionClosedException; import com.hivemq.client.mqtt.mqtt5.advanced.interceptor.qos1.Mqtt5OutgoingQos1Interceptor; @@ -88,33 +89,32 @@ public class MqttOutgoingQosHandler extends MqttSessionAwareHandler private final @NotNull MqttClientConfig clientConfig; private final @NotNull MqttPublishFlowables publishFlowables; + // valid for session private final @NotNull SpscUnboundedArrayQueue queue = new SpscUnboundedArrayQueue<>(32); private final @NotNull AtomicInteger queuedCounter = new AtomicInteger(); - private final @NotNull IntIndex pendingIndex = new IntIndex<>(INDEX_SPEC); private final @NotNull NodeList pending = new NodeList<>(); private final @NotNull Ranges packetIdentifiers = new Ranges(1, 0); - private int sendMaximum; + // valid for connection + private final @NotNull IntIndex pendingIndex = new IntIndex<>(INDEX_SPEC); private @Nullable MqttPubOrRelWithFlow resendPending; private @Nullable MqttPublishWithFlow currentPending; + private int sendMaximum; private @Nullable MqttTopicAliasMapping topicAliasMapping; + private @Nullable Subscription subscription; private int shrinkRequests; @Inject - MqttOutgoingQosHandler( - final @NotNull MqttClientConfig clientConfig, final @NotNull MqttPublishFlowables publishFlowables) { - + MqttOutgoingQosHandler(final @NotNull MqttClientConfig clientConfig) { this.clientConfig = clientConfig; - this.publishFlowables = publishFlowables; + publishFlowables = new MqttPublishFlowables(); } @Override public void onSessionStartOrResume( final @NotNull MqttClientConnectionConfig connectionConfig, final @NotNull EventLoop eventLoop) { - super.onSessionStartOrResume(connectionConfig, eventLoop); - final int oldSendMaximum = sendMaximum; final int newSendMaximum = Math.min(connectionConfig.getSendMaximum(), UnsignedDataTypes.UNSIGNED_SHORT_MAX_VALUE - MqttSubscriptionHandler.MAX_SUB_PENDING); @@ -139,10 +139,12 @@ public void onSessionStartOrResume( topicAliasMapping = connectionConfig.getSendTopicAliasMapping(); pendingIndex.clear(); - if ((pending.getFirst() != null) || (queuedCounter.get() > 0)) { - resendPending = pending.getFirst(); + resendPending = pending.getFirst(); + if ((resendPending != null) || (queuedCounter.get() > 0)) { eventLoop.execute(this); } + + super.onSessionStartOrResume(connectionConfig, eventLoop); } @Override @@ -187,7 +189,9 @@ void request(final long n) { @Override public void run() { if (!hasSession) { - clearQueued(MqttClientStateExceptions.notConnected()); + if (!isRepublishIfSessionExpired()) { + clearQueued(MqttClientStateExceptions.notConnected()); + } return; } final ChannelHandlerContext ctx = this.ctx; @@ -473,6 +477,13 @@ public void exceptionCaught(final @NotNull ChannelHandlerContext ctx, final @Not public void onSessionEnd(final @NotNull Throwable cause) { super.onSessionEnd(cause); + pendingIndex.clear(); + resendPending = null; + + if (isRepublishIfSessionExpired()) { + return; + } + for (MqttPubOrRelWithFlow current = pending.getFirst(); current != null; current = current.getNext()) { packetIdentifiers.returnId(current.packetIdentifier); if (current instanceof MqttPublishWithFlow) { @@ -489,13 +500,14 @@ public void onSessionEnd(final @NotNull Throwable cause) { } } } - pendingIndex.clear(); pending.clear(); - resendPending = null; - clearQueued(cause); } + private boolean isRepublishIfSessionExpired() { + return clientConfig.isRepublishIfSessionExpired() && (clientConfig.getState() != MqttClientState.DISCONNECTED); + } + private void clearQueued(final @NotNull Throwable cause) { int polled = 0; while (true) { diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttPublishFlowableAckLink.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttPublishFlowableAckLink.java index 4c756338e..0d0d96d25 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttPublishFlowableAckLink.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttPublishFlowableAckLink.java @@ -62,15 +62,15 @@ private static class AckLinkSubscriber implements FlowableSubscriber subscriber; private final @NotNull MqttAckFlowableFlow ackFlow; private @Nullable Subscription subscription; - private final @NotNull AtomicInteger state = new AtomicInteger(); + private final @NotNull AtomicInteger state = new AtomicInteger(STATE_NONE); + private final @NotNull AtomicInteger requestState = new AtomicInteger(STATE_NONE); private long published; AckLinkSubscriber( @@ -90,10 +90,10 @@ public void onSubscribe(final @NotNull Subscription subscription) { @Override public void onNext(final @NotNull MqttPublish publish) { - if (state.compareAndSet(STATE_NONE, STATE_EMITTING)) { + if (state.compareAndSet(STATE_NONE, STATE_IN_PROGRESS)) { subscriber.onNext(new MqttPublishWithFlow(publish, ackFlow)); published++; - if (!state.compareAndSet(STATE_EMITTING, STATE_NONE)) { + if (!state.compareAndSet(STATE_IN_PROGRESS, STATE_NONE)) { cancelActual(); } } @@ -120,29 +120,32 @@ public void onError(final @NotNull Throwable error) { @Override public void request(final long n) { assert subscription != null; - subscription.request(n); + if (requestState.compareAndSet(STATE_NONE, STATE_IN_PROGRESS)) { + subscription.request(n); + if (!requestState.compareAndSet(STATE_IN_PROGRESS, STATE_NONE)) { + subscription.cancel(); + } + } } @Override public void cancel() { - assert subscription != null; LOGGER.error("MqttPublishFlowables is global and must never cancel. This must not happen and is a bug."); - subscription.cancel(); } @Override public void cancelLink() { - if (state.getAndSet(STATE_CANCEL) == STATE_NONE) { + if (state.getAndSet(STATE_CANCELLED) == STATE_NONE) { cancelActual(); } } private void cancelActual() { - if (state.compareAndSet(STATE_CANCEL, STATE_CANCELLED)) { - assert subscription != null; + assert subscription != null; + if (requestState.getAndSet(STATE_CANCELLED) == STATE_NONE) { subscription.cancel(); - subscriber.onComplete(); } + subscriber.onComplete(); } } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttPublishFlowables.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttPublishFlowables.java index 23f6ea022..e6f08bb48 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttPublishFlowables.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/publish/outgoing/MqttPublishFlowables.java @@ -27,8 +27,6 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; -import javax.inject.Inject; - /** * @author Silvio Giebl */ @@ -40,7 +38,6 @@ public class MqttPublishFlowables extends Flowable private @Nullable Subscriber> subscriber; private long requested; - @Inject MqttPublishFlowables() {} @Override diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubAckSingle.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubAckSingle.java index 0d7b890b3..3ad068b91 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubAckSingle.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubAckSingle.java @@ -18,14 +18,12 @@ package com.hivemq.client.internal.mqtt.handler.subscribe; import com.hivemq.client.internal.mqtt.MqttClientConfig; -import com.hivemq.client.internal.mqtt.exceptions.MqttClientStateExceptions; import com.hivemq.client.internal.mqtt.ioc.ClientComponent; import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscribe; import com.hivemq.client.internal.mqtt.message.subscribe.suback.MqttSubAck; import com.hivemq.client.mqtt.mqtt5.message.subscribe.suback.Mqtt5SubAck; import io.reactivex.Single; import io.reactivex.SingleObserver; -import io.reactivex.internal.disposables.EmptyDisposable; import org.jetbrains.annotations.NotNull; /** @@ -43,15 +41,11 @@ public MqttSubAckSingle(final @NotNull MqttSubscribe subscribe, final @NotNull M @Override protected void subscribeActual(final @NotNull SingleObserver observer) { - if (clientConfig.getState().isConnectedOrReconnect()) { - final ClientComponent clientComponent = clientConfig.getClientComponent(); - final MqttSubscriptionHandler subscriptionHandler = clientComponent.subscriptionHandler(); + final ClientComponent clientComponent = clientConfig.getClientComponent(); + final MqttSubscriptionHandler subscriptionHandler = clientComponent.subscriptionHandler(); - final MqttSubOrUnsubAckFlow flow = new MqttSubOrUnsubAckFlow<>(observer, clientConfig); - observer.onSubscribe(flow); - subscriptionHandler.subscribe(subscribe, flow); - } else { - EmptyDisposable.error(MqttClientStateExceptions.notConnected(), observer); - } + final MqttSubOrUnsubAckFlow flow = new MqttSubOrUnsubAckFlow<>(observer, clientConfig); + observer.onSubscribe(flow); + subscriptionHandler.subscribe(subscribe, flow); } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubOrUnsubAckFlow.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubOrUnsubAckFlow.java index 0ee2fc969..0ea3f3139 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubOrUnsubAckFlow.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubOrUnsubAckFlow.java @@ -45,9 +45,9 @@ public void onSuccess(final @NotNull T t) { } @Override - public void onError(final @NotNull Throwable t) { + public void onError(final @NotNull Throwable error) { if (setDone()) { - observer.onError(t); + observer.onError(error); } } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubOrUnsubWithFlow.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubOrUnsubWithFlow.java index 510b19d44..a05e484f6 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubOrUnsubWithFlow.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubOrUnsubWithFlow.java @@ -17,21 +17,15 @@ package com.hivemq.client.internal.mqtt.handler.subscribe; -import com.hivemq.client.internal.mqtt.message.MqttStatefulMessage; import com.hivemq.client.internal.util.collections.NodeList; -import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; /** * @author Silvio Giebl */ -abstract class MqttSubOrUnsubWithFlow { +abstract class MqttSubOrUnsubWithFlow extends NodeList.Node { - abstract @NotNull MqttSubscriptionFlow getFlow(); + int packetIdentifier; - static abstract class Stateful extends NodeList.Node { - - abstract @NotNull MqttStatefulMessage.WithId getMessage(); - - abstract @NotNull MqttSubscriptionFlow getFlow(); - } + abstract @Nullable MqttSubscriptionFlow getFlow(); } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubscribeWithFlow.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubscribeWithFlow.java index e6d465133..5347b0ebd 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubscribeWithFlow.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubscribeWithFlow.java @@ -17,8 +17,6 @@ package com.hivemq.client.internal.mqtt.handler.subscribe; -import com.hivemq.client.internal.mqtt.handler.publish.incoming.MqttSubscribedPublishFlow; -import com.hivemq.client.internal.mqtt.message.subscribe.MqttStatefulSubscribe; import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscribe; import com.hivemq.client.internal.mqtt.message.subscribe.suback.MqttSubAck; import org.jetbrains.annotations.NotNull; @@ -29,47 +27,21 @@ */ class MqttSubscribeWithFlow extends MqttSubOrUnsubWithFlow { - private final @NotNull MqttSubscribe subscribe; - private final @NotNull MqttSubscriptionFlow flow; + final @NotNull MqttSubscribe subscribe; + final int subscriptionIdentifier; + private final @Nullable MqttSubscriptionFlow flow; MqttSubscribeWithFlow( - final @NotNull MqttSubscribe subscribe, final @NotNull MqttSubscriptionFlow flow) { + final @NotNull MqttSubscribe subscribe, final int subscriptionIdentifier, + final @Nullable MqttSubscriptionFlow flow) { this.subscribe = subscribe; + this.subscriptionIdentifier = subscriptionIdentifier; this.flow = flow; } - @NotNull MqttSubscribe getMessage() { - return subscribe; - } - @Override - @NotNull MqttSubscriptionFlow getFlow() { + @Nullable MqttSubscriptionFlow getFlow() { return flow; } - - static class Stateful extends MqttSubOrUnsubWithFlow.Stateful { - - private final @NotNull MqttStatefulSubscribe subscribe; - private final @NotNull MqttSubscriptionFlow flow; - - Stateful(final @NotNull MqttStatefulSubscribe subscribe, final @NotNull MqttSubscriptionFlow flow) { - this.subscribe = subscribe; - this.flow = flow; - } - - @Override - @NotNull MqttStatefulSubscribe getMessage() { - return subscribe; - } - - @Override - @NotNull MqttSubscriptionFlow getFlow() { - return flow; - } - - @Nullable MqttSubscribedPublishFlow getPublishFlow() { - return (flow instanceof MqttSubscribedPublishFlow) ? (MqttSubscribedPublishFlow) flow : null; - } - } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubscriptionHandler.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubscriptionHandler.java index c0ea871fb..af25704c9 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubscriptionHandler.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttSubscriptionHandler.java @@ -20,11 +20,12 @@ import com.hivemq.client.internal.annotations.CallByThread; import com.hivemq.client.internal.logging.InternalLogger; import com.hivemq.client.internal.logging.InternalLoggerFactory; +import com.hivemq.client.internal.mqtt.MqttClientConfig; import com.hivemq.client.internal.mqtt.MqttClientConnectionConfig; -import com.hivemq.client.internal.mqtt.datatypes.MqttVariableByteInteger; -import com.hivemq.client.internal.mqtt.exceptions.MqttClientStateExceptions; +import com.hivemq.client.internal.mqtt.datatypes.MqttUserPropertiesImpl; import com.hivemq.client.internal.mqtt.handler.MqttSessionAwareHandler; import com.hivemq.client.internal.mqtt.handler.disconnect.MqttDisconnectUtil; +import com.hivemq.client.internal.mqtt.handler.publish.incoming.MqttGlobalIncomingPublishFlow; import com.hivemq.client.internal.mqtt.handler.publish.incoming.MqttIncomingPublishFlows; import com.hivemq.client.internal.mqtt.handler.publish.incoming.MqttSubscribedPublishFlow; import com.hivemq.client.internal.mqtt.ioc.ClientScope; @@ -41,6 +42,7 @@ import com.hivemq.client.internal.util.collections.ImmutableList; import com.hivemq.client.internal.util.collections.IntIndex; import com.hivemq.client.internal.util.collections.NodeList; +import com.hivemq.client.mqtt.MqttClientState; import com.hivemq.client.mqtt.mqtt5.exceptions.Mqtt5SubAckException; import com.hivemq.client.mqtt.mqtt5.exceptions.Mqtt5UnsubAckException; import com.hivemq.client.mqtt.mqtt5.message.disconnect.Mqtt5DisconnectReasonCode; @@ -53,8 +55,6 @@ import javax.inject.Inject; import java.io.IOException; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicInteger; /** * @author Silvio Giebl @@ -65,23 +65,29 @@ public class MqttSubscriptionHandler extends MqttSessionAwareHandler implements public static final @NotNull String NAME = "subscription"; private static final @NotNull InternalLogger LOGGER = InternalLoggerFactory.getLogger(MqttSubscriptionHandler.class); - private static final @NotNull IntIndex.Spec INDEX_SPEC = - new IntIndex.Spec<>(x -> x.getMessage().getPacketIdentifier(), 4); + private static final @NotNull IntIndex.Spec INDEX_SPEC = + new IntIndex.Spec<>(x -> x.packetIdentifier, 4); public static final int MAX_SUB_PENDING = 10; // TODO configurable + private final @NotNull MqttClientConfig clientConfig; private final @NotNull MqttIncomingPublishFlows incomingPublishFlows; - private final @NotNull ConcurrentLinkedQueue queued = new ConcurrentLinkedQueue<>(); - private final @NotNull AtomicInteger queuedCounter = new AtomicInteger(); - private final @NotNull IntIndex pendingIndex = new IntIndex<>(INDEX_SPEC); - private final @NotNull NodeList pending = new NodeList<>(); + // valid for session + private final @NotNull NodeList pending = new NodeList<>(); private final @NotNull Ranges packetIdentifiers; + private int nextSubscriptionIdentifier = 1; - private @Nullable MqttSubOrUnsubWithFlow.Stateful resendPending, currentPending; - private @Nullable Ranges subscriptionIdentifiers; + // valid for connection + private final @NotNull IntIndex pendingIndex = new IntIndex<>(INDEX_SPEC); + private @Nullable MqttSubOrUnsubWithFlow sendPending, currentPending; + private boolean subscriptionIdentifiersAvailable; @Inject - MqttSubscriptionHandler(final @NotNull MqttIncomingPublishFlows incomingPublishFlows) { + MqttSubscriptionHandler( + final @NotNull MqttClientConfig clientConfig, + final @NotNull MqttIncomingPublishFlows incomingPublishFlows) { + + this.clientConfig = clientConfig; this.incomingPublishFlows = incomingPublishFlows; final int maxPacketIdentifier = UnsignedDataTypes.UNSIGNED_SHORT_MAX_VALUE; @@ -93,142 +99,119 @@ public class MqttSubscriptionHandler extends MqttSessionAwareHandler implements public void onSessionStartOrResume( final @NotNull MqttClientConnectionConfig connectionConfig, final @NotNull EventLoop eventLoop) { - super.onSessionStartOrResume(connectionConfig, eventLoop); - if (connectionConfig.areSubscriptionIdentifiersAvailable() && (subscriptionIdentifiers == null)) { - subscriptionIdentifiers = new Ranges(1, MqttVariableByteInteger.FOUR_BYTES_MAX_VALUE); + subscriptionIdentifiersAvailable = connectionConfig.areSubscriptionIdentifiersAvailable(); + + if (!hasSession) { + incomingPublishFlows.getSubscriptions().forEach((subscriptionIdentifier, subscriptions) -> { + final MqttSubscribe subscribe = new MqttSubscribe(ImmutableList.copyOf(subscriptions), + MqttUserPropertiesImpl.NO_USER_PROPERTIES); + pending.addFirst(new MqttSubscribeWithFlow(subscribe, subscriptionIdentifier, null)); + }); } - if ((pending.getFirst() != null) || (queuedCounter.get() > 0)) { - resendPending = pending.getFirst(); + + pendingIndex.clear(); + sendPending = pending.getFirst(); + if (sendPending != null) { eventLoop.execute(this); } + + super.onSessionStartOrResume(connectionConfig, eventLoop); } public void subscribe( final @NotNull MqttSubscribe subscribe, final @NotNull MqttSubscriptionFlow flow) { - queue(new MqttSubscribeWithFlow(subscribe, flow)); + flow.getEventLoop().execute(() -> { + if (flow.init()) { + final int subscriptionIdentifier = nextSubscriptionIdentifier++; + incomingPublishFlows.subscribe(subscribe, subscriptionIdentifier, + (flow instanceof MqttSubscribedPublishFlow) ? (MqttSubscribedPublishFlow) flow : null); + queue(new MqttSubscribeWithFlow(subscribe, subscriptionIdentifier, flow)); + } + }); } public void unsubscribe( final @NotNull MqttUnsubscribe unsubscribe, final @NotNull MqttSubOrUnsubAckFlow flow) { - queue(new MqttUnsubscribeWithFlow(unsubscribe, flow)); + flow.getEventLoop().execute(() -> { + if (flow.init()) { + queue(new MqttUnsubscribeWithFlow(unsubscribe, flow)); + } + }); + } + + public void subscribeGlobal(final @NotNull MqttGlobalIncomingPublishFlow flow) { + flow.getEventLoop().execute(() -> { + if (flow.init()) { + incomingPublishFlows.subscribeGlobal(flow); + } + }); } private void queue(final @NotNull MqttSubOrUnsubWithFlow subOrUnsubWithFlow) { - queued.offer(subOrUnsubWithFlow); - if (queuedCounter.getAndIncrement() == 0) { - subOrUnsubWithFlow.getFlow().getEventLoop().execute(this); + pending.add(subOrUnsubWithFlow); + if (sendPending == null) { + sendPending = subOrUnsubWithFlow; + run(); } } @CallByThread("Netty EventLoop") @Override public void run() { - if (!hasSession) { - clearQueued(MqttClientStateExceptions.notConnected()); - return; - } final ChannelHandlerContext ctx = this.ctx; if (ctx == null) { return; } - for (; resendPending != null; resendPending = resendPending.getNext()) { - if (resendPending instanceof MqttSubscribeWithFlow.Stateful) { - writeSubscribe(ctx, (MqttSubscribeWithFlow.Stateful) resendPending); - } else { - writeUnsubscribe(ctx, (MqttUnsubscribeWithFlow.Stateful) resendPending); - } - } - int removedFromQueue = 0; - while (true) { - if (pendingIndex.size() == MAX_SUB_PENDING) { - queuedCounter.getAndAdd(-removedFromQueue); - return; - } - final MqttSubOrUnsubWithFlow subOrUnsubWithFlow = queued.poll(); - if (subOrUnsubWithFlow == null) { - if (queuedCounter.addAndGet(-removedFromQueue) == 0) { + int written = 0; + for (MqttSubOrUnsubWithFlow subOrUnsubWithFlow = sendPending; + (subOrUnsubWithFlow != null) && (pendingIndex.size() < MAX_SUB_PENDING); + sendPending = subOrUnsubWithFlow = subOrUnsubWithFlow.getNext()) { + + if (subOrUnsubWithFlow.packetIdentifier == 0) { + final int packetIdentifier = packetIdentifiers.getId(); + if (packetIdentifier == -1) { + LOGGER.error( + "No Packet Identifier available for (UN)SUBSCRIBE. This must not happen and is a bug."); return; - } else { - removedFromQueue = 0; - continue; } + subOrUnsubWithFlow.packetIdentifier = packetIdentifier; } - final int packetIdentifier = packetIdentifiers.getId(); - if (packetIdentifier == -1) { - LOGGER.error("No Packet Identifier available for (UN)SUBSCRIBE. This must not happen and is a bug."); - return; + pendingIndex.put(subOrUnsubWithFlow); + if (sendPending instanceof MqttSubscribeWithFlow) { + writeSubscribe(ctx, (MqttSubscribeWithFlow) subOrUnsubWithFlow); + } else { + writeUnsubscribe(ctx, (MqttUnsubscribeWithFlow) subOrUnsubWithFlow); } - writeSubscribeOrUnsubscribe(ctx, subOrUnsubWithFlow, packetIdentifier); - removedFromQueue++; + written++; } - } - - private void writeSubscribeOrUnsubscribe( - final @NotNull ChannelHandlerContext ctx, final @NotNull MqttSubOrUnsubWithFlow subOrUnsubWithFlow, - final int packetIdentifier) { - - if (!subOrUnsubWithFlow.getFlow().init()) { - return; - } - - if (subOrUnsubWithFlow instanceof MqttSubscribeWithFlow) { - final MqttSubscribeWithFlow subscribeWithFlow = (MqttSubscribeWithFlow) subOrUnsubWithFlow; - - final int subscriptionIdentifier = (subscriptionIdentifiers != null) ? subscriptionIdentifiers.getId() : - MqttStatefulSubscribe.DEFAULT_NO_SUBSCRIPTION_IDENTIFIER; - final MqttStatefulSubscribe statefulSubscribe = - subscribeWithFlow.getMessage().createStateful(packetIdentifier, subscriptionIdentifier); - - final MqttSubscribeWithFlow.Stateful statefulSubscribeWithFlow = - new MqttSubscribeWithFlow.Stateful(statefulSubscribe, subscribeWithFlow.getFlow()); - - addPending(statefulSubscribeWithFlow); - - if (writeSubscribe(ctx, statefulSubscribeWithFlow)) { - incomingPublishFlows.subscribe(statefulSubscribe, statefulSubscribeWithFlow.getPublishFlow()); - } - } else { - final MqttUnsubscribeWithFlow unsubscribeWithFlow = (MqttUnsubscribeWithFlow) subOrUnsubWithFlow; - - final MqttStatefulUnsubscribe statefulUnsubscribe = - unsubscribeWithFlow.getMessage().createStateful(packetIdentifier); - - final MqttUnsubscribeWithFlow.Stateful statefulUnsubscribeWithFlow = - new MqttUnsubscribeWithFlow.Stateful(statefulUnsubscribe, unsubscribeWithFlow.getFlow()); - - addPending(statefulUnsubscribeWithFlow); - - writeUnsubscribe(ctx, statefulUnsubscribeWithFlow); + if (written > 0) { + ctx.flush(); } } - private void addPending(final @NotNull MqttSubOrUnsubWithFlow.Stateful newPending) { - pendingIndex.put(newPending); - pending.add(newPending); - } + private void writeSubscribe( + final @NotNull ChannelHandlerContext ctx, final @NotNull MqttSubscribeWithFlow subscribeWithFlow) { - private boolean writeSubscribe( - final @NotNull ChannelHandlerContext ctx, - final @NotNull MqttSubscribeWithFlow.Stateful statefulSubscribeWithFlow) { + final int subscriptionIdentifier = subscriptionIdentifiersAvailable ? subscribeWithFlow.subscriptionIdentifier : + MqttStatefulSubscribe.DEFAULT_NO_SUBSCRIPTION_IDENTIFIER; + final MqttStatefulSubscribe statefulSubscribe = + subscribeWithFlow.subscribe.createStateful(subscribeWithFlow.packetIdentifier, subscriptionIdentifier); - final MqttStatefulSubscribe statefulSubscribe = statefulSubscribeWithFlow.getMessage(); - currentPending = statefulSubscribeWithFlow; - ctx.writeAndFlush(statefulSubscribe, ctx.voidPromise()); - if (currentPending == null) { // exception was handled - return false; - } + currentPending = subscribeWithFlow; + ctx.write(statefulSubscribe, ctx.voidPromise()); currentPending = null; - return true; } private void writeUnsubscribe( - final @NotNull ChannelHandlerContext ctx, - final @NotNull MqttUnsubscribeWithFlow.Stateful statefulUnsubscribeWithFlow) { + final @NotNull ChannelHandlerContext ctx, final @NotNull MqttUnsubscribeWithFlow unsubscribeWithFlow) { + + final MqttStatefulUnsubscribe statefulUnsubscribe = + unsubscribeWithFlow.unsubscribe.createStateful(unsubscribeWithFlow.packetIdentifier); - currentPending = statefulUnsubscribeWithFlow; - ctx.writeAndFlush(statefulUnsubscribeWithFlow.getMessage(), ctx.voidPromise()); + currentPending = unsubscribeWithFlow; + ctx.write(statefulUnsubscribe, ctx.voidPromise()); currentPending = null; } @@ -244,79 +227,75 @@ public void channelRead(final @NotNull ChannelHandlerContext ctx, final @NotNull } private void readSubAck(final @NotNull ChannelHandlerContext ctx, final @NotNull MqttSubAck subAck) { - final int packetIdentifier = subAck.getPacketIdentifier(); - final MqttSubOrUnsubWithFlow.Stateful statefulSubOrUnsubWithFlow = pendingIndex.remove(packetIdentifier); + final MqttSubOrUnsubWithFlow subOrUnsubWithFlow = pendingIndex.remove(subAck.getPacketIdentifier()); - if (statefulSubOrUnsubWithFlow == null) { + if (subOrUnsubWithFlow == null) { MqttDisconnectUtil.disconnect(ctx.channel(), Mqtt5DisconnectReasonCode.PROTOCOL_ERROR, "Unknown packet identifier for SUBACK"); return; } - if (!(statefulSubOrUnsubWithFlow instanceof MqttSubscribeWithFlow.Stateful)) { + if (!(subOrUnsubWithFlow instanceof MqttSubscribeWithFlow)) { MqttDisconnectUtil.disconnect(ctx.channel(), Mqtt5DisconnectReasonCode.PROTOCOL_ERROR, "SUBACK received for an UNSUBSCRIBE"); return; } - final MqttSubscribeWithFlow.Stateful statefulSubscribeWithFlow = - (MqttSubscribeWithFlow.Stateful) statefulSubOrUnsubWithFlow; - final MqttStatefulSubscribe subscribe = statefulSubscribeWithFlow.getMessage(); - final MqttSubscriptionFlow flow = statefulSubscribeWithFlow.getFlow(); + final MqttSubscribeWithFlow subscribeWithFlow = (MqttSubscribeWithFlow) subOrUnsubWithFlow; + final MqttSubscriptionFlow flow = subscribeWithFlow.getFlow(); final ImmutableList reasonCodes = subAck.getReasonCodes(); - final boolean countNotMatching = subscribe.stateless().getSubscriptions().size() != reasonCodes.size(); + final boolean countNotMatching = subscribeWithFlow.subscribe.getSubscriptions().size() != reasonCodes.size(); final boolean allErrors = MqttCommonReasonCode.allErrors(subAck.getReasonCodes()); - incomingPublishFlows.subAck(subscribe, subAck, statefulSubscribeWithFlow.getPublishFlow()); + incomingPublishFlows.subAck(subscribeWithFlow.subscribe, subscribeWithFlow.subscriptionIdentifier, reasonCodes); - if (!(countNotMatching || allErrors)) { - if (!flow.isCancelled()) { - flow.onSuccess(subAck); - } else { - LOGGER.warn("Subscribe was successful but the SubAck flow has been cancelled"); - } - } else { - final String errorMessage; - if (countNotMatching) { - errorMessage = "Count of Reason Codes in SUBACK does not match count of subscriptions in SUBSCRIBE"; - } else { // allErrors - errorMessage = "SUBACK contains only Error Codes"; - } - if (!flow.isCancelled()) { - flow.onError(new Mqtt5SubAckException(subAck, errorMessage)); + if (flow != null) { + if (!(countNotMatching || allErrors)) { + if (!flow.isCancelled()) { + flow.onSuccess(subAck); + } else { + LOGGER.warn("Subscribe was successful but the SubAck flow has been cancelled"); + } } else { - LOGGER.warn(errorMessage + " but the SubAck flow has been cancelled"); + final String errorMessage; + if (countNotMatching) { + errorMessage = "Count of Reason Codes in SUBACK does not match count of subscriptions in SUBSCRIBE"; + } else { // allErrors + errorMessage = "SUBACK contains only Error Codes"; + } + if (!flow.isCancelled()) { + flow.onError(new Mqtt5SubAckException(subAck, errorMessage)); + } else { + LOGGER.warn(errorMessage + " but the SubAck flow has been cancelled"); + } } } - completePending(ctx, statefulSubscribeWithFlow); + completePending(subscribeWithFlow); } private void readUnsubAck(final @NotNull ChannelHandlerContext ctx, final @NotNull MqttUnsubAck unsubAck) { - final int packetIdentifier = unsubAck.getPacketIdentifier(); - final MqttSubOrUnsubWithFlow.Stateful statefulSubOrUnsubWithFlow = pendingIndex.remove(packetIdentifier); + final MqttSubOrUnsubWithFlow subOrUnsubWithFlow = pendingIndex.remove(unsubAck.getPacketIdentifier()); - if (statefulSubOrUnsubWithFlow == null) { + if (subOrUnsubWithFlow == null) { MqttDisconnectUtil.disconnect(ctx.channel(), Mqtt5DisconnectReasonCode.PROTOCOL_ERROR, "Unknown packet identifier for UNSUBACK"); return; } - if (!(statefulSubOrUnsubWithFlow instanceof MqttUnsubscribeWithFlow.Stateful)) { + if (!(subOrUnsubWithFlow instanceof MqttUnsubscribeWithFlow)) { MqttDisconnectUtil.disconnect(ctx.channel(), Mqtt5DisconnectReasonCode.PROTOCOL_ERROR, "UNSUBACK received for a SUBSCRIBE"); return; } - final MqttUnsubscribeWithFlow.Stateful statefulUnsubscribeWithFlow = - (MqttUnsubscribeWithFlow.Stateful) statefulSubOrUnsubWithFlow; - final MqttStatefulUnsubscribe unsubscribe = statefulUnsubscribeWithFlow.getMessage(); - final MqttSubOrUnsubAckFlow flow = statefulUnsubscribeWithFlow.getFlow(); + final MqttUnsubscribeWithFlow unsubscribeWithFlow = (MqttUnsubscribeWithFlow) subOrUnsubWithFlow; + final MqttSubOrUnsubAckFlow flow = unsubscribeWithFlow.getFlow(); final ImmutableList reasonCodes = unsubAck.getReasonCodes(); - final boolean countNotMatching = unsubscribe.stateless().getTopicFilters().size() != reasonCodes.size(); + final boolean countNotMatching = unsubscribeWithFlow.unsubscribe.getTopicFilters().size() != reasonCodes.size(); final boolean allErrors = MqttCommonReasonCode.allErrors(unsubAck.getReasonCodes()); if ((reasonCodes == Mqtt3UnsubAckView.REASON_CODES_ALL_SUCCESS) || !(countNotMatching || allErrors)) { - incomingPublishFlows.unsubscribe(unsubscribe, unsubAck); + incomingPublishFlows.unsubscribe(unsubscribeWithFlow.unsubscribe, reasonCodes); if (!flow.isCancelled()) { flow.onSuccess(unsubAck); @@ -337,30 +316,33 @@ private void readUnsubAck(final @NotNull ChannelHandlerContext ctx, final @NotNu } } - completePending(ctx, statefulUnsubscribeWithFlow); + completePending(unsubscribeWithFlow); } - private void completePending( - final @NotNull ChannelHandlerContext ctx, final @NotNull MqttSubOrUnsubWithFlow.Stateful oldPending) { - + private void completePending(final @NotNull MqttSubOrUnsubWithFlow oldPending) { pending.remove(oldPending); - - final int packetIdentifier = oldPending.getMessage().getPacketIdentifier(); - final MqttSubOrUnsubWithFlow subOrUnsubWithFlow = queued.poll(); - if (subOrUnsubWithFlow == null) { - packetIdentifiers.returnId(packetIdentifier); - } else { - queuedCounter.getAndDecrement(); - writeSubscribeOrUnsubscribe(ctx, subOrUnsubWithFlow, packetIdentifier); - } + packetIdentifiers.returnId(oldPending.packetIdentifier); + run(); } @Override public void exceptionCaught(final @NotNull ChannelHandlerContext ctx, final @NotNull Throwable cause) { if (!(cause instanceof IOException) && (currentPending != null)) { - pendingIndex.remove(currentPending.getMessage().getPacketIdentifier()); - currentPending.getFlow().onError(cause); - completePending(ctx, currentPending); + pending.remove(currentPending); + packetIdentifiers.returnId(currentPending.packetIdentifier); + pendingIndex.remove(currentPending.packetIdentifier); + + final MqttSubscriptionFlow flow = currentPending.getFlow(); + if (flow != null) { + flow.onError(cause); + } + + if (currentPending instanceof MqttSubscribeWithFlow) { + final MqttSubscribeWithFlow subscribeWithFlow = (MqttSubscribeWithFlow) currentPending; + incomingPublishFlows.subAck(subscribeWithFlow.subscribe, subscribeWithFlow.subscriptionIdentifier, + ImmutableList.of(Mqtt5SubAckReasonCode.UNSPECIFIED_ERROR)); + } + currentPending = null; } else { ctx.fireExceptionCaught(cause); @@ -371,37 +353,28 @@ public void exceptionCaught(final @NotNull ChannelHandlerContext ctx, final @Not public void onSessionEnd(final @NotNull Throwable cause) { super.onSessionEnd(cause); - for (MqttSubOrUnsubWithFlow.Stateful current = pending.getFirst(); current != null; - current = current.getNext()) { - packetIdentifiers.returnId(current.getMessage().getPacketIdentifier()); - if (!(current.getFlow() instanceof MqttSubscribedPublishFlow)) { - current.getFlow().onError(cause); - } // else flow.onError is already called via incomingPublishFlows.clear() in IncomingQosHandler - } pendingIndex.clear(); - pending.clear(); - resendPending = null; - subscriptionIdentifiers = null; + sendPending = null; + for (MqttSubOrUnsubWithFlow current = pending.getFirst(); current != null; current = current.getNext()) { + if (current.packetIdentifier == 0) { + break; + } + packetIdentifiers.returnId(current.packetIdentifier); + current.packetIdentifier = 0; + } - clearQueued(cause); - } + if (clientConfig.isResubscribeIfSessionExpired() && (clientConfig.getState() != MqttClientState.DISCONNECTED)) { + return; + } - private void clearQueued(final @NotNull Throwable cause) { - int polled = 0; - while (true) { - final MqttSubOrUnsubWithFlow subOrUnsubWithFlow = queued.poll(); - if (subOrUnsubWithFlow == null) { - if (queuedCounter.addAndGet(-polled) == 0) { - break; - } else { - polled = 0; - continue; - } + incomingPublishFlows.clear(cause); + for (MqttSubOrUnsubWithFlow current = pending.getFirst(); current != null; current = current.getNext()) { + final MqttSubscriptionFlow flow = current.getFlow(); + if (flow != null) { + flow.onError(cause); } - if (subOrUnsubWithFlow.getFlow().init()) { - subOrUnsubWithFlow.getFlow().onError(cause); - } - polled++; } + pending.clear(); + nextSubscriptionIdentifier = 1; } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttUnsubAckSingle.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttUnsubAckSingle.java index 55a01eb95..a7802d5c6 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttUnsubAckSingle.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttUnsubAckSingle.java @@ -18,14 +18,12 @@ package com.hivemq.client.internal.mqtt.handler.subscribe; import com.hivemq.client.internal.mqtt.MqttClientConfig; -import com.hivemq.client.internal.mqtt.exceptions.MqttClientStateExceptions; import com.hivemq.client.internal.mqtt.ioc.ClientComponent; import com.hivemq.client.internal.mqtt.message.unsubscribe.MqttUnsubscribe; import com.hivemq.client.internal.mqtt.message.unsubscribe.unsuback.MqttUnsubAck; import com.hivemq.client.mqtt.mqtt5.message.unsubscribe.unsuback.Mqtt5UnsubAck; import io.reactivex.Single; import io.reactivex.SingleObserver; -import io.reactivex.internal.disposables.EmptyDisposable; import org.jetbrains.annotations.NotNull; /** @@ -45,15 +43,11 @@ public MqttUnsubAckSingle( @Override protected void subscribeActual(final @NotNull SingleObserver observer) { - if (clientConfig.getState().isConnectedOrReconnect()) { - final ClientComponent clientComponent = clientConfig.getClientComponent(); - final MqttSubscriptionHandler subscriptionHandler = clientComponent.subscriptionHandler(); - - final MqttSubOrUnsubAckFlow flow = new MqttSubOrUnsubAckFlow<>(observer, clientConfig); - observer.onSubscribe(flow); - subscriptionHandler.unsubscribe(unsubscribe, flow); - } else { - EmptyDisposable.error(MqttClientStateExceptions.notConnected(), observer); - } + final ClientComponent clientComponent = clientConfig.getClientComponent(); + final MqttSubscriptionHandler subscriptionHandler = clientComponent.subscriptionHandler(); + + final MqttSubOrUnsubAckFlow flow = new MqttSubOrUnsubAckFlow<>(observer, clientConfig); + observer.onSubscribe(flow); + subscriptionHandler.unsubscribe(unsubscribe, flow); } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttUnsubscribeWithFlow.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttUnsubscribeWithFlow.java index bfd343284..15cf8566a 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttUnsubscribeWithFlow.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/subscribe/MqttUnsubscribeWithFlow.java @@ -17,7 +17,6 @@ package com.hivemq.client.internal.mqtt.handler.subscribe; -import com.hivemq.client.internal.mqtt.message.unsubscribe.MqttStatefulUnsubscribe; import com.hivemq.client.internal.mqtt.message.unsubscribe.MqttUnsubscribe; import com.hivemq.client.internal.mqtt.message.unsubscribe.unsuback.MqttUnsubAck; import org.jetbrains.annotations.NotNull; @@ -27,7 +26,7 @@ */ class MqttUnsubscribeWithFlow extends MqttSubOrUnsubWithFlow { - private final @NotNull MqttUnsubscribe unsubscribe; + final @NotNull MqttUnsubscribe unsubscribe; private final @NotNull MqttSubOrUnsubAckFlow unsubAckFlow; MqttUnsubscribeWithFlow( @@ -38,36 +37,8 @@ class MqttUnsubscribeWithFlow extends MqttSubOrUnsubWithFlow { this.unsubAckFlow = unsubAckFlow; } - @NotNull MqttUnsubscribe getMessage() { - return unsubscribe; - } - @Override @NotNull MqttSubOrUnsubAckFlow getFlow() { return unsubAckFlow; } - - static class Stateful extends MqttSubOrUnsubWithFlow.Stateful { - - private final @NotNull MqttStatefulUnsubscribe unsubscribe; - private final @NotNull MqttSubOrUnsubAckFlow unsubAckFlow; - - Stateful( - final @NotNull MqttStatefulUnsubscribe unsubscribe, - final @NotNull MqttSubOrUnsubAckFlow unsubAckFlow) { - - this.unsubscribe = unsubscribe; - this.unsubAckFlow = unsubAckFlow; - } - - @Override - @NotNull MqttStatefulUnsubscribe getMessage() { - return unsubscribe; - } - - @Override - @NotNull MqttSubOrUnsubAckFlow getFlow() { - return unsubAckFlow; - } - } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/handler/util/FlowWithEventLoop.java b/src/main/java/com/hivemq/client/internal/mqtt/handler/util/FlowWithEventLoop.java index 084b531dd..62662abbf 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/handler/util/FlowWithEventLoop.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/handler/util/FlowWithEventLoop.java @@ -43,11 +43,11 @@ public FlowWithEventLoop(final @NotNull MqttClientConfig clientConfig) { } public boolean init() { - if (doneState.getAndSet(STATE_NOT_DONE) == STATE_CANCELLED) { - clientConfig.releaseEventLoop(); - return false; + if (doneState.compareAndSet(STATE_INIT, STATE_NOT_DONE)) { + return true; } - return true; + clientConfig.releaseEventLoop(); + return false; } protected boolean setDone() { diff --git a/src/main/java/com/hivemq/client/internal/mqtt/ioc/ClientModule.java b/src/main/java/com/hivemq/client/internal/mqtt/ioc/ClientModule.java index 513a994f4..f554e1790 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/ioc/ClientModule.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/ioc/ClientModule.java @@ -17,18 +17,10 @@ package com.hivemq.client.internal.mqtt.ioc; -import com.hivemq.client.internal.mqtt.handler.publish.incoming.MqttSubscriptionFlowTree; -import com.hivemq.client.internal.mqtt.handler.publish.incoming.MqttSubscriptionFlows; -import dagger.Binds; import dagger.Module; -import org.jetbrains.annotations.NotNull; /** * @author Silvio Giebl */ @Module -abstract class ClientModule { - - @Binds - abstract @NotNull MqttSubscriptionFlows provideSubscriptionFlows(final @NotNull MqttSubscriptionFlowTree tree); -} +abstract class ClientModule {} diff --git a/src/main/java/com/hivemq/client/internal/mqtt/lifecycle/MqttClientReconnector.java b/src/main/java/com/hivemq/client/internal/mqtt/lifecycle/MqttClientReconnector.java index 384a7e375..e2709beec 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/lifecycle/MqttClientReconnector.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/lifecycle/MqttClientReconnector.java @@ -41,12 +41,16 @@ public class MqttClientReconnector implements Mqtt5ClientReconnector { private final @NotNull EventLoop eventLoop; private final int attempts; - private boolean reconnect; + private boolean reconnect = DEFAULT_RECONNECT; private @Nullable CompletableFuture future; - private long delayNanos; + private boolean resubscribeIfSessionExpired = DEFAULT_RESUBSCRIBE_IF_SESSION_EXPIRED; + private boolean republishIfSessionExpired = DEFAULT_REPUBLISH_IF_SESSION_EXPIRED; + private long delayNanos = TimeUnit.MILLISECONDS.toNanos(DEFAULT_DELAY_MS); private @NotNull MqttClientTransportConfigImpl transportConfig; private @NotNull MqttConnect connect; + private boolean afterOnDisconnected; + public MqttClientReconnector( final @NotNull EventLoop eventLoop, final int attempts, final @NotNull MqttConnect connect, final @NotNull MqttClientTransportConfigImpl transportConfig) { @@ -59,49 +63,71 @@ public MqttClientReconnector( @Override public int getAttempts() { - checkThread(); + checkInEventLoop(); return attempts; } @Override public @NotNull MqttClientReconnector reconnect(final boolean reconnect) { - checkThread(); + checkInEventLoop(); this.reconnect = reconnect; return this; } @Override - public @NotNull Mqtt5ClientReconnector reconnectWhen( + public @NotNull MqttClientReconnector reconnectWhen( @Nullable CompletableFuture future, final @Nullable BiConsumer callback) { - checkThread(); + checkInOnDisconnected("reconnectWhen"); Checks.notNull(future, "Future"); this.reconnect = true; if (callback != null) { future = future.whenCompleteAsync(callback, eventLoop); } - if (this.future == null) { - this.future = future; - } else { - this.future = CompletableFuture.allOf(this.future, future); - } + this.future = (this.future == null) ? future : CompletableFuture.allOf(this.future, future); return this; } @Override public boolean isReconnect() { - checkThread(); + checkInEventLoop(); return reconnect; } public @NotNull CompletableFuture getFuture() { - checkThread(); + checkInEventLoop(); return (future == null) ? CompletableFuture.completedFuture(null) : future; } + @Override + public @NotNull MqttClientReconnector resubscribeIfSessionExpired(final boolean resubscribe) { + checkInOnDisconnected("resubscribeIfSessionExpired"); + resubscribeIfSessionExpired = resubscribe; + return this; + } + + @Override + public boolean isResubscribeIfSessionExpired() { + checkInEventLoop(); + return resubscribeIfSessionExpired; + } + + @Override + public @NotNull MqttClientReconnector republishIfSessionExpired(final boolean republish) { + checkInOnDisconnected("republishIfSessionExpired"); + republishIfSessionExpired = republish; + return this; + } + + @Override + public boolean isRepublishIfSessionExpired() { + checkInEventLoop(); + return republishIfSessionExpired; + } + @Override public @NotNull MqttClientReconnector delay(final long delay, final @Nullable TimeUnit timeUnit) { - checkThread(); + checkInOnDisconnected("delay"); Checks.notNull(timeUnit, "Time unit"); this.delayNanos = timeUnit.toNanos(delay); return this; @@ -109,14 +135,14 @@ public boolean isReconnect() { @Override public long getDelay(final @NotNull TimeUnit timeUnit) { - checkThread(); + checkInEventLoop(); Checks.notNull(timeUnit, "Time unit"); return timeUnit.convert(delayNanos, TimeUnit.NANOSECONDS); } @Override public @NotNull MqttClientReconnector transportConfig(final @Nullable MqttClientTransportConfig transportConfig) { - checkThread(); + checkInEventLoop(); this.transportConfig = Checks.notImplemented(transportConfig, MqttClientTransportConfigImpl.class, "Transport config"); return this; @@ -124,38 +150,49 @@ public long getDelay(final @NotNull TimeUnit timeUnit) { @Override public @NotNull MqttClientTransportConfigImplBuilder.Nested transportConfig() { - checkThread(); + checkInEventLoop(); return new MqttClientTransportConfigImplBuilder.Nested<>(transportConfig, this::transportConfig); } @Override public @NotNull MqttClientTransportConfigImpl getTransportConfig() { - checkThread(); + checkInEventLoop(); return transportConfig; } @Override public @NotNull MqttClientReconnector connect(final @Nullable Mqtt5Connect connect) { - checkThread(); + checkInEventLoop(); this.connect = MqttChecks.connect(connect); return this; } @Override public @NotNull MqttConnectBuilder.Nested connectWith() { - checkThread(); + checkInEventLoop(); return new MqttConnectBuilder.Nested<>(connect, this::connect); } @Override public @NotNull MqttConnect getConnect() { - checkThread(); + checkInEventLoop(); return connect; } - private void checkThread() { + public void afterOnDisconnected() { + afterOnDisconnected = true; + } + + private void checkInEventLoop() { if (!eventLoop.inEventLoop()) { throw new IllegalStateException("MqttClientReconnector must be called from the eventLoop."); } } + + private void checkInOnDisconnected(final @NotNull String method) { + checkInEventLoop(); + if (afterOnDisconnected) { + throw new UnsupportedOperationException(method + " must only be called in onDisconnected."); + } + } } diff --git a/src/main/java/com/hivemq/client/internal/mqtt/lifecycle/mqtt3/Mqtt3ClientReconnectorView.java b/src/main/java/com/hivemq/client/internal/mqtt/lifecycle/mqtt3/Mqtt3ClientReconnectorView.java index d3d669d8d..5e59b0c89 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/lifecycle/mqtt3/Mqtt3ClientReconnectorView.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/lifecycle/mqtt3/Mqtt3ClientReconnectorView.java @@ -51,7 +51,7 @@ public class Mqtt3ClientReconnectorView implements Mqtt3ClientReconnector { } @Override - public @NotNull Mqtt3ClientReconnector reconnectWhen( + public @NotNull Mqtt3ClientReconnectorView reconnectWhen( final @Nullable CompletableFuture future, final @Nullable BiConsumer callback) { @@ -64,6 +64,28 @@ public boolean isReconnect() { return delegate.isReconnect(); } + @Override + public @NotNull Mqtt3ClientReconnectorView resubscribeIfSessionExpired(final boolean resubscribe) { + delegate.resubscribeIfSessionExpired(resubscribe); + return this; + } + + @Override + public boolean isResubscribeIfSessionExpired() { + return delegate.isResubscribeIfSessionExpired(); + } + + @Override + public @NotNull Mqtt3ClientReconnectorView republishIfSessionExpired(final boolean republish) { + delegate.republishIfSessionExpired(republish); + return this; + } + + @Override + public boolean isRepublishIfSessionExpired() { + return delegate.isRepublishIfSessionExpired(); + } + @Override public int getAttempts() { return delegate.getAttempts(); diff --git a/src/main/java/com/hivemq/client/internal/mqtt/message/subscribe/MqttSubscription.java b/src/main/java/com/hivemq/client/internal/mqtt/message/subscribe/MqttSubscription.java index e2fc3a1cc..ac890eaf9 100644 --- a/src/main/java/com/hivemq/client/internal/mqtt/message/subscribe/MqttSubscription.java +++ b/src/main/java/com/hivemq/client/internal/mqtt/message/subscribe/MqttSubscription.java @@ -111,4 +111,33 @@ public int hashCode() { result = 31 * result + Boolean.hashCode(retainAsPublished); return result; } + + public byte encodeSubscriptionOptions() { + byte subscriptionOptions = 0; + subscriptionOptions |= retainHandling.getCode() << 4; + if (retainAsPublished) { + subscriptionOptions |= 0b0000_1000; + } + if (noLocal) { + subscriptionOptions |= 0b0000_0100; + } + subscriptionOptions |= qos.getCode(); + return subscriptionOptions; + } + + public static @Nullable MqttQos decodeQos(final byte subscriptionOptions) { + return MqttQos.fromCode(subscriptionOptions & 0b0000_0011); + } + + public static boolean decodeNoLocal(final byte subscriptionOptions) { + return (subscriptionOptions & 0b0000_0100) != 0; + } + + public static @Nullable Mqtt5RetainHandling decodeRetainHandling(final byte subscriptionOptions) { + return Mqtt5RetainHandling.fromCode((subscriptionOptions & 0b0011_0000) >> 4); + } + + public static boolean decodeRetainAsPublished(final byte subscriptionOptions) { + return (subscriptionOptions & 0b0000_1000) != 0; + } } diff --git a/src/main/java/com/hivemq/client/internal/rx/CompletableFlow.java b/src/main/java/com/hivemq/client/internal/rx/CompletableFlow.java index 7d3c5ad86..c87ac0127 100644 --- a/src/main/java/com/hivemq/client/internal/rx/CompletableFlow.java +++ b/src/main/java/com/hivemq/client/internal/rx/CompletableFlow.java @@ -37,8 +37,8 @@ public void onComplete() { observer.onComplete(); } - public void onError(final @NotNull Throwable t) { - observer.onError(t); + public void onError(final @NotNull Throwable error) { + observer.onError(error); } @Override diff --git a/src/main/java/com/hivemq/client/internal/rx/operators/FlowableWithSingleCombine.java b/src/main/java/com/hivemq/client/internal/rx/operators/FlowableWithSingleCombine.java index ff6c1fd88..a8293bb13 100644 --- a/src/main/java/com/hivemq/client/internal/rx/operators/FlowableWithSingleCombine.java +++ b/src/main/java/com/hivemq/client/internal/rx/operators/FlowableWithSingleCombine.java @@ -30,7 +30,6 @@ import org.reactivestreams.Subscription; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; /** * @author Silvio Giebl @@ -50,10 +49,14 @@ protected void subscribeActual(final @NotNull Subscriber subscri private static class CombineSubscriber implements FlowableWithSingleSubscriber, Subscription { + private static final @NotNull Object COMPLETE = new Object(); + private final @NotNull Subscriber subscriber; private @Nullable Subscription subscription; private final @NotNull AtomicLong requested = new AtomicLong(); - private final @NotNull AtomicReference<@Nullable Object> queued = new AtomicReference<>(); + + private @Nullable Object queued; + private @Nullable Object done; CombineSubscriber(final @NotNull Subscriber subscriber) { this.subscriber = subscriber; @@ -77,30 +80,36 @@ public void onNext(final @NotNull F f) { private void next(final @NotNull Object next) { if (requested.get() == 0) { - queued.set(next); - if ((requested.get() != 0) && (queued.getAndSet(null)) != null) { - BackpressureHelper.produced(requested, 1); - subscriber.onNext(next); + synchronized (this) { + if (requested.get() == 0) { + queued = next; + return; + } } - } else { - BackpressureHelper.produced(requested, 1); - subscriber.onNext(next); } + BackpressureHelper.produced(requested, 1); + subscriber.onNext(next); } @Override - public void onError(final @NotNull Throwable throwable) { - final Object next = queued.get(); - if ((next == null) || !queued.compareAndSet(next, new TerminalElement(next, throwable))) { - subscriber.onError(throwable); + public void onComplete() { + synchronized (this) { + if (queued != null) { + done = COMPLETE; + } else { + subscriber.onComplete(); + } } } @Override - public void onComplete() { - final Object next = queued.get(); - if ((next == null) || !queued.compareAndSet(next, new TerminalElement(next, null))) { - subscriber.onComplete(); + public void onError(final @NotNull Throwable error) { + synchronized (this) { + if (queued != null) { + done = error; + } else { + subscriber.onError(error); + } } } @@ -108,26 +117,30 @@ public void onComplete() { public void request(long n) { assert subscription != null; if (n > 0) { - if (requested.get() == 0) { - final Object next = queued.getAndSet(null); - if (next != null) { - if (next instanceof TerminalElement) { - final TerminalElement terminalElement = (TerminalElement) next; - subscriber.onNext(terminalElement.element); - if (terminalElement.error == null) { - subscriber.onComplete(); - } else { - subscriber.onError(terminalElement.error); - } - return; - } else { - subscriber.onNext(next); + if (BackpressureHelper.add(requested, n) == 0) { + synchronized (this) { + final Object queued = this.queued; + if (queued != null) { + this.queued = null; + BackpressureHelper.produced(requested, 1); + subscriber.onNext(queued); n--; + final Object done = this.done; + if (done != null) { + this.done = null; + if (done instanceof Throwable) { + subscriber.onError((Throwable) done); + } else { + subscriber.onComplete(); + } + return; + } + } + if (n > 0) { + subscription.request(n); } } - } - if (n > 0) { - BackpressureHelper.add(requested, n); + } else { subscription.request(n); } } @@ -247,15 +260,4 @@ private static class SingleElement { this.element = element; } } - - private static class TerminalElement { - - final @NotNull Object element; - final @Nullable Throwable error; - - TerminalElement(final @NotNull Object element, final @Nullable Throwable error) { - this.element = element; - this.error = error; - } - } } diff --git a/src/main/java/com/hivemq/client/internal/util/collections/NodeList.java b/src/main/java/com/hivemq/client/internal/util/collections/NodeList.java index 7a24dab00..85652b0ea 100644 --- a/src/main/java/com/hivemq/client/internal/util/collections/NodeList.java +++ b/src/main/java/com/hivemq/client/internal/util/collections/NodeList.java @@ -58,6 +58,21 @@ public void add(final @NotNull N node) { size++; } + public void addFirst(final @NotNull N node) { + assert node.prev == null; + assert node.next == null; + + final N first = this.first; + if (first == null) { + this.first = last = node; + } else { + first.prev = node; + node.next = first; + this.first = node; + } + size++; + } + public void remove(final @NotNull N node) { assert (node.prev != null) || (node == first); assert (node.next != null) || (node == last); diff --git a/src/main/java/com/hivemq/client/mqtt/lifecycle/MqttClientReconnector.java b/src/main/java/com/hivemq/client/mqtt/lifecycle/MqttClientReconnector.java index 852aaa15b..5c9ec36a5 100644 --- a/src/main/java/com/hivemq/client/mqtt/lifecycle/MqttClientReconnector.java +++ b/src/main/java/com/hivemq/client/mqtt/lifecycle/MqttClientReconnector.java @@ -32,10 +32,10 @@ * A reconnector is supplied by a {@link MqttClientDisconnectedContext} and can be used for reconnecting. *

* The client will reconnect only if at least one of the methods {@link #reconnect(boolean)} or {@link - * #reconnectWhen(CompletableFuture, BiConsumer)} are called. + * #reconnectWhen(CompletableFuture, BiConsumer)} is called. *

- * All methods must only be called in {@link MqttClientDisconnectedListener#onDisconnected(MqttClientDisconnectedContext)} - * or in the callback of the {@link #reconnectWhen(CompletableFuture, BiConsumer)} method. + * All methods must only be called in {@link MqttClientDisconnectedListener#onDisconnected(MqttClientDisconnectedContext)}. + * Some methods can also be called in the callback supplied to {@link #reconnectWhen(CompletableFuture, BiConsumer)}. * * @author Silvio Giebl * @since 1.1 @@ -43,6 +43,31 @@ @DoNotImplement public interface MqttClientReconnector { + /** + * If reconnect is enabled by default. + * + * @since 1.2 + */ + boolean DEFAULT_RECONNECT = false; + /** + * If resubscribe when the session expired before the client reconnected successfully is enabled by defaultt. + * + * @since 1.2 + */ + boolean DEFAULT_RESUBSCRIBE_IF_SESSION_EXPIRED = true; + /** + * If republish when the session expired before the client reconnected successfully is enabled by default. + * + * @since 1.2 + */ + boolean DEFAULT_REPUBLISH_IF_SESSION_EXPIRED = false; + /** + * Default delay in milliseconds the client will wait for before trying to reconnect. + * + * @since 1.2 + */ + long DEFAULT_DELAY_MS = 0; + /** * @return the number of failed connection attempts. */ @@ -51,7 +76,7 @@ public interface MqttClientReconnector { /** * Instructs the client to reconnect or not. * - * @param reconnect whether to reconnect or not. + * @param reconnect whether to reconnect. * @return this reconnector. */ @NotNull MqttClientReconnector reconnect(boolean reconnect); @@ -59,42 +84,103 @@ public interface MqttClientReconnector { /** * Instructs the client to reconnect after a future completes. *

- * If also a {@link #delay(long, TimeUnit) delay} is supplied, the client will reconnect after both are complete. + * If additionally a {@link #delay(long, TimeUnit) delay} is supplied, the client will reconnect after both are + * complete. + *

+ * This method must only be called in {@link MqttClientDisconnectedListener#onDisconnected(MqttClientDisconnectedContext)} + * and not in the supplied callback. * * @param future the client will reconnect only after the future completes. * @param callback the callback that will be called after the future completes and before the client will reconnect. * It can be used to set new connect properties (e.g. credentials). * @param the result type of the future. * @return this reconnector. + * @throws UnsupportedOperationException if called outside of {@link MqttClientDisconnectedListener#onDisconnected(MqttClientDisconnectedContext)}. */ @NotNull MqttClientReconnector reconnectWhen( @NotNull CompletableFuture future, @Nullable BiConsumer callback); /** - * @return whether the client will reconnect or not. + * @return whether the client will reconnect. */ boolean isReconnect(); /** - * Sets a delay which the client will wait before trying to reconnect. + * Instructs the client to automatically restore its subscriptions when the session expired before it reconnected + * successfully. + *

+ * When the client reconnected successfully and its session is still present, the server still knows its + * subscriptions and they do not need to be restored. + *

+ * This setting only has effect if the client will reconnect (at least one of the methods {@link + * #reconnect(boolean)} or {@link #reconnectWhen(CompletableFuture, BiConsumer)} is called). + *

+ * This method must only be called in {@link MqttClientDisconnectedListener#onDisconnected(MqttClientDisconnectedContext)} + * and not in the callback supplied to {@link #reconnectWhen(CompletableFuture, BiConsumer)}. + * + * @param resubscribe whether to resubscribe when the session expired before the client reconnected successfully. + * @return this reconnector. + * @throws UnsupportedOperationException if called outside of {@link MqttClientDisconnectedListener#onDisconnected(MqttClientDisconnectedContext)}. + * @since 1.2 + */ + @NotNull MqttClientReconnector resubscribeIfSessionExpired(boolean resubscribe); + + /** + * @return whether the client will resubscribe when the session expired before it reconnected successfully. + * @since 1.2 + */ + boolean isResubscribeIfSessionExpired(); + + /** + * Instructs the client to queue pending Publish messages and automatically publish them even if the session expired + * before reconnected. + *

+ * When the client reconnected successfully and its session is still present, the client will always queue pending + * Publish messages and automatically publish them to ensure the QoS guarantees. + *

+ * This setting only has effect if the client will reconnect (at least one of the methods {@link + * #reconnect(boolean)} or {@link #reconnectWhen(CompletableFuture, BiConsumer)} is called). + *

+ * This method must only be called in {@link MqttClientDisconnectedListener#onDisconnected(MqttClientDisconnectedContext)} + * and not in the callback supplied to {@link #reconnectWhen(CompletableFuture, BiConsumer)}. + * + * @param republish whether to republish when the session expired before the client reconnected successfully. + * @return this reconnector. + * @throws UnsupportedOperationException if called outside of {@link MqttClientDisconnectedListener#onDisconnected(MqttClientDisconnectedContext)}. + * @since 1.2 + */ + @NotNull MqttClientReconnector republishIfSessionExpired(boolean republish); + + /** + * @return whether the client will republish when the session expired before it reconnected successfully. + * @since 1.2 + */ + boolean isRepublishIfSessionExpired(); + + /** + * Sets a delay the client will wait for before trying to reconnect. + *

+ * This setting only has effect if the client will reconnect (at least one of the methods {@link + * #reconnect(boolean)} or {@link #reconnectWhen(CompletableFuture, BiConsumer)} is called). *

- * The client will reconnect after the delay only if at least one of the methods {@link #reconnect(boolean)} or - * {@link #reconnectWhen(CompletableFuture, BiConsumer)} are called. + * If additionally a {@link #reconnectWhen(CompletableFuture, BiConsumer) future} is supplied, the client will + * reconnect after both are complete. *

- * If also a {@link #reconnectWhen(CompletableFuture, BiConsumer) future} is supplied, the client will reconnect - * after both are complete. + * This method must only be called in {@link MqttClientDisconnectedListener#onDisconnected(MqttClientDisconnectedContext)} + * and not in the callback supplied to {@link #reconnectWhen(CompletableFuture, BiConsumer)}. * * @param delay delay which the client will wait before trying to reconnect. * @param timeUnit the time unit of the delay. * @return this reconnector. + * @throws UnsupportedOperationException if called outside of {@link MqttClientDisconnectedListener#onDisconnected(MqttClientDisconnectedContext)}. */ @NotNull MqttClientReconnector delay(long delay, @NotNull TimeUnit timeUnit); /** - * Returns the currently set delay which the client will wait before trying to reconnect. + * Returns the currently set delay the client will wait for before trying to reconnect. *

* If the {@link #delay(long, TimeUnit)} method has not been called before (including previous {@link - * MqttClientDisconnectedListener MqttClientDisconnectedListeners}) it will be 0. + * MqttClientDisconnectedListener MqttClientDisconnectedListeners}) it will be {@link #DEFAULT_DELAY_MS}. * * @param timeUnit the time unit of the returned delay. * @return the delay in the given time unit. diff --git a/src/main/java/com/hivemq/client/mqtt/mqtt3/lifecycle/Mqtt3ClientReconnector.java b/src/main/java/com/hivemq/client/mqtt/mqtt3/lifecycle/Mqtt3ClientReconnector.java index 97e44b4af..5a3a72f6e 100644 --- a/src/main/java/com/hivemq/client/mqtt/mqtt3/lifecycle/Mqtt3ClientReconnector.java +++ b/src/main/java/com/hivemq/client/mqtt/mqtt3/lifecycle/Mqtt3ClientReconnector.java @@ -48,6 +48,12 @@ public interface Mqtt3ClientReconnector extends MqttClientReconnector { @NotNull Mqtt3ClientReconnector reconnectWhen( @NotNull CompletableFuture future, @Nullable BiConsumer callback); + @Override + @NotNull Mqtt3ClientReconnector resubscribeIfSessionExpired(boolean resubscribe); + + @Override + @NotNull Mqtt3ClientReconnector republishIfSessionExpired(boolean republish); + @Override @NotNull Mqtt3ClientReconnector delay(long delay, @NotNull TimeUnit timeUnit); diff --git a/src/main/java/com/hivemq/client/mqtt/mqtt5/lifecycle/Mqtt5ClientReconnector.java b/src/main/java/com/hivemq/client/mqtt/mqtt5/lifecycle/Mqtt5ClientReconnector.java index 032aa4eb6..116c28a88 100644 --- a/src/main/java/com/hivemq/client/mqtt/mqtt5/lifecycle/Mqtt5ClientReconnector.java +++ b/src/main/java/com/hivemq/client/mqtt/mqtt5/lifecycle/Mqtt5ClientReconnector.java @@ -48,6 +48,12 @@ public interface Mqtt5ClientReconnector extends MqttClientReconnector { @NotNull Mqtt5ClientReconnector reconnectWhen( @NotNull CompletableFuture future, @Nullable BiConsumer callback); + @Override + @NotNull Mqtt5ClientReconnector resubscribeIfSessionExpired(boolean resubscribe); + + @Override + @NotNull Mqtt5ClientReconnector republishIfSessionExpired(boolean republish); + @Override @NotNull Mqtt5ClientReconnector delay(long delay, @NotNull TimeUnit timeUnit); diff --git a/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowTreeTest.java b/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowTreeTest.java similarity index 85% rename from src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowTreeTest.java rename to src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowTreeTest.java index a5fb107a8..1360b383b 100644 --- a/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowTreeTest.java +++ b/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowTreeTest.java @@ -17,8 +17,12 @@ package com.hivemq.client.internal.mqtt.handler.publish.incoming; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.hivemq.client.internal.mqtt.datatypes.MqttTopicFilterImpl; import com.hivemq.client.internal.mqtt.datatypes.MqttTopicImpl; +import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscription; +import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscriptionBuilder; import org.jetbrains.annotations.NotNull; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; @@ -28,10 +32,10 @@ /** * @author Silvio Giebl */ -class MqttSubscriptionFlowTreeTest extends MqttSubscriptionFlowsTest { +class MqttSubscribedPublishFlowTreeTest extends MqttSubscribedPublishFlowsTest { - MqttSubscriptionFlowTreeTest() { - super(MqttSubscriptionFlowTree::new); + MqttSubscribedPublishFlowTreeTest() { + super(MqttSubscribedPublishFlowTree::new); } @ParameterizedTest @@ -96,9 +100,15 @@ void branching_compaction( final @NotNull String filter3, final @NotNull String topic1, final @NotNull String topic2, final @NotNull String topic3) { - flows.subscribe(MqttTopicFilterImpl.of(filter1), null); - flows.subscribe(MqttTopicFilterImpl.of(filter2), null); - flows.subscribe(MqttTopicFilterImpl.of(filter3), null); + final MqttSubscription subscription1 = new MqttSubscriptionBuilder.Default().topicFilter(filter1).build(); + final MqttSubscription subscription2 = new MqttSubscriptionBuilder.Default().topicFilter(filter2).build(); + final MqttSubscription subscription3 = new MqttSubscriptionBuilder.Default().topicFilter(filter3).build(); + flows.subscribe(subscription1, 1, null); + flows.subscribe(subscription2, 2, null); + flows.subscribe(subscription3, 3, null); + flows.suback(subscription1.getTopicFilter(), 1, false); + flows.suback(subscription2.getTopicFilter(), 2, false); + flows.suback(subscription3.getTopicFilter(), 3, false); final MqttMatchingPublishFlows matching1 = new MqttMatchingPublishFlows(); flows.findMatching(MqttTopicImpl.of(topic1), matching1); @@ -110,16 +120,19 @@ void branching_compaction( flows.findMatching(MqttTopicImpl.of(topic3), matching3); assertTrue(matching3.subscriptionFound); + assertEquals(ImmutableMap.of(1, ImmutableList.of(subscription1), 2, ImmutableList.of(subscription2), 3, + ImmutableList.of(subscription3)), flows.getSubscriptions()); + switch (compactOperation) { case "unsubscribe": - flows.unsubscribe(MqttTopicFilterImpl.of(filter1), null); - flows.unsubscribe(MqttTopicFilterImpl.of(filter2), null); - flows.unsubscribe(MqttTopicFilterImpl.of(filter3), null); + flows.unsubscribe(MqttTopicFilterImpl.of(filter1)); + flows.unsubscribe(MqttTopicFilterImpl.of(filter2)); + flows.unsubscribe(MqttTopicFilterImpl.of(filter3)); break; case "remove": - flows.remove(MqttTopicFilterImpl.of(filter1), null); - flows.remove(MqttTopicFilterImpl.of(filter2), null); - flows.remove(MqttTopicFilterImpl.of(filter3), null); + flows.suback(MqttTopicFilterImpl.of(filter1), 1, true); + flows.suback(MqttTopicFilterImpl.of(filter2), 2, true); + flows.suback(MqttTopicFilterImpl.of(filter3), 3, true); break; default: fail(); @@ -134,5 +147,7 @@ void branching_compaction( final MqttMatchingPublishFlows matching6 = new MqttMatchingPublishFlows(); flows.findMatching(MqttTopicImpl.of(topic3), matching6); assertFalse(matching6.subscriptionFound); + + assertTrue(flows.getSubscriptions().isEmpty()); } } diff --git a/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowsTest.java b/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowsTest.java new file mode 100644 index 000000000..a37897f61 --- /dev/null +++ b/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowsTest.java @@ -0,0 +1,696 @@ +/* + * Copyright 2018 dc-square and the HiveMQ MQTT Client Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package com.hivemq.client.internal.mqtt.handler.publish.incoming; + +import com.google.common.collect.ImmutableSet; +import com.hivemq.client.internal.mqtt.datatypes.MqttTopicFilterImpl; +import com.hivemq.client.internal.mqtt.datatypes.MqttTopicImpl; +import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscription; +import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscriptionBuilder; +import com.hivemq.client.internal.util.collections.HandleList; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.converter.ArgumentConversionException; +import org.junit.jupiter.params.converter.ConvertWith; +import org.junit.jupiter.params.converter.SimpleArgumentConverter; +import org.junit.jupiter.params.provider.CsvSource; + +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +/** + * @author Silvio Giebl + */ +abstract class MqttSubscribedPublishFlowsTest { + + public static class CsvToArray extends SimpleArgumentConverter { + + @Override + protected @NotNull Object convert(final @NotNull Object source, final @NotNull Class targetType) + throws ArgumentConversionException { + final String s = (String) source; + return s.split("\\s*;\\s*"); + } + } + + private final @NotNull Supplier flowsSupplier; + @NotNull MqttSubscribedPublishFlows flows; + + MqttSubscribedPublishFlowsTest(final @NotNull Supplier flowsSupplier) { + this.flowsSupplier = flowsSupplier; + } + + @BeforeEach + void setUp() { + flows = flowsSupplier.get(); + } + + @ParameterizedTest + @CsvSource({ + "a, a; +; a/#; +/#; #, true", + "a, a; +; a/#; +/#; #, false", + "a/b, a/b; a/+; +/b; +/+; a/b/#; a/+/#; +/b/#; +/+/#; a/#; #, true", + "a/b, a/b; a/+; +/b; +/+; a/b/#; a/+/#; +/b/#; +/+/#; a/#; #, false", + "/, /; +/+; +/; /+; +/#; /#; #, true", + "/, /; +/+; +/; /+; +/#; /#; #, false" + }) + void subscribe_matchingTopicFilters_doMatch( + final @NotNull String topic, @ConvertWith(CsvToArray.class) final @NotNull String[] matchingTopicFilters, + final boolean acknowledge) { + + final MqttSubscribedPublishFlow[] matchingFlows = new MqttSubscribedPublishFlow[matchingTopicFilters.length]; + for (int i = 0; i < matchingTopicFilters.length; i++) { + final MqttSubscription subscription = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilters[i]).build(); + final MqttSubscribedPublishFlow flow = mockSubscriptionFlow(matchingTopicFilters[i]); + flows.subscribe(subscription, i, flow); + assertEquals(ImmutableSet.of(subscription.getTopicFilter()), toSet(flow.getTopicFilters())); + matchingFlows[i] = flow; + + if (acknowledge) { + flows.suback(subscription.getTopicFilter(), i, false); + } + } + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertTrue(matching.subscriptionFound); + assertFalse(matching.isEmpty()); + assertEquals(ImmutableSet.copyOf(matchingFlows), toSet(matching)); + } + + @ParameterizedTest + @CsvSource({ + "a, a; +; a/#; +/#; #, true", + "a, a; +; a/#; +/#; #, false", + "a/b, a/b; a/+; +/b; +/+; a/b/#; a/+/#; +/b/#; +/+/#; a/#; #, true", + "a/b, a/b; a/+; +/b; +/+; a/b/#; a/+/#; +/b/#; +/+/#; a/#; #, false", + "/, /; +/+; +/; /+; +/#; /#; #, true", + "/, /; +/+; +/; /+; +/#; /#; #, false" + }) + void subscribe_matchingTopicFilters_doMatch_noFlow( + final @NotNull String topic, @ConvertWith(CsvToArray.class) final @NotNull String[] matchingTopicFilters, + final boolean acknowledge) { + + for (int i = 0; i < matchingTopicFilters.length; i++) { + final MqttSubscription subscription = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilters[i]).build(); + flows.subscribe(subscription, i, null); + + if (acknowledge) { + flows.suback(subscription.getTopicFilter(), i, false); + } + } + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertTrue(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + } + + @ParameterizedTest + @CsvSource({ + "a, /a; b; a/b; a/+; +/a; +/+; a/b/#; /#; / ", + "a/b, /a/b; a/c; c/b; a/b/c; +/a/b; a/+/b; a/b/+; a/b/c/#; +", + "/, //; a/b; a/; /a; + " + }) + void subscribe_nonMatchingTopicFilters_doNotMatch( + final @NotNull String topic, + @ConvertWith(CsvToArray.class) final @NotNull String[] notMatchingTopicFilters) { + + for (int i = 0; i < notMatchingTopicFilters.length; i++) { + final MqttSubscription subscription = + new MqttSubscriptionBuilder.Default().topicFilter(notMatchingTopicFilters[i]).build(); + final MqttSubscribedPublishFlow flow = mockSubscriptionFlow(notMatchingTopicFilters[i]); + flows.subscribe(subscription, i, flow); + assertEquals(ImmutableSet.of(subscription.getTopicFilter()), toSet(flow.getTopicFilters())); + } + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertFalse(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + } + + @ParameterizedTest + @CsvSource({ + "a, /a; b; a/b; a/+; +/a; +/+; a/b/#; /#; / ", + "a/b, /a/b; a/c; c/b; a/b/c; +/a/b; a/+/b; a/b/+; a/b/c/#; +", + "/, //; a/b; a/; /a; + " + }) + void subscribe_nonMatchingTopicFilters_doNotMatch_noFlow( + final @NotNull String topic, + @ConvertWith(CsvToArray.class) final @NotNull String[] notMatchingTopicFilters) { + + for (int i = 0; i < notMatchingTopicFilters.length; i++) { + final MqttSubscription subscription = + new MqttSubscriptionBuilder.Default().topicFilter(notMatchingTopicFilters[i]).build(); + flows.subscribe(subscription, i, null); + } + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertFalse(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + } + + @ParameterizedTest + @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) + void unsubscribe_matchingTopicFilters_doNoLongerMatch( + final @NotNull String topic, final @NotNull String matchingTopicFilter) { + + final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + flows.subscribe(subscription1, 1, flow1); + flows.subscribe(subscription2, 2, flow2); + flows.suback(subscription1.getTopicFilter(), 1, false); + flows.suback(subscription2.getTopicFilter(), 2, false); + + flows.unsubscribe(MqttTopicFilterImpl.of(matchingTopicFilter)); + assertTrue(flow1.getTopicFilters().isEmpty()); + assertTrue(flow2.getTopicFilters().isEmpty()); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertFalse(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + } + + @ParameterizedTest + @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) + void unsubscribe_matchingTopicFilters_doNoLongerMatch_noFlow( + final @NotNull String topic, final @NotNull String matchingTopicFilter) { + + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + flows.subscribe(subscription1, 1, null); + flows.subscribe(subscription2, 2, null); + flows.suback(subscription1.getTopicFilter(), 1, false); + flows.suback(subscription2.getTopicFilter(), 2, false); + + flows.unsubscribe(MqttTopicFilterImpl.of(matchingTopicFilter)); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertFalse(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + } + + @ParameterizedTest + @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) + void unsubscribe_matchingTopicFilters_notAcknowledged_doStillMatch( + final @NotNull String topic, final @NotNull String matchingTopicFilter) { + + final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + flows.subscribe(subscription1, 1, flow1); + flows.subscribe(subscription2, 2, flow2); + + flows.unsubscribe(MqttTopicFilterImpl.of(matchingTopicFilter)); + assertEquals(ImmutableSet.of(subscription1.getTopicFilter()), toSet(flow1.getTopicFilters())); + assertEquals(ImmutableSet.of(subscription2.getTopicFilter()), toSet(flow2.getTopicFilters())); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertTrue(matching.subscriptionFound); + assertEquals(ImmutableSet.of(flow1, flow2), toSet(matching)); + } + + @ParameterizedTest + @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) + void unsubscribe_matchingTopicFilters_notAcknowledged_doStillMatch_noFlow( + final @NotNull String topic, final @NotNull String matchingTopicFilter) { + + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + flows.subscribe(subscription1, 1, null); + flows.subscribe(subscription2, 2, null); + + flows.unsubscribe(MqttTopicFilterImpl.of(matchingTopicFilter)); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertTrue(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + } + + @ParameterizedTest + @CsvSource({"a, a, b", "a, a, a/b", "a/b, a/b, a/c"}) + void unsubscribe_nonMatchingTopicFilters_othersStillMatch( + final @NotNull String topic, final @NotNull String matchingTopicFilter, + final @NotNull String notMatchingTopicFilter) { + + final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow(notMatchingTopicFilter); + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter(notMatchingTopicFilter).build(); + flows.subscribe(subscription1, 1, flow1); + flows.subscribe(subscription2, 2, flow2); + flows.suback(subscription1.getTopicFilter(), 1, false); + flows.suback(subscription2.getTopicFilter(), 2, false); + + flows.unsubscribe(MqttTopicFilterImpl.of(notMatchingTopicFilter)); + assertFalse(flow1.getTopicFilters().isEmpty()); + assertTrue(flow2.getTopicFilters().isEmpty()); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertTrue(matching.subscriptionFound); + assertFalse(matching.isEmpty()); + assertEquals(ImmutableSet.of(flow1), toSet(matching)); + } + + @ParameterizedTest + @CsvSource({"a, a, b", "a, a, a/b", "a/b, a/b, a/c"}) + void unsubscribe_nonMatchingTopicFilters_othersStillMatch_noFlow( + final @NotNull String topic, final @NotNull String matchingTopicFilter, + final @NotNull String notMatchingTopicFilter) { + + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter(notMatchingTopicFilter).build(); + flows.subscribe(subscription1, 1, null); + flows.subscribe(subscription2, 2, null); + flows.suback(subscription1.getTopicFilter(), 1, false); + flows.suback(subscription2.getTopicFilter(), 2, false); + + flows.unsubscribe(MqttTopicFilterImpl.of(notMatchingTopicFilter)); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertTrue(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + } + + @ParameterizedTest + @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) + void cancel_doNoLongerMatch(final @NotNull String topic, final @NotNull String matchingTopicFilter) { + final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + flows.subscribe(subscription1, 1, flow1); + flows.subscribe(subscription2, 2, flow2); + + flows.cancel(flow1); + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertTrue(matching.subscriptionFound); + assertFalse(matching.isEmpty()); + assertEquals(ImmutableSet.of(flow2), toSet(matching)); + + flows.cancel(flow2); + final MqttMatchingPublishFlows matching2 = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching2); + assertTrue(matching2.subscriptionFound); + assertTrue(matching2.isEmpty()); + } + + @ParameterizedTest + @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) + void cancel_notPresentFlows_areIgnored(final @NotNull String topic, final @NotNull String matchingTopicFilter) { + final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscription subscription = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + flows.subscribe(subscription, 1, flow1); + + flows.cancel(flow2); + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertTrue(matching.subscriptionFound); + assertFalse(matching.isEmpty()); + assertEquals(ImmutableSet.of(flow1), toSet(matching)); + } + + @Test + void cancel_partiallyUnsubscribedFlow() { + final MqttSubscribedPublishFlow flow = mockSubscriptionFlow("test/topic(2)"); + final MqttSubscription subscription1 = new MqttSubscriptionBuilder.Default().topicFilter("test/topic").build(); + final MqttSubscription subscription2 = new MqttSubscriptionBuilder.Default().topicFilter("test/topic2").build(); + flows.subscribe(subscription1, 1, flow); + flows.subscribe(subscription2, 1, flow); + flows.suback(subscription1.getTopicFilter(), 1, false); + flows.suback(subscription2.getTopicFilter(), 1, false); + + flows.unsubscribe(MqttTopicFilterImpl.of("test/topic")); + flows.cancel(flow); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of("test/topic"), matching); + assertFalse(matching.subscriptionFound); + flows.findMatching(MqttTopicImpl.of("test/topic2"), matching); + assertTrue(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + } + + @ParameterizedTest + @CsvSource({ + "1/a, 1/a, 2/a, 2/a", "1/a, 1/+, 2/a, 2/+", "1/a, 1/#, 2/a, 2/#", "1/a/b, 1/a/b, 2/a/b, 2/a/b", + "1/a/b, 1/a/+, 2/a/b, 2/a/+", "1/a/b, 1/+/b, 2/a/b, 2/+/b", "1/a/b, 1/+/+, 2/a/b, 2/+/+", + "1/a/b, 1/+/#, 2/a/b, 2/+/#", "1/a/b, 1/#, 2/a/b, 2/#" + }) + void suback_error( + final @NotNull String topic, final @NotNull String matchingTopicFilter, final @NotNull String topic2, + final @NotNull String matchingTopicFilter2) { + + final MqttSubscribedPublishFlow flow = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter2).build(); + flows.subscribe(subscription1, 1, flow); + flows.subscribe(subscription2, 2, flow); + assertEquals( + ImmutableSet.of(subscription1.getTopicFilter(), subscription2.getTopicFilter()), + toSet(flow.getTopicFilters())); + + flows.suback(subscription1.getTopicFilter(), 1, true); + assertEquals(ImmutableSet.of(subscription2.getTopicFilter()), toSet(flow.getTopicFilters())); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertFalse(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + + final MqttMatchingPublishFlows matching2 = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic2), matching2); + assertTrue(matching2.subscriptionFound); + assertFalse(matching2.isEmpty()); + assertEquals(ImmutableSet.of(flow), toSet(matching2)); + } + + @ParameterizedTest + @CsvSource({ + "1/a, 1/a, 2/a, 2/a", "1/a, 1/+, 2/a, 2/+", "1/a, 1/#, 2/a, 2/#", "1/a/b, 1/a/b, 2/a/b, 2/a/b", + "1/a/b, 1/a/+, 2/a/b, 2/a/+", "1/a/b, 1/+/b, 2/a/b, 2/+/b", "1/a/b, 1/+/+, 2/a/b, 2/+/+", + "1/a/b, 1/+/#, 2/a/b, 2/+/#", "1/a/b, 1/#, 2/a/b, 2/#" + }) + void suback_error_noFlow( + final @NotNull String topic, final @NotNull String matchingTopicFilter, final @NotNull String topic2, + final @NotNull String matchingTopicFilter2) { + + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter2).build(); + flows.subscribe(subscription1, 1, null); + flows.subscribe(subscription2, 2, null); + + flows.suback(subscription1.getTopicFilter(), 1, true); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertFalse(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + + final MqttMatchingPublishFlows matching2 = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic2), matching2); + assertTrue(matching2.subscriptionFound); + assertTrue(matching2.isEmpty()); + } + + @ParameterizedTest + @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) + void suback_error_doesNotUnsubscribe(final @NotNull String topic, final @NotNull String matchingTopicFilter) { + final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow(matchingTopicFilter); + final MqttSubscription subscription = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + flows.subscribe(subscription, 1, flow1); + flows.subscribe(subscription, 2, flow2); + assertEquals(ImmutableSet.of(subscription.getTopicFilter()), toSet(flow1.getTopicFilters())); + assertEquals(ImmutableSet.of(subscription.getTopicFilter()), toSet(flow2.getTopicFilters())); + + flows.suback(subscription.getTopicFilter(), 1, true); + assertTrue(flow1.getTopicFilters().isEmpty()); + assertEquals(ImmutableSet.of(subscription.getTopicFilter()), toSet(flow2.getTopicFilters())); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertTrue(matching.subscriptionFound); + assertFalse(matching.isEmpty()); + assertEquals(ImmutableSet.of(flow2), toSet(matching)); + } + + @ParameterizedTest + @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) + void suback_error_doesNotUnsubscribe_noFlow( + final @NotNull String topic, final @NotNull String matchingTopicFilter) { + + final MqttSubscription subscription = + new MqttSubscriptionBuilder.Default().topicFilter(matchingTopicFilter).build(); + flows.subscribe(subscription, 1, null); + flows.subscribe(subscription, 2, null); + + flows.suback(subscription.getTopicFilter(), 1, true); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertTrue(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + } + + @ParameterizedTest + @CsvSource({"a/b, a/b/c, +/b/c, +/+/+", "a/b/c/d, a/b/c/d/e, +/b//d/ec, +/+/+/+/+"}) + void findMatching_matchingMultipleButNotAllLevels( + final @NotNull String topic, final @NotNull String filter1, final @NotNull String filter2, + final @NotNull String filter3) { + + final MqttSubscription subscription1 = new MqttSubscriptionBuilder.Default().topicFilter(filter1).build(); + final MqttSubscription subscription2 = new MqttSubscriptionBuilder.Default().topicFilter(filter2).build(); + final MqttSubscription subscription3 = new MqttSubscriptionBuilder.Default().topicFilter(filter3).build(); + flows.subscribe(subscription1, 1, null); + flows.subscribe(subscription2, 2, null); + flows.subscribe(subscription3, 3, null); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of(topic), matching); + assertFalse(matching.subscriptionFound); + assertTrue(matching.isEmpty()); + } + + @Test + void clear() { + final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow("test/topic/filter"); + final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow("test2/topic/filter"); + final MqttSubscribedPublishFlow flow3 = mockSubscriptionFlow("test/topic2/filter"); + final MqttSubscribedPublishFlow flow4 = mockSubscriptionFlow("test/topic/filter2"); + final MqttSubscribedPublishFlow flow5 = mockSubscriptionFlow("+/topic"); + final MqttSubscribedPublishFlow flow6 = mockSubscriptionFlow("topic/#"); + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter(flow1.toString()).build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter(flow2.toString()).build(); + final MqttSubscription subscription3 = + new MqttSubscriptionBuilder.Default().topicFilter(flow3.toString()).build(); + final MqttSubscription subscription4 = + new MqttSubscriptionBuilder.Default().topicFilter(flow4.toString()).build(); + final MqttSubscription subscription5 = + new MqttSubscriptionBuilder.Default().topicFilter(flow5.toString()).build(); + final MqttSubscription subscription6 = + new MqttSubscriptionBuilder.Default().topicFilter(flow6.toString()).build(); + flows.subscribe(subscription1, 1, flow1); + flows.subscribe(subscription2, 2, flow2); + flows.subscribe(subscription3, 3, flow3); + flows.subscribe(subscription4, 4, flow4); + flows.subscribe(subscription5, 5, flow5); + flows.subscribe(subscription6, 6, flow6); + flows.suback(subscription1.getTopicFilter(), 1, false); + flows.suback(subscription2.getTopicFilter(), 2, false); + flows.suback(subscription3.getTopicFilter(), 3, false); + flows.suback(subscription4.getTopicFilter(), 4, false); + flows.suback(subscription5.getTopicFilter(), 5, false); + flows.suback(subscription6.getTopicFilter(), 6, false); + + final Exception cause = new Exception("test"); + flows.clear(cause); + verify(flow1).onError(cause); + verify(flow2).onError(cause); + verify(flow3).onError(cause); + verify(flow4).onError(cause); + verify(flow5).onError(cause); + verify(flow6).onError(cause); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of("test/topic"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test2/topic/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic2/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic/filter2"), matching); + flows.findMatching(MqttTopicImpl.of("abc/topic"), matching); + flows.findMatching(MqttTopicImpl.of("topic/abc"), matching); + assertFalse(matching.subscriptionFound); + } + + @Test + void clear_noFlow() { + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter("test/topic/filter").build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter("test2/topic/filter").build(); + final MqttSubscription subscription3 = + new MqttSubscriptionBuilder.Default().topicFilter("test/topic2/filter").build(); + final MqttSubscription subscription4 = + new MqttSubscriptionBuilder.Default().topicFilter("test/topic/filter2").build(); + final MqttSubscription subscription5 = new MqttSubscriptionBuilder.Default().topicFilter("+/topic").build(); + final MqttSubscription subscription6 = new MqttSubscriptionBuilder.Default().topicFilter("topic/#").build(); + flows.subscribe(subscription1, 1, null); + flows.subscribe(subscription2, 2, null); + flows.subscribe(subscription3, 3, null); + flows.subscribe(subscription4, 4, null); + flows.subscribe(subscription5, 5, null); + flows.subscribe(subscription6, 6, null); + flows.suback(subscription1.getTopicFilter(), 1, false); + flows.suback(subscription2.getTopicFilter(), 2, false); + flows.suback(subscription3.getTopicFilter(), 3, false); + flows.suback(subscription4.getTopicFilter(), 4, false); + flows.suback(subscription5.getTopicFilter(), 5, false); + flows.suback(subscription6.getTopicFilter(), 6, false); + + final Exception cause = new Exception("test"); + flows.clear(cause); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of("test/topic"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test2/topic/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic2/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic/filter2"), matching); + flows.findMatching(MqttTopicImpl.of("abc/topic"), matching); + flows.findMatching(MqttTopicImpl.of("topic/abc"), matching); + assertFalse(matching.subscriptionFound); + } + + @Test + void clear_notAcknowledged_doesNotErrorFlows() { + final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow("test/topic/filter"); + final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow("test2/topic/filter"); + final MqttSubscribedPublishFlow flow3 = mockSubscriptionFlow("test/topic2/filter"); + final MqttSubscribedPublishFlow flow4 = mockSubscriptionFlow("test/topic/filter2"); + final MqttSubscribedPublishFlow flow5 = mockSubscriptionFlow("+/topic"); + final MqttSubscribedPublishFlow flow6 = mockSubscriptionFlow("topic/#"); + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter(flow1.toString()).build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter(flow2.toString()).build(); + final MqttSubscription subscription3 = + new MqttSubscriptionBuilder.Default().topicFilter(flow3.toString()).build(); + final MqttSubscription subscription4 = + new MqttSubscriptionBuilder.Default().topicFilter(flow4.toString()).build(); + final MqttSubscription subscription5 = + new MqttSubscriptionBuilder.Default().topicFilter(flow5.toString()).build(); + final MqttSubscription subscription6 = + new MqttSubscriptionBuilder.Default().topicFilter(flow6.toString()).build(); + flows.subscribe(subscription1, 1, flow1); + flows.subscribe(subscription2, 2, flow2); + flows.subscribe(subscription3, 3, flow3); + flows.subscribe(subscription4, 4, flow4); + flows.subscribe(subscription5, 5, flow5); + flows.subscribe(subscription6, 6, flow6); + + final Exception cause = new Exception("test"); + flows.clear(cause); + verify(flow1, never()).onError(cause); + verify(flow2, never()).onError(cause); + verify(flow3, never()).onError(cause); + verify(flow4, never()).onError(cause); + verify(flow5, never()).onError(cause); + verify(flow6, never()).onError(cause); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of("test/topic"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test2/topic/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic2/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic/filter2"), matching); + flows.findMatching(MqttTopicImpl.of("abc/topic"), matching); + flows.findMatching(MqttTopicImpl.of("topic/abc"), matching); + assertFalse(matching.subscriptionFound); + } + + @Test + void clear_notAcknowledged_noFlow() { + final MqttSubscription subscription1 = + new MqttSubscriptionBuilder.Default().topicFilter("test/topic/filter").build(); + final MqttSubscription subscription2 = + new MqttSubscriptionBuilder.Default().topicFilter("test2/topic/filter").build(); + final MqttSubscription subscription3 = + new MqttSubscriptionBuilder.Default().topicFilter("test/topic2/filter").build(); + final MqttSubscription subscription4 = + new MqttSubscriptionBuilder.Default().topicFilter("test/topic/filter2").build(); + final MqttSubscription subscription5 = new MqttSubscriptionBuilder.Default().topicFilter("+/topic").build(); + final MqttSubscription subscription6 = new MqttSubscriptionBuilder.Default().topicFilter("topic/#").build(); + flows.subscribe(subscription1, 1, null); + flows.subscribe(subscription2, 2, null); + flows.subscribe(subscription3, 3, null); + flows.subscribe(subscription4, 4, null); + flows.subscribe(subscription5, 5, null); + flows.subscribe(subscription6, 6, null); + + final Exception cause = new Exception("test"); + flows.clear(cause); + + final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); + flows.findMatching(MqttTopicImpl.of("test/topic"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test2/topic/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic2/filter"), matching); + flows.findMatching(MqttTopicImpl.of("test/topic/filter2"), matching); + flows.findMatching(MqttTopicImpl.of("abc/topic"), matching); + flows.findMatching(MqttTopicImpl.of("topic/abc"), matching); + assertFalse(matching.subscriptionFound); + } + + private static @NotNull MqttSubscribedPublishFlow mockSubscriptionFlow(final @NotNull String name) { + final MqttSubscribedPublishFlow flow = mock(MqttSubscribedPublishFlow.class); + final HandleList topicFilters = new HandleList<>(); + when(flow.getTopicFilters()).thenReturn(topicFilters); + when(flow.toString()).thenReturn(name); + return flow; + } + + private @NotNull ImmutableSet toSet(final @NotNull HandleList list) { + final ImmutableSet.Builder builder = ImmutableSet.builder(); + for (HandleList.Handle h = list.getFirst(); h != null; h = h.getNext()) { + builder.add(h.getElement()); + } + return builder.build(); + } +} diff --git a/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowListTest.java b/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowListTest.java deleted file mode 100644 index 38d93265e..000000000 --- a/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowListTest.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright 2018 dc-square and the HiveMQ MQTT Client Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package com.hivemq.client.internal.mqtt.handler.publish.incoming; - -/** - * @author Silvio Giebl - */ -class MqttSubscriptionFlowListTest extends MqttSubscriptionFlowsTest { - - MqttSubscriptionFlowListTest() { - super(MqttSubscriptionFlowList::new); - } - -} diff --git a/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowsTest.java b/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowsTest.java deleted file mode 100644 index 9b6f92cfc..000000000 --- a/src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowsTest.java +++ /dev/null @@ -1,450 +0,0 @@ -/* - * Copyright 2018 dc-square and the HiveMQ MQTT Client Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package com.hivemq.client.internal.mqtt.handler.publish.incoming; - -import com.google.common.collect.ImmutableSet; -import com.hivemq.client.internal.mqtt.datatypes.MqttTopicFilterImpl; -import com.hivemq.client.internal.mqtt.datatypes.MqttTopicImpl; -import com.hivemq.client.internal.util.collections.HandleList; -import org.jetbrains.annotations.NotNull; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.converter.ArgumentConversionException; -import org.junit.jupiter.params.converter.ConvertWith; -import org.junit.jupiter.params.converter.SimpleArgumentConverter; -import org.junit.jupiter.params.provider.CsvSource; - -import java.util.function.Supplier; - -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.Mockito.*; - -/** - * @author Silvio Giebl - */ -abstract class MqttSubscriptionFlowsTest { - - public static class CsvToArray extends SimpleArgumentConverter { - - @Override - protected @NotNull Object convert(final @NotNull Object source, final @NotNull Class targetType) - throws ArgumentConversionException { - final String s = (String) source; - return s.split("\\s*;\\s*"); - } - } - - private final @NotNull Supplier flowsSupplier; - @SuppressWarnings("NullabilityAnnotations") - MqttSubscriptionFlows flows; - - MqttSubscriptionFlowsTest(final @NotNull Supplier flowsSupplier) { - this.flowsSupplier = flowsSupplier; - } - - @BeforeEach - void setUp() { - flows = flowsSupplier.get(); - } - - @ParameterizedTest - @CsvSource({ - "a, a; +; a/#; +/#; # ", - "a/b, a/b; a/+; +/b; +/+; a/b/#; a/+/#; +/b/#; +/+/#; a/#; #", - "/, /; +/+; +/; /+; +/#; /#; # " - }) - void subscribe_matchingTopicFilters_doMatch( - final @NotNull String topic, @ConvertWith(CsvToArray.class) final @NotNull String[] matchingTopicFilters) { - - final MqttSubscribedPublishFlow[] matchingFlows = new MqttSubscribedPublishFlow[matchingTopicFilters.length]; - for (int i = 0; i < matchingTopicFilters.length; i++) { - final MqttSubscribedPublishFlow flow = mockSubscriptionFlow(matchingTopicFilters[i]); - final MqttTopicFilterImpl topicFilter = MqttTopicFilterImpl.of(matchingTopicFilters[i]); - flows.subscribe(topicFilter, flow); - assertEquals(ImmutableSet.of(topicFilter), toSet(flow.getTopicFilters())); - matchingFlows[i] = flow; - } - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertTrue(matching.subscriptionFound); - assertFalse(matching.isEmpty()); - assertEquals(ImmutableSet.copyOf(matchingFlows), toSet(matching)); - } - - @ParameterizedTest - @CsvSource({ - "a, a; +; a/#; +/#; # ", - "a/b, a/b; a/+; +/b; +/+; a/b/#; a/+/#; +/b/#; +/+/#; a/#; #", - "/, /; +/+; +/; /+; +/#; /#; # " - }) - void subscribe_matchingTopicFilters_doMatch_noFlow( - final @NotNull String topic, @ConvertWith(CsvToArray.class) final @NotNull String[] matchingTopicFilters) { - - for (final String matchingTopicFilter : matchingTopicFilters) { - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), null); - } - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertTrue(matching.subscriptionFound); - assertTrue(matching.isEmpty()); - } - - @ParameterizedTest - @CsvSource({ - "a, /a; b; a/b; a/+; +/a; +/+; a/b/#; /#; / ", - "a/b, /a/b; a/c; c/b; a/b/c; +/a/b; a/+/b; a/b/+; a/b/c/#; +", - "/, //; a/b; a/; /a; + " - }) - void subscribe_nonMatchingTopicFilters_doNotMatch( - final @NotNull String topic, - @ConvertWith(CsvToArray.class) final @NotNull String[] notMatchingTopicFilters) { - - for (final String notMatchingTopicFilter : notMatchingTopicFilters) { - final MqttSubscribedPublishFlow flow = mockSubscriptionFlow(notMatchingTopicFilter); - final MqttTopicFilterImpl topicFilter = MqttTopicFilterImpl.of(notMatchingTopicFilter); - flows.subscribe(topicFilter, flow); - assertEquals(ImmutableSet.of(topicFilter), toSet(flow.getTopicFilters())); - } - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertFalse(matching.subscriptionFound); - assertTrue(matching.isEmpty()); - } - - @ParameterizedTest - @CsvSource({ - "a, /a; b; a/b; a/+; +/a; +/+; a/b/#; /#; / ", - "a/b, /a/b; a/c; c/b; a/b/c; +/a/b; a/+/b; a/b/+; a/b/c/#; +", - "/, //; a/b; a/; /a; + " - }) - void subscribe_nonMatchingTopicFilters_doNotMatch_noFlow( - final @NotNull String topic, - @ConvertWith(CsvToArray.class) final @NotNull String[] notMatchingTopicFilters) { - - for (final String notMatchingTopicFilter : notMatchingTopicFilters) { - flows.subscribe(MqttTopicFilterImpl.of(notMatchingTopicFilter), null); - } - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertFalse(matching.subscriptionFound); - assertTrue(matching.isEmpty()); - } - - @ParameterizedTest - @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) - void unsubscribe_matchingTopicFilters_doNoLongerMatch( - final @NotNull String topic, final @NotNull String matchingTopicFilter) { - - final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow(matchingTopicFilter); - final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow(matchingTopicFilter); - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), flow1); - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), flow2); - - final HandleList unsubscribed = new HandleList<>(); - flows.unsubscribe(MqttTopicFilterImpl.of(matchingTopicFilter), unsubscribed::add); - assertTrue(flow1.getTopicFilters().isEmpty()); - assertTrue(flow2.getTopicFilters().isEmpty()); - assertEquals(ImmutableSet.of(flow1, flow2), toSet(unsubscribed)); - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertFalse(matching.subscriptionFound); - assertTrue(matching.isEmpty()); - } - - @ParameterizedTest - @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) - void unsubscribe_matchingTopicFilters_doNoLongerMatch_noFlow( - final @NotNull String topic, final @NotNull String matchingTopicFilter) { - - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), null); - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), null); - - final HandleList unsubscribed = new HandleList<>(); - flows.unsubscribe(MqttTopicFilterImpl.of(matchingTopicFilter), unsubscribed::add); - assertTrue(unsubscribed.isEmpty()); - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertFalse(matching.subscriptionFound); - assertTrue(matching.isEmpty()); - } - - @ParameterizedTest - @CsvSource({"a, a, b", "a, a, a/b", "a/b, a/b, a/c"}) - void unsubscribe_nonMatchingTopicFilters_othersStillMatch( - final @NotNull String topic, final @NotNull String matchingTopicFilter, - final @NotNull String notMatchingTopicFilter) { - - final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow(matchingTopicFilter); - final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow(notMatchingTopicFilter); - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), flow1); - flows.subscribe(MqttTopicFilterImpl.of(notMatchingTopicFilter), flow2); - - final HandleList unsubscribed = new HandleList<>(); - flows.unsubscribe(MqttTopicFilterImpl.of(notMatchingTopicFilter), unsubscribed::add); - assertFalse(flow1.getTopicFilters().isEmpty()); - assertTrue(flow2.getTopicFilters().isEmpty()); - assertEquals(ImmutableSet.of(flow2), toSet(unsubscribed)); - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertTrue(matching.subscriptionFound); - assertFalse(matching.isEmpty()); - assertEquals(ImmutableSet.of(flow1), toSet(matching)); - } - - @ParameterizedTest - @CsvSource({"a, a, b", "a, a, a/b", "a/b, a/b, a/c"}) - void unsubscribe_nonMatchingTopicFilters_othersStillMatch_noFlow( - final @NotNull String topic, final @NotNull String matchingTopicFilter, - final @NotNull String notMatchingTopicFilter) { - - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), null); - flows.subscribe(MqttTopicFilterImpl.of(notMatchingTopicFilter), null); - - final HandleList unsubscribed = new HandleList<>(); - flows.unsubscribe(MqttTopicFilterImpl.of(notMatchingTopicFilter), unsubscribed::add); - assertTrue(unsubscribed.isEmpty()); - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertTrue(matching.subscriptionFound); - assertTrue(matching.isEmpty()); - } - - @ParameterizedTest - @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) - void cancel_doNoLongerMatch(final @NotNull String topic, final @NotNull String matchingTopicFilter) { - final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow(matchingTopicFilter); - final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow(matchingTopicFilter); - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), flow1); - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), flow2); - - flows.cancel(flow1); - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertTrue(matching.subscriptionFound); - assertFalse(matching.isEmpty()); - assertEquals(ImmutableSet.of(flow2), toSet(matching)); - - flows.cancel(flow2); - final MqttMatchingPublishFlows matching2 = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching2); - assertTrue(matching2.subscriptionFound); - assertTrue(matching2.isEmpty()); - } - - @ParameterizedTest - @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) - void cancel_notPresentFlows_areIgnored(final @NotNull String topic, final @NotNull String matchingTopicFilter) { - final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow(matchingTopicFilter); - final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow(matchingTopicFilter); - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), flow1); - - flows.cancel(flow2); - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertTrue(matching.subscriptionFound); - assertFalse(matching.isEmpty()); - assertEquals(ImmutableSet.of(flow1), toSet(matching)); - } - - @Test - void cancel_partiallyUnsubscribedFlow() { - final MqttSubscribedPublishFlow flow = mockSubscriptionFlow("test/topic(2)"); - flows.subscribe(MqttTopicFilterImpl.of("test/topic"), flow); - flows.subscribe(MqttTopicFilterImpl.of("test/topic2"), flow); - - flows.unsubscribe(MqttTopicFilterImpl.of("test/topic"), null); - flows.cancel(flow); - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of("test/topic"), matching); - assertFalse(matching.subscriptionFound); - flows.findMatching(MqttTopicImpl.of("test/topic2"), matching); - assertTrue(matching.subscriptionFound); - assertTrue(matching.isEmpty()); - } - - @ParameterizedTest - @CsvSource({ - "1/a, 1/a, 2/a, 2/a", "1/a, 1/+, 2/a, 2/+", "1/a, 1/#, 2/a, 2/#", "1/a/b, 1/a/b, 2/a/b, 2/a/b", - "1/a/b, 1/a/+, 2/a/b, 2/a/+", "1/a/b, 1/+/b, 2/a/b, 2/+/b", "1/a/b, 1/+/+, 2/a/b, 2/+/+", - "1/a/b, 1/+/#, 2/a/b, 2/+/#", "1/a/b, 1/#, 2/a/b, 2/#" - }) - void remove( - final @NotNull String topic, final @NotNull String matchingTopicFilter, final @NotNull String topic2, - final @NotNull String matchingTopicFilter2) { - - final MqttSubscribedPublishFlow flow = mockSubscriptionFlow(matchingTopicFilter); - final MqttTopicFilterImpl topicFilter = MqttTopicFilterImpl.of(matchingTopicFilter); - final MqttTopicFilterImpl topicFilter2 = MqttTopicFilterImpl.of(matchingTopicFilter2); - flows.subscribe(topicFilter, flow); - flows.subscribe(topicFilter2, flow); - assertEquals(ImmutableSet.of(topicFilter, topicFilter2), toSet(flow.getTopicFilters())); - - flows.remove(topicFilter, flow); - assertEquals(ImmutableSet.of(topicFilter2), toSet(flow.getTopicFilters())); - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertFalse(matching.subscriptionFound); - assertTrue(matching.isEmpty()); - - final MqttMatchingPublishFlows matching2 = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic2), matching2); - assertTrue(matching2.subscriptionFound); - assertFalse(matching2.isEmpty()); - assertEquals(ImmutableSet.of(flow), toSet(matching2)); - } - - @ParameterizedTest - @CsvSource({ - "1/a, 1/a, 2/a, 2/a", "1/a, 1/+, 2/a, 2/+", "1/a, 1/#, 2/a, 2/#", "1/a/b, 1/a/b, 2/a/b, 2/a/b", - "1/a/b, 1/a/+, 2/a/b, 2/a/+", "1/a/b, 1/+/b, 2/a/b, 2/+/b", "1/a/b, 1/+/+, 2/a/b, 2/+/+", - "1/a/b, 1/+/#, 2/a/b, 2/+/#", "1/a/b, 1/#, 2/a/b, 2/#" - }) - void remove_noFlow( - final @NotNull String topic, final @NotNull String matchingTopicFilter, final @NotNull String topic2, - final @NotNull String matchingTopicFilter2) { - - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), null); - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter2), null); - - flows.remove(MqttTopicFilterImpl.of(matchingTopicFilter), null); - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertFalse(matching.subscriptionFound); - assertTrue(matching.isEmpty()); - - final MqttMatchingPublishFlows matching2 = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic2), matching2); - assertTrue(matching2.subscriptionFound); - assertTrue(matching2.isEmpty()); - } - - @ParameterizedTest - @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) - void remove_doesNotUnsubscribe(final @NotNull String topic, final @NotNull String matchingTopicFilter) { - final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow(matchingTopicFilter); - final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow(matchingTopicFilter); - final MqttTopicFilterImpl topicFilter = MqttTopicFilterImpl.of(matchingTopicFilter); - flows.subscribe(topicFilter, flow1); - flows.subscribe(topicFilter, flow2); - assertEquals(ImmutableSet.of(topicFilter), toSet(flow1.getTopicFilters())); - assertEquals(ImmutableSet.of(topicFilter), toSet(flow2.getTopicFilters())); - - flows.remove(topicFilter, flow1); - assertTrue(flow1.getTopicFilters().isEmpty()); - assertEquals(ImmutableSet.of(topicFilter), toSet(flow2.getTopicFilters())); - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertTrue(matching.subscriptionFound); - assertFalse(matching.isEmpty()); - assertEquals(ImmutableSet.of(flow2), toSet(matching)); - } - - @ParameterizedTest - @CsvSource({"a, a", "a, +", "a, #", "a/b, a/b", "a/b, a/+", "a/b, +/b", "a/b, +/+", "a/b, +/#", "a/b, #"}) - void remove_doesNotUnsubscribe_noFlow(final @NotNull String topic, final @NotNull String matchingTopicFilter) { - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), null); - flows.subscribe(MqttTopicFilterImpl.of(matchingTopicFilter), null); - - flows.remove(MqttTopicFilterImpl.of(matchingTopicFilter), null); - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertTrue(matching.subscriptionFound); - assertTrue(matching.isEmpty()); - } - - @ParameterizedTest - @CsvSource({"a/b, a/b/c, +/b/c, +/+/+", "a/b/c/d, a/b/c/d/e, +/b//d/ec, +/+/+/+/+"}) - void findMatching_matchingMultipleButNotAllLevels( - final @NotNull String topic, final @NotNull String filter1, final @NotNull String filter2, - final @NotNull String filter3) { - - flows.subscribe(MqttTopicFilterImpl.of(filter1), null); - flows.subscribe(MqttTopicFilterImpl.of(filter2), null); - flows.subscribe(MqttTopicFilterImpl.of(filter3), null); - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of(topic), matching); - assertFalse(matching.subscriptionFound); - assertTrue(matching.isEmpty()); - } - - @Test - void clear() { - final MqttSubscribedPublishFlow flow1 = mockSubscriptionFlow("test/topic/filter"); - final MqttSubscribedPublishFlow flow2 = mockSubscriptionFlow("test2/topic/filter"); - final MqttSubscribedPublishFlow flow3 = mockSubscriptionFlow("test/topic2/filter"); - final MqttSubscribedPublishFlow flow4 = mockSubscriptionFlow("test/topic/filter2"); - final MqttSubscribedPublishFlow flow5 = mockSubscriptionFlow("+/topic"); - final MqttSubscribedPublishFlow flow6 = mockSubscriptionFlow("topic/#"); - flows.subscribe(MqttTopicFilterImpl.of(flow1.toString()), flow1); - flows.subscribe(MqttTopicFilterImpl.of(flow2.toString()), flow2); - flows.subscribe(MqttTopicFilterImpl.of(flow3.toString()), flow3); - flows.subscribe(MqttTopicFilterImpl.of(flow4.toString()), flow4); - flows.subscribe(MqttTopicFilterImpl.of(flow5.toString()), flow5); - flows.subscribe(MqttTopicFilterImpl.of(flow6.toString()), flow6); - - final Exception cause = new Exception("test"); - flows.clear(cause); - verify(flow1).onError(cause); - verify(flow2).onError(cause); - verify(flow3).onError(cause); - verify(flow4).onError(cause); - verify(flow5).onError(cause); - verify(flow6).onError(cause); - - final MqttMatchingPublishFlows matching = new MqttMatchingPublishFlows(); - flows.findMatching(MqttTopicImpl.of("test/topic"), matching); - flows.findMatching(MqttTopicImpl.of("test/topic/filter"), matching); - flows.findMatching(MqttTopicImpl.of("test2/topic/filter"), matching); - flows.findMatching(MqttTopicImpl.of("test/topic2/filter"), matching); - flows.findMatching(MqttTopicImpl.of("test/topic/filter2"), matching); - flows.findMatching(MqttTopicImpl.of("abc/topic"), matching); - flows.findMatching(MqttTopicImpl.of("topic/abc"), matching); - assertFalse(matching.subscriptionFound); - } - - private static @NotNull MqttSubscribedPublishFlow mockSubscriptionFlow(final @NotNull String name) { - final MqttSubscribedPublishFlow flow = mock(MqttSubscribedPublishFlow.class); - final HandleList topicFilters = new HandleList<>(); - when(flow.getTopicFilters()).thenReturn(topicFilters); - when(flow.toString()).thenReturn(name); - return flow; - } - - private @NotNull ImmutableSet toSet(final @NotNull HandleList list) { - final ImmutableSet.Builder builder = ImmutableSet.builder(); - for (HandleList.Handle h = list.getFirst(); h != null; h = h.getNext()) { - builder.add(h.getElement()); - } - return builder.build(); - } -} diff --git a/src/test/java/com/hivemq/client/rx/FlowableWithSingleTest.java b/src/test/java/com/hivemq/client/rx/FlowableWithSingleTest.java index 424c8ba8a..9b775efec 100644 --- a/src/test/java/com/hivemq/client/rx/FlowableWithSingleTest.java +++ b/src/test/java/com/hivemq/client/rx/FlowableWithSingleTest.java @@ -36,6 +36,7 @@ import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.*; @@ -201,6 +202,46 @@ void observeOnBoth_delayError_bufferSize() throws InterruptedException { executorService.shutdown(); } + @Test + void observeOnBoth_delayError_bufferSize_2() { + final Flowable flowable = + Flowable.just(new StringBuilder("single")).concatWith( + Flowable.range(0, 1024).zipWith(Flowable.just("next").repeat(1024), (i, s) -> s + i)) + .concatWith(Flowable.error(new IllegalArgumentException("test"))) + .hide(); + final FlowableWithSingle flowableWithSingle = + new FlowableWithSingleSplit<>(flowable, String.class, StringBuilder.class); + + final ExecutorService executorService = + Executors.newSingleThreadExecutor(new ThreadFactoryBuilder().setNameFormat("test_thread").build()); + + final AtomicInteger count = new AtomicInteger(); + final AtomicReference single = new AtomicReference<>(); + final AtomicReference error = new AtomicReference<>(); + // @formatter:off + flowableWithSingle.observeOnBoth(Schedulers.from(executorService), true, 1024) + .doOnSingle(stringBuilder -> { + single.set(stringBuilder); + assertEquals("test_thread", Thread.currentThread().getName()); + }) + .doOnNext(string -> { + assertEquals("next" + count.getAndIncrement(), string); + assertEquals("test_thread", Thread.currentThread().getName()); + }) + .doOnError(error::set) + .ignoreElements() + .onErrorComplete() + .blockingAwait(); + // @formatter:on + + assertEquals(1024, count.get()); + assertEquals("single", single.get().toString()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertEquals("test", error.get().getMessage()); + + executorService.shutdown(); + } + @MethodSource("singleNext3") @ParameterizedTest void mapSingle(final @NotNull FlowableWithSingle flowableWithSingle) { @@ -273,16 +314,15 @@ void mapBoth_multiple(final @NotNull FlowableWithSingle f final AtomicInteger nextCounter = new AtomicInteger(); final AtomicInteger singleCounter = new AtomicInteger(); - flowableWithSingle // - .mapBoth(s -> { - nextCounter.incrementAndGet(); - assertNotEquals("test_thread", Thread.currentThread().getName()); - return s + "-1"; - }, stringBuilder -> { - assertEquals(1, singleCounter.incrementAndGet()); - assertNotEquals("test_thread", Thread.currentThread().getName()); - return stringBuilder.append("-1"); - }).mapBoth(s -> { + flowableWithSingle.mapBoth(s -> { + nextCounter.incrementAndGet(); + assertNotEquals("test_thread", Thread.currentThread().getName()); + return s + "-1"; + }, stringBuilder -> { + assertEquals(1, singleCounter.incrementAndGet()); + assertNotEquals("test_thread", Thread.currentThread().getName()); + return stringBuilder.append("-1"); + }).mapBoth(s -> { nextCounter.incrementAndGet(); assertNotEquals("test_thread", Thread.currentThread().getName()); return s + "-2";