Support broadcast index put #3421

Merged: 3 commits, Mar 14, 2025
Changes from 2 commits
61 changes: 61 additions & 0 deletions examples/dynamo/torch_compile_pg.py
@@ -0,0 +1,61 @@
import torch
import torch_tensorrt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image

DEVICE = "cuda:0"

model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)


model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16
).eval()
model.to(DEVICE)

processor = PaliGemmaProcessor.from_pretrained(model_id)
prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.float16).to(DEVICE)
input_len = model_inputs["input_ids"].shape[-1]


with torch.inference_mode():
    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    pyt_generation_out = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation_out, skip_special_tokens=True)
    print("=============================")
    print("pyt_generation whole text:")
    print(pyt_generation)
    print("=============================")
    print("=============================")
    print("PyTorch generated text:")
    print(pyt_decoded)
    print("=============================")

with torch_tensorrt.logging.debug():
    torch._dynamo.mark_dynamic(model_inputs["input_ids"], 1, min=2, max=256)
    model.forward = torch.compile(
        model.forward,
        backend="tensorrt",
        dynamic=None,
        options={
            "enabled_precisions": {torch.float16},
            "disable_tf32": True,
            "min_block_size": 1,
            # "use_explicit_typing": True,
            # "use_fp32_acc": True,
            "debug": True,
            # "use_aot_joint_export": False,
        },
    )
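# Note: mark_dynamic above uses min=2 because Torchdynamo specializes
# sizes 0 and 1; the extract_var_range_info change in this PR (see the
# utils.py diff below) maps a traced lower bound of 2 back to 1.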

with torch.inference_mode():
    trt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    trt_generation_out = trt_generation[0][input_len:]
    trt_decoded = processor.decode(trt_generation_out, skip_special_tokens=True)
    print(trt_generation)
    print("TensorRT generated text:")
    print(trt_decoded)
179 changes: 179 additions & 0 deletions examples/dynamo/torch_export_pg.py
@@ -0,0 +1,179 @@
import torch
import torch_tensorrt
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
from transformers.image_utils import load_image

# 1. Model
DEVICE = torch.device("cuda:0")
model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = (
    PaliGemmaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.float16
    )
    .eval()
    .to(DEVICE)
)
processor = PaliGemmaProcessor.from_pretrained(model_id)

prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(DEVICE)
input_len = model_inputs["input_ids"].shape[-1]

# 2. PyTorch
with torch.inference_mode():
    pyt_generation = model.generate(
        **model_inputs, max_new_tokens=100, do_sample=False
    )
    # The newly generated tokens after the input tokens.
    pyt_generation = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation, skip_special_tokens=True)
    print("=============================")
    print("PyTorch generated text:")
    print(pyt_decoded)
    print("=============================")

# (a) Dummy inputs
batch_size = 1
dummy_input_ids = model_inputs["input_ids"]
dummy_attention_mask = model_inputs["attention_mask"]
dummy_pixel_values = model_inputs["pixel_values"]

dummy_inputs = {
    "input_ids": dummy_input_ids,
    "attention_mask": dummy_attention_mask,
    "pixel_values": dummy_pixel_values,
}

# (b) Dynamic shape
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=1024)
dynamic_shapes = {
    "input_ids": {0: BATCH, 1: SEQ_LEN},
    "attention_mask": {0: BATCH, 1: SEQ_LEN},
    "pixel_values": {0: BATCH},
}
# (c) ExportedProgram
# The standard export call is kept below for reference; this script instead
# registers HybridCache with pytree and uses draft_export further down.
# torch.export.export(
#     model,
#     args=(),
#     kwargs=dummy_inputs,
#     dynamic_shapes=dynamic_shapes,
#     strict=False,
# )


import torch.utils._pytree as pytree
import transformers


def flatten_hybridcache(hc: transformers.cache_utils.HybridCache):
    """
    1) Collects all tensors inside HybridCache into a list.
    2) Stores non-tensor values in the context (dictionary).
    """
    # 1. Tensors: is_sliding, entire key_cache, entire value_cache
    flat_tensors = []
    flat_tensors.append(hc.is_sliding)  # shape: [num_hidden_layers], bool
    flat_tensors.extend(hc.key_cache)  # List[Tensor]
    flat_tensors.extend(hc.value_cache)  # List[Tensor]

    # 2. Store non-tensor fields in the context
    context = {
        "max_cache_len": hc.max_cache_len,
        "max_batch_size": hc.max_batch_size,
        "head_dim": hc.head_dim,
        "dtype": hc.dtype,
        "num_key_value_heads": hc.num_key_value_heads,
        "num_layers": len(hc.key_cache),  # == len(hc.value_cache) == config.num_hidden_layers
    }

    return flat_tensors, context


def unflatten_hybridcache(flat_tensors, context):
    """
    Restores a HybridCache object from the (flat_tensors, context) produced by flatten_hybridcache.
    """
    num_layers = context["num_layers"]

    # 1. Parse flat_tensors
    #    - First element is is_sliding
    #    - Next num_layers elements: key_cache
    #    - Next num_layers elements: value_cache
    is_sliding = flat_tensors[0]
    key_cache = flat_tensors[1 : 1 + num_layers]
    value_cache = flat_tensors[1 + num_layers : 1 + 2 * num_layers]

    # 2. Create an empty HybridCache object using __new__ (without calling __init__)
    hc = transformers.cache_utils.HybridCache.__new__(
        transformers.cache_utils.HybridCache
    )

    # 3. Manually set required fields
    hc.max_cache_len = context["max_cache_len"]
    hc.max_batch_size = context["max_batch_size"]
    hc.head_dim = context["head_dim"]
    hc.dtype = context["dtype"]
    hc.num_key_value_heads = context["num_key_value_heads"]
    hc.is_sliding = is_sliding
    hc.key_cache = list(key_cache)
    hc.value_cache = list(value_cache)

    return hc


# Register with pytree
pytree.register_pytree_node(
    transformers.cache_utils.HybridCache, flatten_hybridcache, unflatten_hybridcache
)
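# A minimal round-trip sketch of the registration above (hypothetical shapes,
# not part of the original script): build a bare HybridCache, flatten it via
# the registered pytree node, and rebuild an equivalent cache.
#
#   hc = transformers.cache_utils.HybridCache.__new__(transformers.cache_utils.HybridCache)
#   hc.max_cache_len, hc.max_batch_size = 16, 1
#   hc.head_dim, hc.dtype, hc.num_key_value_heads = 8, torch.float16, 2
#   hc.is_sliding = torch.tensor([False, True])
#   hc.key_cache = [torch.zeros(1, 2, 16, 8) for _ in range(2)]
#   hc.value_cache = [torch.zeros(1, 2, 16, 8) for _ in range(2)]
#
#   leaves, spec = pytree.tree_flatten(hc)
#   rebuilt = pytree.tree_unflatten(leaves, spec)
#   assert len(rebuilt.key_cache) == 2 and rebuilt.max_cache_len == 16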

# from torch.export._trace import _export
# exported_program = _export(
#     model,
#     args=(),
#     kwargs=dummy_inputs,
#     dynamic_shapes=dynamic_shapes,
#     strict=False,
#     allow_complex_guards_as_runtime_asserts=True,
# )

# draft_export returns a tuple here; element [0] is the ExportedProgram
# consumed by torch_tensorrt.dynamo.compile below.
import torch.export._draft_export

exported_program = torch.export._draft_export.draft_export(
    model,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    # allow_complex_guards_as_runtime_asserts=True,
)


trt_model = torch_tensorrt.dynamo.compile(
    exported_program[0],
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
    use_explicit_typing=True,
    use_fp32_acc=True,
)

# Execute generation using the TensorRT model
model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
with torch.inference_mode():
    trt_generation = trt_model.generate(
        **model_inputs, max_new_tokens=100, do_sample=False
    )
    trt_generation = trt_generation[0][input_len:]
    trt_decoded = processor.decode(trt_generation, skip_special_tokens=True)
    print("TensorRT generated text:")
    print(trt_decoded)
32 changes: 4 additions & 28 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -813,40 +813,16 @@ def aten_ops_select(
 def index_put_validator(
     node: Node, settings: Optional[CompilationSettings] = None
 ) -> bool:
-    if args_bounds_check(node.args, 3, False):  # Check if accumulate is valid
-        _LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
-        accumulate_valid = False
-    else:
-        accumulate_valid = True
-
-    # Retrieve input tensor's meta information
-    input_meta = node.args[0].meta.get("tensor_meta")
-    if not input_meta:
-        _LOGGER.warning(
-            "Meta information of input is missing. Unable to validate if broadcasting is needed, falling back to PyTorch operation."
-        )
+    if None in node.args[1]:
+        _LOGGER.debug("We do not support None index yet.")
         return False
-
-    input_shape = input_meta.shape
-    input_num_dims = len(input_shape)
-
-    # Check if broadcasting is valid
-    indices_num_dims = len(node.args[1])
-    if indices_num_dims == input_num_dims:
-        broadcast_valid = True
-    else:
-        _LOGGER.debug(
-            "We do not support broadcasting when the number of index dimensions does not match the number of input tensor dimensions."
-        )
-        broadcast_valid = False
-
-    # Return validation result
-    return accumulate_valid and broadcast_valid
+    return True


 @dynamo_tensorrt_converter(
     torch.ops.aten.index_put.default,
-    capability_validator=index_put_validator,
+    # capability_validator=index_put_validator,
 )
 @enforce_tensor_types(
     {
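For context, a minimal plain-PyTorch illustration (not part of the PR) of the broadcast case this change opens up: aten.index_put accepts values of lower rank than the gathered sub-tensor and broadcasts them across the indexed rows.

    import torch

    x = torch.zeros(4, 3)
    rows = (torch.tensor([0, 2]),)        # one index tensor -> rows 0 and 2
    vals = torch.tensor([1.0, 2.0, 3.0])  # shape (3,), broadcast to (2, 3)
    x.index_put_(rows, vals)              # rows 0 and 2 become [1., 2., 3.]
    print(x)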
45 changes: 31 additions & 14 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -2,6 +2,7 @@
 from typing import Optional, Sequence, Union

 import numpy as np
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -23,8 +24,6 @@
 )
 from torch_tensorrt.fx.types import TRTTensor

-import tensorrt as trt
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -480,28 +479,46 @@ def index_put_converter(
     values: TRTTensor,
     accumulate: bool = False,
 ) -> TRTTensor:
-    # Reshape indices to add an extra dimension if necessary (indices is a Tuple of ITensors)
     reshaped_indices = []
-    for i, each_input in enumerate(indices):
-        if not isinstance(each_input, TRTTensor):
-            each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
-        each_input = impl.shuffle.reshape(
+    for i, each_idx in enumerate(indices):
+        idx_trt = get_trt_tensor(ctx, each_idx, f"{name}_idx_{i}")
+        idx_trt = impl.shuffle.reshape(
             ctx,
             target,
             source_ir,
-            f"{name}_reshape_{i}",
-            each_input,
-            (-1, 1),  # Reshape to (N, 1)
+            f"{name}_reshape_idx_{i}",
+            idx_trt,
+            shape=(-1, 1),
         )
-        reshaped_indices.append(each_input)
+        reshaped_indices.append(idx_trt)

-    # Concatenate along the second dimension (columns)
+    # Concat -> (N, K)
     indices_cat = impl.cat.cat(
-        ctx, target, source_ir, f"{name}_cat", reshaped_indices, dim=1
+        ctx, target, source_ir, f"{name}_cat_indices", reshaped_indices, dim=1
     )

+    source_shape = tuple(input_tensor.shape)
+    k = len(indices)
+    leftover_dims = source_shape[k:]
+
+    index_shapes_py = [tuple(idx.shape) for idx in reshaped_indices]
+    N = index_shapes_py[0][0]
+    sub_tensor_shape = (N,) + leftover_dims
+
+    broadcasted_values = impl.slice.expand(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_expand_values",
+        values,
+        sub_tensor_shape,
+    )
+
     scatter_layer = ctx.net.add_scatter(
-        input_tensor, indices_cat, values, trt.ScatterMode.ND
+        input_tensor,
+        indices_cat,
+        broadcasted_values,
+        trt.ScatterMode.ND,
     )
     scatter_layer.axis = 0
     set_layer_name(scatter_layer, target, f"{name}_scatter_layer", source_ir)
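A plain-PyTorch sketch (hypothetical shapes, not from the PR) of the shape bookkeeping the converter now performs before the ND scatter:

    import torch

    input_tensor = torch.zeros(8, 4, 16)           # tensor being written into
    indices = (
        torch.randint(0, 8, (5,)),                 # K = 2 index tensors,
        torch.randint(0, 4, (5,)),                 # each of length N = 5
    )

    k = len(indices)                               # number of indexed dims
    leftover_dims = tuple(input_tensor.shape[k:])  # (16,)
    N = indices[0].shape[0]                        # 5 index rows
    sub_tensor_shape = (N,) + leftover_dims        # (5, 16)

    values = torch.ones(16)                        # lower-rank values
    broadcasted = values.expand(sub_tensor_shape)  # mirrors impl.slice.expand
    print(broadcasted.shape)                       # torch.Size([5, 16])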
14 changes: 13 additions & 1 deletion py/torch_tensorrt/dynamo/utils.py
@@ -378,7 +378,19 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
         shape_env.var_to_val
     )
     assert var_range, var_val
-    min_val, max_val = int(var_range.lower), int(var_range.upper)
+    min_val_, max_val_ = (
+        var_range.lower,
+        var_range.upper,
+    )
+
+    if isinstance(var_range.lower, torch.utils._sympy.numbers.NegativeIntInfinity):
+        min_val_ = 1
+
+    # if isinstance(var_range.upper, torch.utils._sympy.numbers.IntInfinity):
+    #     max_val_ = 2048
+    min_val = int(min_val_)
+    max_val = int(max_val_)
+
     # Torchdynamo 0/1 specialization outlier
     min_val = 1 if min_val == 2 else min_val
     min_max_opt = {}
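A plain-Python sketch (hypothetical helper, not from the patch) of the clamping extract_var_range_info now applies: an unbounded lower bound is clamped to 1 before the int() cast, and Torchdynamo's 0/1 specialization maps a traced lower bound of 2 back to 1.

    def clamp_bounds(lower, upper):
        # None stands in for sympy NegativeIntInfinity in this sketch
        min_val = 1 if lower is None else int(lower)
        max_val = int(upper)
        # Torchdynamo 0/1 specialization outlier: a traced min of 2 really means 1
        min_val = 1 if min_val == 2 else min_val
        return {"min": min_val, "max": max_val}

    print(clamp_bounds(None, 256))  # {'min': 1, 'max': 256}
    print(clamp_bounds(2, 256))     # {'min': 1, 'max': 256}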