diff --git a/include/model/CTokenListDataCategorizerBase.h b/include/model/CTokenListDataCategorizerBase.h index e53ffe7af9..1b295b041f 100644 --- a/include/model/CTokenListDataCategorizerBase.h +++ b/include/model/CTokenListDataCategorizerBase.h @@ -324,7 +324,8 @@ class MODEL_EXPORT CTokenListDataCategorizerBase : public CDataCategorizer { TTokenListCategoryVec m_Categories; //! List of match count/index into category vector in descending order of - //! match count + //! match count. Note that the second element is an index into m_Categories, + //! not a category ID. TSizeSizePrVec m_CategoriesByCount; //! Used for looking up tokens to a unique ID diff --git a/lib/model/CTokenListDataCategorizerBase.cc b/lib/model/CTokenListDataCategorizerBase.cc index 780737ff92..fdfe0d99ad 100644 --- a/lib/model/CTokenListDataCategorizerBase.cc +++ b/lib/model/CTokenListDataCategorizerBase.cc @@ -179,6 +179,9 @@ bool CTokenListDataCategorizerBase::createReverseSearch(int categoryId, std::string& part2, std::size_t& maxMatchingLength, bool& wasCached) { + wasCached = false; + maxMatchingLength = 0; + if (m_ReverseSearchCreator == nullptr) { LOG_ERROR(<< "Cannot create reverse search - no reverse search creator"); @@ -207,8 +210,8 @@ bool CTokenListDataCategorizerBase::createReverseSearch(int categoryId, maxMatchingLength = category.maxMatchingStringLen(); // If we can retrieve cached reverse search terms we'll save a lot of time - if (category.cachedReverseSearch(part1, part2) == true) { - wasCached = true; + wasCached = category.cachedReverseSearch(part1, part2); + if (wasCached) { return true; } @@ -634,17 +637,17 @@ CDataCategorizer::TIntVec CTokenListDataCategorizerBase::usurpedCategories(int c } auto iter = std::find_if(m_CategoriesByCount.begin(), m_CategoriesByCount.end(), [categoryId](const TSizeSizePr& pr) { - return pr.second == static_cast(categoryId); + return pr.second == + static_cast(categoryId - 1); }); if (iter == m_CategoriesByCount.end()) { LOG_WARN(<< "Could not find category definition for category: " << categoryId); return usurped; } - ++iter; + const CTokenListCategory& category{m_Categories[categoryId - 1]}; - for (; iter != m_CategoriesByCount.end(); ++iter) { - const CTokenListCategory& lessFrequentCategory{ - m_Categories[static_cast(iter->second) - 1]}; + for (++iter; iter != m_CategoriesByCount.end(); ++iter) { + const CTokenListCategory& lessFrequentCategory{m_Categories[iter->second]}; bool matchesSearch{category.maxMatchingStringLen() >= lessFrequentCategory.maxMatchingStringLen() && category.isMissingCommonTokenWeightZero( @@ -652,9 +655,10 @@ CDataCategorizer::TIntVec CTokenListDataCategorizerBase::usurpedCategories(int c category.containsCommonInOrderTokensInOrder( lessFrequentCategory.baseTokenIds())}; if (matchesSearch) { - usurped.emplace_back(static_cast(iter->second)); + usurped.emplace_back(1 + static_cast(iter->second)); } } + std::sort(usurped.begin(), usurped.end()); return usurped; } diff --git a/lib/model/unittest/CTokenListDataCategorizerTest.cc b/lib/model/unittest/CTokenListDataCategorizerTest.cc index 5dcd77d16e..664366206d 100644 --- a/lib/model/unittest/CTokenListDataCategorizerTest.cc +++ b/lib/model/unittest/CTokenListDataCategorizerTest.cc @@ -4,6 +4,7 @@ * you may not use this file except in compliance with the Elastic License. */ +#include #include #include #include @@ -542,13 +543,13 @@ BOOST_FIXTURE_TEST_CASE(testUsurpedCategories, CTestFixture) { 500)); BOOST_REQUIRE_EQUAL(2, categorizer.numMatches(1)); - std::vector expected{2, 3, 4, 5, 6}; - std::vector actual = categorizer.usurpedCategories(1); - BOOST_REQUIRE_EQUAL(expected.size(), actual.size()); - for (std::size_t i = 0; i < actual.size(); i++) { - BOOST_REQUIRE_EQUAL(expected[i], actual[i]); - } + using TIntVec = std::vector; + TIntVec expected{2, 3, 4, 5, 6, 7}; + TIntVec actual{categorizer.usurpedCategories(1)}; + + BOOST_REQUIRE_EQUAL(ml::core::CContainerPrinter::print(expected), + ml::core::CContainerPrinter::print(actual)); checkMemoryUsageInstrumentation(categorizer); }