Skip to content

Commit 5381375

Browse files
committed
[refactoring/test] Refactor some code and add test cases
This commit refactors some code, which fixes some bugs we had previously. Test cases are added.
1 parent e809c83 commit 5381375

File tree

4 files changed

+220
-35
lines changed

4 files changed

+220
-35
lines changed

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py

+27-29
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,8 @@ def generate_plugin(plugin_name: str):
3333
# helper function that generates the required signature based on the torch operation
3434
def generate_signature(torch_op):
3535
schema = torch_op._schemas[""]
36-
tensor_args = []
37-
arg_list = []
3836

39-
args = []
40-
kwargs = []
37+
arg_list = []
4138

4239
register_func_annotation = {}
4340
impl_func_annotation = {}
@@ -56,7 +53,6 @@ def generate_signature(torch_op):
5653
# - torch._C.ClassType
5754

5855
if arg.type.isSubtypeOf(torch._C.TensorType.get()):
59-
tensor_args.append(arg)
6056
register_func_annotation[arg.name] = trtp.TensorDesc
6157
impl_func_annotation[arg.name] = trtp.Tensor
6258
elif arg.type.isSubtypeOf(torch._C.FloatType.get()):
@@ -74,40 +70,32 @@ def generate_signature(torch_op):
7470
else:
7571
raise ValueError("arg type is not handled")
7672

77-
if arg.default_value is None:
78-
args.append(arg.name)
79-
else:
80-
kwargs.append(f"{arg.name} = {arg.default_value}")
81-
8273
input_signature = ", ".join(arg_list)
74+
8375
plugin_signature = f"def add_plugin_desc({input_signature}):"
84-
args_input = ", ".join(args)
85-
kwargs_input = ", ".join(kwargs)
8676

8777
plugin_impl_arg_list = arg_list
8878
plugin_impl_arg_list.append("outputs")
8979
plugin_impl_arg_list.append("stream")
9080
plugin_impl_input = ", ".join(plugin_impl_arg_list)
91-
plugin_impl_signagture = f"def add_plugin_impl({plugin_impl_input}):"
81+
plugin_impl_signature = f"def add_plugin_impl({plugin_impl_input}):"
9282

9383
register_func_annotation["return"] = Tuple[trtp.TensorDesc]
9484

9585
impl_func_annotation["outputs"] = Tuple[trtp.Tensor]
9686
impl_func_annotation["stream"] = int
9787

9888
return (
99-
args_input,
100-
kwargs_input,
89+
input_signature,
10190
plugin_signature,
102-
plugin_impl_signagture,
91+
plugin_impl_signature,
10392
register_func_annotation,
10493
impl_func_annotation,
10594
)
10695

10796
# Use the helper function to get the required signatures
10897
(
109-
args_input,
110-
kwargs_input,
98+
input_signature,
11199
plugin_signature,
112100
plugin_impl_signature,
113101
register_func_annotation,
@@ -118,8 +106,11 @@ def _generic_plugin_desc(*args, **kwargs) -> Tuple[trtp.TensorDesc]:
118106
shape_env = ShapeEnv()
119107
fake_mode = FakeTensorMode(shape_env=shape_env)
120108
syms_args = []
121-
for arg in args:
122-
sample = {f"{i}": 5 for i in range(arg.ndim)}
109+
tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)]
110+
111+
for tensor_arg in tensor_args:
112+
113+
sample = {f"{i}": 5 for i in range(tensor_arg.ndim)}
123114
syms_arg = [
124115
mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC)
125116
for k, v in sample.items()
@@ -142,16 +133,16 @@ def _generic_plugin_desc(*args, **kwargs) -> Tuple[trtp.TensorDesc]:
142133
tuple(input_node_expr), output.shape[i].node.expr, "math"
143134
)
144135

145-
out_desc = args[0].like()
136+
out_desc = tensor_args[0].like()
146137
for i in range(out_desc.ndim):
147-
input_shape_expr = [arg.shape_expr[i] for arg in args]
138+
input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args]
148139
out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr)
149140

150141
return (out_desc,)
151142

152143
codegen_plugin = f"""
153144
{plugin_signature}
154-
return _generic_plugin_desc({args_input}, {kwargs_input})
145+
return _generic_plugin_desc({input_signature})
155146
"""
156147

