Skip to content

Commit 2bafe0e

Browse files
author
Chengzhe Xu
committed
feat: support broadcasting index_put
1 parent 4f0bb6f commit 2bafe0e

File tree

6 files changed

+325
-46
lines changed

6 files changed

+325
-46
lines changed

examples/dynamo/torch_compile_pg.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Example: compare eager PyTorch generation against a Torch-TensorRT
# (torch.compile backend="tensorrt") run of PaliGemma 2 on a single image.
import torch
import torch_tensorrt
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
from transformers.image_utils import load_image

DEVICE = "cuda:0"

model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

# Load the checkpoint in FP16, switch to eval mode, then move it to the GPU.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16
).eval()
model.to(DEVICE).to(torch.float16)

processor = PaliGemmaProcessor.from_pretrained(model_id)
prompt = ""
model_inputs = (
    processor(text=prompt, images=image, return_tensors="pt")
    .to(torch.float16)
    .to(DEVICE)
)
# Number of prompt tokens; used to strip the prompt from generated sequences.
input_len = model_inputs["input_ids"].shape[-1]

separator = "============================="

# Reference run with the unmodified PyTorch model.
with torch.inference_mode():
    pyt_generation = model.generate(
        **model_inputs, max_new_tokens=100, do_sample=False
    )
    pyt_generation_out = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation_out, skip_special_tokens=True)
    print(separator)
    print("pyt_generation whole text:")
    print(pyt_generation)
    print(separator)
    print(separator)
    print("PyTorch generated text:")
    print(pyt_decoded)
    print(separator)

with torch_tensorrt.logging.debug():
    # Let the sequence dimension of input_ids vary between 2 and 256 tokens.
    torch._dynamo.mark_dynamic(model_inputs["input_ids"], 1, min=2, max=256)
    model.forward = torch.compile(
        model.forward,
        backend="tensorrt",
        dynamic=None,
        options={
            "enabled_precisions": {torch.float16},
            "disable_tf32": True,
            "min_block_size": 1,
            "debug": True,
            # Options left disabled by the author: "use_explicit_typing",
            # "use_fp32_acc", "use_aot_joint_export".
        },
    )

    # torch.compile only wraps the callable; the TensorRT conversion runs on
    # the first call. Keep that call inside the debug-logging context so the
    # conversion logs are actually emitted at debug level.
    with torch.inference_mode():
        trt_generation = model.generate(
            **model_inputs, max_new_tokens=100, do_sample=False
        )
        trt_generation_out = trt_generation[0][input_len:]
        trt_decoded = processor.decode(trt_generation_out, skip_special_tokens=True)
        print(trt_generation)
        print("TensorRT generated text:")
        print(trt_decoded)

examples/dynamo/torch_export_pg.py

+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import torch
2+
import torch_tensorrt
3+
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
4+
from transformers.image_utils import load_image
5+
6+
7+
# 1. Model: load PaliGemma 2 in FP16 on the GPU and prepare one image input.
DEVICE = torch.device("cuda:0")
model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = (
    PaliGemmaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.float16
    )
    .eval()
    .to(DEVICE)
)
processor = PaliGemmaProcessor.from_pretrained(model_id)

prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(DEVICE)
# Number of prompt tokens; used to strip the prompt from generated sequences.
input_len = model_inputs["input_ids"].shape[-1]

# 2. PyTorch reference generation
separator = "============================="
with torch.inference_mode():
    pyt_generation = model.generate(
        **model_inputs, max_new_tokens=100, do_sample=False
    )
    # Keep only the newly generated tokens that follow the input tokens.
    pyt_generation = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation, skip_special_tokens=True)
    print(separator)
    print("PyTorch generated text:")
    print(pyt_decoded)
    print(separator)

# (a) Dummy inputs: reuse the real processor outputs as example inputs.
batch_size = 1
dummy_inputs = {
    input_name: model_inputs[input_name]
    for input_name in ("input_ids", "attention_mask", "pixel_values")
}

# (b) Dynamic shape: batch in [1, 2]; sequence length in [1, 1024].
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=1024)
dynamic_shapes = {
    "input_ids": {0: BATCH, 1: SEQ_LEN},
    "attention_mask": {0: BATCH, 1: SEQ_LEN},
    "pixel_values": {0: BATCH},
}

