import torch
import torch_tensorrt
-from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
+from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
from transformers.image_utils import load_image

-
# 1. Model
DEVICE = torch.device("cuda:0")
model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

-model = PaliGemmaForConditionalGeneration.from_pretrained(
-    model_id, torch_dtype=torch.float16
-).eval().to(DEVICE)
+model = (
+    PaliGemmaForConditionalGeneration.from_pretrained(
+        model_id, torch_dtype=torch.float16
+    )
+    .eval()
+    .to(DEVICE)
+)
processor = PaliGemmaProcessor.from_pretrained(model_id)

prompt = ""

# 2. PyTorch
with torch.inference_mode():
-    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)  # , use_cache=False)
-    # Take only the tokens newly generated after the input tokens.
+    pyt_generation = model.generate(
+        **model_inputs, max_new_tokens=100, do_sample=False
+    )  # , use_cache=False)
+    # The newly generated tokens after the input tokens.
    pyt_generation = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation, skip_special_tokens=True)
print("=============================")
print("PyTorch generated text:")
print(pyt_decoded)
print("=============================")

-# (a) Dummy inputs
+# (a) Dummy inputs
batch_size = 1
-dummy_input_ids = model_inputs["input_ids"]
-dummy_attention_mask = model_inputs["attention_mask"]
+dummy_input_ids = model_inputs["input_ids"]
+dummy_attention_mask = model_inputs["attention_mask"]
dummy_pixel_values = model_inputs["pixel_values"]

dummy_inputs = {
    "pixel_values": dummy_pixel_values,
}

-# (b) Dynamic shape
+# (b) Dynamic shape
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=1024)
dynamic_shapes = {
    "input_ids": {0: BATCH, 1: SEQ_LEN},
    "attention_mask": {0: BATCH, 1: SEQ_LEN},
    "pixel_values": {0: BATCH},
}
-# (c) ExportedProgram
+# (c) ExportedProgram
# torch.export.export(
#     model,
#     args=(),
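Note: the commented-out torch.export.export call above is cut off by the collapsed portion of the diff. For orientation only, a minimal sketch of how the dummy_inputs and dynamic_shapes built in steps (a) and (b) typically feed torch.export; the kwargs and strict settings here are assumptions, not the hidden lines of this file:

```python
# Sketch only: wiring dummy_inputs and dynamic_shapes into torch.export.
# The exact arguments in the collapsed lines of this diff may differ.
exported_program = torch.export.export(
    model,
    args=(),
    kwargs=dummy_inputs,            # assumption: the dict built in step (a)
    dynamic_shapes=dynamic_shapes,  # assumption: the dict built in step (b)
    strict=False,                   # assumption: non-strict tracing for HF models
)
```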
import torch.utils._pytree as pytree
import transformers

+
def flatten_hybridcache(hc: transformers.cache_utils.HybridCache):
    """
-    1) Collect the tensors inside the HybridCache into a list.
-    2) Put the non-tensor values into the context (dict).
+    1) Collects all tensors inside HybridCache into a list.
+    2) Stores non-tensor values in the context (dictionary).
    """
-    # 1. Treated as tensors: is_sliding, the entire key_cache, the entire value_cache
+    # 1. Tensors: is_sliding, entire key_cache, entire value_cache
    flat_tensors = []
-    flat_tensors.append(hc.is_sliding)   # shape: [num_hidden_layers], bool
-    flat_tensors.extend(hc.key_cache)    # List[Tensor]
-    flat_tensors.extend(hc.value_cache)  # List[Tensor]
+    flat_tensors.append(hc.is_sliding)  # shape: [num_hidden_layers], bool
+    flat_tensors.extend(hc.key_cache)  # List[Tensor]
+    flat_tensors.extend(hc.value_cache)  # List[Tensor]

-    # 2. Store the non-tensor fields in the context
+    # 2. Store non-tensor fields in the context
    context = {
        "max_cache_len": hc.max_cache_len,
        "max_batch_size": hc.max_batch_size,
        "head_dim": hc.head_dim,
        "dtype": hc.dtype,
        "num_key_value_heads": hc.num_key_value_heads,
-        # Needed so that unflatten knows how many key_cache / value_cache entries to split off
-        "num_layers": len(hc.key_cache),  # = len(hc.value_cache) = config.num_hidden_layers
+        "num_layers": len(
+            hc.key_cache
+        ),  # = len(hc.value_cache) = config.num_hidden_layers
    }

    return flat_tensors, context


def unflatten_hybridcache(flat_tensors, context):
    """
-    Takes the (flat_tensors, context) produced by flatten_hybridcache and
-    restores it back into a HybridCache object.
+    Restores a HybridCache object from the (flat_tensors, context) produced by flatten_hybridcache.
    """
    num_layers = context["num_layers"]

-    # 1. Parse flat_tensors
-    #    - The first element is is_sliding
-    #    - The next num_layers elements: key_cache
-    #    - The next num_layers elements: value_cache
+    # 1. Parse flat_tensors
+    #    - First element is is_sliding
+    #    - Next num_layers elements: key_cache
+    #    - Next num_layers elements: value_cache
    is_sliding = flat_tensors[0]
    key_cache = flat_tensors[1 : 1 + num_layers]
-    value_cache = flat_tensors[1 + num_layers : 1 + 2*num_layers]
+    value_cache = flat_tensors[1 + num_layers : 1 + 2 * num_layers]

-    # 2. Create an empty HybridCache object with __new__ (the constructor __init__ is not called)
-    hc = transformers.cache_utils.HybridCache.__new__(transformers.cache_utils.HybridCache)
+    # 2. Create an empty HybridCache object using __new__ (without calling __init__)
+    hc = transformers.cache_utils.HybridCache.__new__(
+        transformers.cache_utils.HybridCache
+    )

-    # 3. Directly set the required fields
+    # 3. Manually set required fields
    hc.max_cache_len = context["max_cache_len"]
    hc.max_batch_size = context["max_batch_size"]
    hc.head_dim = context["head_dim"]
@@ -119,14 +127,13 @@ def unflatten_hybridcache(flat_tensors, context):

    return hc

-# pytree registration
+
+# Register with pytree
pytree.register_pytree_node(
-    transformers.cache_utils.HybridCache,
-    flatten_hybridcache,
-    unflatten_hybridcache
+    transformers.cache_utils.HybridCache, flatten_hybridcache, unflatten_hybridcache
)
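With HybridCache registered as a pytree node, cache objects returned by the model can be flattened into tensors and restored again, which is what torch.export relies on when such a cache appears in the traced inputs or outputs. A quick sanity check one could run, assuming `cache` is a HybridCache instance obtained from an earlier model call (that variable is not defined in this file):

```python
# Sketch only: round-trip a HybridCache through the registered flatten/unflatten.
# `cache` is assumed to be a transformers HybridCache produced at runtime.
leaves, spec = pytree.tree_flatten(cache)
restored = pytree.tree_unflatten(leaves, spec)
assert len(restored.key_cache) == len(cache.key_cache)
assert all(torch.equal(a, b) for a, b in zip(restored.key_cache, cache.key_cache))
```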

-# from torch.export._trace import _export
+# from torch.export._trace import _export
# exported_program = _export(
#     model,
#     args=(),
@@ -138,6 +145,7 @@ def unflatten_hybridcache(flat_tensors, context):

# torch.export._draft_export.draft_export
import torch.export._draft_export
+
exported_program = torch.export._draft_export.draft_export(
    model,
    args=(),
@@ -156,20 +164,16 @@ def unflatten_hybridcache(flat_tensors, context):
    device=DEVICE,
    disable_tf32=True,
    use_explicit_typing=True,
-    use_fp32_acc=True,  # Use FP32 accumulation to preserve accuracy.
+    use_fp32_acc=True,
)
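The head of this call sits in the collapsed lines of the diff, but the keyword arguments above and the later use of trt_model suggest a Torch-TensorRT dynamo compilation of the exported program. A hedged sketch of what such a call can look like; every argument not visible in the hunk above is an assumption:

```python
# Sketch only: the actual call in the hidden lines may use different inputs/options.
trt_model = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=list(dummy_inputs.values()),  # hypothetical example inputs
    enabled_precisions={torch.float16},  # assumption
    device=DEVICE,
    disable_tf32=True,
    use_explicit_typing=True,
    use_fp32_acc=True,
)
```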

-# ----------------------------
-# 5. Perform generation with the TensorRT model
-# ----------------------------
-# (Move the original model inputs to the GPU, then call generate())
+# Execute generation using TensorRT model
model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
with torch.inference_mode():
-    trt_generation = trt_model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
+    trt_generation = trt_model.generate(
+        **model_inputs, max_new_tokens=100, do_sample=False
+    )
    trt_generation = trt_generation[0][input_len:]
    trt_decoded = processor.decode(trt_generation, skip_special_tokens=True)
print("TensorRT generated text:")
print(trt_decoded)
-
-
-# pytree._register_pytree_node(transformers.modeling_outputs.MaskedLMOutput, lambda x: ([x.loss, x.logits], None), lambda values, _: transformers.modeling_outputs.MaskedLMOutput(loss=values[0], logits=values[1]))
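Since the script keeps both pyt_decoded and trt_decoded, a natural final check (not part of this diff) is to compare the two generations directly:

```python
# Sketch only: compare the PyTorch and TensorRT generations produced above.
print("Outputs match:", pyt_decoded == trt_decoded)
if pyt_decoded != trt_decoded:
    print("PyTorch :", pyt_decoded)
    print("TensorRT:", trt_decoded)
```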