diff --git a/cve_bin_tool/cli.py b/cve_bin_tool/cli.py
index f1bfd7342e..a4c78d40dd 100644
--- a/cve_bin_tool/cli.py
+++ b/cve_bin_tool/cli.py
@@ -48,8 +48,8 @@
     curl_source,
     epss_source,
     gad_source,
+    new_osv_source,
     nvd_source,
-    osv_source,
     purl2cpe_source,
     redhat_source,
 )
@@ -782,7 +782,7 @@ def main(argv=None):
     enabled_sources = []
 
     if "OSV" not in disabled_sources:
-        source_osv = osv_source.OSV_Source(incremental_update=incremental_db_update)
+        source_osv = new_osv_source.OSVDataSource()
         enabled_sources.append(source_osv)
 
     if "GAD" not in disabled_sources:
diff --git a/cve_bin_tool/data_sources/new_osv_source.py b/cve_bin_tool/data_sources/new_osv_source.py
new file mode 100644
index 0000000000..0d71450d4b
--- /dev/null
+++ b/cve_bin_tool/data_sources/new_osv_source.py
@@ -0,0 +1,249 @@
+import asyncio
+import io
+import json
+import os
+import pathlib
+import zipfile
+
+import aiohttp
+from cvss import CVSS3
+from google.auth.credentials import AnonymousCredentials  # type: ignore[import-untyped]
+from google.cloud import storage  # type: ignore[import-untyped]
+
+from cve_bin_tool.data_sources import DISK_LOCATION_DEFAULT, Data_Source
+from cve_bin_tool.log import LOGGER
+
+
+class OSVDataSource(Data_Source):
+    """Reimplementation of the OSV data source."""
+
+    def __init__(self, bucket_name=None, max_parallel_downloads=5):
+        self.source_name = "OSV"
+        self._client = storage.Client(credentials=AnonymousCredentials(), project=None)
+        self.ecosystems_fetched = set()
+        self.download_url = []
+        self._semaphore = asyncio.Semaphore(max_parallel_downloads)
+        self.osv_path = str(pathlib.Path(DISK_LOCATION_DEFAULT) / "osv")
+
+        self.bucket_name = bucket_name or "osv-vulnerabilities"
+
+    async def update_ecosystem_info(self) -> None:
+        """Fetch OSV ecosystem information and prepare download links."""
+        LOGGER.info("OSV: Started fetching ecosystem info...")
+        blobs = self._client.list_blobs(self.bucket_name)
+        for blob in blobs:
+            if blob.name.endswith("all.zip"):
+                try:
+                    ecosystem_name = blob.name.split("/")[-2]
+                    url = f"https://storage.googleapis.com/{self.bucket_name}/{ecosystem_name}/all.zip"
+                    self.download_url.append(url)
+                    LOGGER.debug(f"OSV: Download link for {ecosystem_name} added.")
+                    # Strip ecosystem version suffixes, e.g. "Debian:10" is recorded as "Debian".
+                    if ecosystem_name.find(":") >= 0:
+                        ecosystem_name = ecosystem_name.split(":")[0]
+                    self.ecosystems_fetched.add(ecosystem_name)
+                except (ValueError, IndexError):
+                    pass
+
+    async def fetch_single(self, url: str, download_to: str, session):
+        """
+        Fetch a single file, with at most max_parallel_downloads downloads running concurrently.
+ """ + async with self._semaphore: + try: + async with session.get(url, timeout=120) as response: + if response.status == 200: + try: + content: bytes = await response.read() + content_size_mb = len(content) / (1024 * 1024) + + # If the file is more than 512 MB, download it to disk + if content_size_mb > 512: + _fname = url.split("/")[-2] + filename: str = f"{_fname}.zip" + location = os.path.join(download_to, filename) + with open(location, "wb") as file: + file.write(content) + LOGGER.debug(f"OSV: Fetched {url}.") + else: + in_memory_file: io.BytesIO = io.BytesIO(content) + zip_file = zipfile.ZipFile(in_memory_file) + zip_file.extractall(download_to) + del in_memory_file + del zip_file + LOGGER.debug(f"OSV: Fetched and unzipped {url}.") + + del content + del content_size_mb + except ( + ValueError, + IndexError, + aiohttp.ClientPayloadError, + ) as e: + LOGGER.warning(f"OSV: Unable to fetch {url}: {str(e)}") + else: + LOGGER.warning(f"OSV: Unable to fetch {url}.") + except (TimeoutError, asyncio.TimeoutError): + LOGGER.warning(f"OSV: Timeout error while fetching {url}.") + + async def fetch_all(self): + """Concurrently fetch all zip files from OSV.""" + LOGGER.info("OSV: Started fetching OSV CVE files...") + async with aiohttp.ClientSession() as session: + tasks = [ + self.fetch_single(url, self.osv_path, session) + for url in self.download_url + ] + await asyncio.gather(*tasks) + + async def extract_all(self): + """Extract and delete all files in the OSV cache directory.""" + # .glob("zip") returns an iterator, so it is memory efficient to process files in the loop + LOGGER.info("OSV: Started extracting zip files...") + for file in pathlib.Path(self.osv_path).glob("*.zip"): + try: + LOGGER.debug(f"OSV: Extracting {file}") + with zipfile.ZipFile(file, "r") as zip_ref: + zip_ref.extractall(self.osv_path) + except zipfile.BadZipFile: + LOGGER.warning(f"OSV: Error while extracting {file}.") + finally: + os.remove(file) + await asyncio.sleep(0.5) + + @staticmethod + def get_formatted_data_from_json(content: dict): + cve_id, severity, vector = ( + content.get("id"), + content.get("severity", None), + None, + ) + + severity: dict | None # type: ignore + if severity and "CVSS_V3" in [x["type"] for x in severity]: + try: + # Ensure the CVSS vector is valid + if severity[0]["score"].endswith("/"): + cvss_data = CVSS3(severity[0]["score"][:-1]) + LOGGER.debug(f"{cve_id} : Correcting malformed CVSS3 vector.") + else: + cvss_data = CVSS3(severity[0]["score"]) + # Now extract CVSS attributes + version = "3" + severity = cvss_data.severities()[0] + score = cvss_data.scores()[0] + vector = cvss_data.clean_vector() + + except Exception as e: + LOGGER.debug(e) + LOGGER.debug(f"{cve_id} : {severity}") + vector = None + + cve = { + "ID": cve_id, + "severity": severity if vector is not None else "unknown", + "description": content.get("summary", "unknown"), + "score": score if vector is not None else "unknown", # noqa + "CVSS_version": version if vector is not None else "unknown", # noqa + "CVSS_vector": vector if vector is not None else "unknown", + "last_modified": ( + content["modified"] + if content.get("modified", None) + else content["published"] + ), + } + + affected = None + + for package_data in content.get("affected", []): + package = package_data.get("package", {}) + if not package: + continue + + product = package.get("name") + vendor = "unknown" # OSV Schema does not provide vendor names for packages + + if product.startswith("github.com/"): + vendor = product.split("/")[-2] + product = 
product.split("/")[-1] + + _affected = { + "cve_id": cve_id, + "vendor": vendor, + "product": product, + "version": "*", + "versionStartIncluding": "", + "versionStartExcluding": "", + "versionEndIncluding": "", + "versionEndExcluding": "", + } + + events = None + for ranges in package_data.get("ranges", []): + if ranges["type"] == "SEMVER": + events = ranges["events"] + + if events is None and "versions" in package_data: + versions = package_data["versions"] + + if not versions: + continue + + version_affected = _affected.copy() + + version_affected["versionStartIncluding"] = versions[0] + version_affected["versionEndIncluding"] = versions[-1] + + affected = version_affected + elif events is not None: + introduced = None + fixed = None + + for event in events: + if event.get("introduced", None): + introduced = event.get("introduced") + if event.get("fixed", None): + fixed = event.get("fixed") + + if fixed is not None: + range_affected = _affected.copy() + + range_affected["versionStartIncluding"] = introduced + range_affected["versionEndExcluding"] = fixed + + fixed = None + affected = range_affected + + return cve, affected + + def process_data_from_disk(self): + """Read data from disk and yield each instance in the required format.""" + for file in pathlib.Path(self.osv_path).glob("*.json"): + with open(file) as opened_file: + content = opened_file.read() + + json_data: dict = json.loads(content) # type: ignore + data = self.get_formatted_data_from_json(json_data) + # Delete unused json data before the garbage collector does. + del json_data + del content + + yield data + + async def get_cve_data(self): + """Returns OSV CVE data to insert into the database.""" + await self.update_ecosystem_info() + await self.fetch_all() + await self.extract_all() + + # No need to keep links after download, as there may be a lot of them. 
+        del self.download_url
+
+        severity_data, affected_data = [], []
+        for cve, affected in self.process_data_from_disk():
+            if cve:
+                severity_data.append(cve)
+            if affected:
+                affected_data.append(affected)
+
+        return (severity_data, affected_data), self.source_name
diff --git a/requirements.csv b/requirements.csv
index 3ad34e857e..dff46f5cec 100644
--- a/requirements.csv
+++ b/requirements.csv
@@ -15,7 +15,7 @@ sissaschool_not_in_db,xmlschema
 python_not_in_db,importlib_metadata
 python,requests
 python,urllib3
-google,gsutil
+google,google-cloud-storage
 skontar,cvss
 python_not_in_db,packaging
 python_not_in_db,importlib_resources
diff --git a/requirements.txt b/requirements.txt
index e6d8e62c47..45279ce6ce 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ cvss
 defusedxml
 distro
 filetype>=1.2.0
-gsutil
+google-cloud-storage
 importlib_metadata>=3.6; python_version < "3.10"
 importlib_resources; python_version < "3.9"
 jinja2>=2.11.3
diff --git a/test/test_source_osv.py b/test/test_source_osv.py
index 9bb105b0ce..44643ab37d 100644
--- a/test/test_source_osv.py
+++ b/test/test_source_osv.py
@@ -2,24 +2,22 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import io
 import shutil
 import tempfile
-import zipfile
 from pathlib import Path
 from test.utils import EXTERNAL_SYSTEM
 
 import aiohttp
 import pytest
 
-from cve_bin_tool.data_sources import osv_source
+from cve_bin_tool.data_sources import new_osv_source
 from cve_bin_tool.util import make_http_requests
 
 
 class TestSourceOSV:
     @classmethod
     def setup_class(cls):
-        cls.osv = osv_source.OSV_Source()
+        cls.osv = new_osv_source.OSVDataSource()
         cls.osv.cachedir = tempfile.mkdtemp(prefix="cvedb-")
         cls.osv.osv_path = str(Path(cls.osv.cachedir) / "osv")
@@ -168,37 +166,18 @@ def teardown_class(cls):
     @pytest.mark.asyncio
     @pytest.mark.skipif(not EXTERNAL_SYSTEM(), reason="Needs network connection.")
     async def test_update_ecosystems(self):
-        await self.osv.update_ecosystems()
-
         ecosystems_txt = make_http_requests(
             "text", url=self.ecosystems_url, timeout=300
         ).strip("\n")
         expected_ecosystems = set(ecosystems_txt.split("\n"))
 
-        # Because ecosystems.txt does not contain the complete list, this must be
-        # manually fixed up.
-        expected_ecosystems.add("DWF")
-        expected_ecosystems.add("JavaScript")
-
-        # Assert that there are no missing ecosystems
-        assert all(x in self.osv.ecosystems for x in expected_ecosystems)
-        # Assert that there are no extra ecosystems
-        assert all(x in expected_ecosystems for x in self.osv.ecosystems)
-
-    @pytest.mark.asyncio
-    @pytest.mark.skipif(not EXTERNAL_SYSTEM(), reason="Needs network connection.")
-    @pytest.mark.parametrize("ecosystem_url", [url for url in cve_file_data])
-    async def test_get_ecosystem_00(self, ecosystem_url):
-        connector = aiohttp.TCPConnector(limit_per_host=19)
-        async with aiohttp.ClientSession(
-            connector=connector, trust_env=True
-        ) as session:
-            content = await self.osv.get_ecosystem(ecosystem_url, session)
-
-        cve_data = self.cve_file_data[ecosystem_url]
+        await self.osv.update_ecosystem_info()
 
-        assert content["id"] == cve_data["id"]
-        assert content["published"] == cve_data["published"]
+        # There may be more ecosystems fetched than are listed in ecosystems.txt
+        assert all(
+            ecosystem in self.osv.ecosystems_fetched
+            for ecosystem in expected_ecosystems
+        )
 
     @pytest.mark.asyncio
     @pytest.mark.skipif(not EXTERNAL_SYSTEM(), reason="Needs network connection.")
@@ -209,19 +188,24 @@ async def test_get_ecosystem_01(self):
         connector = aiohttp.TCPConnector(limit_per_host=19)
         async with aiohttp.ClientSession(
             connector=connector, trust_env=True
         ) as session:
-            content = await self.osv.get_ecosystem(eco_url, session, mode="bytes")
+            await self.osv.fetch_single(eco_url, self.osv.osv_path, session)
 
-        z = zipfile.ZipFile(io.BytesIO(content))
+        p = Path(self.osv.osv_path).glob("**/*")
+        files = [x.name for x in p if x.is_file()]
 
         # Shouldn't be any files as DWF is no longer a valid ecosystems
-        assert len(z.namelist()) == 0
+        assert len(files) == 0
 
     @pytest.mark.asyncio
     @pytest.mark.skipif(not EXTERNAL_SYSTEM(), reason="Needs network connection.")
     async def test_fetch_cves(self):
-        self.osv.ecosystems = ["PyPI"]
+        ecosystem_name = "PyPI"
+        self.osv.ecosystems_fetched = [ecosystem_name]
+        self.osv.download_url = [
+            f"https://storage.googleapis.com/osv-vulnerabilities/{ecosystem_name}/all.zip"
+        ]
 
-        await self.osv.fetch_cves()
+        await self.osv.fetch_all()
 
         p = Path(self.osv.osv_path).glob("**/*")
         files = [x.name for x in p if x.is_file()]
@@ -234,7 +218,11 @@
     @pytest.mark.parametrize("cve_entries", [[x] for _, x in cve_file_data.items()])
 
     def test_format_data(self, cve_entries):
-        severity_data, affected_data = self.osv.format_data(cve_entries)
+        severity_data, affected_data = [], []
+        for cve_entry in cve_entries:
+            severity, affected = self.osv.get_formatted_data_from_json(cve_entry)
+            severity_data.append(severity)
+            affected_data.append(affected)
 
         severity_data = severity_data[0]
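
Usage notes (not part of the patch). The gsutil dependency is replaced by google-cloud-storage with anonymous credentials. A minimal sketch of the bucket listing that update_ecosystem_info() relies on, mirroring the calls in the patch; network access to the public osv-vulnerabilities bucket is assumed:

# Sketch: anonymous listing of the public OSV bucket, as done in update_ecosystem_info().
from google.auth.credentials import AnonymousCredentials
from google.cloud import storage

client = storage.Client(credentials=AnonymousCredentials(), project=None)
for blob in client.list_blobs("osv-vulnerabilities"):
    if blob.name.endswith("all.zip"):
        print(blob.name)  # e.g. "PyPI/all.zip"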
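
The static formatter can be exercised without any network access. A hedged example with a hand-written OSV-style record; the field values below are illustrative, not a real advisory:

# Sketch: run one made-up OSV record through get_formatted_data_from_json().
from cve_bin_tool.data_sources.new_osv_source import OSVDataSource

record = {
    "id": "GHSA-xxxx-xxxx-xxxx",  # illustrative ID
    "summary": "Example vulnerability",
    "published": "2024-01-01T00:00:00Z",
    "modified": "2024-01-02T00:00:00Z",
    "severity": [
        {"type": "CVSS_V3", "score": "CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:H/I:H/A:H"}
    ],
    "affected": [
        {
            "package": {"name": "example-package", "ecosystem": "PyPI"},
            "ranges": [
                {"type": "SEMVER", "events": [{"introduced": "0"}, {"fixed": "1.2.3"}]}
            ],
        }
    ],
}

cve, affected = OSVDataSource.get_formatted_data_from_json(record)
# Expected shape: cve["score"] == 9.8 and affected["versionEndExcluding"] == "1.2.3".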
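
End to end, the source is driven through get_cve_data(), the same entry point cli.py reaches via enabled_sources. A sketch for manual testing, assuming network access and a writable cache directory; the full download is large, so this is not something to run casually:

# Sketch: run the new source standalone; downloads every OSV ecosystem archive.
import asyncio

from cve_bin_tool.data_sources.new_osv_source import OSVDataSource


async def main():
    source = OSVDataSource(max_parallel_downloads=2)
    (severity_data, affected_data), source_name = await source.get_cve_data()
    print(f"{source_name}: {len(severity_data)} CVEs, {len(affected_data)} affected entries")


asyncio.run(main())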