|
16 | 16 | """
|
17 | 17 |
|
18 | 18 | import importlib.util
|
| 19 | +import inspect |
19 | 20 | import operator as op
|
20 | 21 | import os
|
21 | 22 | import sys
|
22 |
| -from collections import OrderedDict |
| 23 | +from collections import OrderedDict, defaultdict |
23 | 24 | from itertools import chain
|
24 | 25 | from types import ModuleType
|
25 |
| -from typing import Any, Union |
| 26 | +from typing import Any, Tuple, Union |
26 | 27 |
|
27 | 28 | from huggingface_hub.utils import is_jinja_available # noqa: F401
|
28 | 29 | from packaging.version import Version, parse
|
|
54 | 55 | _is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
|
55 | 56 |
|
56 | 57 |
|
57 |
| -def _is_package_available(pkg_name: str): |
| 58 | +def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]: |
58 | 59 | pkg_exists = importlib.util.find_spec(pkg_name) is not None
|
59 | 60 | pkg_version = "N/A"
|
60 | 61 |
|
61 | 62 | if pkg_exists:
|
62 | 63 | try:
|
| 64 | + package_map = importlib_metadata.packages_distributions() |
| 65 | + except Exception as e: |
| 66 | + package_map = defaultdict(list) |
| 67 | + if isinstance(e, AttributeError): |
| 68 | + try: |
| 69 | + # Fallback for Python < 3.10 |
| 70 | + for dist in importlib_metadata.distributions(): |
| 71 | + _top_level_declared = (dist.read_text("top_level.txt") or "").split() |
| 72 | + _infered_opt_names = { |
| 73 | + f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or []) |
| 74 | + } - {None} |
| 75 | + _top_level_inferred = filter(lambda name: "." not in name, _infered_opt_names) |
| 76 | + for pkg in _top_level_declared or _top_level_inferred: |
| 77 | + package_map[pkg].append(dist.metadata["Name"]) |
| 78 | + except Exception as _: |
| 79 | + pass |
| 80 | + |
| 81 | + try: |
| 82 | + if get_dist_name and pkg_name in package_map and package_map[pkg_name]: |
| 83 | + if len(package_map[pkg_name]) > 1: |
| 84 | + logger.warning( |
| 85 | + f"Multiple distributions found for package {pkg_name}. Picked distribution: {package_map[pkg_name][0]}" |
| 86 | + ) |
| 87 | + pkg_name = package_map[pkg_name][0] |
63 | 88 | pkg_version = importlib_metadata.version(pkg_name)
|
64 | 89 | logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
|
65 | 90 | except (ImportError, importlib_metadata.PackageNotFoundError):
|
@@ -189,15 +214,7 @@ def _is_package_available(pkg_name: str):
|
189 | 214 | _gguf_available, _gguf_version = _is_package_available("gguf")
|
190 | 215 | _torchao_available, _torchao_version = _is_package_available("torchao")
|
191 | 216 | _bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
|
192 |
| -_torchao_available, _torchao_version = _is_package_available("torchao") |
193 |
| - |
194 |
| -_optimum_quanto_available = importlib.util.find_spec("optimum") is not None |
195 |
| -if _optimum_quanto_available: |
196 |
| - try: |
197 |
| - _optimum_quanto_version = importlib_metadata.version("optimum_quanto") |
198 |
| - logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}") |
199 |
| - except importlib_metadata.PackageNotFoundError: |
200 |
| - _optimum_quanto_available = False |
| 217 | +_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True) |
201 | 218 |
|
202 | 219 |
|
203 | 220 | def is_torch_available():
|
|
0 commit comments