157148
_LOGGER.warning(f"Plugin registration function: \n{codegen_plugin}")
@@ -160,26 +151,35 @@ def _generic_plugin_desc(*args, **kwargs) -> Tuple[trtp.TensorDesc]:
160151

161152
globals()["_generic_plugin_desc"] = _generic_plugin_desc
162153

163-
plugin = FunctionType(plugin_code.co_consts[0], globals(), "plugin")
154+
plugin = FunctionType(
155+
plugin_code.co_consts[0],
156+
globals(),
157+
"plugin",
158+
)
164159

165160
# Function annotation is required for dynamic function to work in TensorRT.Plugin
166161
plugin.__annotations__ = register_func_annotation
167162

168163
trtp.register(plugin_name)(plugin)
169164

170165
def _generic_plugin_impl(outputs, stream, *args, **kwargs):
    """Run ``torch_op`` on the plugin inputs and copy the results into ``outputs``.

    Args:
        outputs: Iterable of output buffers; each one is wrapped with
            ``torch.as_tensor(..., device="cuda")`` and written in place.
        stream: Raw stream handle, wrapped in ``torch.cuda.ExternalStream``.
        *args: Plugin inputs. ``trtp.Tensor`` instances are converted to
            CUDA tensors; every other positional value is forwarded as-is.
        **kwargs: Forwarded unchanged to ``torch_op``.
    """
    # Split positional inputs: tensors must be wrapped as torch tensors,
    # everything else (scalars, attributes) is passed straight through.
    tensor_args = []
    non_tensor_args = []
    for elem in args:
        if isinstance(elem, trtp.Tensor):
            tensor_args.append(elem)
        else:
            non_tensor_args.append(elem)

    in_tensors = [torch.as_tensor(i, device="cuda") for i in tensor_args]
    dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs]

    stream = torch.cuda.ExternalStream(stream)
    with torch.cuda.stream(stream):
        # NOTE: tensors are passed first, then non-tensor values, regardless
        # of how they were interleaved in ``args``.
        out_tensors = torch_op(*in_tensors, *non_tensor_args, **kwargs)
        # Normalise a single-tensor result so it can be zipped with outputs.
        if isinstance(out_tensors, torch.Tensor):
            out_tensors = (out_tensors,)
        # Explicit loop: the copies are side effects, not a collection to build.
        for dest, out in zip(dest_tensors, out_tensors):
            dest.copy_(out)
179179

180180
plugin_impl_func = f"""
181181
{plugin_impl_signature}
182-
_generic_plugin_impl(outputs, stream, {args_input}, {kwargs_input})
182+
_generic_plugin_impl(outputs, stream, {input_signature})
183183
"""
184184

185185
_LOGGER.warning(f"Plugin implementation function: \n{plugin_impl_func}")
@@ -193,5 +193,3 @@ def _generic_plugin_impl(outputs, stream, *args, **kwargs):
193193
plugin_impl.__annotations__ = impl_func_annotation
194194

195195
trtp.impl(plugin_name)(plugin_impl)
196-
197-
return plugin

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
33

44
import numpy as np
5+
import tensorrt as trt
56

67
# Seems like a bug in TensorRT
78
import tensorrt_bindings.plugin as trtp
@@ -18,8 +19,6 @@
1819
)
1920
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
2021

21-
import tensorrt as trt
22-
2322
_LOGGER: logging.Logger = logging.getLogger(__name__)
2423

2524

