Skip to content

Commit 6595826

Browse files
authored
[ML] Fix off-by-one error in usurped categories (#1122)
The mistake was that m_CategoriesByCount holds indices into m_Categories, not category IDs. (Category IDs are one greater than the corresponding indices, so that IDs start at 1 rather than 0.) Fixes #1121
1 parent a7f071a commit 6595826

File tree

3 files changed

+21
-15
lines changed

3 files changed

+21
-15
lines changed

include/model/CTokenListDataCategorizerBase.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,8 @@ class MODEL_EXPORT CTokenListDataCategorizerBase : public CDataCategorizer {
324324
TTokenListCategoryVec m_Categories;
325325

326326
//! List of match count/index into category vector in descending order of
327-
//! match count
327+
//! match count. Note that the second element is an index into m_Categories,
328+
//! not a category ID.
328329
TSizeSizePrVec m_CategoriesByCount;
329330

330331
//! Used for looking up tokens to a unique ID

lib/model/CTokenListDataCategorizerBase.cc

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ bool CTokenListDataCategorizerBase::createReverseSearch(int categoryId,
179179
std::string& part2,
180180
std::size_t& maxMatchingLength,
181181
bool& wasCached) {
182+
wasCached = false;
183+
maxMatchingLength = 0;
184+
182185
if (m_ReverseSearchCreator == nullptr) {
183186
LOG_ERROR(<< "Cannot create reverse search - no reverse search creator");
184187

@@ -207,8 +210,8 @@ bool CTokenListDataCategorizerBase::createReverseSearch(int categoryId,
207210
maxMatchingLength = category.maxMatchingStringLen();
208211

209212
// If we can retrieve cached reverse search terms we'll save a lot of time
210-
if (category.cachedReverseSearch(part1, part2) == true) {
211-
wasCached = true;
213+
wasCached = category.cachedReverseSearch(part1, part2);
214+
if (wasCached) {
212215
return true;
213216
}
214217

@@ -634,27 +637,28 @@ CDataCategorizer::TIntVec CTokenListDataCategorizerBase::usurpedCategories(int c
634637
}
635638
auto iter = std::find_if(m_CategoriesByCount.begin(), m_CategoriesByCount.end(),
636639
[categoryId](const TSizeSizePr& pr) {
637-
return pr.second == static_cast<std::size_t>(categoryId);
640+
return pr.second ==
641+
static_cast<std::size_t>(categoryId - 1);
638642
});
639643
if (iter == m_CategoriesByCount.end()) {
640644
LOG_WARN(<< "Could not find category definition for category: " << categoryId);
641645
return usurped;
642646
}
643-
++iter;
647+
644648
const CTokenListCategory& category{m_Categories[categoryId - 1]};
645-
for (; iter != m_CategoriesByCount.end(); ++iter) {
646-
const CTokenListCategory& lessFrequentCategory{
647-
m_Categories[static_cast<int>(iter->second) - 1]};
649+
for (++iter; iter != m_CategoriesByCount.end(); ++iter) {
650+
const CTokenListCategory& lessFrequentCategory{m_Categories[iter->second]};
648651
bool matchesSearch{category.maxMatchingStringLen() >=
649652
lessFrequentCategory.maxMatchingStringLen() &&
650653
category.isMissingCommonTokenWeightZero(
651654
lessFrequentCategory.commonUniqueTokenIds()) &&
652655
category.containsCommonInOrderTokensInOrder(
653656
lessFrequentCategory.baseTokenIds())};
654657
if (matchesSearch) {
655-
usurped.emplace_back(static_cast<int>(iter->second));
658+
usurped.emplace_back(1 + static_cast<int>(iter->second));
656659
}
657660
}
661+
std::sort(usurped.begin(), usurped.end());
658662
return usurped;
659663
}
660664

lib/model/unittest/CTokenListDataCategorizerTest.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
* you may not use this file except in compliance with the Elastic License.
55
*/
66

7+
#include <core/CContainerPrinter.h>
78
#include <core/CLogger.h>
89
#include <core/CRapidXmlParser.h>
910
#include <core/CRapidXmlStatePersistInserter.h>
@@ -542,13 +543,13 @@ BOOST_FIXTURE_TEST_CASE(testUsurpedCategories, CTestFixture) {
542543
500));
543544

544545
BOOST_REQUIRE_EQUAL(2, categorizer.numMatches(1));
545-
std::vector<int> expected{2, 3, 4, 5, 6};
546-
std::vector<int> actual = categorizer.usurpedCategories(1);
547546

548-
BOOST_REQUIRE_EQUAL(expected.size(), actual.size());
549-
for (std::size_t i = 0; i < actual.size(); i++) {
550-
BOOST_REQUIRE_EQUAL(expected[i], actual[i]);
551-
}
547+
using TIntVec = std::vector<int>;
548+
TIntVec expected{2, 3, 4, 5, 6, 7};
549+
TIntVec actual{categorizer.usurpedCategories(1)};
550+
551+
BOOST_REQUIRE_EQUAL(ml::core::CContainerPrinter::print(expected),
552+
ml::core::CContainerPrinter::print(actual));
552553
checkMemoryUsageInstrumentation(categorizer);
553554
}
554555

0 commit comments

Comments
 (0)