Skip to content

Commit 2e7a8b8

Browse files
authored
fix: enhance test_checkers to cover new situations (#4942)
* fix: enhance test_checkers to cover new situations #4940 * fix: enhance test_checkers to cover new situations
1 parent e8a0f7f commit 2e7a8b8

File tree

3 files changed

+42
-23
lines changed

3 files changed

+42
-23
lines changed

cve_bin_tool/checkers/sqlite.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,6 @@ def get_versions(self, lines, filename):
9797
# https://www.sqlite.org/c3ref/c_source_id.html
9898
if mapping[1][:-4] in lines:
9999
# overwrite version with the version found by sha mapping
100-
version_info.versions.append(mapping[0])
100+
version_info.versions.add(mapping[0])
101101

102102
return version_info

cve_bin_tool/util.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -215,22 +215,22 @@ class VersionMatchInfo:
215215
attributes:
216216
matched_filename: bool
217217
matched_contains: bool
218-
versions: list[str]
218+
versions: set[str]
219219
"""
220220

221221
matched_filename: bool
222222
matched_contains: bool
223-
versions: list[str]
223+
versions: set[str]
224224

225225
def __init__(
226226
self,
227227
matched_filename: bool = False,
228228
matched_contains: bool = False,
229-
versions: list[str] = [],
229+
versions: set[str] | None = None,
230230
):
231231
self.matched_filename = matched_filename
232232
self.matched_contains = matched_contains
233-
self.versions = versions or []
233+
self.versions = versions if versions is not None else set()
234234

235235

236236
class VersionInfo(NamedTuple):
@@ -280,24 +280,22 @@ def __missing__(self, key: str) -> list[CVE] | set[str]:
280280

281281
def regex_find(
282282
lines: str, version_patterns: list[Pattern[str]], ignore: list[Pattern[str]]
283-
) -> list[str]:
283+
) -> set[str]:
284284
"""Search a set of lines to find all matches for the given regex"""
285-
versions = list()
285+
versions = set()
286286

287287
for pattern in version_patterns:
288288
version_matches = pattern.finditer(lines)
289289
for match in version_matches:
290290
# before collecting a potential version number, ensure the version string isn't on the ignore list
291-
if not check_ignored(match.string, ignore):
291+
matched_text = match.group(0)
292+
if not check_ignored(matched_text, ignore):
292293
version = match.group(1).strip()
293294
version = version.replace("_", ".").replace("-", ".")
294-
versions.append(version)
295+
versions.add(version)
295296

296-
# if we searched and found no matches, just return a one-element list containing "UNKNOWN"
297-
if not versions:
298-
versions.append("UNKNOWN")
299-
300-
return versions
297+
# if we searched and found no matches, just return a one-element set containing "UNKNOWN"
298+
return versions if versions else {"UNKNOWN"}
301299

302300

303301
def check_ignored(possible_version_string: str, ignore: list[Pattern[str]]) -> bool:

test/test_checkers.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,33 @@ class MyFakeChecker(Checker):
145145
VENDOR_PRODUCT = [("my", "checker")]
146146
IGNORE_PATTERNS = [r"mychecker-5.6"]
147147

148-
string = "Some other lines. \n Ignore this version pattern mychecker-5.6. \n Consider this version pattern mychecker-5.8. \n Some more lines."
149-
lines = string.split("\n")
150-
checker = MyFakeChecker()
151-
152-
result1 = checker.get_versions(lines[1], "")
153-
assert result1.versions[0] == "UNKNOWN"
154-
155-
result2 = checker.get_versions(lines[2], "")
156-
assert result2.versions[0] == "5.8"
148+
class TestVersionParsing:
149+
"""Unit tests for MyFakeChecker class, validating version parsing and detection logic."""
150+
151+
@pytest.mark.parametrize(
152+
"description, input_text, expected_versions",
153+
[
154+
(
155+
"Multiple valid versions",
156+
"mychecker-5.8\nmychecker-6.0",
157+
{"5.8", "6.0"},
158+
),
159+
("All versions ignored", "mychecker-5.6\nmychecker-5.6", {"UNKNOWN"}),
160+
(
161+
"Mixed valid and ignored versions",
162+
"mychecker-5.8\nmychecker-5.6\nmychecker-6.0",
163+
{"5.8", "6.0"},
164+
),
165+
("No version pattern", "random text", set()),
166+
(
167+
"Duplicate versions",
168+
"mychecker-5.8\nmychecker-5.8\nmychecker-6.0",
169+
{"5.8", "6.0"},
170+
),
171+
],
172+
)
173+
def test_checker_version(self, description, input_text, expected_versions):
174+
"""Test get_versions logic using MyFakeChecker class, ensuring multiple versions are processed correctly."""
175+
checker = TestCheckerVersionParser.MyFakeChecker()
176+
result = checker.get_versions(input_text, "")
177+
assert result.versions == expected_versions, f"{description} failed"

0 commit comments

Comments
 (0)