from torch import nn

from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version


if is_torch_npu_available():
@@ -79,10 +79,10 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate, approximate=self.approximate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate)

    def forward(self, hidden_states, *args, **kwargs):
        if len(args) > 0 or kwargs.get("scale", None) is not None:
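Both hunks apply the same gating: the existing mps fallback (upcast the gate to fp32 for the gelu op, then cast back) now runs only when torch is older than 2.0, where fp16 gelu is not implemented for the mps backend; on newer versions F.gelu is called directly with no extra casts. A minimal standalone sketch of that gating, using packaging.version in place of the library's is_torch_version helper; the function name below is illustrative and not part of the diff:

import torch
import torch.nn.functional as F
from packaging import version


def gelu_with_mps_fp16_fallback(gate: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    # Illustrative helper mirroring the version-gated fallback in the diff above.
    if gate.device.type == "mps" and version.parse(torch.__version__).release < (2, 0, 0):
        # fp16 gelu is not implemented on mps before torch 2.0: upcast for the op, cast back after.
        return F.gelu(gate.to(dtype=torch.float32), approximate=approximate).to(dtype=gate.dtype)
    return F.gelu(gate, approximate=approximate)

On torch >= 2.0 this keeps the mps path identical to other devices and avoids the fp32 round trip.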