# (c) ExportedProgram
# A plain torch.export.export call is kept here, disabled, for reference:
#   torch.export.export(model, args=(), kwargs=dummy_inputs,
#                       dynamic_shapes=dynamic_shapes, strict=False)
61+
62+
63+
import torch
64+
import torch.utils._pytree as pytree
65+
import transformers
66+
67+
def flatten_hybridcache(hc: transformers.cache_utils.HybridCache):
    """Split a ``HybridCache`` into pytree leaves and a static context.

    Args:
        hc: The cache instance to flatten.

    Returns:
        A ``(leaves, context)`` pair. ``leaves`` is a list holding
        ``hc.is_sliding`` first, then every entry of ``hc.key_cache``, then
        every entry of ``hc.value_cache``. ``context`` is a dict with the
        remaining (non-tensor) fields plus ``"num_layers"``.
    """
    # Leaves: is_sliding (presumably a per-layer bool tensor -- TODO confirm),
    # followed by the per-layer key tensors and the per-layer value tensors.
    leaves = [hc.is_sliding, *hc.key_cache, *hc.value_cache]

    # Everything that is not treated as a tensor travels in the context.
    context = dict(
        max_cache_len=hc.max_cache_len,
        max_batch_size=hc.max_batch_size,
        head_dim=hc.head_dim,
        dtype=hc.dtype,
        num_key_value_heads=hc.num_key_value_heads,
        # Needed when unflattening to know how many leaves belong to
        # key_cache and how many to value_cache.
        num_layers=len(hc.key_cache),
    )
    return leaves, context
90+
91+
92+
def unflatten_hybridcache(flat_tensors, context):
    """Rebuild a ``HybridCache`` from the output of ``flatten_hybridcache``.

    Args:
        flat_tensors: Sequence laid out as ``[is_sliding, *key_cache,
            *value_cache]`` with ``context["num_layers"]`` entries in each of
            the two cache sections.
        context: Dict of non-tensor fields as produced by
            ``flatten_hybridcache``.

    Returns:
        A ``HybridCache`` created without running ``__init__`` and populated
        only with the fields listed below.
    """
    layer_count = context["num_layers"]
    keys_end = 1 + layer_count
    values_end = keys_end + layer_count

    # Leaf layout: index 0 is is_sliding, then the keys, then the values.
    sliding_flags = flat_tensors[0]
    restored_keys = flat_tensors[1:keys_end]
    restored_values = flat_tensors[keys_end:values_end]

    # Allocate an empty instance via __new__ so that __init__ is not called.
    cache_cls = transformers.cache_utils.HybridCache
    restored = cache_cls.__new__(cache_cls)

    # Copy the plain fields straight from the context.
    for field_name in (
        "max_cache_len",
        "max_batch_size",
        "head_dim",
        "dtype",
        "num_key_value_heads",
    ):
        setattr(restored, field_name, context[field_name])

    restored.is_sliding = sliding_flags
    restored.key_cache = list(restored_keys)
    restored.value_cache = list(restored_values)
    return restored
121+
122+
# Register HybridCache as a pytree node using the flatten/unflatten helpers.
pytree.register_pytree_node(
    transformers.cache_utils.HybridCache,
    flatten_hybridcache,
    unflatten_hybridcache,
)

# Alternative kept for reference (disabled): torch.export._trace._export with
# the same arguments plus allow_complex_guards_as_runtime_asserts=True.

# Export through torch.export._draft_export.draft_export.
from torch.export._draft_export import draft_export

exported_program = draft_export(
    model,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
)

# The first element of the draft_export result is what gets compiled.
compile_settings = dict(
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
    use_explicit_typing=True,
    use_fp32_acc=True,  # Use FP32 accumulation to preserve accuracy.
)
trt_model = torch_tensorrt.dynamo.compile(exported_program[0], **compile_settings)

# ----------------------------
# 5. Run generation with the TensorRT model
# ----------------------------
# (Move the original model inputs to the GPU, then call generate().)
model_inputs = {
    input_name: tensor.to(DEVICE) for input_name, tensor in model_inputs.items()
}
# NOTE(review): this relies on the compiled module exposing ``generate``;
# confirm that the object returned by torch_tensorrt.dynamo.compile has it.
with torch.inference_mode():
    trt_generation = trt_model.generate(
        **model_inputs, max_new_tokens=100, do_sample=False
    )
    trt_generation = trt_generation[0][input_len:]
    trt_decoded = processor.decode(trt_generation, skip_special_tokens=True)
    print("TensorRT generated text:")
    print(trt_decoded)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+7-27
