Skip to content

Commit c4a915f

Browse files
committed
Fix typo
1 parent 671dab7 commit c4a915f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/diffusers/models/activations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
7979
self.approximate = approximate
8080

8181
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
82-
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
82+
return F.gelu(gate, approximate=self.approximate)
8383

8484
def forward(self, hidden_states):
8585
hidden_states = self.proj(hidden_states)
@@ -102,7 +102,7 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
102102
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
103103

104104
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
105-
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
105+
return F.gelu(gate)
106106

107107
def forward(self, hidden_states, *args, **kwargs):
108108
if len(args) > 0 or kwargs.get("scale", None) is not None:

0 commit comments

Comments
 (0)