Commit a35a9ec

feat: Automatically generate QDP plugins
1 parent 43831dc commit a35a9ec

File tree

8 files changed, +563 -5 lines changed
+151
@@ -0,0 +1,151 @@
"""
.. _auto_generate_converters:

Automatically Generate a Plugin for a Custom Kernel
===================================================================

We are going to demonstrate how to automatically generate a plugin for a custom kernel with Torch-TensorRT,
using the new Python-based plugin system in TensorRT 10.7.

Torch-TensorRT supports falling back to PyTorch implementations of operations that it
does not know how to compile in TensorRT. However, this comes at the cost of a graph break and reduces the performance of the model.
The easiest way to fix a lack of support for an op is by adding a decomposition (see:
`Writing lowering passes for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html>`_), which defines the operator
in terms of PyTorch ops that Torch-TensorRT supports, or a converter (see:
`Writing converters for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/dynamo_converters.html>`_), which defines the operator in terms of TensorRT operators.

In some cases there isn't a great way to do either of these, perhaps because the operator is a custom kernel that is not part of standard PyTorch,
or because TensorRT cannot support it natively.

For these cases, it is possible to use a TensorRT plugin to replace the operator **inside** the TensorRT engine, thereby avoiding
the performance and resource overhead of a graph break.

Previously this involved a complex process of not only building a performant kernel but also setting it up to run in TensorRT (see: `Using Custom Kernels within TensorRT Engines with Torch-TensorRT <https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/custom_kernel_plugins.html>`_).
With TensorRT 10.7, there is a new Python-native plugin system which greatly streamlines this process. This
plugin system also allows Torch-TensorRT to automatically generate the conversion code needed to map the
PyTorch operation to TensorRT.
"""

# %%
# Writing Custom Operators in PyTorch
# -----------------------------------------
#
# Previous tutorials already cover creating custom operators in PyTorch which later get used with Torch-TensorRT.
# Here we define a simple elementwise multiplication operator in Triton. The operator is then registered as a custom op in PyTorch,
# with its host launch code as well as a "meta-kernel". A meta-kernel is a function that describes the shape and data type
# transformations that the operator will perform. It is used by Dynamo and Torch-TensorRT, so it
# is necessary to define.
#

from typing import Tuple

import tensorrt_bindings.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl


@triton.jit
def elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    # Compute the range of elements that this thread block will work on
    block_start = pid * BLOCK_SIZE
    # Range of indices this thread will handle
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Load elements from the X and Y tensors
    x_vals = tl.load(X + offsets)
    y_vals = tl.load(Y + offsets)
    # Perform the element-wise multiplication and scaling
    z_vals = x_vals * y_vals * a + b
    # Store the result in Z
    tl.store(Z + offsets, z_vals)


@torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=())  # type: ignore[misc]
def elementwise_scale_mul(
    X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
    assert X.shape == Y.shape, "Tensors must have the same shape."

    # Create output tensor
    Z = torch.empty_like(X)

    # Define block size
    BLOCK_SIZE = 1024

    # Grid of programs
    grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)

    # Launch the kernel with parameters a and b
    elementwise_scale_mul_kernel[grid](X, Y, Z, a, b, BLOCK_SIZE=BLOCK_SIZE)

    return Z

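# %%
# Before involving TensorRT, a quick eager sanity check (optional, and not part of the
# plugin workflow itself) confirms that the Triton kernel matches the equivalent PyTorch
# expression:

x = torch.randn(64, 64, device="cuda")
y = torch.randn(64, 64, device="cuda")
out = torch.ops.torchtrt_ex.elementwise_scale_mul(x, y, b=0.2, a=2)
assert torch.allclose(out, x * y * 2 + 0.2), "eager custom op does not match reference"
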
# %%
# The meta kernel for an elementwise operation just returns a tensor with the shape and dtype of one of its inputs,
# since the operation does not change the shape.


@torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
    return x

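# %%
# Optionally, we can check that the meta-kernel propagates shape and dtype as expected by
# calling the op under ``FakeTensorMode`` (a minimal sanity check that uses only what is
# defined above):

from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    fake_x = torch.empty((64, 64), device="cuda")
    fake_y = torch.empty((64, 64), device="cuda")
    fake_out = torch.ops.torchtrt_ex.elementwise_scale_mul(fake_x, fake_y)
    assert fake_out.shape == (64, 64) and fake_out.dtype == torch.float32
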
# %%
# Here we use the automatic plugin creation feature in Torch-TensorRT, which enables plugin registration using
# the TensorRT QDP APIs
torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
    "torchtrt_ex::elementwise_scale_mul"
)


# %%
# Generating the Converter
# -------------------------------------------------------------------
# Given that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation.
# As long as the namespace and names match, the following function will automatically generate the converter for the operation.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True
)


# %%
# The above two calls can be replaced with the following single line:
# torch_tensorrt.dynamo.conversion.plugins.custom_op("torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True)


# %%
# Using our converter with a model
# -------------------------------------------------------------------
#
# Now we can use our custom operator in a model and compile it with Torch-TensorRT.
# We can see that the custom operator is used as one of the operations in the forward pass of the model.
# The process of compiling the model at this point is identical to standard Torch-TensorRT usage.
class MyModel(torch.nn.Module):  # type: ignore[misc]
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        z = torch.add(x, y)
        res = torch.ops.torchtrt_ex.elementwise_scale_mul.default(x, z, b=0.5)

        return res


my_model = MyModel().to("cuda")
m = torch.randint(0, 5, (64, 64), device="cuda", dtype=torch.float)
n = torch.randint(0, 5, (64, 64), device="cuda", dtype=torch.float)

with torch_tensorrt.logging.errors():
    model_trt = torch_tensorrt.compile(
        my_model, inputs=[m, n], debug=True, min_block_size=1
    )
    for i in range(300):
        res = model_trt(m, n)
        assert torch.allclose(res, my_model(m, n))

print("Ran with custom plugin!")
@@ -1,3 +1,5 @@
+from torch_tensorrt.dynamo.conversion.plugins._custom_op import custom_op
+from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin
 from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import (
     generate_plugin_converter,
 )
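With these two imports added, all three entry points used in this commit are importable directly from torch_tensorrt.dynamo.conversion.plugins:

    from torch_tensorrt.dynamo.conversion.plugins import (
        custom_op,
        generate_plugin,
        generate_plugin_converter,
    )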
@@ -0,0 +1,21 @@
from typing import Callable, Optional

from torch.fx.node import Node
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority
from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin
from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import (
    generate_plugin_converter,
)


def custom_op(
    op_name: str,
    capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
    priority: ConverterPriority = ConverterPriority.STANDARD,
    supports_dynamic_shapes: bool = False,
):
    generate_plugin(op_name)
    generate_plugin_converter(
        op_name, capability_validator, priority, supports_dynamic_shapes
    )
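The new custom_op helper wraps the two steps above (plugin generation and converter generation) into one call. A usage sketch, mirroring the tutorial's op; the capability validator shown here is purely illustrative and not part of this commit:

    import torch
    from torch.fx.node import Node
    from torch_tensorrt.dynamo._settings import CompilationSettings
    from torch_tensorrt.dynamo.conversion.plugins import custom_op

    def _only_fp32(node: Node, settings: CompilationSettings) -> bool:
        # Hypothetical validator: only claim nodes whose output is float32
        val = node.meta.get("val")
        return val is not None and val.dtype == torch.float32

    custom_op(
        "torchtrt_ex::elementwise_scale_mul",
        capability_validator=_only_fp32,
        supports_dynamic_shapes=True,
    )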
@@ -0,0 +1,195 @@
import logging
from types import FunctionType
from typing import Tuple

import tensorrt_bindings.plugin as trtp
import torch
from sympy import lambdify
from torch._dynamo.source import LocalSource
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv

_LOGGER: logging.Logger = logging.getLogger(__name__)


def mksym(shape_env, value, source, dynamic_dim):
    # Create a SymInt for `value` in the given ShapeEnv, tracked under `source`
    return shape_env.create_symintnode(
        shape_env.create_symbol(
            value,
            source=source,
            dynamic_dim=dynamic_dim,
        ),
        hint=value,
        source=source,
    )


def generate_plugin(plugin_name: str):
    namespace, name = plugin_name.split("::")

    # retrieve the corresponding torch operation using the passed in string
    torch_op = getattr(getattr(torch.ops, namespace), name)

    # helper function that generates the required signature based on the torch operation
    def generate_signature(torch_op):
        schema = torch_op._schemas[""]

        arg_list = []

        register_func_annotation = {}
        impl_func_annotation = {}

        for arg in schema.arguments:
            arg_list.append(arg.name)

            # TODO: Torch types need to be converted to python primitive types here
            # Some other types are not handled:
            # - torch._C.ListType.ofT(<type>)
            # - torch._C.TupleType.get()
            # - torch._C.DictType.get(<key_type>, <value_type>)
            # - torch._C.OptionalType.ofT(<type>)
            # - torch._C.DeviceObjType.get()
            # - torch._C.FunctionType.get()
            # - torch._C.ClassType

            if arg.type.isSubtypeOf(torch._C.TensorType.get()):
                register_func_annotation[arg.name] = trtp.TensorDesc
                impl_func_annotation[arg.name] = trtp.Tensor
            elif arg.type.isSubtypeOf(torch._C.FloatType.get()):
                register_func_annotation[arg.name] = float
                impl_func_annotation[arg.name] = float
            elif arg.type.isSubtypeOf(torch._C.IntType.get()):
                register_func_annotation[arg.name] = int
                impl_func_annotation[arg.name] = int
            elif arg.type.isSubtypeOf(torch._C.BoolType.get()):
                register_func_annotation[arg.name] = bool
                impl_func_annotation[arg.name] = bool
            elif arg.type.isSubtypeOf(torch._C.StringType.get()):
                register_func_annotation[arg.name] = str
                impl_func_annotation[arg.name] = str
            else:
                raise ValueError(f"Unhandled argument type: {arg.type}")

        input_signature = ", ".join(arg_list)

        plugin_signature = f"def add_plugin_desc({input_signature}):"

        plugin_impl_arg_list = arg_list
        plugin_impl_arg_list.append("outputs")
        plugin_impl_arg_list.append("stream")
        plugin_impl_input = ", ".join(plugin_impl_arg_list)
        plugin_impl_signature = f"def add_plugin_impl({plugin_impl_input}):"

        register_func_annotation["return"] = Tuple[trtp.TensorDesc]

        impl_func_annotation["outputs"] = Tuple[trtp.Tensor]
        impl_func_annotation["stream"] = int

        return (
            input_signature,
            plugin_signature,
            plugin_impl_signature,
            register_func_annotation,
            impl_func_annotation,
        )

    # Use the helper function to get the required signatures
    (
        input_signature,
        plugin_signature,
        plugin_impl_signature,
        register_func_annotation,
        impl_func_annotation,
    ) = generate_signature(torch_op)

    def _generic_plugin_desc(*args, **kwargs) -> Tuple[trtp.TensorDesc]:
        # Trace the torch op with symbolic input shapes, then express each output
        # dimension in terms of the TensorRT input shape expressions.
        shape_env = ShapeEnv()
        fake_mode = FakeTensorMode(shape_env=shape_env)
        syms_args = []
        tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)]

        for tensor_arg in tensor_args:
            sample = {f"{i}": 5 for i in range(tensor_arg.ndim)}
            syms_arg = [
                mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC)
                for k, v in sample.items()
            ]
            syms_args.append(syms_arg)

        with FakeTensorMode() as fake_mode:
            fake_args = []
            for syms_arg in syms_args:
                fake_arg = torch.randn(syms_arg)
                fake_args.append(fake_arg)

            output = torch_op(*fake_args, **kwargs)

        # We assume that the number of dimensions is the same in the torch op
        shape_calc_fns = [None] * args[0].ndim
        for i in range(args[0].ndim):
            input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args]
            shape_calc_fns[i] = lambdify(
                tuple(input_node_expr), output.shape[i].node.expr, "math"
            )

        out_desc = tensor_args[0].like()
        for i in range(out_desc.ndim):
            input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args]
            out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr)

        return (out_desc,)

    codegen_plugin = f"""
{plugin_signature}
    return _generic_plugin_desc({input_signature})
    """

    _LOGGER.warning(f"Plugin registration function: \n{codegen_plugin}")

    plugin_code = compile(codegen_plugin, "<string>", "exec")

    globals()["_generic_plugin_desc"] = _generic_plugin_desc

    plugin = FunctionType(
        plugin_code.co_consts[0],
        globals(),
        "plugin",
    )

    # Function annotation is required for dynamic function to work in TensorRT.Plugin
    plugin.__annotations__ = register_func_annotation

    trtp.register(plugin_name)(plugin)

    def _generic_plugin_impl(outputs, stream, *args, **kwargs):
        # Wrap the TensorRT tensors as torch tensors, run the original torch op on the
        # provided CUDA stream, and copy the results into the plugin outputs.
        tensor_args = [elem for elem in args if isinstance(elem, trtp.Tensor)]
        non_tensor_args = [elem for elem in args if not isinstance(elem, trtp.Tensor)]
        in_tensors = [torch.as_tensor(i, device="cuda") for i in tensor_args]

        dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs]

        stream = torch.cuda.ExternalStream(stream)
        with torch.cuda.stream(stream):
            out_tensors = torch_op(*in_tensors, *non_tensor_args, **kwargs)
            if isinstance(out_tensors, torch.Tensor):
                out_tensors = (out_tensors,)
            [d.copy_(o) for (d, o) in zip(dest_tensors, out_tensors)]

    plugin_impl_func = f"""
{plugin_impl_signature}
    _generic_plugin_impl(outputs, stream, {input_signature})
    """

    _LOGGER.warning(f"Plugin implementation function: \n{plugin_impl_func}")

    plugin_impl_code = compile(plugin_impl_func, "<string>", "exec")

    globals()["_generic_plugin_impl"] = _generic_plugin_impl

    plugin_impl = FunctionType(plugin_impl_code.co_consts[0], globals(), "plugin_impl")

    plugin_impl.__annotations__ = impl_func_annotation

    trtp.impl(plugin_name)(plugin_impl)
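For the tutorial's torchtrt_ex::elementwise_scale_mul op, the generated strings above expand to roughly the following (a reconstruction from the codegen templates, shown only to illustrate what ends up registered via trtp.register and trtp.impl):

    def add_plugin_desc(X, Y, b, a):
        return _generic_plugin_desc(X, Y, b, a)

    def add_plugin_impl(X, Y, b, a, outputs, stream):
        _generic_plugin_impl(outputs, stream, X, Y, b, a)

The registration function is annotated with X: trtp.TensorDesc, Y: trtp.TensorDesc, b: float, a: int and a Tuple[trtp.TensorDesc] return type; the implementation function with X: trtp.Tensor, Y: trtp.Tensor, b: float, a: int, outputs: Tuple[trtp.Tensor], and stream: int.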
