Skip to content

Commit d7ae1a7

Browse files
committed
Fixed MqttSubscribedPublishFlowTree.getSubscriptions with root multi level wildcards
1 parent 66345e0 commit d7ae1a7

File tree

3 files changed

+107
-27
lines changed

3 files changed

+107
-27
lines changed

Diff for: src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicLevel.java

+34-18
Original file line numberDiff line numberDiff line change
@@ -60,26 +60,42 @@ public boolean isSingleLevelWildcard() {
6060
return this;
6161
}
6262

63-
public @Nullable MqttTopicFilterImpl toFilter(final byte @Nullable [] prefix, final boolean multiLevelWildcard) {
64-
final byte[] bytes;
63+
public static @Nullable MqttTopicFilterImpl toFilter(
64+
final byte @Nullable [] prefix,
65+
final @Nullable MqttTopicLevel topicLevel,
66+
final boolean multiLevelWildcard) {
67+
68+
int length = 0;
69+
if (prefix != null) {
70+
length += prefix.length + 1;
71+
}
72+
if (topicLevel != null) {
73+
length += topicLevel.array.length;
74+
}
75+
if (multiLevelWildcard) {
76+
if (topicLevel != null) {
77+
length++;
78+
}
79+
length++;
80+
}
81+
final byte[] bytes = new byte[length];
82+
int cursor = 0;
6583
if (prefix != null) {
66-
if (multiLevelWildcard) {
67-
bytes = new byte[prefix.length + 1 + array.length + 2];
68-
bytes[bytes.length - 2] = MqttTopicImpl.TOPIC_LEVEL_SEPARATOR;
69-
bytes[bytes.length - 1] = MqttTopicFilterImpl.MULTI_LEVEL_WILDCARD;
70-
} else {
71-
bytes = new byte[prefix.length + 1 + array.length];
84+
System.arraycopy(prefix, 0, bytes, cursor, prefix.length);
85+
cursor += prefix.length;
86+
bytes[cursor] = MqttTopicImpl.TOPIC_LEVEL_SEPARATOR;
87+
cursor++;
88+
}
89+
if (topicLevel != null) {
90+
System.arraycopy(topicLevel.array, 0, bytes, cursor, topicLevel.array.length);
91+
cursor += topicLevel.array.length;
92+
}
93+
if (multiLevelWildcard) {
94+
if (topicLevel != null) {
95+
bytes[cursor] = MqttTopicImpl.TOPIC_LEVEL_SEPARATOR;
96+
cursor++;
7297
}
73-
System.arraycopy(prefix, 0, bytes, 0, prefix.length);
74-
bytes[prefix.length] = MqttTopicImpl.TOPIC_LEVEL_SEPARATOR;
75-
System.arraycopy(array, 0, bytes, prefix.length + 1, array.length);
76-
} else if (multiLevelWildcard) {
77-
bytes = new byte[array.length + 2];
78-
System.arraycopy(array, 0, bytes, 0, array.length);
79-
bytes[bytes.length - 2] = MqttTopicImpl.TOPIC_LEVEL_SEPARATOR;
80-
bytes[bytes.length - 1] = MqttTopicFilterImpl.MULTI_LEVEL_WILDCARD;
81-
} else {
82-
bytes = array;
98+
bytes[cursor] = MqttTopicFilterImpl.MULTI_LEVEL_WILDCARD;
8399
}
84100
return MqttTopicFilterImpl.of(bytes);
85101
}

Diff for: src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowTree.java

+11-9
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ public void clear(final @NotNull Throwable cause) {
112112

113113
@Override
114114
public @NotNull Map<@NotNull Integer, @NotNull List<@NotNull MqttSubscription>> getSubscriptions() {
115+
// we sort in reverse order of subscription identifiers so that newer subscriptions are first
115116
final Map<Integer, List<MqttSubscription>> map = new TreeMap<>(Comparator.reverseOrder());
116117
if (rootNode != null) {
117118
final Queue<IteratorNode> nodes = new LinkedList<>();
@@ -533,13 +534,11 @@ void getSubscriptions(
533534

534535
final MqttTopicLevel topicLevels = ((parentTopicLevels == null) || (topicLevel == null)) ? topicLevel :
535536
MqttTopicLevels.concat(parentTopicLevels, topicLevel);
536-
if (topicLevels != null) {
537-
if (entries != null) {
538-
getSubscriptions(entries, topicLevels, false, map);
539-
}
540-
if (multiLevelEntries != null) {
541-
getSubscriptions(multiLevelEntries, topicLevels, true, map);
542-
}
537+
if (entries != null) {
538+
getSubscriptions(entries, topicLevels, false, map);
539+
}
540+
if (multiLevelEntries != null) {
541+
getSubscriptions(multiLevelEntries, topicLevels, true, map);
543542
}
544543
if (next != null) {
545544
next.forEach(node -> nodes.add(new IteratorNode(node, topicLevels)));
@@ -551,21 +550,24 @@ void getSubscriptions(
551550

552551
private static void getSubscriptions(
553552
final @NotNull NodeList<TopicTreeEntry> entries,
554-
final @NotNull MqttTopicLevel topicLevels,
553+
final @Nullable MqttTopicLevel topicLevels,
555554
final boolean multiLevelWildcard,
556555
final @NotNull Map<@NotNull Integer, @NotNull List<@NotNull MqttSubscription>> map) {
557556

557+
// exact subscription = subscription without prefix, so no shared subscription
558558
boolean exactFound = false;
559+
// iterate in reverse order to only include the newest exact subscription
559560
for (TopicTreeEntry entry = entries.getLast(); entry != null; entry = entry.getPrev()) {
560561
if (entry.acknowledged) {
561562
if (entry.topicFilterPrefix == null) {
562563
if (exactFound) {
564+
// ignore older exact subscriptions as they are overwritten by the newest
563565
continue;
564566
}
565567
exactFound = true;
566568
}
567569
final MqttTopicFilterImpl topicFilter =
568-
topicLevels.toFilter(entry.topicFilterPrefix, multiLevelWildcard);
570+
MqttTopicLevel.toFilter(entry.topicFilterPrefix, topicLevels, multiLevelWildcard);
569571
assert topicFilter != null : "reconstructed topic filter must be valid";
570572
final MqttQos qos = MqttSubscription.decodeQos(entry.subscriptionOptions);
571573
assert qos != null : "reconstructed qos must be valid";

Diff for: src/test/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscribedPublishFlowsTest.java

+62
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscription;
2424
import com.hivemq.client.internal.mqtt.message.subscribe.MqttSubscriptionBuilder;
2525
import com.hivemq.client.internal.util.collections.HandleList;
26+
import com.hivemq.client.internal.util.collections.ImmutableList;
2627
import org.jetbrains.annotations.NotNull;
2728
import org.junit.jupiter.api.BeforeEach;
2829
import org.junit.jupiter.api.Test;
@@ -32,6 +33,9 @@
3233
import org.junit.jupiter.params.converter.SimpleArgumentConverter;
3334
import org.junit.jupiter.params.provider.CsvSource;
3435

36+
import java.util.List;
37+
import java.util.Map;
38+
import java.util.concurrent.atomic.AtomicInteger;
3539
import java.util.function.Supplier;
3640

3741
import static org.junit.jupiter.api.Assertions.*;
@@ -744,6 +748,64 @@ void clear_notAcknowledged_noFlow() {
744748
MqttStatefulPublish.DEFAULT_NO_SUBSCRIPTION_IDENTIFIERS));
745749
}
746750

751+
@Test
752+
void getSubscriptions() {
753+
final ImmutableList<MqttSubscription> subscriptions = ImmutableList.of(
754+
new MqttSubscriptionBuilder.Default().topicFilter("abc").build(),
755+
new MqttSubscriptionBuilder.Default().topicFilter("$share/group/abc").build(),
756+
new MqttSubscriptionBuilder.Default().topicFilter("#").build(),
757+
new MqttSubscriptionBuilder.Default().topicFilter("$share/group/#").build(),
758+
new MqttSubscriptionBuilder.Default().topicFilter("test/#").build(),
759+
new MqttSubscriptionBuilder.Default().topicFilter("$share/group/test/#").build(),
760+
new MqttSubscriptionBuilder.Default().topicFilter("+/#").build(),
761+
new MqttSubscriptionBuilder.Default().topicFilter("$share/group/+/#").build(),
762+
new MqttSubscriptionBuilder.Default().topicFilter("+/abc").build(),
763+
new MqttSubscriptionBuilder.Default().topicFilter("$share/group/+/abc").build());
764+
for (int i = 0; i < subscriptions.size(); i++) {
765+
flows.subscribe(subscriptions.get(i), i, null);
766+
flows.suback(subscriptions.get(i).getTopicFilter(), i, false);
767+
}
768+
final Map<Integer, List<MqttSubscription>> allSubscriptions = flows.getSubscriptions();
769+
for (int i = 0; i < subscriptions.size(); i++) {
770+
assertEquals(ImmutableList.of(subscriptions.get(i)), allSubscriptions.get(i));
771+
}
772+
// check if sorted in reverse order
773+
final AtomicInteger atomicInteger = new AtomicInteger(subscriptions.size());
774+
allSubscriptions.forEach(
775+
(subscriptionId, subscriptionsForId) -> assertEquals(atomicInteger.decrementAndGet(), subscriptionId));
776+
}
777+
778+
@Test
779+
void getSubscriptions_sameSubscriptionIdentifiers() {
780+
final ImmutableList<MqttSubscription> subscriptions = ImmutableList.of(
781+
new MqttSubscriptionBuilder.Default().topicFilter("abc").build(),
782+
new MqttSubscriptionBuilder.Default().topicFilter("$share/group/abc").build(),
783+
new MqttSubscriptionBuilder.Default().topicFilter("#").build(),
784+
new MqttSubscriptionBuilder.Default().topicFilter("$share/group/#").build(),
785+
new MqttSubscriptionBuilder.Default().topicFilter("test/#").build(),
786+
new MqttSubscriptionBuilder.Default().topicFilter("$share/group/test/#").build(),
787+
new MqttSubscriptionBuilder.Default().topicFilter("+/#").build(),
788+
new MqttSubscriptionBuilder.Default().topicFilter("$share/group/+/#").build(),
789+
new MqttSubscriptionBuilder.Default().topicFilter("+/abc").build(),
790+
new MqttSubscriptionBuilder.Default().topicFilter("$share/group/+/abc").build());
791+
for (int i = 0; i < subscriptions.size(); i += 2) {
792+
for (int j = 0; j < 2; j++) {
793+
flows.subscribe(subscriptions.get(i + j), i, null);
794+
flows.suback(subscriptions.get(i + j).getTopicFilter(), i, false);
795+
}
796+
}
797+
final Map<Integer, List<MqttSubscription>> allSubscriptions = flows.getSubscriptions();
798+
for (int i = 0; i < subscriptions.size(); i += 2) {
799+
assertEquals(
800+
ImmutableSet.of(subscriptions.get(i), subscriptions.get(i + 1)),
801+
ImmutableSet.copyOf(allSubscriptions.get(i)));
802+
}
803+
// check if sorted in reverse order
804+
final AtomicInteger atomicInteger = new AtomicInteger(subscriptions.size());
805+
allSubscriptions.forEach(
806+
(subscriptionId, subscriptionsForId) -> assertEquals(atomicInteger.addAndGet(-2), subscriptionId));
807+
}
808+
747809
private static @NotNull MqttSubscribedPublishFlow mockSubscriptionFlow(final @NotNull String name) {
748810
final MqttSubscribedPublishFlow flow = mock(MqttSubscribedPublishFlow.class);
749811
final HandleList<MqttTopicFilterImpl> topicFilters = new HandleList<>();

0 commit comments

Comments
 (0)