```diff
@@ -13,14 +13,14 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Callable, List, Optional

 import torch

 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
-from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
-from ._helpers import TransformerBlockRegistry
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
+from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
 from .hooks import HookRegistry, ModelHook


@@ -44,9 +44,50 @@ class LayerSkipConfig:

     indices: List[int]
     fqn: str = "auto"
+    skip_attention: bool = True
+    skip_attention_scores: bool = False
+    skip_ff: bool = True


-class LayerSkipHook(ModelHook):
+class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        if func is torch.nn.functional.scaled_dot_product_attention:
+            value = kwargs.get("value", None)
+            if value is None:
+                value = args[2]
+            return value
+        return func(*args, **kwargs)
+
+
+class AttentionProcessorSkipHook(ModelHook):
+    def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False):
+        self.skip_processor_output_fn = skip_processor_output_fn
+        self.skip_attention_scores = skip_attention_scores
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        if self.skip_attention_scores:
+            with AttentionScoreSkipFunctionMode():
+                return self.fn_ref.original_forward(*args, **kwargs)
+        else:
+            return self.skip_processor_output_fn(module, *args, **kwargs)
+
+
+class FeedForwardSkipHook(ModelHook):
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        output = kwargs.get("hidden_states", None)
+        if output is None:
+            output = kwargs.get("x", None)
+        if output is None and len(args) > 0:
+            output = args[0]
+        return output
+
+
+class TransformerBlockSkipHook(ModelHook):
     def initialize_hook(self, module):
         self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
         return module
```
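The `AttentionScoreSkipFunctionMode` added above intercepts `torch.nn.functional.scaled_dot_product_attention` via `torch.overrides.TorchFunctionMode` and returns the `value` tensor untouched, so the attention output behaves as if the score matrix were the identity. A minimal standalone sketch of that interception mechanism (not part of the patch; the class name `_SkipSDPA` and the tensor shape are illustrative only):

```python
import torch


class _SkipSDPA(torch.overrides.TorchFunctionMode):
    """Toy function mode mirroring AttentionScoreSkipFunctionMode: SDPA calls return `value` directly."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.nn.functional.scaled_dot_product_attention:
            # Positional signature is (query, key, value, ...), so fall back to args[2].
            return kwargs.get("value", args[2] if len(args) > 2 else None)
        # Every other torch function dispatches normally.
        return func(*args, **kwargs)


query = key = value = torch.randn(1, 4, 8, 16)
with _SkipSDPA():
    out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
assert torch.equal(out, value)  # the value tensor is passed through unchanged
```

The remaining hunks wire the new hooks into `_apply_layer_skip_hook`: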
```diff
@@ -81,6 +122,9 @@ def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
 def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
     name = name or _LAYER_SKIP_HOOK

+    if config.skip_attention and config.skip_attention_scores:
+        raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
+
     if config.fqn == "auto":
         for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
             if hasattr(module, identifier):
@@ -101,10 +145,38 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
     if len(config.indices) == 0:
         raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")

+    blocks_found = False
     for i, block in enumerate(transformer_blocks):
         if i not in config.indices:
             continue
-        logger.debug(f"Apply LayerSkipHook to '{config.fqn}.{i}'")
-        registry = HookRegistry.check_if_exists_or_initialize(block)
-        hook = LayerSkipHook()
-        registry.register_hook(hook, name)
+        blocks_found = True
+        if config.skip_attention and config.skip_ff:
+            logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
+            registry = HookRegistry.check_if_exists_or_initialize(block)
+            hook = TransformerBlockSkipHook()
+            registry.register_hook(hook, name)
+        elif config.skip_attention or config.skip_attention_scores:
+            for submodule_name, submodule in block.named_modules():
+                if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
+                    logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+                    output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
+                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
+                    hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores)
+                    registry.register_hook(hook, name)
+        elif config.skip_ff:
+            for submodule_name, submodule in block.named_modules():
+                if isinstance(submodule, _FEEDFORWARD_CLASSES):
+                    logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
+                    hook = FeedForwardSkipHook()
+                    registry.register_hook(hook, name)
+        else:
+            raise ValueError(
+                "At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True."
+            )
+
+    if not blocks_found:
+        raise ValueError(
+            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
+            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
+        )
```
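For reference, a hedged usage sketch of the new flags (not part of the patch). It assumes `LayerSkipConfig` and `apply_layer_skip` are importable from `diffusers.hooks` as in the surrounding module layout, and that `transformer` is an already-loaded denoiser whose block and attention-processor classes are registered in the corresponding registries; the indices are illustrative.

```python
from diffusers.hooks import LayerSkipConfig, apply_layer_skip  # import path assumed

# transformer = ...  # placeholder: e.g. `pipe.transformer` from a supported pipeline

# Default behaviour: skip the whole block (attention + feed-forward) at the given indices.
full_skip = LayerSkipConfig(indices=[2, 5])
apply_layer_skip(transformer, full_skip)

# Alternative: keep the blocks but skip only the attention score computation, so SDPA
# returns the `value` tensor directly. `skip_attention` must be disabled explicitly,
# otherwise the new validation check raises.
scores_only = LayerSkipConfig(
    indices=[2, 5],
    skip_attention=False,
    skip_attention_scores=True,
    skip_ff=False,
)
# apply_layer_skip(transformer, scores_only)  # use instead of the full-block config above
```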