Skip to content

Commit cb20f90

Browse files
authored
feat: support embedding_bag converter (1D input) (#2395)
1 parent d649d12 commit cb20f90

File tree

3 files changed

+321
-3
lines changed

3 files changed

+321
-3
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+45
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,51 @@ def aten_ops_embedding(
233233
)
234234

235235

236+
def embedding_bag_validator(node: Node) -> bool:
    """Report whether the embedding_bag converter can handle ``node``.

    The node is accepted only when all of the following hold:

    * the indices argument (``node.args[1]``) carries ``tensor_meta``,
    * the offsets argument (``node.args[2]``) is a ``get_attr`` node,
    * ``mode`` (positional argument 4, defaulting to 0) is 0, 1 or 2,
    * the indices tensor is one-dimensional.

    Args:
        node: The FX node for the ``embedding_bag`` call being validated.

    Returns:
        True if every condition above is met, False otherwise.
    """
    mode = args_bounds_check(node.args, 4, 0)
    indices_meta = node.args[1].meta.get("tensor_meta")

    # Without shape metadata the rank of the indices cannot be checked.
    if indices_meta is None:
        return False
    if node.args[2].op != "get_attr":
        return False
    if mode not in (0, 1, 2):
        return False
    # Only 1D indices are supported by the converter.
    return len(indices_meta.shape) == 1
246+
247+
248+
@dynamo_tensorrt_converter(torch.ops.aten.embedding_bag.default, capability_validator=embedding_bag_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten._embedding_bag.default, capability_validator=embedding_bag_validator)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
        1: (TRTTensor,),
        2: (np.ndarray, torch.Tensor),
    }
)  # type: ignore[misc]
def aten_ops_embedding_bag(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter entry point for ``aten.embedding_bag`` / ``aten._embedding_bag``.

    Unpacks the positional ATen arguments and forwards them by keyword to
    ``impl.embedding.embedding_bag``. Arguments 3-7 are optional and fall back
    to the defaults given to ``args_bounds_check`` below. ``kwargs`` is not
    consulted.

    Args:
        ctx: Conversion context passed through to the implementation.
        target: The ATen op being converted.
        args: Positional arguments of the op: weight, indices, offsets, then
            the optional scale_grad_by_freq, mode, sparse, per_sample_weights
            and include_last_offset.
        kwargs: Keyword arguments of the op (unused here).
        name: Base name for the layers created by the implementation.

    Returns:
        Whatever ``impl.embedding.embedding_bag`` returns.
    """
    # Required arguments.
    weight = args[0]
    indices = args[1]
    offsets = args[2]

    # Optional arguments, each with its fallback value.
    scale_grad_by_freq = args_bounds_check(args, 3, False)
    mode = args_bounds_check(args, 4, 0)
    sparse = args_bounds_check(args, 5, False)
    per_sample_weights = args_bounds_check(args, 6, None)
    include_last_offset = args_bounds_check(args, 7, False)
    # padding index is useful for training only, so it is not forwarded

    return impl.embedding.embedding_bag(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        weight=weight,
        indices=indices,
        offsets=offsets,
        scale_grad_by_freq=scale_grad_by_freq,
        mode=mode,
        sparse=sparse,
        per_sample_weights=per_sample_weights,
        include_last_offset=include_last_offset,
    )
279+
280+
236281
@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar) # type: ignore[misc]
237282
@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor) # type: ignore[misc]
238283
def aten_ops_fmod(

py/torch_tensorrt/dynamo/conversion/impl/embedding.py

+135-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from typing import Optional
1+
import functools
2+
from typing import Optional, Sequence, Tuple, Union
23

4+
import numpy as np
35
import torch
6+
import torch_tensorrt.dynamo.conversion.impl as impl
47
from torch.fx.node import Target
58
from torch_tensorrt.dynamo._SourceIR import SourceIR
69
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
7-
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
10+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy
811
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
912
from torch_tensorrt.fx.types import TRTTensor
1013

@@ -40,5 +43,134 @@ def embedding(
4043

4144
# Implement embedding lookup with gather layer
4245
gather_layer = ctx.net.add_gather(embedding_tensor, indices_tensor, axis=0)
43-
set_layer_name(gather_layer, target, name + "_gather", source_ir)
46+
set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
4447
return gather_layer.get_output(0)
48+
49+
50+
def embedding_bag(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    weight: TRTTensor,
    indices: TRTTensor,
    offsets: Union[torch.Tensor, np.ndarray, Sequence[int]],
    scale_grad_by_freq: bool,
    mode: int,
    sparse: bool,
    per_sample_weights: Optional[TRTTensor],
    include_last_offset: bool,
) -> Tuple[TRTTensor, Optional[TRTTensor], Optional[TRTTensor], Optional[TRTTensor]]:
    """Build the TensorRT layers that compute embedding bags for 1D indices.

    The embeddings for all indices are looked up once, optionally scaled by
    ``per_sample_weights``, and then each bag ``[offsets[i], offsets[i + 1])``
    is sliced out and reduced along dim 0. The per-bag results are
    concatenated along dim 0.

    In PyTorch, `offsets` is only used when input is 1D. If input is 2D of
    shape (B, N), it is treated as B bags of fixed length N and `offsets` is
    ignored. 2D input is not handled here yet (see the TODO below).

    Args:
        ctx: Conversion context that owns the network being built.
        target: The op being converted (used for layer naming).
        source_ir: IR the op originates from (used for layer naming).
        name: Base name; every layer created here is named with it as prefix.
        weight: Embedding table.
        indices: Indices into the embedding table.
        offsets: Start position of each bag inside ``indices``. Must be known
            at build time because it is converted to a NumPy array.
        scale_grad_by_freq: Forwarded to ``embedding``.
        mode: Reduction applied to each bag: 0 = sum, 1 = mean, 2 = max.
        sparse: Forwarded to ``embedding``.
        per_sample_weights: Optional weights with the same shape as
            ``indices``; each embedding is multiplied by its weight before
            the reduction.
        include_last_offset: If exactly ``False``, the end index
            ``indices.shape[0]`` is appended to ``offsets``. Otherwise the
            last element of ``offsets`` is replaced by that end index.

    Returns:
        A 4-tuple whose first element is the tensor of reduced bags; the
        remaining three elements are always ``None``.

    Raises:
        ValueError: If ``mode`` is not 0, 1 or 2.
        AssertionError: If ``per_sample_weights`` is given and its shape
            differs from the shape of ``indices``.
    """

    # TODO: support 2D inputs
    # indices = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,))

    if mode == 0:  # sum
        reduce_op = functools.partial(
            impl.reduce.sum, ctx=ctx, target=target, source_ir=source_ir
        )
        reduce_name = "sum"
    elif mode == 1:  # mean
        reduce_op = functools.partial(
            impl.reduce.mean, ctx=ctx, target=target, source_ir=source_ir
        )
        reduce_name = "mean"
    elif mode == 2:  # max
        reduce_op = functools.partial(
            impl.reduce.max,
            ctx=ctx,
            target=target,
            source_ir=source_ir,
            return_indices=False,
        )
        reduce_name = "max"
    else:
        # Fail with a clear message instead of an UnboundLocalError later on.
        raise ValueError(
            f"Unsupported embedding_bag mode {mode!r}; expected 0 (sum), 1 (mean) or 2 (max)"
        )

    # calculate embedding
    embed = embedding(
        ctx,
        target,
        source_ir,
        f"{name}_embedding",
        indices,
        weight,
        scale_grad_by_freq,
        sparse,
    )

    # give weights to embedding
    if per_sample_weights is not None:
        assert (
            per_sample_weights.shape == indices.shape
        ), f"`per_sample_weights` (shape: {per_sample_weights.shape}) must have exactly the same shape as indices/input (shape: {indices.shape})!"
        per_sample_weights = get_trt_tensor(
            ctx, per_sample_weights, f"{name}_per_sample_weights", np.float32
        )
        # Reshape to a column so that each weight scales one embedding row.
        per_sample_weights = impl.shuffle.reshape(
            ctx,
            target,
            source_ir,
            f"{name}_reshape_per_sample_weights",
            per_sample_weights,
            (-1, 1),
        )
        embed = impl.elementwise.mul(
            ctx,
            target,
            source_ir,
            f"{name}_mul_per_sample_weights",
            embed,
            per_sample_weights,
        )

    # Work on a private copy: the converted array may alias the caller's
    # tensor/array, and the last element is overwritten in place below.
    offsets = np.array(to_numpy(offsets))

    if include_last_offset is False:
        # add the end index to offsets
        offsets = np.append(offsets, indices.shape[0])
    else:
        # modify the last index of offsets to the end index
        # however, pytorch doc says if `include_last_offset` is True, the size of offsets
        # is equal to the number of bags + 1. The last element is the size of the input,
        # or the ending index position of the last bag (sequence).

        offsets[-1] = indices.shape[0]

    # separately reduce embeddings for different bags
    reduced_embed = []
    len_offsets = len(offsets)
    for i in range(len_offsets - 1):
        if offsets[i] < offsets[i + 1]:
            sliced_embed = impl.slice.slice_op(
                ctx,
                target,
                source_ir,
                f"{name}_slice_embed_{i}",
                embed,
                0,
                offsets[i],
                offsets[i + 1],
                1,
            )
            reduced_sliced_embed = reduce_op(
                name=f"{name}_{reduce_name}_{i}",
                input_val=sliced_embed,
                dim=0,
                keepdim=True,
            )
            reduced_embed.append(reduced_sliced_embed)
        else:
            # Empty bag: PyTorch returns a zero-filled vector for it. Emit a
            # zero row so the output keeps one row per bag instead of silently
            # dropping the bag.
            # NOTE(review): the constant is float32, like per_sample_weights
            # above -- confirm this is right for non-float32 weights.
            empty_bag = get_trt_tensor(
                ctx,
                np.zeros((1, *tuple(embed.shape)[1:]), dtype=np.float32),
                f"{name}_empty_bag_{i}",
                np.float32,
            )
            reduced_embed.append(empty_bag)

    out = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", reduced_embed, 0)
    # out = reduce_op(input_val=embed, dim=1, keepdim=False)  # Note: This implementation doesn't work for N-dim

    return out, None, None, None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import torch
2+
from parameterized import param, parameterized
3+
from torch.testing._internal.common_utils import run_tests
4+
5+
from .harness import DispatchTestCase
6+
7+
8+
class TestEmbeddingBagConverter(DispatchTestCase):
    """Converter tests for ``aten._embedding_bag`` with 1D indices."""

    @parameterized.expand(
        [
            # 1D input
            param(
                test_name="1d_indices_1",
                weight=torch.randn((10, 3), dtype=torch.float32),
                indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
                offsets=torch.tensor([0, 3], dtype=torch.int32),
                scale_grad_by_freq=False,
                mode=1,
                sparse=False,
                per_sample_weights=None,
                include_last_offset=True,
                padding_idx=-1,
            ),
            param(
                test_name="1d_indices_2",
                weight=torch.randn((10, 3), dtype=torch.float32),
                indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
                offsets=torch.tensor([0, 5], dtype=torch.int32),
                scale_grad_by_freq=False,
                mode=0,
                sparse=False,
                per_sample_weights=torch.randn((6,)),
                include_last_offset=False,
                padding_idx=-1,
            ),
            param(
                test_name="1d_indices_3",
                weight=torch.randn((10, 3), dtype=torch.float32),
                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
                scale_grad_by_freq=False,
                mode=2,
                sparse=False,
                per_sample_weights=None,
                include_last_offset=False,
                padding_idx=-1,
            ),
            # 2D input -- disabled until the converter supports 2D indices.
            # param(
            #     test_name="2d_indices_1",
            #     weight=torch.randn((5, 10), dtype=torch.float32),
            #     indices=torch.tensor([[3, 1], [4, 3]], dtype=torch.int32),
            #     offsets=torch.tensor([0, 1], dtype=torch.int32),
            #     scale_grad_by_freq=False,
            #     mode=0,
            #     sparse=False,
            #     per_sample_weights=torch.randn((4,)),
            #     include_last_offset=False,
            #     padding_idx=-1,
            # ),
            # param(
            #     test_name="2d_indices_3",
            #     weight=torch.tensor([
            #         [0.0, 0.0, 0.0],
            #         [1.0, 1.0, 1.0],
            #         [2.0, 2.0, 2.0],
            #         [3.0, 3.0, 3.0],
            #         [4.0, 4.0, 4.0],
            #         [5.0, 5.0, 5.0],
            #     ], dtype=torch.float32),
            #     indices=torch.tensor([[0, 2, 1], [3, 5, 4]], dtype=torch.int32),
            #     offsets=torch.tensor([0, 1], dtype=torch.int32),
            #     scale_grad_by_freq=False,
            #     mode=2,
            #     sparse=False,
            #     per_sample_weights=None,
            #     include_last_offset=False,
            #     padding_idx=-1,
            # ),
            # param(
            #     test_name="2d_indices_2",
            #     weight=torch.randn((5, 5), dtype=torch.float32),
            #     indices=torch.tensor([[3, 1, 2], [4, 2, 3]], dtype=torch.int32),
            #     offsets=torch.tensor([0, 2], dtype=torch.int32),
            #     scale_grad_by_freq=False,
            #     mode=1,
            #     sparse=False,
            #     per_sample_weights=None,
            #     include_last_offset=False,
            #     padding_idx=-1,
            # ),
            # param(
            #     test_name="2d_indices_4",
            #     weight=torch.randn((5, 10), dtype=torch.float32),
            #     indices=torch.tensor([[3, 1, 2, 4], [4, 1, 3, 1]], dtype=torch.int32),
            #     offsets=torch.tensor([0, 2], dtype=torch.int32),
            #     scale_grad_by_freq=False,
            #     mode=0,
            #     sparse=False,
            #     per_sample_weights=torch.randn((8,)),
            #     include_last_offset=True,
            #     padding_idx=-1,
            # ),
        ]
    )
    def test_embedding_bag(
        self,
        test_name,
        weight,
        indices,
        offsets,
        scale_grad_by_freq,
        mode,
        sparse,
        per_sample_weights,
        include_last_offset,
        padding_idx,
    ):
        """Run one parameterized embedding_bag case through the test harness.

        ``weight`` and ``indices`` are fed in as runtime inputs; every other
        parameter is captured by the module's ``forward`` as a closure value.
        """

        class EmbeddingBagModule(torch.nn.Module):
            def forward(self, weight, indices):
                outputs = torch.ops.aten._embedding_bag.default(
                    weight,
                    indices,
                    offsets,
                    scale_grad_by_freq,
                    mode,
                    sparse,
                    per_sample_weights,
                    include_last_offset,
                    padding_idx,
                )
                # Only the first element of the op's result is compared.
                return outputs[0]

        module = EmbeddingBagModule()
        self.run_test(
            module,
            inputs=[weight, indices],
            enable_passes=True,
        )
138+
139+
140+
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)