From 671dab71e872c76220802f46e0cd0b0c7a9048e6 Mon Sep 17 00:00:00 2001 From: Siddharth Kotapati Date: Thu, 5 Dec 2024 10:12:06 -0800 Subject: [PATCH 1/5] Remove mps workaround for fp16 GELU, which is now supported natively --- src/diffusers/models/activations.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index f4318fc3cd39..334b1a536204 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -79,9 +79,6 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type != "mps": - return F.gelu(gate, approximate=self.approximate) - # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) def forward(self, hidden_states): @@ -105,9 +102,6 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True): self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type != "mps": - return F.gelu(gate) - # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) def forward(self, hidden_states, *args, **kwargs): From c4a915f17d4893de5a4f907a551d0553fbf0c444 Mon Sep 17 00:00:00 2001 From: Siddharth Kotapati Date: Thu, 5 Dec 2024 10:20:12 -0800 Subject: [PATCH 2/5] Fix typo --- src/diffusers/models/activations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 334b1a536204..2aa409e2b61e 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -79,7 +79,7 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: - return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + return F.gelu(gate, approximate=self.approximate) def forward(self, hidden_states): hidden_states = self.proj(hidden_states) @@ -102,7 +102,7 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True): self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: - return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + return F.gelu(gate) def forward(self, hidden_states, *args, **kwargs): if len(args) > 0 or kwargs.get("scale", None) is not None: From d0feeae2667c43cc94666988587ab984d78f5e5a Mon Sep 17 00:00:00 2001 From: Siddharth Kotapati Date: Fri, 6 Dec 2024 10:55:17 -0800 Subject: [PATCH 3/5] Add fp32 fallback for torch<2.0 --- src/diffusers/models/activations.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 2aa409e2b61e..66c6881139d9 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -79,6 +79,9 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type == "mps" and torch.__version__ <'2.0.0': + # fp16 gelu not supported on mps before torch 2.0 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) return F.gelu(gate, approximate=self.approximate) def forward(self, hidden_states): @@ -102,6 +105,9 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True): self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type == "mps" and torch.__version__ <'2.0.0': + # fp16 gelu not supported on mps before torch 2.0 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) return F.gelu(gate) def forward(self, hidden_states, *args, **kwargs): From 68f5c3cba6ed5a6e43e41f5b7658357831e7ab1e Mon Sep 17 00:00:00 2001 From: Siddharth Kotapati Date: Fri, 6 Dec 2024 10:56:07 -0800 Subject: [PATCH 4/5] Minor formatting change --- src/diffusers/models/activations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 66c6881139d9..8ca8edc1fba9 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -79,7 +79,7 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type == "mps" and torch.__version__ <'2.0.0': + if gate.device.type == "mps" and torch.__version__ < '2.0.0': # fp16 gelu not supported on mps before torch 2.0 return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) return F.gelu(gate, approximate=self.approximate) @@ -105,7 +105,7 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True): self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type == "mps" and torch.__version__ <'2.0.0': + if gate.device.type == "mps" and torch.__version__ < '2.0.0': # fp16 gelu not supported on mps before torch 2.0 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) return F.gelu(gate) From 562eb4bc0349377714cfb527566d45520c2b9a60 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 12 Dec 2024 09:13:37 +0000 Subject: [PATCH 5/5] use is_torch_version --- src/diffusers/models/activations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 8ca8edc1fba9..c1d4f0b46e15 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -18,7 +18,7 @@ from torch import nn from ..utils import deprecate -from ..utils.import_utils import is_torch_npu_available +from ..utils.import_utils import is_torch_npu_available, is_torch_version if is_torch_npu_available(): @@ -79,7 +79,7 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type == "mps" and torch.__version__ < '2.0.0': + if gate.device.type == "mps" and is_torch_version("<", "2.0.0"): # fp16 gelu not supported on mps before torch 2.0 return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) return F.gelu(gate, approximate=self.approximate) @@ -105,7 +105,7 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True): self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type == "mps" and torch.__version__ < '2.0.0': + if gate.device.type == "mps" and is_torch_version("<", "2.0.0"): # fp16 gelu not supported on mps before torch 2.0 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) return F.gelu(gate)