Original file line numberDiff line numberDiff line change
@@ -813,40 +813,20 @@ def aten_ops_select(
813813
def index_put_validator(
    node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
    """Decide whether an ``aten.index_put`` node can be converted to TensorRT.

    Args:
        node: The FX node for ``aten.index_put``; ``node.args[1]`` is the
            sequence of indices and ``node.args[3]`` (when present) is the
            ``accumulate`` flag.
        settings: Compilation settings; not consulted by this validator.

    Returns:
        ``True`` when the node can be converted; ``False`` when it must fall
        back to PyTorch.
    """
    # accumulate=True must add into the destination. The index_put converter
    # appears to lower to a plain ND scatter without reading the flag
    # (NOTE(review): confirm), so accepting it here would silently produce
    # overwritten instead of accumulated values.
    if args_bounds_check(node.args, 3, False):
        _LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
        return False

    # Identity comparison: avoids invoking ``==`` on index nodes.
    if any(index is None for index in node.args[1]):
        _LOGGER.debug("We do not support None index yet.")
        return False

    return True
824+
845825

846826

847827
@dynamo_tensorrt_converter(
848828
torch.ops.aten.index_put.default,
849-
capability_validator=index_put_validator,
829+
# capability_validator=index_put_validator,
850830
)
851831
@enforce_tensor_types(
852832
{

py/torch_tensorrt/dynamo/conversion/impl/select.py

+31-12
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
get_positive_dim,
1414
get_trt_tensor,
1515
to_numpy,
16+
create_constant,
1617
)
1718
from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
1819
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
@@ -480,28 +481,46 @@ def index_put_converter(
480481
values: TRTTensor,
481482
accumulate: bool = False,
482483
) -> TRTTensor:
483-
# Reshape indices to add an extra dimension if necessary (indices is a Tuple of ITensors)
484484
reshaped_indices = []
485-
for i, each_input in enumerate(indices):
486-
if not isinstance(each_input, TRTTensor):
487-
each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
488-
each_input = impl.shuffle.reshape(
485+
for i, each_idx in enumerate(indices):
486+
idx_trt = get_trt_tensor(ctx, each_idx, f"{name}_idx_{i}")
487+
idx_trt = impl.shuffle.reshape(
489488
ctx,
490489
target,
491490
source_ir,
492-
f"{name}_reshape_{i}",
493-
each_input,
494-
(-1, 1), # Reshape to (N, 1)
491+
f"{name}_reshape_idx_{i}",
492+
idx_trt,
493+
shape=(-1, 1),
495494
)
496-
reshaped_indices.append(each_input)
495+
reshaped_indices.append(idx_trt)
497496

498-
# Concatenate along the second dimension (columns)
497+
# Concat -> (N, K)
499498
indices_cat = impl.cat.cat(
500-
ctx, target, source_ir, f"{name}_cat", reshaped_indices, dim=1
499+
ctx, target, source_ir, f"{name}_cat_indices", reshaped_indices, dim=1
500+
)
501+
502+
source_shape = tuple(input_tensor.shape)
503+
k = len(indices)
504+
leftover_dims = source_shape[k:]
505+
506+
index_shapes_py = [tuple(idx.shape) for idx in reshaped_indices]
507+
N = index_shapes_py[0][0]
508+
sub_tensor_shape = (N,) + leftover_dims
509+
510+
broadcasted_values = impl.slice.expand(
511+
ctx,
512+
target,
513+
source_ir,
514+
f"{name}_expand_values",
515+
values,
516+
sub_tensor_shape,
501517
)
502518

503519
scatter_layer = ctx.net.add_scatter(
504-
input_tensor, indices_cat, values, trt.ScatterMode.ND
520+
input_tensor,
521+
indices_cat,
522+
broadcasted_values,
523+
trt.ScatterMode.ND,
505524
)
506525
scatter_layer.axis = 0
507526
set_layer_name(scatter_layer, target, f"{name}_scatter_layer", source_ir)

py/torch_tensorrt/dynamo/utils.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,17 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
378378
shape_env.var_to_val
379379
)
380380
assert var_range, var_val
381-
min_val, max_val = int(var_range.lower), int(var_range.upper)
381+
min_val_, max_val_ = var_range.lower, var_range.upper # int(var_range.lower), int(var_range.upper)
382+
383+
if isinstance(var_range.lower, torch.utils._sympy.numbers.NegativeIntInfinity):
384+
min_val_ = 1
385+
386+
# if isinstance(var_range.upper, torch.utils._sympy.numbers.IntInfinity):
387+
# max_val_ = 2048
388+
389+
min_val = int(min_val_)
390+
max_val = int(max_val_)
391+
382392
# Torchdynamo 0/1 specialization outlier
383393
min_val = 1 if min_val == 2 else min_val
384394
min_max_opt = {}

0 commit comments

Comments
 (0)