Skip to content

Commit 158c5c4

Browse files
authored
Add provider_options to OnnxRuntimeModel (#10661)
1 parent 4157177 commit 158c5c4

File tree

1 file changed: +4 −2 lines changed

src/diffusers/pipelines/onnx_utils.py (+4 −2)

src/diffusers/pipelines/onnx_utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __call__(self, **kwargs):
         return self.model.run(None, inputs)

     @staticmethod
-    def load_model(path: Union[str, Path], provider=None, sess_options=None):
+    def load_model(path: Union[str, Path], provider=None, sess_options=None, provider_options=None):
         """
         Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`

@@ -75,7 +75,9 @@ def load_model(path: Union[str, Path], provider=None, sess_options=None):
             logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
             provider = "CPUExecutionProvider"

-        return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
+        return ort.InferenceSession(
+            path, providers=[provider], sess_options=sess_options, provider_options=provider_options
+        )

     def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
         """

Comments (0)