Skip to content

Commit f59df3b

Browse files
ishan-modiyiyixuxu
andauthored
[Refactor] Minor Improvement for import utils (#11161)
* update * update * addressed PR comments * update --------- Co-authored-by: YiYi Xu <[email protected]>
1 parent a00c73a commit f59df3b

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

src/diffusers/utils/import_utils.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
"""
1717

1818
import importlib.util
19+
import inspect
1920
import operator as op
2021
import os
2122
import sys
22-
from collections import OrderedDict
23+
from collections import OrderedDict, defaultdict
2324
from itertools import chain
2425
from types import ModuleType
25-
from typing import Any, Union
26+
from typing import Any, Tuple, Union
2627

2728
from huggingface_hub.utils import is_jinja_available # noqa: F401
2829
from packaging.version import Version, parse
@@ -54,12 +55,36 @@
5455
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
5556

5657

57-
def _is_package_available(pkg_name: str):
58+
def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
5859
pkg_exists = importlib.util.find_spec(pkg_name) is not None
5960
pkg_version = "N/A"
6061

6162
if pkg_exists:
6263
try:
64+
package_map = importlib_metadata.packages_distributions()
65+
except Exception as e:
66+
package_map = defaultdict(list)
67+
if isinstance(e, AttributeError):
68+
try:
69+
# Fallback for Python < 3.10
70+
for dist in importlib_metadata.distributions():
71+
_top_level_declared = (dist.read_text("top_level.txt") or "").split()
72+
_infered_opt_names = {
73+
f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or [])
74+
} - {None}
75+
_top_level_inferred = filter(lambda name: "." not in name, _infered_opt_names)
76+
for pkg in _top_level_declared or _top_level_inferred:
77+
package_map[pkg].append(dist.metadata["Name"])
78+
except Exception as _:
79+
pass
80+
81+
try:
82+
if get_dist_name and pkg_name in package_map and package_map[pkg_name]:
83+
if len(package_map[pkg_name]) > 1:
84+
logger.warning(
85+
f"Multiple distributions found for package {pkg_name}. Picked distribution: {package_map[pkg_name][0]}"
86+
)
87+
pkg_name = package_map[pkg_name][0]
6388
pkg_version = importlib_metadata.version(pkg_name)
6489
logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
6590
except (ImportError, importlib_metadata.PackageNotFoundError):
@@ -189,15 +214,7 @@ def _is_package_available(pkg_name: str):
189214
_gguf_available, _gguf_version = _is_package_available("gguf")
190215
_torchao_available, _torchao_version = _is_package_available("torchao")
191216
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
192-
_torchao_available, _torchao_version = _is_package_available("torchao")
193-
194-
_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
195-
if _optimum_quanto_available:
196-
try:
197-
_optimum_quanto_version = importlib_metadata.version("optimum_quanto")
198-
logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}")
199-
except importlib_metadata.PackageNotFoundError:
200-
_optimum_quanto_available = False
217+
_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
201218

202219

203220
def is_torch_available():

0 commit comments

Comments
 (0)