Skip to content

Commit abbc996

Browse files
authored
several optimization in cvedb and csv2cve files and fixes issue #413 (#477)
1 parent c168cc9 commit abbc996

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

cve_bin_tool/csv2cve.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,12 @@ def csv2cve(filename):
5353
LOGGER.error("Error: invalid CSV")
5454
return ERR_BADCSV
5555

56-
required_columns = ["vendor", "product", "version"]
57-
for column in required_columns:
58-
if column not in csvdata.fieldnames:
59-
LOGGER.error(f"Error: no {column} column found")
60-
return ERR_MISSINGCOLUMN
56+
required_columns = {"vendor", "product", "version"}
57+
csv_columns = set(csvdata.fieldnames)
58+
missing_columns = required_columns - csv_columns
59+
if missing_columns != set():
60+
LOGGER.error(f"Error: no {missing_columns} columns found")
61+
return ERR_MISSINGCOLUMN
6162

6263
# Initialize the NVD database
6364
cvedb = CVEDB()
@@ -70,8 +71,14 @@ def csv2cve(filename):
7071
)
7172
cves = cvedb.get_cves(row["vendor"], row["product"], row["version"])
7273
if cves:
73-
LOGGER.info("\n".join(sorted(cves.keys())))
74-
cveoutput.append(cves.keys())
74+
s = map(
75+
lambda key: " ".join(
76+
[row["vendor"], row["product"], row["version"], key, cves[key]]
77+
),
78+
sorted(cves.keys()),
79+
)
80+
LOGGER.info("\n".join(s))
81+
cveoutput.append(cves)
7582
else:
7683
LOGGER.debug("No CVEs found. Is the vendor/product info correct?")
7784
LOGGER.debug("")

cve_bin_tool/cvedb.py

+7-15
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
import urllib2 as request
2424

2525
from collections import namedtuple
26-
from .log import LOGGER
26+
from string import ascii_lowercase
27+
from cve_bin_tool.log import LOGGER
2728
from pkg_resources import parse_version
2829

2930
logging.basicConfig(level=logging.DEBUG)
@@ -268,10 +269,7 @@ def get_cves(self, vendor, product, version):
268269
query = """SELECT CVE_number FROM cve_range WHERE
269270
vendor=? AND product=? AND version=?"""
270271
cursor.execute(query, [vendor, product, version])
271-
# FIXME: this seems inefficient
272-
cve_list = []
273-
for cve in cursor:
274-
cve_list.append(cve[0])
272+
cve_list = list(map(lambda x: x[0], cursor.fetchall()))
275273

276274
# Check for any ranges
277275
query = """SELECT CVE_number, versionStartIncluding, versionStartExcluding, versionEndIncluding, versionEndExcluding FROM cve_range WHERE
@@ -297,6 +295,7 @@ def get_cves(self, vendor, product, version):
297295
versionEndExcluding = self.openssl_convert(versionEndExcluding)
298296

299297
parsed_version = parse_version(version)
298+
300299
# check the start range
301300
passes_start = False
302301
if (
@@ -336,21 +335,17 @@ def get_cves(self, vendor, product, version):
336335
):
337336
# then there is no end range so it passes
338337
passes_end = True
339-
340338
# if it fits into both ends of the range, add the cve number
341339
if passes_start and passes_end:
342340
cve_list.append(cve_number)
343341

344342
# Go through and get all the severities
345343
if cve_list:
346-
query = f'SELECT CVE_number, severity from cve_severity where CVE_number IN ({",".join("?" for i in cve_list)}) ORDER BY CVE_number ASC'
344+
query = f'SELECT CVE_number, severity from cve_severity where CVE_number IN ({",".join(["?"]*len(cve_list))}) ORDER BY CVE_number ASC'
347345
cursor.execute(query, cve_list)
348346
# Everything expects a data structure of cve[number] = severity so you can search through keys
349347
# and do other easy manipulations
350-
severity_list = dict()
351-
for cve_id, severity in cursor:
352-
severity_list[cve_id] = severity
353-
return severity_list
348+
return dict(cursor)
354349

355350
return cve_list
356351

@@ -361,10 +356,7 @@ def openssl_convert(self, version):
361356
return version
362357

363358
lastchar = version[len(version) - 1]
364-
letters = {
365-
letter: str(index)
366-
for index, letter in enumerate("abcdefghijklmnopqrstuvwxyz")
367-
}
359+
letters = dict(zip(ascii_lowercase, range(26)))
368360

369361
if lastchar in letters:
370362
version = f"{version[0 : len(version) - 1]}.{letters[lastchar]}"

0 commit comments

Comments
 (0)