@@ -58,13 +57,25 @@ def custom_kernel_converter(
5857
# Assuming TensorRT preserves kwargs order like PyTorch does
5958
non_tensor_inputs = plugin.input_attrs
6059

60+
kwargs = {}
61+
62+
for arg in torch_schema.arguments:
63+
if arg.default_value is not None:
64+
kwargs[arg.name] = arg.default_value
65+
6166
non_tensor_args = args[len(tensor_inputs) :]
6267
non_tensor_kwargs = dict(zip(list(non_tensor_inputs.keys()), non_tensor_args))
63-
for k, v in non_tensor_kwargs.items():
68+
69+
for k, v in kwargs.items():
70+
if k in non_tensor_kwargs:
71+
kwargs[k] = non_tensor_kwargs[k]
72+
73+
for k, v in kwargs.items():
6474
if isinstance(v, torch.fx.immutable_collections.immutable_list):
65-
non_tensor_kwargs[k] = np.array(v)
75+
kwargs[k] = np.array(v)
76+
6677

67-
layer = ctx.net.add_plugin(plugin(*itensor_args, **non_tensor_kwargs))
78+
layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs))
6879
assert layer, f"{namespace}::{name} plugin layer was not able to be created"
6980
_LOGGER.debug(
7081
f"Adding generated plugin for {namespace}::{name} to tensorrt network"
@@ -91,7 +102,7 @@ def generate_plugin_converter(
91102
supports_dynamic_shapes: bool = False,
92103
) -> DynamoConverterImplSignature:
93104
plugin_ns, plugin_name = plugin_id.split("::")
94-
return _generate_plugin_converter(
105+
return _generate_plugin_converter(
95106
plugin_ns,
96107
plugin_name,
97108
capability_validator=capability_validator,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from typing import Tuple
2+
3+
import torch
4+
import torch.nn as nn
5+
import torch_tensorrt
6+
import triton
7+
import triton.language as tl
8+
from parameterized import parameterized
9+
from torch.testing._internal.common_utils import run_tests
10+
11+
from .harness import DispatchTestCase
12+
13+
14+
@triton.jit
def elementwise_mul_kernel(X, Y, Z, n_elements, BLOCK_SIZE: tl.constexpr):
    """Write ``X * Y`` element-wise into ``Z`` for one block of ``BLOCK_SIZE`` offsets.

    ``n_elements`` is the total element count; offsets at or beyond it are
    masked out so the last (partial) block never reads or writes out of range.
    """
    # Program ID selects which block of offsets this program instance handles
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block when n_elements is not a multiple of BLOCK_SIZE
    mask = offsets < n_elements
    x_vals = tl.load(X + offsets, mask=mask)
    y_vals = tl.load(Y + offsets, mask=mask)
    z_vals = x_vals * y_vals
    tl.store(Z + offsets, z_vals, mask=mask)


@torch.library.custom_op("torchtrt_ex::elementwise_mul", mutates_args=())  # type: ignore[misc]
def elementwise_mul(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Return the element-wise product of two same-shaped CUDA tensors.

    Args:
        X: Left operand; must be on a CUDA device.
        Y: Right operand; must be on a CUDA device and have ``X``'s shape.

    Returns:
        A new tensor created with ``torch.empty_like(X)`` holding ``X * Y``.

    Raises:
        AssertionError: If either tensor is not on CUDA or the shapes differ.
    """
    # Ensure the tensors are on the GPU
    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
    assert X.shape == Y.shape, "Tensors must have the same shape."

    # Create output tensor
    Z = torch.empty_like(X)

    n_elements = X.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching an empty grid.
        return Z

    # Define block size
    BLOCK_SIZE = 1024

    # Ceil-divide so a trailing partial block is still launched; the kernel
    # masks the out-of-range offsets of that block.
    # NOTE(review): the kernel indexes flat offsets, which presumably assumes
    # contiguous inputs -- confirm against callers.
    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Launch the kernel
    elementwise_mul_kernel[grid](X, Y, Z, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    return Z
50+
51+
52+
@torch.library.register_fake("torchtrt_ex::elementwise_mul")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Fake (meta) implementation: describe the output without computing it.

    Named ``_`` so it does not rebind the module-level ``elementwise_mul``
    custom op. Returns a fresh tensor shaped like ``x`` instead of ``x``
    itself, so the output is not reported as an alias of an input.
    """
    return torch.empty_like(x)
55+
56+
57+
torch_tensorrt.dynamo.conversion.plugins.custom_op(
58+
"torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True
59+
)
60+
61+
62+
class TestAutomaticPlugin(DispatchTestCase):
    """Test case for the ``torchtrt_ex::elementwise_mul`` custom op."""

    @parameterized.expand([((64, 64), torch.float)])
    def test_mul_plugin_float(self, input_shape, dtype):
        """Pass a module wrapping the custom op and two random CUDA tensors to ``run_test``."""

        class elementwise_mul(nn.Module):
            def forward(self, lhs, rhs):
                product = torch.ops.torchtrt_ex.elementwise_mul.default(lhs, rhs)
                return product

        # Two operands, each drawn with the same randint arguments.
        inputs = [
            torch.randint(0, 5, input_shape, device="cuda", dtype=dtype)
            for _ in range(2)
        ]

        module = elementwise_mul()
        self.run_test(module, inputs)
79+
80+
81+
if __name__ == "__main__":
82+
run_tests()
83+
84+
# Example Usage
85+
# A = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
86+
# B = torch.full((64, 64), 3, device="cuda", dtype=torch.float)
87+
88+
# C, D = torch.ops.torchtrt_ex.elementwise_add_mul.default(A, B)
89+
90+
# print("C (Addition):", C)
91+
# print("D (Multiplication):", D)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import Tuple
2+
3+
import torch
4+
import torch.nn as nn
5+
import torch_tensorrt
6+
import triton
7+
import triton.language as tl
8+
from parameterized import parameterized
9+
from torch.testing._internal.common_utils import run_tests
10+
11+
from .harness import DispatchTestCase
12+
13+
14+
@triton.jit
def elementwise_scale_mul_kernel(X, Y, Z, a, b, n_elements, BLOCK_SIZE: tl.constexpr):
    """Write ``X * Y * a + b`` element-wise into ``Z`` for one block of offsets.

    ``n_elements`` is the total element count; offsets at or beyond it are
    masked out so the last (partial) block never reads or writes out of range.
    """
    pid = tl.program_id(0)
    # Compute the range of offsets this program instance works on
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block when n_elements is not a multiple of BLOCK_SIZE
    mask = offsets < n_elements
    x_vals = tl.load(X + offsets, mask=mask)
    y_vals = tl.load(Y + offsets, mask=mask)
    # Scaled element-wise multiplication with offset
    z_vals = x_vals * y_vals * a + b
    tl.store(Z + offsets, z_vals, mask=mask)


@torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=())  # type: ignore[misc]
def elementwise_scale_mul(
    X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
) -> torch.Tensor:
    """Return ``X * Y * a + b`` computed element-wise on CUDA tensors.

    Args:
        X: Left operand; must be on a CUDA device.
        Y: Right operand; must be on a CUDA device and have ``X``'s shape.
        b: Additive offset applied after scaling.
        a: Multiplicative scale applied to the product.

    Returns:
        A new tensor created with ``torch.empty_like(X)`` holding the result.

    Raises:
        AssertionError: If either tensor is not on CUDA or the shapes differ.
    """
    # Ensure the tensors are on the GPU
    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
    assert X.shape == Y.shape, "Tensors must have the same shape."

    # Create output tensor
    Z = torch.empty_like(X)

    n_elements = X.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching an empty grid.
        return Z

    # Define block size
    BLOCK_SIZE = 1024

    # Ceil-divide so a trailing partial block is still launched; the kernel
    # masks the out-of-range offsets of that block.
    # NOTE(review): the kernel indexes flat offsets, which presumably assumes
    # contiguous inputs -- confirm against callers.
    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Launch the kernel with parameters a and b
    elementwise_scale_mul_kernel[grid](
        X, Y, Z, a, b, n_elements, BLOCK_SIZE=BLOCK_SIZE
    )

    return Z
51+
52+
53+
@torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
    """Fake (meta) implementation: describe the output without computing it.

    Returns a fresh tensor shaped like ``x`` instead of ``x`` itself, so the
    output is not reported as an alias of an input.
    """
    return torch.empty_like(x)
56+
57+
58+
torch_tensorrt.dynamo.conversion.plugins.custom_op(
59+
"torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True
60+
)
61+
62+
63+
class TestAutomaticPlugin(DispatchTestCase):
    """Test case for the ``torchtrt_ex::elementwise_scale_mul`` custom op."""

    @parameterized.expand([((64, 64), torch.float)])
    def test_scale_mul_plugin_float(self, input_shape, dtype):
        """Pass a module wrapping the custom op (``b=1, a=0``) and two random CUDA tensors to ``run_test``."""

        class elementwise_scale_mul(nn.Module):
            def forward(self, lhs, rhs):
                op = torch.ops.torchtrt_ex.elementwise_scale_mul.default
                return op(lhs, rhs, b=1, a=0)

        # Two operands, each drawn with the same randint arguments.
        inputs = [
            torch.randint(0, 5, input_shape, device="cuda", dtype=dtype)
            for _ in range(2)
        ]

        module = elementwise_scale_mul()
        self.run_test(module, inputs)
82+
83+
84+
if __name__ == "__main__":
85+
run_tests()

0 commit comments

Comments
 (0)