From e52e29199f656b56d66a843a4203845d8fa169b9 Mon Sep 17 00:00:00 2001 From: Hendrik Muhs Date: Fri, 29 Jul 2022 14:42:56 +0200 Subject: [PATCH 1/4] use a bitset for deduplication --- .../CountingItemSetTraverser.java | 42 +- .../frequentitemsets/EclatMapReducer.java | 32 +- .../FrequentItemSetCollector.java | 76 ++- .../aggs/frequentitemsets/ItemSetBitSet.java | 250 ++++++++++ .../frequentitemsets/ItemSetTraverser.java | 66 ++- .../frequentitemsets/TransactionStore.java | 13 +- .../FrequentItemSetCollectorTests.java | 187 +++---- .../frequentitemsets/ItemSetBitSetTests.java | 213 ++++++++ .../ItemSetTraverserTests.java | 466 +++++++++++------- 9 files changed, 978 insertions(+), 367 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSetTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/CountingItemSetTraverser.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/CountingItemSetTraverser.java index a0d4b407f5feb..4d9de3a86e23c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/CountingItemSetTraverser.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/CountingItemSetTraverser.java @@ -7,11 +7,13 @@ package org.elasticsearch.xpack.ml.aggs.frequentitemsets; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.FixedBitSet; -import org.apache.lucene.util.LongsRef; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.ml.aggs.frequentitemsets.TransactionStore.TopItemIds; import java.io.IOException; import java.util.Arrays; @@ -30,6 +32,7 @@ * if [a, b] is not in T, [a, b, c] can not be in T either */ class CountingItemSetTraverser implements Releasable { + private static final Logger logger = LogManager.getLogger(CountingItemSetTraverser.class); // start size and size increment for the occurences stack private static final int OCCURENCES_SIZE_INCREMENT = 10; @@ -48,13 +51,19 @@ class CountingItemSetTraverser implements Releasable { // growable bit set from java util private java.util.BitSet visited; - CountingItemSetTraverser(TransactionStore transactionStore, int cacheTraversalDepth, int cacheNumberOfTransactions, long minCount) { + CountingItemSetTraverser( + TransactionStore transactionStore, + TopItemIds topItemIds, + int cacheTraversalDepth, + int cacheNumberOfTransactions, + long minCount + ) { this.transactionStore = transactionStore; boolean success = false; try { // we allocate 2 big arrays, if the 2nd allocation fails, ensure we clean up - this.topItemSetTraverser = transactionStore.getTopItemIdTraverser(); + this.topItemSetTraverser = new ItemSetTraverser(topItemIds); this.topTransactionIds = transactionStore.getTopTransactionIds(); success = true; } finally { @@ -80,11 +89,15 @@ public boolean next(long earlyStopMinCount) throws IOException { final long totalTransactionCount = transactionStore.getTotalTransactionCount(); int depth = topItemSetTraverser.getNumberOfItems(); + long occurencesOfSingleItem = transactionStore.getItemCount(topItemSetTraverser.getItemId()); + if (depth == 1) { // at the 1st level, we can take the count directly from the transaction store - occurencesStack[0] = transactionStore.getItemCount(topItemSetTraverser.getItemId()); + occurencesStack[0] = occurencesOfSingleItem; + return true; + } else if (occurencesOfSingleItem < earlyStopMinCount) { + rememberCountInStack(depth, occurencesOfSingleItem); return true; - // till a certain depth store results in a cache matrix } else if (depth < cacheTraversalDepth) { // get the cached skip count @@ -187,7 +200,7 @@ public long getCount() { /** * Get the count of the item set without the last item */ - public long getPreviousCount() { + public long getParentCount() { if (topItemSetTraverser.getNumberOfItems() > 1) { return occurencesStack[topItemSetTraverser.getNumberOfItems() - 2]; } @@ -201,7 +214,7 @@ public boolean hasBeenVisited() { return true; } - public boolean hasPredecessorBeenVisited() { + public boolean hasParentBeenVisited() { if (topItemSetTraverser.getNumberOfItems() > 1) { return visited.get(topItemSetTraverser.getNumberOfItems() - 2); } @@ -214,7 +227,7 @@ public void setVisited() { } } - public void setPredecessorVisited() { + public void setParentVisited() { if (topItemSetTraverser.getNumberOfItems() > 1) { visited.set(topItemSetTraverser.getNumberOfItems() - 2); } @@ -228,10 +241,15 @@ public int getNumberOfItems() { } /** - * Get the current item set + * + * Get a bitset representation of the current item set */ - public LongsRef getItemSet() { - return topItemSetTraverser.getItemSet(); + public ItemSetBitSet getItemSetBitSet() { + return topItemSetTraverser.getItemSetBitSet(); + } + + public ItemSetBitSet getParentItemSetBitSet() { + return topItemSetTraverser.getParentItemSetBitSet(); } /** @@ -250,7 +268,7 @@ public boolean atLeaf() { @Override public void close() { - Releasables.close(topItemSetTraverser, topTransactionIds); + Releasables.close(topTransactionIds); } // remember the count in the stack without tracking push and pop diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/EclatMapReducer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/EclatMapReducer.java index 92d31fc7fe118..ef7c168d1fa0d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/EclatMapReducer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/EclatMapReducer.java @@ -9,7 +9,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.lucene.util.LongsRef; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -25,6 +24,7 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.ml.aggs.frequentitemsets.FrequentItemSetCollector.FrequentItemSet; +import org.elasticsearch.xpack.ml.aggs.frequentitemsets.TransactionStore.TopItemIds; import org.elasticsearch.xpack.ml.aggs.frequentitemsets.mr.AbstractItemSetMapReducer; import org.elasticsearch.xpack.ml.aggs.frequentitemsets.mr.ItemSetMapReduceValueSource.Field; @@ -338,8 +338,6 @@ private static EclatResult eclat( final long totalTransactionCount = transactionStore.getTotalTransactionCount(); Map profilingInfo = null; long minCount = (long) Math.ceil(totalTransactionCount * minimumSupport); - FrequentItemSetCollector collector = new FrequentItemSetCollector(transactionStore, size, minCount); - long numberOfSetsChecked = 0; if (profilingInfoReduce != null) { profilingInfo = new LinkedHashMap<>(profilingInfoReduce); @@ -347,8 +345,10 @@ private static EclatResult eclat( } try ( + TopItemIds topItemIds = transactionStore.getTopItemIds(); CountingItemSetTraverser setTraverser = new CountingItemSetTraverser( transactionStore, + topItemIds, BITSET_CACHE_TRAVERSAL_DEPTH, (int) Math.min(MAX_BITSET_CACHE_NUMBER_OF_TRANSACTIONS, totalTransactionCount), minCount @@ -360,7 +360,8 @@ private static EclatResult eclat( minCount, transactionStore.getTotalItemCount() ); - + FrequentItemSetCollector collector = new FrequentItemSetCollector(transactionStore, topItemIds, size, minCount); + long numberOfSetsChecked = 0; long previousMinCount = 0; while (setTraverser.next(minCount)) { @@ -402,8 +403,11 @@ private static EclatResult eclat( if (setTraverser.atLeaf() && setTraverser.hasBeenVisited() == false && setTraverser.getCount() >= minCount - && setTraverser.getItemSet().length >= minimumSetSize) { - minCount = collector.add(setTraverser.getItemSet(), setTraverser.getCount()); + && setTraverser.getItemSetBitSet().cardinality() >= minimumSetSize) { + + logger.trace("add after prune"); + + minCount = collector.add(setTraverser.getItemSetBitSet(), setTraverser.getCount()); // no need to set visited, as we are on a leaf } @@ -418,19 +422,17 @@ private static EclatResult eclat( * * iff the count of the subset is higher, collect */ - if (setTraverser.hasPredecessorBeenVisited() == false - && setTraverser.getItemSet().length > minimumSetSize - && setTraverser.getCount() < setTraverser.getPreviousCount()) { + if (setTraverser.hasParentBeenVisited() == false + && setTraverser.getItemSetBitSet().cardinality() > minimumSetSize + && setTraverser.getCount() < setTraverser.getParentCount()) { // add the set without the last item - LongsRef subItemSet = setTraverser.getItemSet().clone(); - subItemSet.length--; - minCount = collector.add(subItemSet, setTraverser.getPreviousCount()); + minCount = collector.add(setTraverser.getParentItemSetBitSet(), setTraverser.getParentCount()); } // closed set criteria: the predecessor is no longer of interest: either we reported in the previous step or we found a // super set - setTraverser.setPredecessorVisited(); + setTraverser.setParentVisited(); /** * Iff the traverser reached a leaf, the item set can not be further expanded, e.g. we reached [f]: @@ -445,8 +447,8 @@ private static EclatResult eclat( * * Note: this also covers the last item, e.g. [a, x, y] */ - if (setTraverser.atLeaf() && setTraverser.getItemSet().length >= minimumSetSize) { - minCount = collector.add(setTraverser.getItemSet(), setTraverser.getCount()); + if (setTraverser.atLeaf() && setTraverser.getItemSetBitSet().cardinality() >= minimumSetSize) { + minCount = collector.add(setTraverser.getItemSetBitSet(), setTraverser.getCount()); // no need to set visited, as we are on a leaf } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetCollector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetCollector.java index e38f1dde9b2e5..1ceb5935ae2cf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetCollector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetCollector.java @@ -9,7 +9,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.lucene.util.LongsRef; import org.apache.lucene.util.PriorityQueue; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -20,11 +19,11 @@ import org.elasticsearch.search.aggregations.Aggregation.CommonFields; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.ml.aggs.frequentitemsets.TransactionStore.TopItemIds; import org.elasticsearch.xpack.ml.aggs.frequentitemsets.mr.ItemSetMapReduceValueSource.Field; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -159,7 +158,8 @@ public String toString() { */ class FrequentItemSetCandidate { - private LongsRef items; + private static final int STARTBITS = 64; + private ItemSetBitSet items; private long docCount; // every set has a unique id, required for the outer logic @@ -167,15 +167,16 @@ class FrequentItemSetCandidate { private FrequentItemSetCandidate() { this.id = -1; - this.items = new LongsRef(10); + this.items = new ItemSetBitSet(STARTBITS); this.docCount = -1; } FrequentItemSet toFrequentItemSet(List fields) throws IOException { Map> frequentItemsKeyValues = new HashMap<>(); - for (int i = 0; i < items.length; ++i) { - Tuple item = transactionStore.getItem(items.longs[i]); + int pos = items.nextSetBit(0); + while (pos > 0) { + Tuple item = transactionStore.getItem(topItemIds.getItemIdAt(pos - 1)); final Field field = fields.get(item.v1()); Object formattedValue = field.formatValue(item.v2()); String fieldName = fields.get(item.v1()).getName(); @@ -187,6 +188,8 @@ FrequentItemSet toFrequentItemSet(List fields) throws IOException { l.add(formattedValue); frequentItemsKeyValues.put(fieldName, l); } + + pos = items.nextSetBit(++pos); } return new FrequentItemSet(frequentItemsKeyValues, docCount, (double) docCount / transactionStore.getTotalTransactionCount()); @@ -196,7 +199,7 @@ long getDocCount() { return docCount; } - LongsRef getItems() { + ItemSetBitSet getItems() { return items; } @@ -205,17 +208,11 @@ int getId() { } int size() { - return items.length; + return items.cardinality(); } - private void reset(int id, LongsRef items, long docCount) { - if (items.length > this.items.length) { - this.items = new LongsRef(items.length); - } - - System.arraycopy(items.longs, 0, this.items.longs, 0, items.length); - - this.items.length = items.length; + private void reset(int id, ItemSetBitSet items, long docCount) { + this.items.reset(items); this.docCount = docCount; this.id = id; } @@ -229,10 +226,7 @@ static class FrequentItemSetPriorityQueue extends PriorityQueue setsThatShareSameDocCount = frequentItemsByCount.get(docCount); - if (setsThatShareSameDocCount != null) { - for (FrequentItemSetCandidate otherSet : setsThatShareSameDocCount) { - if (otherSet.size() < itemSet.length) { - continue; - } - - // quick, intrinsic optimized prefix matching - int commonPrefix = Arrays.mismatch(otherSet.items.longs, 0, otherSet.items.longs.length, itemSet.longs, 0, itemSet.length); + private boolean hasSuperSet(ItemSetBitSet itemSetBitSet, long docCount) { + List setsThatShareSameDocCountBits = frequentItemsByCount.get(docCount); - if (commonPrefix == -1 || commonPrefix == itemSet.length) { + if (setsThatShareSameDocCountBits != null) { + for (FrequentItemSetCandidate otherSet : setsThatShareSameDocCountBits) { + if (itemSetBitSet.isSubset(otherSet.getItems())) { return true; } - - int pos = commonPrefix; - int posOther = commonPrefix; - - while (otherSet.size() - posOther >= itemSet.length - pos) { - if (otherSet.items.longs[posOther++] == itemSet.longs[pos]) { - pos++; - if (pos == itemSet.length) { - return true; - } - } - } } } + return false; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java new file mode 100644 index 0000000000000..f78031306ba74 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java @@ -0,0 +1,250 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.frequentitemsets; + +import java.util.Arrays; + +/** + * Custom implementation of a bitset for fast item set deduplication. + * + * Unfortunately other {@code BitSet} implementation, e.g. java.util, + * lack a subset check. + */ +class ItemSetBitSet implements Cloneable { + + // taken from {@code BitSet} + private static final int ADDRESS_BITS_PER_WORD = 6; + private static final int BITS_PER_WORD = 1 << ADDRESS_BITS_PER_WORD; + + /* Used to shift left or right for a partial word mask */ + private static final long WORD_MASK = 0xffffffffffffffffL; + + private long[] words; + private transient int wordsInUse = 0; + private int cardinality = 0; + + ItemSetBitSet() { + initWords(BITS_PER_WORD); + } + + ItemSetBitSet(int nbits) { + // nbits can't be negative; size 0 is OK + if (nbits < 0) throw new NegativeArraySizeException("nbits < 0: " + nbits); + + initWords(nbits); + } + + /*private ItemSetBitSet(long[] words) { + this.words = words; + this.wordsInUse = words.length; + } + + public static ItemSetBitSet valueOf(long[] longs) { + int n; + for (n = longs.length; n > 0 && longs[n - 1] == 0; n--) + ; + return new ItemSetBitSet(Arrays.copyOf(longs, n)); + }*/ + void reset(ItemSetBitSet bitSet) { + ensureCapacity(bitSet.wordsInUse); + System.arraycopy(bitSet.words, 0, this.words, 0, bitSet.wordsInUse); + this.cardinality = bitSet.cardinality; + this.wordsInUse = bitSet.wordsInUse; + } + + void set(int bitIndex) { + if (bitIndex < 0) throw new IndexOutOfBoundsException("bitIndex < 0: " + bitIndex); + + int wordIndex = wordIndex(bitIndex); + expandTo(wordIndex); + + final long oldWord = words[wordIndex]; + words[wordIndex] |= (1L << bitIndex); // Restores invariants + + if (oldWord != words[wordIndex]) { + cardinality++; + } + } + + boolean get(int bitIndex) { + if (bitIndex < 0) throw new IndexOutOfBoundsException("bitIndex < 0: " + bitIndex); + + int wordIndex = wordIndex(bitIndex); + return (wordIndex < wordsInUse) && ((words[wordIndex] & (1L << bitIndex)) != 0); + } + + void clear(int bitIndex) { + if (bitIndex < 0) throw new IndexOutOfBoundsException("bitIndex < 0: " + bitIndex); + + int wordIndex = wordIndex(bitIndex); + if (wordIndex >= wordsInUse) return; + + final long oldWord = words[wordIndex]; + + words[wordIndex] &= ~(1L << bitIndex); + if (oldWord != words[wordIndex]) { + cardinality--; + } + recalculateWordsInUse(); + } + + /** + * Returns true if the specified {@code ItemBitSet} is a subset of this + * set. + * + * @param set {@code ItemBitSet} to check + * @return true if the given set is a subset of this set + */ + boolean isSubset(ItemSetBitSet set) { + if (wordsInUse > set.wordsInUse) { + return false; + } + + for (int i = wordsInUse - 1; i >= 0; i--) + if ((words[i] & set.words[i]) != words[i]) return false; + + return true; + } + + int nextSetBit(int fromIndex) { + if (fromIndex < 0) throw new IndexOutOfBoundsException("fromIndex < 0: " + fromIndex); + + int u = wordIndex(fromIndex); + if (u >= wordsInUse) return -1; + + long word = words[u] & (WORD_MASK << fromIndex); + + while (true) { + if (word != 0) return (u * BITS_PER_WORD) + Long.numberOfTrailingZeros(word); + if (++u == wordsInUse) return -1; + word = words[u]; + } + } + + int cardinality() { + return cardinality; + } + + public static int compare(ItemSetBitSet a, ItemSetBitSet b) { + if (a.cardinality != b.cardinality) { + return a.cardinality > b.cardinality ? 1 : -1; + } + + if (a.wordsInUse != b.wordsInUse) { + return a.wordsInUse < b.wordsInUse ? 1 : -1; + } + + int i = Arrays.mismatch(a.words, 0, a.wordsInUse, b.words, 0, b.wordsInUse); + + if (i == -1) { + return 0; + } + + return a.words[i] < b.words[i] ? 1 : -1; + } + + @Override + public Object clone() { + trimToSize(); + + try { + ItemSetBitSet result = (ItemSetBitSet) super.clone(); + result.words = words.clone(); + return result; + } catch (CloneNotSupportedException e) { + throw new InternalError(e); + } + } + + @Override + public String toString() { + final int MAX_INITIAL_CAPACITY = Integer.MAX_VALUE - 8; + int numBits = wordsInUse * BITS_PER_WORD; + // Avoid overflow in the case of a humongous numBits + int initialCapacity = (numBits <= (MAX_INITIAL_CAPACITY - 2) / 6) ? 6 * numBits + 2 : MAX_INITIAL_CAPACITY; + StringBuilder b = new StringBuilder(initialCapacity); + + for (int i = 0; i < wordsInUse; ++i) { + b.append(words[i]); + b.append(" "); + } + + return b.toString(); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + if (this == obj) { + return true; + } + + final ItemSetBitSet set = (ItemSetBitSet) obj; + if (wordsInUse != set.wordsInUse) return false; + + // Check words in use by both BitSets + for (int i = 0; i < wordsInUse; i++) + if (words[i] != set.words[i]) return false; + + return true; + } + + @Override + public int hashCode() { + // Arrays.hashCode does not support subarrays + int result = 1; + for (int i = 0; i < wordsInUse; i++) { + int elementHash = (int) (words[i] ^ (words[i] >>> 32)); + result = 31 * result + elementHash; + } + + return result; + } + + private void trimToSize() { + if (wordsInUse != words.length) { + words = Arrays.copyOf(words, wordsInUse); + } + } + + private void initWords(int nbits) { + words = new long[wordIndex(nbits - 1) + 1]; + } + + private void ensureCapacity(int wordsRequired) { + if (words.length < wordsRequired) { + // Allocate larger of doubled size or required size + int request = Math.max(2 * words.length, wordsRequired); + words = Arrays.copyOf(words, request); + } + } + + private void recalculateWordsInUse() { + // Traverse the bitset until a used word is found + int i; + for (i = wordsInUse - 1; i >= 0; i--) + if (words[i] != 0) break; + + wordsInUse = i + 1; // The new logical size + } + + private void expandTo(int wordIndex) { + int wordsRequired = wordIndex + 1; + if (wordsInUse < wordsRequired) { + ensureCapacity(wordsRequired); + wordsInUse = wordsRequired; + } + } + + private static int wordIndex(int bitIndex) { + return bitIndex >> ADDRESS_BITS_PER_WORD; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverser.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverser.java index ee3e6479de404..a69d8c31a0116 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverser.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverser.java @@ -7,9 +7,8 @@ package org.elasticsearch.xpack.ml.aggs.frequentitemsets; +import org.apache.lucene.util.IntsRef; import org.apache.lucene.util.LongsRef; -import org.elasticsearch.core.Releasable; -import org.elasticsearch.core.Releasables; import java.util.ArrayList; import java.util.List; @@ -27,7 +26,7 @@ * Note: In order to avoid churn, the traverser is reusing objects as much as it can, * see the comments containing the non-optimized code */ -class ItemSetTraverser implements Releasable { +class ItemSetTraverser { // start size and size increment for array holding items private static final int SIZE_INCREMENT = 100; @@ -39,12 +38,21 @@ class ItemSetTraverser implements Releasable { private final List itemIterators = new ArrayList<>(); private LongsRef itemIdStack = new LongsRef(SIZE_INCREMENT); + private final ItemSetBitSet itemPositionsVector; + private final ItemSetBitSet itemPositionsVectorParent; + private IntsRef itemPositionsStack = new IntsRef(SIZE_INCREMENT); + private int stackPosition = 0; ItemSetTraverser(TransactionStore.TopItemIds topItemIds) { this.topItemIds = topItemIds; // push the first iterator itemIterators.add(topItemIds.iterator()); + + // create a bit vector that corresponds to the number of items + itemPositionsVector = new ItemSetBitSet((int) topItemIds.size()); + // create a bit vector that corresponds to the item set + itemPositionsVectorParent = new ItemSetBitSet((int) topItemIds.size()); } /** @@ -81,25 +89,33 @@ public boolean next() { return false; } itemIdStack.length--; + itemPositionsStack.length--; + itemPositionsVectorParent.clear(itemPositionsStack.ints[itemPositionsStack.length]); + itemPositionsVector.clear(itemPositionsStack.ints[itemPositionsStack.length]); } } // push a new iterator on the stack + + int itemPosition = itemIterators.get(stackPosition).getIndex(); // non-optimized: itemIterators.add(topItemIds.iterator(itemIteratorStack.peek().getIndex())); if (itemIterators.size() == stackPosition + 1) { - itemIterators.add(topItemIds.iterator(itemIterators.get(stackPosition).getIndex())); + itemIterators.add(topItemIds.iterator(itemPosition)); } else { - itemIterators.get(stackPosition + 1).reset(itemIterators.get(stackPosition).getIndex()); + itemIterators.get(stackPosition + 1).reset(itemPosition); } - if (itemIdStack.longs.length == itemIdStack.length) { - LongsRef resizedItemIdStack2 = new LongsRef(itemIdStack.length + SIZE_INCREMENT); - System.arraycopy(itemIdStack.longs, 0, resizedItemIdStack2.longs, 0, itemIdStack.length); - resizedItemIdStack2.length = itemIdStack.length; - itemIdStack = resizedItemIdStack2; + growStacksIfNecessary(); + itemIdStack.longs[itemIdStack.length++] = itemId; + + // set the position from the previous step + if (itemPositionsStack.length > 0) { + itemPositionsVectorParent.set(itemPositionsStack.ints[itemPositionsStack.length - 1]); } - itemIdStack.longs[itemIdStack.length++] = itemId; + // set the position from the this step + itemPositionsStack.ints[itemPositionsStack.length++] = itemPosition; + itemPositionsVector.set(itemPosition); ++stackPosition; return true; @@ -113,6 +129,14 @@ public LongsRef getItemSet() { return itemIdStack; } + public ItemSetBitSet getItemSetBitSet() { + return itemPositionsVector; + } + + public ItemSetBitSet getParentItemSetBitSet() { + return itemPositionsVectorParent; + } + public int getNumberOfItems() { return stackPosition; } @@ -132,11 +156,25 @@ public void prune() { return; } itemIdStack.length--; + itemPositionsStack.length--; + itemPositionsVectorParent.clear(itemPositionsStack.ints[itemPositionsStack.length]); + itemPositionsVector.clear(itemPositionsStack.ints[itemPositionsStack.length]); } - @Override - public void close() { - Releasables.close(topItemIds); + private void growStacksIfNecessary() { + if (itemIdStack.longs.length == itemIdStack.length) { + LongsRef resizedItemIdStack = new LongsRef(itemIdStack.length + SIZE_INCREMENT); + System.arraycopy(itemIdStack.longs, 0, resizedItemIdStack.longs, 0, itemIdStack.length); + resizedItemIdStack.length = itemIdStack.length; + itemIdStack = resizedItemIdStack; + } + + if (itemPositionsStack.ints.length == itemPositionsStack.length) { + IntsRef resizeditemPositionsStack = new IntsRef(itemPositionsStack.length + SIZE_INCREMENT); + System.arraycopy(itemPositionsStack.ints, 0, resizeditemPositionsStack.ints, 0, itemPositionsStack.length); + resizeditemPositionsStack.length = itemPositionsStack.length; + itemPositionsStack = resizeditemPositionsStack; + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/TransactionStore.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/TransactionStore.java index 0269abe47b628..5a4b48dc1c53c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/TransactionStore.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/TransactionStore.java @@ -86,6 +86,10 @@ public IdIterator iterator(int startIndex) { return new IdIterator(startIndex); } + public long getItemIdAt(long index) { + return sortedItems.get(index); + } + public long size() { return sortedItems.size(); } @@ -346,15 +350,6 @@ public TopItemIds getTopItemIds() { return getTopItemIds(getItems().size()); } - /** - * Get a traverser object to traverse top items - * - * @return a top item traverser - */ - public ItemSetTraverser getTopItemIdTraverser() { - return new ItemSetTraverser(getTopItemIds()); - } - /** * Check if a transaction specified by id contains the item * diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetCollectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetCollectorTests.java index cd759af3559db..671f7b3a7c07f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetCollectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetCollectorTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.ml.aggs.frequentitemsets; -import org.apache.lucene.util.LongsRef; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.MockBigArrays; @@ -16,6 +15,7 @@ import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.aggs.frequentitemsets.FrequentItemSetCollector.FrequentItemSetPriorityQueue; +import org.elasticsearch.xpack.ml.aggs.frequentitemsets.TransactionStore.TopItemIds; import org.junit.After; import java.io.IOException; @@ -38,141 +38,148 @@ public void closeReleasables() throws IOException { public void testQueue() { transactionStore = new HashBasedTransactionStore(mockBigArrays()); - FrequentItemSetCollector collector = new FrequentItemSetCollector(transactionStore, 5, Long.MAX_VALUE); + try (TopItemIds topItemIds = transactionStore.getTopItemIds();) { + FrequentItemSetCollector collector = new FrequentItemSetCollector(transactionStore, topItemIds, 5, Long.MAX_VALUE); - assertEquals(Long.MAX_VALUE, collector.add(longsRef(1L, 2L, 3L, 4L), 10L)); - assertEquals(Long.MAX_VALUE, collector.add(longsRef(5L, 6L, 7L, 8L), 11L)); - assertEquals(Long.MAX_VALUE, collector.add(longsRef(11L, 12L, 13L, 14L), 9L)); - assertEquals(Long.MAX_VALUE, collector.add(longsRef(21L, 2L, 3L, 4L), 13L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 1L, 2L, 3L, 4L }, 10L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 5L, 6L, 7L, 8L }, 11L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 11L, 12L, 13L, 14L }, 9L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 21L, 2L, 3L, 4L }, 13L)); - // queue should be full, drop weakest element - assertEquals(9L, collector.add(longsRef(31L, 2L, 3L, 4L), 14L)); - assertEquals(10L, collector.add(longsRef(41L, 2L, 3L, 4L), 15L)); - assertEquals(11L, collector.add(longsRef(51L, 2L, 3L, 4L), 16L)); + // queue should be full, drop weakest element + assertEquals(9L, addToCollector(collector, new long[] { 31L, 2L, 3L, 4L }, 14L)); + assertEquals(10L, addToCollector(collector, new long[] { 41L, 2L, 3L, 4L }, 15L)); + assertEquals(11L, addToCollector(collector, new long[] { 51L, 2L, 3L, 4L }, 16L)); - // check that internal data has been removed as well - assertEquals(5, collector.getFrequentItemsByCount().size()); + // check that internal data has been removed as well + assertEquals(5, collector.getFrequentItemsByCount().size()); - // fill slots with same doc count - assertEquals(13L, collector.add(longsRef(61L, 2L, 3L, 4L), 20L)); - assertEquals(14L, collector.add(longsRef(71L, 2L, 3L, 4L), 20L)); - assertEquals(15L, collector.add(longsRef(81L, 2L, 3L, 4L), 20L)); - assertEquals(16L, collector.add(longsRef(91L, 2L, 3L, 4L), 20L)); - assertEquals(20L, collector.add(longsRef(101L, 2L, 3L, 4L), 20L)); + // fill slots with same doc count + assertEquals(13L, addToCollector(collector, new long[] { 61L, 2L, 3L, 4L }, 20L)); + assertEquals(14L, addToCollector(collector, new long[] { 71L, 2L, 3L, 4L }, 20L)); + assertEquals(15L, addToCollector(collector, new long[] { 81L, 2L, 3L, 4L }, 20L)); + assertEquals(16L, addToCollector(collector, new long[] { 91L, 2L, 3L, 4L }, 20L)); + assertEquals(20L, addToCollector(collector, new long[] { 101L, 2L, 3L, 4L }, 20L)); - // check that internal map has only 1 key - assertEquals(1, collector.getFrequentItemsByCount().size()); + // check that internal map has only 1 key + assertEquals(1, collector.getFrequentItemsByCount().size()); - // ignore set below current weakest one - assertEquals(20L, collector.add(longsRef(111L, 2L, 3L, 4L), 1L)); + // ignore set below current weakest one + assertEquals(20L, addToCollector(collector, new long[] { 111L, 2L, 3L, 4L }, 1L)); - FrequentItemSetPriorityQueue queue = collector.getQueue(); + FrequentItemSetPriorityQueue queue = collector.getQueue(); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 61L, 2L, 3L, 4L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 71L, 2L, 3L, 4L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 81L, 2L, 3L, 4L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 91L, 2L, 3L, 4L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 101L, 2L, 3L, 4L })); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 101L, 2L, 3L, 4L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 91L, 2L, 3L, 4L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 81L, 2L, 3L, 4L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 71L, 2L, 3L, 4L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 61L, 2L, 3L, 4L }))); - assertEquals(0, collector.size()); + assertEquals(0, collector.size()); + } } public void testClosedSetSkipping() { transactionStore = new HashBasedTransactionStore(mockBigArrays()); - FrequentItemSetCollector collector = new FrequentItemSetCollector(transactionStore, 5, Long.MAX_VALUE); + try (TopItemIds topItemIds = transactionStore.getTopItemIds();) { + FrequentItemSetCollector collector = new FrequentItemSetCollector(transactionStore, topItemIds, 5, Long.MAX_VALUE); - assertEquals(Long.MAX_VALUE, collector.add(longsRef(1L, 2L, 3L, 4L), 10L)); - assertEquals(Long.MAX_VALUE, collector.add(longsRef(5L, 6L, 7L, 8L), 11L)); - assertEquals(Long.MAX_VALUE, collector.add(longsRef(11L, 12L, 13L, 14L), 12L)); - assertEquals(Long.MAX_VALUE, collector.add(longsRef(21L, 2L, 3L, 4L), 13L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 1L, 2L, 3L, 4L }, 10L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 5L, 6L, 7L, 8L }, 11L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 11L, 12L, 13L, 14L }, 12L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 21L, 2L, 3L, 4L }, 13L)); - // add a subset of the 1st entry, it should be ignored - assertEquals(Long.MAX_VALUE, collector.add(longsRef(1L, 2L, 3L), 10L)); + // add a subset of the 1st entry, it should be ignored + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 1L, 2L, 3L }, 10L)); - // fill slots with same doc count - assertEquals(10L, collector.add(longsRef(61L, 2L, 3L, 4L), 20L)); - assertEquals(11L, collector.add(longsRef(71L, 2L, 3L, 4L), 20L)); - assertEquals(12L, collector.add(longsRef(81L, 2L, 3L, 4L), 20L)); - assertEquals(13L, collector.add(longsRef(91L, 2L, 3L, 4L), 20L)); + // fill slots with same doc count + assertEquals(10L, addToCollector(collector, new long[] { 61L, 2L, 3L, 4L }, 20L)); + assertEquals(11L, addToCollector(collector, new long[] { 71L, 2L, 3L, 4L }, 20L)); + assertEquals(12L, addToCollector(collector, new long[] { 81L, 2L, 3L, 4L }, 20L)); + assertEquals(13L, addToCollector(collector, new long[] { 91L, 2L, 3L, 4L }, 20L)); - // add a subset of an entry, it should be ignored - assertEquals(13L, collector.add(longsRef(81L, 2L, 4L), 20L)); + // add a subset of an entry, it should be ignored + assertEquals(13L, addToCollector(collector, new long[] { 81L, 2L, 4L }, 20L)); - FrequentItemSetPriorityQueue queue = collector.getQueue(); + FrequentItemSetPriorityQueue queue = collector.getQueue(); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 21L, 2L, 3L, 4L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 61L, 2L, 3L, 4L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 71L, 2L, 3L, 4L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 81L, 2L, 3L, 4L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 91L, 2L, 3L, 4L })); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 21L, 2L, 3L, 4L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 91L, 2L, 3L, 4L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 81L, 2L, 3L, 4L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 71L, 2L, 3L, 4L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 61L, 2L, 3L, 4L }))); - assertEquals(0, collector.size()); + assertEquals(0, collector.size()); + } } public void testCopyOnAdd() { transactionStore = new HashBasedTransactionStore(mockBigArrays()); + try (TopItemIds topItemIds = transactionStore.getTopItemIds();) { + FrequentItemSetCollector collector = new FrequentItemSetCollector(transactionStore, topItemIds, 5, Long.MAX_VALUE); + long[] itemSet = new long[] { 1L, 2L, 3L, 4L, 5L }; - FrequentItemSetCollector collector = new FrequentItemSetCollector(transactionStore, 5, Long.MAX_VALUE); - LongsRef itemSet = longsRef(1L, 2L, 3L, 4L, 5L); + assertEquals(Long.MAX_VALUE, addToCollector(collector, itemSet, 10L)); + itemSet[0] = 42L; + itemSet[4] = 42L; - assertEquals(Long.MAX_VALUE, collector.add(itemSet, 10L)); - itemSet.longs[0] = 42L; - itemSet.longs[4] = 42L; + FrequentItemSetPriorityQueue queue = collector.getQueue(); - FrequentItemSetPriorityQueue queue = collector.getQueue(); - - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 1L, 2L, 3L, 4L, 5L })); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 1L, 2L, 3L, 4L, 5L }))); + } } public void testLargerItemSetsPreference() { transactionStore = new HashBasedTransactionStore(mockBigArrays()); + try (TopItemIds topItemIds = transactionStore.getTopItemIds();) { + FrequentItemSetCollector collector = new FrequentItemSetCollector(transactionStore, topItemIds, 5, Long.MAX_VALUE); - FrequentItemSetCollector collector = new FrequentItemSetCollector(transactionStore, 5, Long.MAX_VALUE); - - assertEquals(Long.MAX_VALUE, collector.add(longsRef(1L, 2L, 3L, 4L), 10L)); - assertEquals(Long.MAX_VALUE, collector.add(longsRef(5L, 6L, 7L, 8L), 11L)); - assertEquals(Long.MAX_VALUE, collector.add(longsRef(11L, 12L, 13L, 14L), 9L)); - assertEquals(Long.MAX_VALUE, collector.add(longsRef(21L, 2L, 3L, 4L), 13L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 1L, 2L, 3L, 4L }, 10L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 5L, 6L, 7L, 8L }, 11L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 11L, 12L, 13L, 14L }, 9L)); + assertEquals(Long.MAX_VALUE, addToCollector(collector, new long[] { 21L, 2L, 3L, 4L }, 13L)); - // queue should be full, drop weakest element - assertEquals(9L, collector.add(longsRef(31L, 2L, 3L, 4L), 14L)); + // queue should be full, drop weakest element + assertEquals(9L, addToCollector(collector, new long[] { 31L, 2L, 3L, 4L }, 14L)); - assertEquals(9L, collector.getLastSet().getDocCount()); - assertEquals(4, collector.getLastSet().size()); + assertEquals(9L, collector.getLastSet().getDocCount()); + assertEquals(4, collector.getLastSet().size()); - // ignore set with same doc count but fewer items - assertEquals(9L, collector.add(longsRef(22L, 23L, 24L), 9L)); + // ignore set with same doc count but fewer items + assertEquals(9L, addToCollector(collector, new long[] { 22L, 23L, 24L }, 9L)); - assertEquals(9L, collector.getLastSet().getDocCount()); - assertEquals(4, collector.getLastSet().size()); + assertEquals(9L, collector.getLastSet().getDocCount()); + assertEquals(4, collector.getLastSet().size()); - // take set with same doc count but more items - assertEquals(9L, collector.add(longsRef(25L, 26L, 27L, 28L, 29L), 9L)); + // take set with same doc count but more items + assertEquals(9L, addToCollector(collector, new long[] { 25L, 26L, 27L, 28L, 29L }, 9L)); - assertEquals(9L, collector.getLastSet().getDocCount()); - assertEquals(5, collector.getLastSet().size()); + assertEquals(9L, collector.getLastSet().getDocCount()); + assertEquals(5, collector.getLastSet().size()); - FrequentItemSetPriorityQueue queue = collector.getQueue(); + FrequentItemSetPriorityQueue queue = collector.getQueue(); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 25L, 26L, 27L, 28L, 29L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 1L, 2L, 3L, 4L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 5L, 6L, 7L, 8L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 21L, 2L, 3L, 4L })); - assertThat(queue.pop().getItems().longs, equalTo(new long[] { 31L, 2L, 3L, 4L })); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 25L, 26L, 27L, 28L, 29L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 1L, 2L, 3L, 4L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 5L, 6L, 7L, 8L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 21L, 2L, 3L, 4L }))); + assertThat(queue.pop().getItems(), equalTo(createItemSetBitSet(new long[] { 31L, 2L, 3L, 4L }))); - assertEquals(0, collector.size()); + assertEquals(0, collector.size()); + } } - private static LongsRef longsRef(long l1, long l2, long l3) { - return new LongsRef(new long[] { l1, l2, l3 }, 0, 3); - } + private static ItemSetBitSet createItemSetBitSet(long[] longs) { + ItemSetBitSet itemsAsBitVector = new ItemSetBitSet(); + for (int i = 0; i < longs.length; ++i) { + itemsAsBitVector.set((int) longs[i]); + } - private static LongsRef longsRef(long l1, long l2, long l3, long l4) { - return new LongsRef(new long[] { l1, l2, l3, l4 }, 0, 4); + return itemsAsBitVector; } - private static LongsRef longsRef(long l1, long l2, long l3, long l4, long l5) { - return new LongsRef(new long[] { l1, l2, l3, l4, l5 }, 0, 5); + private static long addToCollector(FrequentItemSetCollector collector, long[] longsRef, long docCount) { + return collector.add(createItemSetBitSet(longsRef), docCount); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSetTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSetTests.java new file mode 100644 index 0000000000000..a0c599d1c6da7 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSetTests.java @@ -0,0 +1,213 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.frequentitemsets; + +import org.elasticsearch.test.ESTestCase; + +public class ItemSetBitSetTests extends ESTestCase { + + public void testBasics() { + ItemSetBitSet set = new ItemSetBitSet(); + set.set(0); + set.set(3); + set.set(200); + set.set(5); + set.set(65); + + assertTrue(set.get(0)); + assertFalse(set.get(1)); + assertFalse(set.get(2)); + assertTrue(set.get(3)); + assertTrue(set.get(5)); + assertFalse(set.get(64)); + assertTrue(set.get(65)); + assertTrue(set.get(200)); + + set.clear(0); + set.clear(65); + set.clear(5); + + assertFalse(set.get(0)); + assertFalse(set.get(1)); + assertFalse(set.get(2)); + assertTrue(set.get(3)); + assertFalse(set.get(5)); + assertFalse(set.get(64)); + assertFalse(set.get(65)); + assertTrue(set.get(200)); + } + + public void testIsSubSet() { + ItemSetBitSet set1 = new ItemSetBitSet(); + set1.set(0); + set1.set(3); + set1.set(200); + set1.set(5); + set1.set(65); + + ItemSetBitSet set2 = new ItemSetBitSet(); + set2.set(3); + set2.set(200); + set2.set(65); + + assertTrue(set2.isSubset(set1)); + assertFalse(set1.isSubset(set2)); + assertTrue(set1.isSubset(set1)); + + set2.set(0); + set2.set(5); + assertTrue(set2.isSubset(set1)); + assertTrue(set1.isSubset(set2)); + + set2.set(99); + assertFalse(set2.isSubset(set1)); + assertTrue(set1.isSubset(set2)); + + set1.set(999); + assertFalse(set1.isSubset(set2)); + set2.set(999); + assertTrue(set1.isSubset(set2)); + set2.set(2222); + assertTrue(set1.isSubset(set2)); + } + + public void testClone() { + ItemSetBitSet set1 = new ItemSetBitSet(); + set1.set(0); + set1.set(3); + set1.set(200); + set1.set(5); + set1.set(65); + + ItemSetBitSet set2 = (ItemSetBitSet) set1.clone(); + + assertTrue(set2.get(0)); + assertFalse(set2.get(1)); + assertFalse(set2.get(2)); + assertTrue(set2.get(3)); + assertTrue(set2.get(5)); + assertFalse(set2.get(64)); + assertTrue(set2.get(65)); + assertTrue(set2.get(200)); + + set1.clear(200); + assertTrue(set2.get(200)); + + set1.set(42); + assertTrue(set1.get(42)); + assertFalse(set2.get(42)); + } + + public void testReset() { + ItemSetBitSet set1 = new ItemSetBitSet(); + set1.set(0); + set1.set(3); + set1.set(200); + set1.set(5); + set1.set(65); + + ItemSetBitSet set2 = new ItemSetBitSet(); + set2.reset(set1); + assertEquals(set1, set2); + assertEquals(set1.cardinality(), set2.cardinality()); + + assertTrue(set2.get(0)); + assertFalse(set2.get(1)); + assertFalse(set2.get(2)); + assertTrue(set2.get(3)); + assertTrue(set2.get(5)); + assertFalse(set2.get(64)); + assertTrue(set2.get(65)); + assertTrue(set2.get(200)); + + set1.clear(200); + assertTrue(set2.get(200)); + + set1.set(42); + assertTrue(set1.get(42)); + assertFalse(set2.get(42)); + + set2.set(99999999); + assertTrue(set2.get(99999999)); + + ItemSetBitSet set3 = new ItemSetBitSet(); + set3.set(2); + set3.set(9); + set2.reset(set3); + + assertEquals(set3, set2); + assertFalse(set2.get(99999999)); + } + + public void testHashCode() { + ItemSetBitSet set1 = new ItemSetBitSet(); + set1.set(0); + set1.set(3); + set1.set(200); + set1.set(5); + set1.set(65); + + ItemSetBitSet set2 = new ItemSetBitSet(); + set2.reset(set1); + + assertEquals(set1.hashCode(), set2.hashCode()); + set2.set(99999999); + assertNotEquals(set1.hashCode(), set2.hashCode()); + set2.clear(99999999); + assertEquals(set1.hashCode(), set2.hashCode()); + } + + public void testCompare() { + ItemSetBitSet set1 = new ItemSetBitSet(); + set1.set(0); + set1.set(3); + ItemSetBitSet set2 = new ItemSetBitSet(); + set2.set(0); + + assertEquals(1, ItemSetBitSet.compare(set1, set2)); + assertEquals(-1, ItemSetBitSet.compare(set2, set1)); + + set2.set(3); + assertEquals(0, ItemSetBitSet.compare(set2, set1)); + set1.set(4); + set2.set(5); + + assertEquals(1, ItemSetBitSet.compare(set1, set2)); + set1.set(6); + set2.set(6); + assertEquals(1, ItemSetBitSet.compare(set1, set2)); + set1.set(7); + set2.set(8); + assertEquals(1, ItemSetBitSet.compare(set1, set2)); + + ItemSetBitSet set3 = new ItemSetBitSet(); + set3.set(2); + set3.set(3); + set3.set(4); + ItemSetBitSet set4 = new ItemSetBitSet(); + set4.set(2); + set4.set(3); + set4.set(4); + + set3.set(71); + set4.set(101); + assertEquals(1, ItemSetBitSet.compare(set3, set4)); + assertEquals(-1, ItemSetBitSet.compare(set4, set3)); + + set3.set(61); + assertEquals(1, ItemSetBitSet.compare(set3, set4)); + assertEquals(-1, ItemSetBitSet.compare(set4, set3)); + + set3.clear(71); + set4.set(101); + + assertEquals(1, ItemSetBitSet.compare(set3, set4)); + assertEquals(-1, ItemSetBitSet.compare(set4, set3)); + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverserTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverserTests.java index d92bc65c02df4..64b142ca7e2ec 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverserTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverserTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.core.Releasables; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ml.aggs.frequentitemsets.TransactionStore.TopItemIds; import org.elasticsearch.xpack.ml.aggs.frequentitemsets.mr.ItemSetMapReduceValueSource.Field; import org.junit.After; @@ -31,11 +32,10 @@ static BigArrays mockBigArrays() { } private HashBasedTransactionStore transactionStore = null; - private ItemSetTraverser it = null; @After public void closeReleasables() throws IOException { - Releasables.close(transactionStore, it); + Releasables.close(transactionStore); } public void testIteration() throws IOException { @@ -60,99 +60,166 @@ public void testIteration() throws IOException { // we don't want to prune transactionStore.prune(0.1); - it = new ItemSetTraverser(transactionStore.getTopItemIds()); - - /** - * items are sorted by frequency: - * d:8, b:7, c:5, a:4, e:3, f:2, g:1 - * this creates the following traversal tree: - * - * 1: d-->b-->c-->a-->e-->f-->g - * 2: | | `->g - * 3: | |`->f-->g - * 4: | `->g - * 5: |`->e-->f-->g - * 6: | `->g - * 7: |`->f-->g - * 8: `->g - * ... - */ - - assertTrue(it.next()); - assertEquals("d", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(1, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("b", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(2, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("c", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(3, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("a", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(4, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("e", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(5, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(6, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(7, it.getNumberOfItems()); - - // branch row 2 - assertTrue(it.next()); - assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(6, it.getNumberOfItems()); - - // branch row 3 - assertTrue(it.next()); - assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(5, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(6, it.getNumberOfItems()); - - // branch row 4 - assertTrue(it.next()); - assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(5, it.getNumberOfItems()); - - // branch row 5 - assertTrue(it.next()); - assertEquals("e", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(4, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(5, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(6, it.getNumberOfItems()); - - // branch row 6 - assertTrue(it.next()); - assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(5, it.getNumberOfItems()); - - // branch row 7 - assertTrue(it.next()); - assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(4, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(5, it.getNumberOfItems()); - - // branch row 8 - assertTrue(it.next()); - assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(4, it.getNumberOfItems()); - - int furtherSteps = 0; - while (it.next()) { - ++furtherSteps; - } - assertEquals(109, furtherSteps); + try (TopItemIds topItemIds = transactionStore.getTopItemIds()) { + ItemSetTraverser it = new ItemSetTraverser(topItemIds); + + /** + * items are sorted by frequency: + * d:8, b:7, c:5, a:4, e:3, f:2, g:1 + * this creates the following traversal tree: + * + * 1: d-->b-->c-->a-->e-->f-->g + * 2: | | `->g + * 3: | |`->f-->g + * 4: | `->g + * 5: |`->e-->f-->g + * 6: | `->g + * 7: |`->f-->g + * 8: `->g + * ... + * + * bit representation: + * d:1, b:2, c:3, a:4, e:5, f:6, g:7 + */ + + assertTrue(it.next()); + assertEquals("d", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(1, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(1)); + assertTrue(it.next()); + assertEquals("b", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(2, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(2)); + assertTrue(it.next()); + assertEquals("c", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(3, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(3)); + assertTrue(it.next()); + assertEquals("a", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(4, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(4)); + assertFalse(it.getParentItemSetBitSet().get(4)); + assertTrue(it.next()); + assertEquals("e", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(5, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(5)); + assertFalse(it.getParentItemSetBitSet().get(5)); + assertTrue(it.getParentItemSetBitSet().get(4)); + assertTrue(it.next()); + assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(6, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(6)); + assertFalse(it.getParentItemSetBitSet().get(6)); + assertTrue(it.getParentItemSetBitSet().get(5)); + assertTrue(it.next()); + assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(7, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(7)); + assertFalse(it.getParentItemSetBitSet().get(7)); + assertTrue(it.getParentItemSetBitSet().get(6)); + + // branch row 2 + it.next(); + // assertTrue(it.next()); + assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(6, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(7)); + assertFalse(it.getItemSetBitSet().get(6)); + assertFalse(it.getParentItemSetBitSet().get(6)); + assertFalse(it.getParentItemSetBitSet().get(7)); + + // branch row 3 + assertTrue(it.next()); + assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(5, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(6)); + assertFalse(it.getItemSetBitSet().get(5)); + assertFalse(it.getItemSetBitSet().get(7)); + assertFalse(it.getParentItemSetBitSet().get(5)); + assertFalse(it.getParentItemSetBitSet().get(6)); + assertTrue(it.next()); + assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(6, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(7)); + assertTrue(it.getItemSetBitSet().get(6)); + assertFalse(it.getItemSetBitSet().get(5)); + assertTrue(it.getParentItemSetBitSet().get(6)); + assertFalse(it.getParentItemSetBitSet().get(7)); + assertFalse(it.getParentItemSetBitSet().get(5)); + + // branch row 4 + assertTrue(it.next()); + assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(5, it.getNumberOfItems()); + + // branch row 5 + assertTrue(it.next()); + assertEquals("e", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(4, it.getNumberOfItems()); + assertTrue(it.next()); + assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(5, it.getNumberOfItems()); + assertTrue(it.next()); + assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(6, it.getNumberOfItems()); + + // branch row 6: "dbceg" + assertTrue(it.next()); + assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(5, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(1)); + assertTrue(it.getItemSetBitSet().get(2)); + assertTrue(it.getItemSetBitSet().get(3)); + assertFalse(it.getItemSetBitSet().get(4)); + assertTrue(it.getItemSetBitSet().get(5)); + assertFalse(it.getItemSetBitSet().get(6)); + assertTrue(it.getItemSetBitSet().get(7)); + + assertTrue(it.getParentItemSetBitSet().get(1)); + assertTrue(it.getParentItemSetBitSet().get(2)); + assertTrue(it.getParentItemSetBitSet().get(3)); + assertFalse(it.getParentItemSetBitSet().get(4)); + assertTrue(it.getParentItemSetBitSet().get(5)); + assertFalse(it.getParentItemSetBitSet().get(6)); + assertFalse(it.getParentItemSetBitSet().get(7)); + + // branch row 7 + assertTrue(it.next()); + assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(4, it.getNumberOfItems()); + assertTrue(it.next()); + assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(5, it.getNumberOfItems()); + + // branch row 8: "dbcg" + assertTrue(it.next()); + assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(4, it.getNumberOfItems()); + + assertTrue(it.getItemSetBitSet().get(1)); + assertTrue(it.getItemSetBitSet().get(2)); + assertTrue(it.getItemSetBitSet().get(3)); + assertFalse(it.getItemSetBitSet().get(4)); + assertFalse(it.getItemSetBitSet().get(5)); + assertFalse(it.getItemSetBitSet().get(6)); + assertTrue(it.getItemSetBitSet().get(7)); + + assertTrue(it.getParentItemSetBitSet().get(1)); + assertTrue(it.getParentItemSetBitSet().get(2)); + assertTrue(it.getParentItemSetBitSet().get(3)); + assertFalse(it.getParentItemSetBitSet().get(4)); + assertFalse(it.getParentItemSetBitSet().get(5)); + assertFalse(it.getParentItemSetBitSet().get(6)); + assertFalse(it.getParentItemSetBitSet().get(7)); + + int furtherSteps = 0; + while (it.next()) { + ++furtherSteps; + } + + assertEquals(109, furtherSteps); + } } public void testPruning() throws IOException { @@ -177,92 +244,133 @@ public void testPruning() throws IOException { // we don't want to prune transactionStore.prune(0.1); - it = new ItemSetTraverser(transactionStore.getTopItemIds()); - - /** - * items are sorted by frequency: - * d:8, b:7, c:5, a:4, e:3, f:2, g:1 - * this creates the following traversal tree: - * - * this item we prune the tree in various places marked with "[", "]" - * - * 1: d-->b-->c-->a-->e[-->f-->g ] - * 2: | | [`->g ] - * 3: | |`->f-->g - * 4: | `->g - * 5: |`->e-->f-->g - * 6: | `->g - * 7: |`->f-->g - * 8: `->g - * ... - */ - - assertTrue(it.next()); - assertEquals("d", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(1, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("b", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(2, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("c", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(3, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("a", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(4, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("e", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(5, it.getNumberOfItems()); - - // now prune the tree - it.prune(); - - // branch row 3 - assertTrue(it.next()); - assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(5, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(6, it.getNumberOfItems()); - - // prune, which actually is ineffective, as we would go up anyway - it.prune(); - - // branch row 4 - assertTrue(it.next()); - assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(5, it.getNumberOfItems()); - - // branch row 5 - assertTrue(it.next()); - assertEquals("e", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(4, it.getNumberOfItems()); - - // prune - it.prune(); - - // branch row 7 - assertTrue(it.next()); - assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(4, it.getNumberOfItems()); - assertTrue(it.next()); - assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); - assertEquals(5, it.getNumberOfItems()); - - // prune aggressively - it.prune(); - it.prune(); - it.prune(); - it.prune(); - it.prune(); - it.prune(); - it.prune(); - - int furtherSteps = 0; - while (it.next()) { - ++furtherSteps; + try (TopItemIds topItemIds = transactionStore.getTopItemIds()) { + ItemSetTraverser it = new ItemSetTraverser(topItemIds); + + /** + * items are sorted by frequency: + * d:8, b:7, c:5, a:4, e:3, f:2, g:1 + * this creates the following traversal tree: + * + * this item we prune the tree in various places marked with "[", "]" + * + * 1: d-->b-->c-->a-->e[-->f-->g ] + * 2: | | [`->g ] + * 3: | |`->f-->g + * 4: | `->g + * 5: |`->e-->f-->g + * 6: | `->g + * 7: |`->f-->g + * 8: `->g + * ... + * + * bit representation: + * d:1, b:2, c:3, a:4, e:5, f:6, g:7 + */ + + assertTrue(it.next()); + assertEquals("d", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(1, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(1)); + assertTrue(it.next()); + assertEquals("b", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(2, it.getNumberOfItems()); + assertTrue(it.next()); + assertEquals("c", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(3, it.getNumberOfItems()); + assertTrue(it.next()); + assertEquals("a", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(4, it.getNumberOfItems()); + assertTrue(it.next()); + assertEquals("e", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(5, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(1)); + assertTrue(it.getItemSetBitSet().get(2)); + assertTrue(it.getItemSetBitSet().get(3)); + assertTrue(it.getItemSetBitSet().get(4)); + assertTrue(it.getItemSetBitSet().get(5)); + assertTrue(it.getParentItemSetBitSet().get(1)); + assertTrue(it.getParentItemSetBitSet().get(2)); + assertTrue(it.getParentItemSetBitSet().get(3)); + assertTrue(it.getParentItemSetBitSet().get(4)); + assertFalse(it.getParentItemSetBitSet().get(5)); + + // now prune the tree + it.prune(); + + // branch row 3 + assertTrue(it.next()); + assertTrue(it.getItemSetBitSet().get(1)); + assertTrue(it.getItemSetBitSet().get(2)); + assertTrue(it.getItemSetBitSet().get(3)); + assertTrue(it.getItemSetBitSet().get(4)); + assertFalse(it.getItemSetBitSet().get(5)); + assertTrue(it.getItemSetBitSet().get(6)); + assertTrue(it.getParentItemSetBitSet().get(1)); + assertTrue(it.getParentItemSetBitSet().get(2)); + assertTrue(it.getParentItemSetBitSet().get(3)); + assertTrue(it.getParentItemSetBitSet().get(4)); + assertFalse(it.getParentItemSetBitSet().get(5)); + assertFalse(it.getParentItemSetBitSet().get(6)); + assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(5, it.getNumberOfItems()); + assertTrue(it.next()); + assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(6, it.getNumberOfItems()); + + // prune, which actually is ineffective, as we would go up anyway + it.prune(); + + // branch row 4 + assertTrue(it.next()); + assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(5, it.getNumberOfItems()); + + // branch row 5 + assertTrue(it.next()); + assertEquals("e", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(4, it.getNumberOfItems()); + + // prune + it.prune(); + + // branch row 7 + assertTrue(it.next()); + assertEquals("f", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(4, it.getNumberOfItems()); + assertTrue(it.getItemSetBitSet().get(1)); + assertTrue(it.getItemSetBitSet().get(2)); + assertTrue(it.getItemSetBitSet().get(3)); + assertFalse(it.getItemSetBitSet().get(4)); + assertFalse(it.getItemSetBitSet().get(5)); + assertTrue(it.getItemSetBitSet().get(6)); + + assertTrue(it.getParentItemSetBitSet().get(1)); + assertTrue(it.getParentItemSetBitSet().get(2)); + assertTrue(it.getParentItemSetBitSet().get(3)); + assertFalse(it.getParentItemSetBitSet().get(4)); + assertFalse(it.getParentItemSetBitSet().get(5)); + assertFalse(it.getParentItemSetBitSet().get(6)); + assertTrue(it.next()); + assertEquals("g", transactionStore.getItem(it.getItemId()).v2()); + assertEquals(5, it.getNumberOfItems()); + + // prune aggressively + it.prune(); + it.prune(); + it.prune(); + it.prune(); + it.prune(); + it.prune(); + it.prune(); + + int furtherSteps = 0; + while (it.next()) { + ++furtherSteps; + } + + assertEquals(0, furtherSteps); } - - assertEquals(0, furtherSteps); } } From 3d14701b9e3f237190d027b91694e74d0dd133b3 Mon Sep 17 00:00:00 2001 From: Hendrik Muhs Date: Fri, 29 Jul 2022 15:37:30 +0200 Subject: [PATCH 2/4] remove dead code --- .../xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java index f78031306ba74..031deec36b2a8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java @@ -39,17 +39,6 @@ class ItemSetBitSet implements Cloneable { initWords(nbits); } - /*private ItemSetBitSet(long[] words) { - this.words = words; - this.wordsInUse = words.length; - } - - public static ItemSetBitSet valueOf(long[] longs) { - int n; - for (n = longs.length; n > 0 && longs[n - 1] == 0; n--) - ; - return new ItemSetBitSet(Arrays.copyOf(longs, n)); - }*/ void reset(ItemSetBitSet bitSet) { ensureCapacity(bitSet.wordsInUse); System.arraycopy(bitSet.words, 0, this.words, 0, bitSet.wordsInUse); From f5f7a31b20b85fe8ade185bd51747bff01d311a8 Mon Sep 17 00:00:00 2001 From: Hendrik Muhs Date: Fri, 29 Jul 2022 15:40:41 +0200 Subject: [PATCH 3/4] Update docs/changelog/88943.yaml --- docs/changelog/88943.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/88943.yaml diff --git a/docs/changelog/88943.yaml b/docs/changelog/88943.yaml new file mode 100644 index 0000000000000..63dd57750ffb5 --- /dev/null +++ b/docs/changelog/88943.yaml @@ -0,0 +1,5 @@ +pr: 88943 +summary: "Frequent Items: use a bitset for deduplication" +area: Machine Learning +type: enhancement +issues: [] From a3ffc1ccae0346159a240ccb03d6bfc40148ea62 Mon Sep 17 00:00:00 2001 From: Hendrik Muhs Date: Mon, 1 Aug 2022 13:01:01 +0200 Subject: [PATCH 4/4] use ArrayUtil for array resizing, add test for itemsetbitset cardinality --- .../aggs/frequentitemsets/ItemSetBitSet.java | 20 ++++++------- .../frequentitemsets/ItemSetTraverser.java | 11 ++----- .../frequentitemsets/ItemSetBitSetTests.java | 29 +++++++++++++++++++ 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java index 031deec36b2a8..9a87fad024101 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSet.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.ml.aggs.frequentitemsets; +import org.apache.lucene.util.ArrayUtil; + import java.util.Arrays; /** @@ -14,6 +16,12 @@ * * Unfortunately other {@code BitSet} implementation, e.g. java.util, * lack a subset check. + * + * For this implementation I took the code from {@code BitSet}, removed + * unnecessary parts and added additional functionality like the subset check. + * Cardinality - the number of set bits == number of items - is used a lot. + * The original {@code BitSet} uses a scan, this implementation uses + * a counter for faster retrieval. */ class ItemSetBitSet implements Cloneable { @@ -40,7 +48,7 @@ class ItemSetBitSet implements Cloneable { } void reset(ItemSetBitSet bitSet) { - ensureCapacity(bitSet.wordsInUse); + words = ArrayUtil.grow(words, bitSet.wordsInUse); System.arraycopy(bitSet.words, 0, this.words, 0, bitSet.wordsInUse); this.cardinality = bitSet.cardinality; this.wordsInUse = bitSet.wordsInUse; @@ -208,14 +216,6 @@ private void initWords(int nbits) { words = new long[wordIndex(nbits - 1) + 1]; } - private void ensureCapacity(int wordsRequired) { - if (words.length < wordsRequired) { - // Allocate larger of doubled size or required size - int request = Math.max(2 * words.length, wordsRequired); - words = Arrays.copyOf(words, request); - } - } - private void recalculateWordsInUse() { // Traverse the bitset until a used word is found int i; @@ -228,7 +228,7 @@ private void recalculateWordsInUse() { private void expandTo(int wordIndex) { int wordsRequired = wordIndex + 1; if (wordsInUse < wordsRequired) { - ensureCapacity(wordsRequired); + words = ArrayUtil.grow(words, wordsRequired); wordsInUse = wordsRequired; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverser.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverser.java index a69d8c31a0116..41d8c43fd4ffd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverser.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetTraverser.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.aggs.frequentitemsets; +import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.IntsRef; import org.apache.lucene.util.LongsRef; @@ -163,17 +164,11 @@ public void prune() { private void growStacksIfNecessary() { if (itemIdStack.longs.length == itemIdStack.length) { - LongsRef resizedItemIdStack = new LongsRef(itemIdStack.length + SIZE_INCREMENT); - System.arraycopy(itemIdStack.longs, 0, resizedItemIdStack.longs, 0, itemIdStack.length); - resizedItemIdStack.length = itemIdStack.length; - itemIdStack = resizedItemIdStack; + itemIdStack.longs = ArrayUtil.grow(itemIdStack.longs, itemIdStack.length + SIZE_INCREMENT); } if (itemPositionsStack.ints.length == itemPositionsStack.length) { - IntsRef resizeditemPositionsStack = new IntsRef(itemPositionsStack.length + SIZE_INCREMENT); - System.arraycopy(itemPositionsStack.ints, 0, resizeditemPositionsStack.ints, 0, itemPositionsStack.length); - resizeditemPositionsStack.length = itemPositionsStack.length; - itemPositionsStack = resizeditemPositionsStack; + itemPositionsStack.ints = ArrayUtil.grow(itemPositionsStack.ints, itemPositionsStack.length + SIZE_INCREMENT); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSetTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSetTests.java index a0c599d1c6da7..b70775391f122 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSetTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/ItemSetBitSetTests.java @@ -19,6 +19,7 @@ public void testBasics() { set.set(5); set.set(65); + assertEquals(5, set.cardinality()); assertTrue(set.get(0)); assertFalse(set.get(1)); assertFalse(set.get(2)); @@ -31,6 +32,7 @@ public void testBasics() { set.clear(0); set.clear(65); set.clear(5); + assertEquals(2, set.cardinality()); assertFalse(set.get(0)); assertFalse(set.get(1)); @@ -50,11 +52,13 @@ public void testIsSubSet() { set1.set(5); set1.set(65); + assertEquals(5, set1.cardinality()); ItemSetBitSet set2 = new ItemSetBitSet(); set2.set(3); set2.set(200); set2.set(65); + assertEquals(3, set2.cardinality()); assertTrue(set2.isSubset(set1)); assertFalse(set1.isSubset(set2)); assertTrue(set1.isSubset(set1)); @@ -85,6 +89,7 @@ public void testClone() { set1.set(65); ItemSetBitSet set2 = (ItemSetBitSet) set1.clone(); + assertEquals(5, set2.cardinality()); assertTrue(set2.get(0)); assertFalse(set2.get(1)); @@ -144,6 +149,30 @@ public void testReset() { assertFalse(set2.get(99999999)); } + public void testCardinality() { + ItemSetBitSet set = new ItemSetBitSet(); + set.set(0); + set.set(3); + set.set(200); + set.set(5); + set.set(65); + + assertEquals(5, set.cardinality()); + set.clear(1); + assertEquals(5, set.cardinality()); + set.clear(200); + assertEquals(4, set.cardinality()); + set.set(204); + set.set(204); + set.set(204); + set.set(204); + assertEquals(5, set.cardinality()); + ItemSetBitSet set2 = new ItemSetBitSet(); + set.reset(set2); + assertEquals(0, set.cardinality()); + set.clear(999); + } + public void testHashCode() { ItemSetBitSet set1 = new ItemSetBitSet(); set1.set(0);