Skip to content

Commit 175f191

Browse files
author
Chengzhe Xu
committed
chore: linting
1 parent 2bafe0e commit 175f191

File tree

5 files changed

+68
-69
lines changed

5 files changed

+68
-69
lines changed

examples/dynamo/torch_export_pg.py

+49-45
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
import torch
22
import torch_tensorrt
3-
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
3+
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
44
from transformers.image_utils import load_image
55

6-
76
# 1. Model
87
DEVICE = torch.device("cuda:0")
98
model_id = "google/paligemma2-3b-pt-224"
109
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
1110
image = load_image(url)
1211

13-
model = PaliGemmaForConditionalGeneration.from_pretrained(
14-
model_id, torch_dtype=torch.float16
15-
).eval().to(DEVICE)
12+
model = (
13+
PaliGemmaForConditionalGeneration.from_pretrained(
14+
model_id, torch_dtype=torch.float16
15+
)
16+
.eval()
17+
.to(DEVICE)
18+
)
1619
processor = PaliGemmaProcessor.from_pretrained(model_id)
1720

1821
prompt = ""
@@ -21,19 +24,21 @@
2124

2225
# 2. PyTorch
2326
with torch.inference_mode():
24-
pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False) #, use_cache=False)
25-
# 입력 토큰 이후의 새로 생성된 토큰만 취합니다.
27+
pyt_generation = model.generate(
28+
**model_inputs, max_new_tokens=100, do_sample=False
29+
) # , use_cache=False)
30+
# Keep only the newly generated tokens that follow the input tokens.
2631
pyt_generation = pyt_generation[0][input_len:]
2732
pyt_decoded = processor.decode(pyt_generation, skip_special_tokens=True)
2833
print("=============================")
2934
print("PyTorch generated text:")
3035
print(pyt_decoded)
3136
print("=============================")
3237

33-
# (a) Dummy inputs
38+
# (a) Dummy inputs
3439
batch_size = 1
35-
dummy_input_ids = model_inputs["input_ids"]
36-
dummy_attention_mask = model_inputs["attention_mask"]
40+
dummy_input_ids = model_inputs["input_ids"]
41+
dummy_attention_mask = model_inputs["attention_mask"]
3742
dummy_pixel_values = model_inputs["pixel_values"]
3843

3944
dummy_inputs = {
@@ -42,15 +47,15 @@
4247
"pixel_values": dummy_pixel_values,
4348
}
4449

45-
# (b) Dynamic shape
50+
# (b) Dynamic shape
4651
BATCH = torch.export.Dim("batch", min=1, max=2)
4752
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=1024)
4853
dynamic_shapes = {
4954
"input_ids": {0: BATCH, 1: SEQ_LEN},
5055
"attention_mask": {0: BATCH, 1: SEQ_LEN},
5156
"pixel_values": {0: BATCH},
5257
}
53-
# (c) ExportedProgram
58+
# (c) ExportedProgram
5459
# torch.export.export(
5560
# model,
5661
# args=(),
@@ -64,50 +69,53 @@
6469
import torch.utils._pytree as pytree
6570
import transformers
6671

72+
6773
def flatten_hybridcache(hc: transformers.cache_utils.HybridCache):
6874
"""
69-
1) HybridCache 내부의 '텐서'들을 리스트로 모은다.
70-
2) 텐서가 아닌 값들은 context(dict)에 담는다.
75+
1) Collects all tensors inside HybridCache into a list.
76+
2) Stores non-tensor values in the context (dictionary).
7177
"""
72-
# 1. 텐서로 취급할 것들: is_sliding, key_cache 전체, value_cache 전체
78+
# 1. Tensors: is_sliding, entire key_cache, entire value_cache
7379
flat_tensors = []
74-
flat_tensors.append(hc.is_sliding) # shape: [num_hidden_layers], bool
75-
flat_tensors.extend(hc.key_cache) # List[Tensor]
76-
flat_tensors.extend(hc.value_cache) # List[Tensor]
80+
flat_tensors.append(hc.is_sliding) # shape: [num_hidden_layers], bool
81+
flat_tensors.extend(hc.key_cache) # List[Tensor]
82+
flat_tensors.extend(hc.value_cache) # List[Tensor]
7783

78-
# 2. 텐서가 아닌 필드는 context로 저장
84+
# 2. Store non-tensor fields in the context
7985
context = {
8086
"max_cache_len": hc.max_cache_len,
8187
"max_batch_size": hc.max_batch_size,
8288
"head_dim": hc.head_dim,
8389
"dtype": hc.dtype,
8490
"num_key_value_heads": hc.num_key_value_heads,
85-
# unflatten 시에 key_cache / value_cache를 몇 개씩 떼어낼지 알아야 하므로
86-
"num_layers": len(hc.key_cache), # = len(hc.value_cache) = config.num_hidden_layers
91+
"num_layers": len(
92+
hc.key_cache
93+
), # = len(hc.value_cache) = config.num_hidden_layers
8794
}
8895

