From 8f86e56829a5d32a7c0850f93e1f796ab420cb9c Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Tue, 4 Feb 2025 12:18:11 +0530 Subject: [PATCH 1/2] Update pipeline_utils.py Added Self in from_pretrained method so inference will correctly recognize pipeline --- src/diffusers/pipelines/pipeline_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 0c1371c7556f..7f2a0598fcbc 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, Self, Union, get_args, get_origin import numpy as np import PIL.Image @@ -513,7 +513,7 @@ def dtype(self) -> torch.dtype: @classmethod @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self: r""" Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights. From 39c7424e1a841d65673623b399a75d38d1f6119a Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 4 Feb 2025 08:27:16 +0000 Subject: [PATCH 2/2] Use typing_extensions --- src/diffusers/pipelines/pipeline_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 7f2a0598fcbc..c4593b3e698b 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Self, Union, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin import numpy as np import PIL.Image @@ -41,6 +41,7 @@ from packaging import version from requests.exceptions import HTTPError from tqdm.auto import tqdm +from typing_extensions import Self from .. import __version__ from ..configuration_utils import ConfigMixin