import torch
import torch_tensorrt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image

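# This script captions an image with PaliGemma 2 twice: first with the plain
# PyTorch model as a baseline, then with the same model exported via
# torch.export (draft export) and compiled with Torch-TensorRT, so the two
# outputs can be compared.
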
# 1. Model
DEVICE = torch.device("cuda:0")
model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16
).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)

prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(DEVICE)
input_len = model_inputs["input_ids"].shape[-1]

# 2. PyTorch baseline generation
with torch.inference_mode():
    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)  # , use_cache=False)
    # Keep only the tokens generated after the input prompt.
    pyt_generation = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation, skip_special_tokens=True)
    print("=============================")
    print("PyTorch generated text:")
    print(pyt_decoded)
    print("=============================")

# 3. Export preparation
# (a) Dummy inputs
batch_size = 1
dummy_input_ids = model_inputs["input_ids"]
dummy_attention_mask = model_inputs["attention_mask"]
dummy_pixel_values = model_inputs["pixel_values"]

dummy_inputs = {
    "input_ids": dummy_input_ids,
    "attention_mask": dummy_attention_mask,
    "pixel_values": dummy_pixel_values,
}

# (b) Dynamic shapes
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=1024)
dynamic_shapes = {
    "input_ids": {0: BATCH, 1: SEQ_LEN},
    "attention_mask": {0: BATCH, 1: SEQ_LEN},
    "pixel_values": {0: BATCH},
}
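# torch.export.Dim declares symbolic dimensions: the exported graph should stay
# valid for any batch size in [1, 2] and any text sequence length in [1, 1024],
# rather than being specialized to the concrete shapes of the dummy inputs.
# pixel_values only gets a dynamic batch dimension; its spatial shape stays fixed.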
# (c) ExportedProgram
# A plain torch.export.export call is kept here for reference; this script goes
# through the draft_export path further below instead.
# torch.export.export(
#     model,
#     args=(),
#     kwargs=dummy_inputs,
#     dynamic_shapes=dynamic_shapes,
#     strict=False,
# )

import torch.utils._pytree as pytree
import transformers

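# torch.export flattens the model's inputs and outputs with pytree. The model
# returns its KV cache as a transformers HybridCache, which is not a registered
# pytree node, so export cannot traverse it. The two helpers below teach pytree
# how to take a HybridCache apart into tensors plus metadata and rebuild it.
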
def flatten_hybridcache(hc: transformers.cache_utils.HybridCache):
    """
    1) Gather the tensors held by the HybridCache into a flat list.
    2) Put the non-tensor values into a context dict.
    """
    # 1. Values treated as tensors: is_sliding plus every key_cache / value_cache entry
    flat_tensors = []
    flat_tensors.append(hc.is_sliding)   # shape: [num_hidden_layers], bool
    flat_tensors.extend(hc.key_cache)    # List[Tensor]
    flat_tensors.extend(hc.value_cache)  # List[Tensor]

    # 2. Store the non-tensor fields in the context
    context = {
        "max_cache_len": hc.max_cache_len,
        "max_batch_size": hc.max_batch_size,
        "head_dim": hc.head_dim,
        "dtype": hc.dtype,
        "num_key_value_heads": hc.num_key_value_heads,
        # unflatten needs to know how many entries belong to key_cache / value_cache
        "num_layers": len(hc.key_cache),  # = len(hc.value_cache) = config.num_hidden_layers
    }

    return flat_tensors, context


def unflatten_hybridcache(flat_tensors, context):
    """
    Rebuild a HybridCache object from the (flat_tensors, context) pair
    produced by flatten_hybridcache.
    """
    num_layers = context["num_layers"]

    # 1. Parse flat_tensors:
    #    - the first element is is_sliding
    #    - the next num_layers elements are key_cache
    #    - the next num_layers elements are value_cache
    is_sliding = flat_tensors[0]
    key_cache = flat_tensors[1 : 1 + num_layers]
    value_cache = flat_tensors[1 + num_layers : 1 + 2 * num_layers]

    # 2. Create an empty HybridCache via __new__ (skipping __init__)
    hc = transformers.cache_utils.HybridCache.__new__(transformers.cache_utils.HybridCache)

    # 3. Set the required fields directly
    hc.max_cache_len = context["max_cache_len"]
    hc.max_batch_size = context["max_batch_size"]
    hc.head_dim = context["head_dim"]
    hc.dtype = context["dtype"]
    hc.num_key_value_heads = context["num_key_value_heads"]
    hc.is_sliding = is_sliding
    hc.key_cache = list(key_cache)
    hc.value_cache = list(value_cache)

    return hc

# Register HybridCache with pytree so torch.export can flatten/unflatten it
pytree.register_pytree_node(
    transformers.cache_utils.HybridCache,
    flatten_hybridcache,
    unflatten_hybridcache,
)
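# Sanity check for the registration (sketch only; assumes `hc` is a live
# HybridCache instance, which this script never materializes directly):
#
#     leaves, spec = pytree.tree_flatten(hc)
#     assert isinstance(pytree.tree_unflatten(leaves, spec),
#                       transformers.cache_utils.HybridCache)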

# Alternative export path via the private _export entry point, kept for reference:
# from torch.export._trace import _export
# exported_program = _export(
#     model,
#     args=(),
#     kwargs=dummy_inputs,
#     dynamic_shapes=dynamic_shapes,
#     strict=False,
#     allow_complex_guards_as_runtime_asserts=True,
# )

# torch.export._draft_export.draft_export: experimental entry point that tolerates
# tracing/guard issues and collects them in a report instead of failing outright.
import torch.export._draft_export
exported_program = torch.export._draft_export.draft_export(
    model,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    # allow_complex_guards_as_runtime_asserts=True,
)

# 4. Compile with Torch-TensorRT
trt_model = torch_tensorrt.dynamo.compile(
    exported_program[0],  # draft_export returned an (ExportedProgram, report) pair here
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
    use_explicit_typing=True,
    use_fp32_acc=True,  # accumulate in FP32 to preserve accuracy
)
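# With use_explicit_typing=True, Torch-TensorRT builds a strongly typed network
# that keeps the dtypes of the exported graph (FP16 weights here) instead of
# letting TensorRT autotune precisions, and use_fp32_acc=True inserts casts so
# matmul accumulation happens in FP32, trading some speed for accuracy.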

# ----------------------------
# 5. Generate with the TensorRT model
# ----------------------------
# (move the original model inputs to the GPU, then call generate())
model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
with torch.inference_mode():
    trt_generation = trt_model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    trt_generation = trt_generation[0][input_len:]
    trt_decoded = processor.decode(trt_generation, skip_special_tokens=True)
    print("TensorRT generated text:")
    print(trt_decoded)
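# NOTE (assumption, not verified here): torch_tensorrt.dynamo.compile returns a
# GraphModule rather than a transformers model, so trt_model may not expose a
# generate() method in every Torch-TensorRT version. If the call above fails,
# one possible workaround is to keep the HuggingFace model as the driver and
# swap in the compiled forward, along the lines of:
#
#     model.forward = trt_model  # hypothetical wiring; the TRT module's outputs
#     trt_generation = model.generate(  # must match what generate() expects
#         **model_inputs, max_new_tokens=100, do_sample=False
#     )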


# pytree._register_pytree_node(transformers.modeling_outputs.MaskedLMOutput, lambda x: ([x.loss, x.logits], None), lambda values, _: transformers.modeling_outputs.MaskedLMOutput(loss=values[0], logits=values[1]))