8996
return flat_tensors, context
9097

9198

9299
def unflatten_hybridcache(flat_tensors, context):
93100
"""
94-
flatten_hybridcache에서 분리한 (flat_tensors, context)를 받아
95-
다시 HybridCache 객체로 복원하는 함수.
101+
Restores a HybridCache object from the (flat_tensors, context) produced by flatten_hybridcache.
96102
"""
97103
num_layers = context["num_layers"]
98104

99-
# 1. flat_tensors 파싱
100-
# - 첫 번째 요소가 is_sliding
101-
# - 그 다음 num_layers개: key_cache
102-
# - 그 다음 num_layers개: value_cache
105+
# 1. Parse flat_tensors
106+
# - First element is is_sliding
107+
# - Next num_layers elements: key_cache
108+
# - Next num_layers elements: value_cache
103109
is_sliding = flat_tensors[0]
104110
key_cache = flat_tensors[1 : 1 + num_layers]
105-
value_cache = flat_tensors[1 + num_layers : 1 + 2*num_layers]
111+
value_cache = flat_tensors[1 + num_layers : 1 + 2 * num_layers]
106112

107-
# 2. __new__로 빈 HybridCache 객체 생성 (생성자 __init__은 호출 안 함)
108-
hc = transformers.cache_utils.HybridCache.__new__(transformers.cache_utils.HybridCache)
113+
# 2. Create an empty HybridCache object using __new__ (without calling __init__)
114+
hc = transformers.cache_utils.HybridCache.__new__(
115+
transformers.cache_utils.HybridCache
116+
)
109117

110-
# 3. 필요한 필드를 직접 셋팅
118+
# 3. Manually set required fields
111119
hc.max_cache_len = context["max_cache_len"]
112120
hc.max_batch_size = context["max_batch_size"]
113121
hc.head_dim = context["head_dim"]
@@ -119,14 +127,13 @@ def unflatten_hybridcache(flat_tensors, context):
119127

120128
return hc
121129

122-
# pytree 등록
130+
131+
# Register with pytree
123132
pytree.register_pytree_node(
124-
transformers.cache_utils.HybridCache,
125-
flatten_hybridcache,
126-
unflatten_hybridcache
133+
transformers.cache_utils.HybridCache, flatten_hybridcache, unflatten_hybridcache
127134
)
128135

129-
# from torch.export._trace import _export
136+
# from torch.export._trace import _export
130137
# exported_program = _export(
131138
# model,
132139
# args=(),
@@ -138,6 +145,7 @@ def unflatten_hybridcache(flat_tensors, context):
138145

139146
# torch.export._draft_export.draft_export
140147
import torch.export._draft_export
148+
141149
exported_program = torch.export._draft_export.draft_export(
142150
model,
143151
args=(),
@@ -156,20 +164,16 @@ def unflatten_hybridcache(flat_tensors, context):
156164
device=DEVICE,
157165
disable_tf32=True,
158166
use_explicit_typing=True,
159-
use_fp32_acc=True, # FP32 누적을 사용해 정확도를 보존합니다.
167+
use_fp32_acc=True,
160168
)
161169

162-
# ----------------------------
163-
# 5. TensorRT 모델로 생성 수행
164-
# ----------------------------
165-
# (원래의 모델 입력을 GPU로 이동시킨 후 generate() 호출)
170+
# Execute generation using TensorRT model
166171
model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
167172
with torch.inference_mode():
168-
trt_generation = trt_model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
173+
trt_generation = trt_model.generate(
174+
**model_inputs, max_new_tokens=100, do_sample=False
175+
)
169176
trt_generation = trt_generation[0][input_len:]
170177
trt_decoded = processor.decode(trt_generation, skip_special_tokens=True)
171178
print("TensorRT generated text:")
172179
print(trt_decoded)
173-
174-
175-
# pytree._register_pytree_node(transformers.modeling_outputs.MaskedLMOutput, lambda x: ([x.loss, x.logits], None), lambda values, _: transformers.modeling_outputs.MaskedLMOutput(loss=values[0], logits=values[1]))

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -813,15 +813,11 @@ def aten_ops_select(
813813
def index_put_validator(
814814
node: Node, settings: Optional[CompilationSettings] = None
815815
) -> bool:
816-
817816
if None in node.args[1]:
818-
_LOGGER.debug(
819-
"We do not support None index yet."
820-
)
817+
_LOGGER.debug("We do not support None index yet.")
821818
return False
822819
else:
823820
return True
824-
825821

826822

827823
@dynamo_tensorrt_converter(

py/torch_tensorrt/dynamo/conversion/impl/select.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Optional, Sequence, Union
33

44
import numpy as np
5+
import tensorrt as trt
56
import torch
67
from torch.fx.node import Target
78
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -13,7 +14,6 @@
1314
get_positive_dim,
1415
get_trt_tensor,
1516
to_numpy,
16-
create_constant,
1717
)
1818
from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
1919
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
@@ -24,8 +24,6 @@
2424
)
2525
from torch_tensorrt.fx.types import TRTTensor
2626

