Commit ec9bfa9

skotapati and hlky authored
Remove mps workaround for fp16 GELU, which is now supported natively (#10133)
* Remove mps workaround for fp16 GELU, which is now supported natively

Co-authored-by: hlky <[email protected]>
1 parent bdbaea8 commit ec9bfa9

File tree

1 file changed: +9 -9 lines


src/diffusers/models/activations.py

@@ -18,7 +18,7 @@
 from torch import nn
 
 from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version
 
 
 if is_torch_npu_available():
@@ -79,10 +79,10 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
         self.approximate = approximate
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate, approximate=self.approximate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        return F.gelu(gate, approximate=self.approximate)
 
     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate)
 
     def forward(self, hidden_states, *args, **kwargs):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
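
For context, below is a minimal, self-contained sketch of the version-gated fallback this commit introduces, written against plain torch. The class name VersionGatedGELU and the packaging-based version comparison are illustrative stand-ins (the diffusers module itself uses the is_torch_version helper imported above); treat this as an approximation of the pattern, not the library code.

# Illustrative sketch, not the diffusers implementation.
import torch
import torch.nn.functional as F
from packaging import version


class VersionGatedGELU(torch.nn.Module):
    """Linear projection followed by GELU, upcasting to fp32 only on older mps builds."""

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
        self.proj = torch.nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        # torch >= 2.0 supports fp16 GELU natively on mps, so the fp32 round trip
        # is only needed for mps tensors on older torch builds.
        if gate.device.type == "mps" and version.parse(torch.__version__) < version.parse("2.0.0"):
            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.gelu(self.proj(hidden_states))


# Example usage: on CPU/CUDA, or on mps with torch >= 2.0, the input goes straight
# through F.gelu; move the module to "mps" and cast to fp16 to exercise the gate.
layer = VersionGatedGELU(dim_in=64, dim_out=128)
out = layer(torch.randn(2, 64))

The design point of the change itself is simply that the unconditional upcast penalized every mps user, while the capability gap it worked around only exists before torch 2.0, so the check is now scoped to that version range.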
