Commit de81be2

feat: support 1d ITensor offsets for embedding_bag converter (#2677)
1 parent ff8c872 commit de81be2

File tree

6 files changed: +708 −128 lines


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

8 additions & 34 deletions

@@ -231,26 +231,7 @@ def aten_ops_cat(
     )


-def embedding_param_validator(embedding_node: Node) -> bool:
-    scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
-    sparse = args_bounds_check(embedding_node.args, 4)
-
-    if scale_grad_by_freq is not None:
-        _LOGGER.debug(
-            f"Currently we don't support specifying scale gradient by word frequency, got {scale_grad_by_freq}."
-        )
-        return False
-
-    if sparse is not None:
-        _LOGGER.debug(f"Currently we don't support sparse gradient, got {sparse}.")
-        return False
-
-    return True
-
-
-@dynamo_tensorrt_converter(
-    torch.ops.aten.embedding.default, capability_validator=embedding_param_validator
-)
+@dynamo_tensorrt_converter(torch.ops.aten.embedding.default)
 def aten_ops_embedding(
     ctx: ConversionContext,
     target: Target,
@@ -265,22 +246,19 @@ def aten_ops_embedding(
         name,
         input=args[1],
         weight=args[0],
-        # args[2] is the padding index, which is useful for training only
-        scale_grad_by_freq=args_bounds_check(args, 3),
-        sparse=args_bounds_check(args, 4),
     )


 def embedding_bag_validator(node: Node) -> bool:
-    mode = args_bounds_check(node.args, 4, 0)
-    indices = node.args[1].meta.get("tensor_meta")
+    if not one_user_validator(node):
+        return False
+    meta = node.args[1].meta
+    indices = meta.get("tensor_meta")
+    if indices is None:
+        indices = meta.get("val")
     if indices is None:
         return False
-    return (
-        bool(node.args[2].op == "get_attr")
-        and (mode == 0 or mode == 1 or mode == 2)
-        and len(indices.shape) == 1
-    )
+    return len(indices.shape) == 1  # currently only support 1D indices


 @dynamo_tensorrt_converter(
@@ -293,7 +271,6 @@ def embedding_bag_validator(node: Node) -> bool:
     {
         0: (TRTTensor,),
         1: (TRTTensor,),
-        2: (np.ndarray, torch.Tensor),
     }
 )
 def aten_ops_embedding_bag(
@@ -311,12 +288,9 @@ def aten_ops_embedding_bag(
         weight=args[0],
         indices=args[1],
         offsets=args[2],
-        scale_grad_by_freq=args_bounds_check(args, 3, False),
         mode=args_bounds_check(args, 4, 0),
-        sparse=args_bounds_check(args, 5, False),
         per_sample_weights=args_bounds_check(args, 6, None),
        include_last_offset=args_bounds_check(args, 7, False),
-        # padding index is useful for training only
     )

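For context on what the relaxed validator enables: `offsets` (args[2]) no longer has to be a frozen `get_attr` constant, so a graph can pass it in as a runtime 1D tensor and the converter will receive it as an ITensor. The sketch below is illustrative only; the module, input shapes, dtypes, and compile options are assumptions, not taken from this commit's tests.

# Hypothetical sketch: offsets arrive as a graph input (a 1D ITensor at
# conversion time) rather than a frozen constant.
import torch
import torch_tensorrt


class EmbeddingBagModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(10, 4))

    def forward(self, indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        # mode="sum" corresponds to mode=0 in the aten op.
        return torch.nn.functional.embedding_bag(
            indices, self.weight, offsets=offsets, mode="sum"
        )


indices = torch.tensor([0, 2, 4, 1, 3, 5], dtype=torch.int32).cuda()
offsets = torch.tensor([0, 2, 4], dtype=torch.int32).cuda()
model = EmbeddingBagModule().eval().cuda()

# Compile options here are illustrative; consult the Torch-TensorRT docs for details.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[indices, offsets], min_block_size=1)
print(trt_model(indices, offsets))
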
py/torch_tensorrt/dynamo/conversion/converter_utils.py

109 additions & 0 deletions

@@ -5,6 +5,7 @@

 import numpy as np
 import torch
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch import SymBool, SymFloat, SymInt
 from torch.fx.node import Argument, Target
 from torch_tensorrt import _enums
@@ -530,3 +531,111 @@ def flatten_dims(
     new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])

     return new_shape
+
+
+def append(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    original_tensor: TRTTensor,
+    new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
+    dim: int = 0,
+) -> TRTTensor:
+    """
+    Append a new value to the end of the original tensor along the specified dimension (default 0).
+    For example, if the original tensor is [1, 2, 3], the new value is 4, and the dim is 0,
+    the new tensor will be [1, 2, 3, 4].
+
+    Args:
+        ctx (ConversionContext): A ConversionContext containing the TensorRT network
+        target (Target): Target of calling node
+        source_ir (Optional[SourceIR]): SourceIR of calling converter
+        name (str): Name of the calling layer
+        original_tensor (TRTTensor): A TRTTensor to append the new value to
+        new_value (Union[TRTTensor, int, float, torch.Tensor, np.ndarray]): A new value to append
+        dim (int, optional): Dimension to append the new value. Defaults to 0.
+
+    Returns:
+        TRTTensor: A new TRTTensor that is the result of appending the new value to the original tensor
+    """
+    if isinstance(new_value, (int, float)):
+        new_value = np.array([new_value])
+    new_value = get_trt_tensor(ctx, new_value, name, original_tensor.dtype)
+
+    return impl.cat.cat(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_concat",
+        [original_tensor, new_value],
+        get_positive_dim(dim, len(original_tensor.shape)),
+    )
+
+
+def set_item(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    original_tensor: TRTTensor,
+    index: int,
+    new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
+) -> TRTTensor:
+    """
+    Set a new value to the original tensor at the specified index. For example,
+    if the original tensor is [1, 2, 3], the new value is 4, and the index is 1,
+    the new tensor will be [1, 4, 3].
+    If the index is out of bound, the new value will be appended to the end.
+
+    Args:
+        ctx (ConversionContext): A ConversionContext containing the TensorRT network
+        target (Target): Target of calling node
+        source_ir (Optional[SourceIR]): SourceIR of calling converter
+        name (str): Name of the calling layer
+        original_tensor (TRTTensor): A TRTTensor to set the new value to
+        index (int): The index to set the new value
+        new_value (Union[TRTTensor, int, float, torch.Tensor, np.ndarray]): A new value to set
+
+    Returns:
+        TRTTensor: A new TRTTensor that is the result of setting the new value to the original tensor
+    """
+    if isinstance(new_value, (int, float)):
+        new_value = np.array([new_value])
+    new_value = get_trt_tensor(ctx, new_value, name, original_tensor.dtype)
+
+    len_original_tensor = original_tensor.shape[0]
+    index = get_positive_dim(index, len_original_tensor)
+
+    front_tensor = impl.slice.slice_op(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_slice_front",
+        original_tensor,
+        dim=0,
+        start=0,
+        stop=index,
+        step=1,
+    )
+    rear_tensor = impl.slice.slice_op(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_slice_rear",
+        original_tensor,
+        dim=0,
+        start=index + 1,
+        stop=len_original_tensor,
+        step=1,
+    )
+
+    ans = impl.cat.cat(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_concat",
+        [front_tensor, new_value, rear_tensor],
+        0,
+    )
+    return ans
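The two new helpers build TensorRT slice and concatenation layers, so they cannot be run standalone outside a conversion context. As a reference for their intended semantics only, here is a host-side NumPy analogue (an assumption for illustration, not part of the commit):

# NumPy analogue of the semantics of converter_utils.append / set_item.
# The real helpers emit TRT layers; this just mirrors the documented behavior.
import numpy as np


def append_ref(original: np.ndarray, new_value: float) -> np.ndarray:
    # append([1, 2, 3], 4) -> [1, 2, 3, 4]  (dim=0 case)
    return np.concatenate([original, np.array([new_value], dtype=original.dtype)])


def set_item_ref(original: np.ndarray, index: int, new_value: float) -> np.ndarray:
    # set_item([1, 2, 3], 1, 4) -> [1, 4, 3]: slice front, insert value, slice rear,
    # then concatenate, just as the TRT helper does with slice_op + cat.
    # (The out-of-bounds "append to the end" case from the docstring is omitted here.)
    index = index % len(original)  # rough stand-in for get_positive_dim
    front = original[:index]
    rear = original[index + 1 :]
    return np.concatenate([front, np.array([new_value], dtype=original.dtype), rear])


print(append_ref(np.array([1, 2, 3]), 4))       # [1 2 3 4]
print(set_item_ref(np.array([1, 2, 3]), 1, 4))  # [1 4 3]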
