
Commit 0160e51

Adds local_files_only bool to prevent forced online connection (#3486)
1 parent 194b0a4 commit 0160e51

File tree

1 file changed: +6, -3 lines

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 6 additions & 3 deletions
@@ -727,8 +727,8 @@ def _copy_layers(hf_layers, pt_layers):
     return hf_model
 
 
-def convert_ldm_clip_checkpoint(checkpoint):
-    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
+    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
 
     keys = list(checkpoint.keys())
 
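The forwarded flag is the standard `local_files_only` argument of `from_pretrained` in transformers: with it set to `True`, the text encoder is resolved from the local Hugging Face cache and no download is attempted. A minimal sketch of the behavior the new parameter toggles (not part of this commit):

from transformers import CLIPTextModel

# With local_files_only=True, from_pretrained looks only in the local
# Hugging Face cache and raises an error instead of trying to download.
text_model = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", local_files_only=True
)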

@@ -992,6 +992,7 @@ def download_from_original_stable_diffusion_ckpt(
     controlnet: Optional[bool] = None,
     load_safety_checker: bool = True,
     pipeline_class: DiffusionPipeline = None,
+    local_files_only=False
 ) -> DiffusionPipeline:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -1037,6 +1038,8 @@ def download_from_original_stable_diffusion_ckpt(
             Whether to load the safety checker or not. Defaults to `True`.
         pipeline_class (`str`, *optional*, defaults to `None`):
             The pipeline class to use. Pass `None` to determine automatically.
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            Whether or not to only look at local files (i.e., do not try to download the model).
     return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """

@@ -1292,7 +1295,7 @@ def download_from_original_stable_diffusion_ckpt(
             feature_extractor=feature_extractor,
         )
     elif model_type == "FrozenCLIPEmbedder":
-        text_model = convert_ldm_clip_checkpoint(checkpoint)
+        text_model = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
 
         if load_safety_checker:
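Taken together, the new keyword lets the checkpoint conversion run without a forced online connection. A hedged usage sketch, assuming the checkpoint and YAML config already exist on disk and the CLIP weights are already in the local cache (paths below are hypothetical):

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Hypothetical local files; with local_files_only=True the CLIP text encoder
# is loaded from the local Hugging Face cache rather than re-downloaded.
pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="./v1-5-pruned-emaonly.ckpt",
    original_config_file="./v1-inference.yaml",
    local_files_only=True,
)
pipe.save_pretrained("./converted-pipeline")

Note that the `CLIPTokenizer.from_pretrained` call in the last hunk is left unchanged by this commit, so the tokenizer is still fetched through the regular `from_pretrained` path.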
