fix: osv data source memory consumption #4956

Open · wants to merge 21 commits into main
4 changes: 2 additions & 2 deletions cve_bin_tool/cli.py
@@ -48,8 +48,8 @@
curl_source,
epss_source,
gad_source,
new_osv_source,
nvd_source,
osv_source,
purl2cpe_source,
redhat_source,
)
@@ -782,7 +782,7 @@ def main(argv=None):
enabled_sources = []

if "OSV" not in disabled_sources:
source_osv = osv_source.OSV_Source(incremental_update=incremental_db_update)
source_osv = new_osv_source.OSVDataSource()
enabled_sources.append(source_osv)

if "GAD" not in disabled_sources:
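The replacement call takes different constructor arguments than the old `OSV_Source`. A minimal sketch of the same call with the optional parameters spelled out; the values are simply the defaults from `new_osv_source.py` below and are not part of the diff:

```python
# Equivalent to the bare OSVDataSource() call above, with defaults made explicit.
source_osv = new_osv_source.OSVDataSource(
    bucket_name="osv-vulnerabilities",  # public OSV bucket
    max_parallel_downloads=5,           # cap on concurrent zip downloads
)
enabled_sources.append(source_osv)
```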
249 changes: 249 additions & 0 deletions cve_bin_tool/data_sources/new_osv_source.py
@@ -0,0 +1,249 @@
import asyncio
import io
import json
import os
import pathlib
import zipfile

import aiohttp
from cvss import CVSS3
from google.auth.credentials import AnonymousCredentials # type: ignore[import-untyped]
from google.cloud import storage # type: ignore[import-untyped]

from cve_bin_tool.data_sources import DISK_LOCATION_DEFAULT, Data_Source
from cve_bin_tool.log import LOGGER


class OSVDataSource(Data_Source):
"""Reimplementation of the OSV data source."""

def __init__(self, bucket_name=None, max_parallel_downloads=5):
self.source_name = "OSV"
self._client = storage.Client(credentials=AnonymousCredentials(), project=None)
self.ecosystems_fetched = set()
self.download_url = []
self._semaphore = asyncio.Semaphore(max_parallel_downloads)
self.osv_path = str(pathlib.Path(DISK_LOCATION_DEFAULT) / "osv")

self.bucket_name = bucket_name or "osv-vulnerabilities"

async def update_ecosystem_info(self) -> None:
"""Fetch OSV ecosystem information and prepare download links."""
LOGGER.info("OSV: Started fetching ecosystem info...")
blobs = self._client.list_blobs(self.bucket_name)
for blob in blobs:
if blob.name.endswith("all.zip"):
try:
ecosystem_name = blob.name.split("/")[-2]
url = f"https://storage.googleapis.com/{self.bucket_name}/{ecosystem_name}/all.zip"
self.download_url.append(url)
LOGGER.debug(f"OSV: Download link for {ecosystem_name} added.")
# Strip the version suffix so that e.g. Debian:10 is recorded as just Debian
if ecosystem_name.find(":") >= 0:
ecosystem_name = ecosystem_name.split(":")[0]
self.ecosystems_fetched.add(ecosystem_name)
except (ValueError, IndexError):
pass

async def fetch_single(self, url: str, download_to: str, session):
"""
Fetch a single file, limiting concurrent downloads to max_parallel_downloads.
"""
async with self._semaphore:
try:
async with session.get(url, timeout=120) as response:
if response.status == 200:
try:
content: bytes = await response.read()
content_size_mb = len(content) / (1024 * 1024)

# If the archive is larger than 512 MB, write it to disk; extract_all() unzips it later
if content_size_mb > 512:
_fname = url.split("/")[-2]
filename: str = f"{_fname}.zip"
location = os.path.join(download_to, filename)
with open(location, "wb") as file:
file.write(content)
LOGGER.debug(f"OSV: Fetched {url}.")
else:
in_memory_file: io.BytesIO = io.BytesIO(content)
zip_file = zipfile.ZipFile(in_memory_file)
zip_file.extractall(download_to)
del in_memory_file
del zip_file
LOGGER.debug(f"OSV: Fetched and unzipped {url}.")

del content
del content_size_mb
except (
ValueError,
IndexError,
aiohttp.ClientPayloadError,
) as e:
LOGGER.warning(f"OSV: Unable to fetch {url}: {str(e)}")
else:
LOGGER.warning(f"OSV: Unable to fetch {url}.")
except (TimeoutError, asyncio.TimeoutError):
LOGGER.warning(f"OSV: Timeout error while fetching {url}.")

async def fetch_all(self):
"""Concurrently fetch all zip files from OSV."""
LOGGER.info("OSV: Started fetching OSV CVE files...")
async with aiohttp.ClientSession() as session:
tasks = [
self.fetch_single(url, self.osv_path, session)
for url in self.download_url
]
await asyncio.gather(*tasks)

async def extract_all(self):
"""Extract and delete all files in the OSV cache directory."""
# .glob("zip") returns an iterator, so it is memory efficient to process files in the loop
LOGGER.info("OSV: Started extracting zip files...")
for file in pathlib.Path(self.osv_path).glob("*.zip"):
try:
LOGGER.debug(f"OSV: Extracting {file}")
with zipfile.ZipFile(file, "r") as zip_ref:
zip_ref.extractall(self.osv_path)
except zipfile.BadZipFile:
LOGGER.warning(f"OSV: Error while extracting {file}.")
finally:
os.remove(file)
await asyncio.sleep(0.5)

@staticmethod
def get_formatted_data_from_json(content: dict):
"""Convert one OSV record into the severity and affected-range format used by the database."""
cve_id, severity, vector = (
content.get("id"),
content.get("severity", None),
None,
)

severity: dict | None # type: ignore
if severity and "CVSS_V3" in [x["type"] for x in severity]:
try:
# Ensure the CVSS vector is valid
if severity[0]["score"].endswith("/"):
cvss_data = CVSS3(severity[0]["score"][:-1])
LOGGER.debug(f"{cve_id} : Correcting malformed CVSS3 vector.")
else:
cvss_data = CVSS3(severity[0]["score"])
# Now extract CVSS attributes
version = "3"
severity = cvss_data.severities()[0]
score = cvss_data.scores()[0]
vector = cvss_data.clean_vector()

except Exception as e:
LOGGER.debug(e)
LOGGER.debug(f"{cve_id} : {severity}")
vector = None

cve = {
"ID": cve_id,
"severity": severity if vector is not None else "unknown",
"description": content.get("summary", "unknown"),
"score": score if vector is not None else "unknown", # noqa
"CVSS_version": version if vector is not None else "unknown", # noqa
"CVSS_vector": vector if vector is not None else "unknown",
"last_modified": (
content["modified"]
if content.get("modified", None)
else content["published"]
),
}

affected = None

for package_data in content.get("affected", []):
package = package_data.get("package", {})
if not package:
continue

product = package.get("name")
vendor = "unknown" # OSV Schema does not provide vendor names for packages

# GitHub-hosted Go modules encode "vendor/product" in the module path
if product.startswith("github.com/"):
vendor = product.split("/")[-2]
product = product.split("/")[-1]

_affected = {
"cve_id": cve_id,
"vendor": vendor,
"product": product,
"version": "*",
"versionStartIncluding": "",
"versionStartExcluding": "",
"versionEndIncluding": "",
"versionEndExcluding": "",
}

events = None
for ranges in package_data.get("ranges", []):
if ranges["type"] == "SEMVER":
events = ranges["events"]

if events is None and "versions" in package_data:
versions = package_data["versions"]

if not versions:
continue

version_affected = _affected.copy()

version_affected["versionStartIncluding"] = versions[0]
version_affected["versionEndIncluding"] = versions[-1]

affected = version_affected
elif events is not None:
introduced = None
fixed = None

for event in events:
if event.get("introduced", None):
introduced = event.get("introduced")
if event.get("fixed", None):
fixed = event.get("fixed")

if fixed is not None:
range_affected = _affected.copy()

range_affected["versionStartIncluding"] = introduced
range_affected["versionEndExcluding"] = fixed

fixed = None
affected = range_affected

return cve, affected

def process_data_from_disk(self):
"""Read data from disk and yield each instance in the required format."""
for file in pathlib.Path(self.osv_path).glob("*.json"):
with open(file) as opened_file:
content = opened_file.read()

json_data: dict = json.loads(content) # type: ignore
data = self.get_formatted_data_from_json(json_data)
# Drop local references so the parsed JSON can be freed promptly.
del json_data
del content

yield data

async def get_cve_data(self):
"""Returns OSV CVE data to insert into the database."""
await self.update_ecosystem_info()
await self.fetch_all()
await self.extract_all()

# No need to keep links after download, as there may be a lot of them.
del self.download_url

severity_data, affected_data = [], []
for cve, affected in self.process_data_from_disk():
if cve:
severity_data.append(cve)
if affected:
affected_data.append(affected)

return (severity_data, affected_data), self.source_name
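
As a rough usage sketch, the new source can also be exercised on its own. This assumes network access to the public OSV bucket and write access to cve-bin-tool's default cache directory; it is illustrative only and not part of the diff:

```python
import asyncio

from cve_bin_tool.data_sources import new_osv_source


async def demo():
    source = new_osv_source.OSVDataSource(max_parallel_downloads=5)
    # get_cve_data() fetches, extracts and parses all ecosystems.
    (severity_data, affected_data), source_name = await source.get_cve_data()
    print(source_name, len(severity_data), len(affected_data))


asyncio.run(demo())
```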
2 changes: 1 addition & 1 deletion requirements.csv
@@ -15,7 +15,7 @@ sissaschool_not_in_db,xmlschema
python_not_in_db,importlib_metadata
python,requests
python,urllib3
google,gsutil
google,google-cloud-storage
skontar,cvss
python_not_in_db,packaging
python_not_in_db,importlib_resources
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,7 +4,7 @@ cvss
defusedxml
distro
filetype>=1.2.0
gsutil
google-cloud-storage
importlib_metadata>=3.6; python_version < "3.10"
importlib_resources; python_version < "3.9"
jinja2>=2.11.3
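
The gsutil dependency is dropped because the new source lists the public OSV bucket directly through the google-cloud-storage client with anonymous credentials. A minimal sketch of that access pattern, mirroring the calls made in new_osv_source.py above:

```python
from google.auth.credentials import AnonymousCredentials
from google.cloud import storage

# Anonymous, read-only listing of the public OSV bucket (no gsutil required).
client = storage.Client(credentials=AnonymousCredentials(), project=None)
for blob in client.list_blobs("osv-vulnerabilities"):
    if blob.name.endswith("all.zip"):
        print(blob.name)
```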