
Commit d0feeae

Add fp32 fallback for torch<2.0

1 parent c4a915f

1 file changed: +6 -0

src/diffusers/models/activations.py

Lines changed: 6 additions & 0 deletions
@@ -79,6 +79,9 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
         self.approximate = approximate
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        if gate.device.type == "mps" and torch.__version__ < '2.0.0':
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
         return F.gelu(gate, approximate=self.approximate)
 
     def forward(self, hidden_states):
@@ -102,6 +105,9 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        if gate.device.type == "mps" and torch.__version__ < '2.0.0':
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
         return F.gelu(gate)
 
     def forward(self, hidden_states, *args, **kwargs):
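
For readers who want to try the fallback outside of diffusers, the sketch below reproduces the same pattern as a standalone function: when the tensor lives on MPS and torch is older than 2.0, run GELU in fp32 and cast the result back to the input dtype; otherwise call F.gelu directly, so behavior is unchanged elsewhere. The helper name gelu_with_mps_fallback is hypothetical (not part of the diffusers API), and it uses packaging.version.parse for the version check, which is sturdier than comparing torch.__version__ as a raw string.

import torch
import torch.nn.functional as F
from packaging import version


def gelu_with_mps_fallback(gate: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    # Hypothetical standalone helper mirroring the commit's fallback logic.
    needs_fp32 = (
        gate.device.type == "mps"
        and version.parse(torch.__version__) < version.parse("2.0.0")
    )
    if needs_fp32:
        # fp16 gelu is not implemented on mps before torch 2.0: compute in fp32, cast back.
        return F.gelu(gate.to(dtype=torch.float32), approximate=approximate).to(dtype=gate.dtype)
    return F.gelu(gate, approximate=approximate)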