27-
import tensorrt as trt
28-
2927
_LOGGER: logging.Logger = logging.getLogger(__name__)
3028

3129

@@ -499,12 +497,12 @@ def index_put_converter(
499497
ctx, target, source_ir, f"{name}_cat_indices", reshaped_indices, dim=1
500498
)
501499

502-
source_shape = tuple(input_tensor.shape)
500+
source_shape = tuple(input_tensor.shape)
503501
k = len(indices)
504-
leftover_dims = source_shape[k:]
502+
leftover_dims = source_shape[k:]
505503

506-
index_shapes_py = [tuple(idx.shape) for idx in reshaped_indices]
507-
N = index_shapes_py[0][0]
504+
index_shapes_py = [tuple(idx.shape) for idx in reshaped_indices]
505+
N = index_shapes_py[0][0]
508506
sub_tensor_shape = (N,) + leftover_dims
509507

510508
broadcasted_values = impl.slice.expand(

py/torch_tensorrt/dynamo/utils.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -378,16 +378,18 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
378378
shape_env.var_to_val
379379
)
380380
assert var_range, var_val
381-
min_val_, max_val_ = var_range.lower, var_range.upper # int(var_range.lower), int(var_range.upper)
381+
min_val_, max_val_ = (
382+
var_range.lower,
383+
var_range.upper,
384+
) # int(var_range.lower), int(var_range.upper)
382385

383386
if isinstance(var_range.lower, torch.utils._sympy.numbers.NegativeIntInfinity):
384-
min_val_ = 1
387+
min_val_ = 1
385388

386389
# if isinstance(var_range.upper, torch.utils._sympy.numbers.IntInfinity):
387-
# max_val_ = 2048
388-
390+
# max_val_ = 2048
389391
min_val = int(min_val_)
390-
max_val = int(max_val_)
392+
max_val = int(max_val_)
391393

392394
# Torchdynamo 0/1 specialization outlier
393395
min_val = 1 if min_val == 2 else min_val

tests/py/dynamo/conversion/test_index_put_aten.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -136,26 +136,25 @@ class TestIndexPutConverter(DispatchTestCase):
136136
),
137137
param(
138138
test_name="3d_indices_float_broadcast_index",
139-
source_tensor=torch.zeros([3, 3, 3], dtype = torch.int32),
139+
source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
140140
indices_tensor=(
141-
torch.tensor([0,1], dtype=torch.int32),
142-
torch.tensor([0,1], dtype=torch.int32),
141+
torch.tensor([0, 1], dtype=torch.int32),
142+
torch.tensor([0, 1], dtype=torch.int32),
143143
),
144-
value_tensor=torch.tensor([10], dtype = torch.int32),
144+
value_tensor=torch.tensor([10], dtype=torch.int32),
145145
),
146146
param(
147147
test_name="3d_indices_broadcast_1dim",
148148
source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
149149
indices_tensor=(torch.tensor([1], dtype=torch.int32),),
150150
value_tensor=torch.tensor([7], dtype=torch.int32),
151151
),
152-
param(
152+
param(
153153
test_name="2d_indices_broadcast_1dim",
154154
source_tensor=torch.zeros([4, 4], dtype=torch.int32),
155155
indices_tensor=(torch.tensor([1, 3], dtype=torch.int32),),
156156
value_tensor=torch.tensor([5], dtype=torch.int32),
157157
),
158-
159158
# Example 4) 4D source, 2D indices → broadcast over the entire last two dimensions
160159
param(
161160
test_name="4d_indices_broadcast_2dim",
@@ -169,7 +168,7 @@ class TestIndexPutConverter(DispatchTestCase):
169168
# param(
170169
# test_name="4d_indices_none_none_single_idx",
171170
# source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.int32),
172-
# # None이 들어가면 현재 코드에서 문제가 발생할 수 있음
171+
# # indexing with None is WIP.
173172
# indices_tensor=(None, None, torch.tensor([2], dtype=torch.int32)),
174173
# value_tensor=torch.tensor(
175174
# [[[10, 20, 30],

0 commit comments

Comments
 (0)