
Commit 671dab7

Remove mps workaround for fp16 GELU, which is now supported natively
1 parent 6394d90 commit 671dab7


1 file changed: 0 additions, 6 deletions

src/diffusers/models/activations.py

Lines changed: 0 additions & 6 deletions
@@ -79,9 +79,6 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
         self.approximate = approximate
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate, approximate=self.approximate)
-        # mps: gelu is not implemented for float16
         return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
 
     def forward(self, hidden_states):
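
For reference, the surviving line computes GELU through a float32 round-trip. The standalone sketch below is not part of the commit, and the helper name gelu_fp32_roundtrip is made up for illustration; it only isolates that pattern: upcast the gate to float32, apply F.gelu, then cast the result back to the input dtype.

import torch
import torch.nn.functional as F

# Illustrative only: the upcast/compute/downcast pattern used by the kept line.
def gelu_fp32_roundtrip(gate: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    return F.gelu(gate.to(dtype=torch.float32), approximate=approximate).to(dtype=gate.dtype)

x = torch.randn(4, 8, dtype=torch.float16)  # runs on CPU because the math happens in float32
y = gelu_fp32_roundtrip(x, approximate="tanh")
assert y.dtype == torch.float16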
@@ -105,9 +102,6 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate)
-        # mps: gelu is not implemented for float16
         return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
 
     def forward(self, hidden_states, *args, **kwargs):
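
The commit message states that fp16 GELU is now supported natively on mps, so the device-type branch is no longer needed. The short script below is a hedged way to confirm that premise on a given install; it is not part of the repository and assumes an Apple Silicon machine with an mps-enabled PyTorch build.

import torch
import torch.nn.functional as F

# Probe whether F.gelu accepts float16 input directly on the mps backend.
if torch.backends.mps.is_available():
    x = torch.randn(2, 16, device="mps", dtype=torch.float16)
    y = F.gelu(x, approximate="tanh")
    print(y.dtype, y.device)  # expected: torch.float16 mps:0
else:
    print("mps backend not available; skipping the check")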
