diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index e8d9429f6204..79f4601ae6eb 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -16,13 +16,14 @@ """ import importlib.util +import inspect import operator as op import os import sys -from collections import OrderedDict +from collections import OrderedDict, defaultdict from itertools import chain from types import ModuleType -from typing import Any, Union +from typing import Any, Tuple, Union from huggingface_hub.utils import is_jinja_available # noqa: F401 from packaging.version import Version, parse @@ -54,12 +55,36 @@ _is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ) -def _is_package_available(pkg_name: str): +def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]: pkg_exists = importlib.util.find_spec(pkg_name) is not None pkg_version = "N/A" if pkg_exists: try: + package_map = importlib_metadata.packages_distributions() + except Exception as e: + package_map = defaultdict(list) + if isinstance(e, AttributeError): + try: + # Fallback for Python < 3.10 + for dist in importlib_metadata.distributions(): + _top_level_declared = (dist.read_text("top_level.txt") or "").split() + _infered_opt_names = { + f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or []) + } - {None} + _top_level_inferred = filter(lambda name: "." not in name, _infered_opt_names) + for pkg in _top_level_declared or _top_level_inferred: + package_map[pkg].append(dist.metadata["Name"]) + except Exception as _: + pass + + try: + if get_dist_name and pkg_name in package_map and package_map[pkg_name]: + if len(package_map[pkg_name]) > 1: + logger.warning( + f"Multiple distributions found for package {pkg_name}. Picked distribution: {package_map[pkg_name][0]}" + ) + pkg_name = package_map[pkg_name][0] pkg_version = importlib_metadata.version(pkg_name) logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") except (ImportError, importlib_metadata.PackageNotFoundError): @@ -189,15 +214,7 @@ def _is_package_available(pkg_name: str): _gguf_available, _gguf_version = _is_package_available("gguf") _torchao_available, _torchao_version = _is_package_available("torchao") _bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes") -_torchao_available, _torchao_version = _is_package_available("torchao") - -_optimum_quanto_available = importlib.util.find_spec("optimum") is not None -if _optimum_quanto_available: - try: - _optimum_quanto_version = importlib_metadata.version("optimum_quanto") - logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}") - except importlib_metadata.PackageNotFoundError: - _optimum_quanto_available = False +_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True) def is_torch_available():