|
| 1 | +import logging |
| 2 | +from types import FunctionType |
| 3 | +from typing import Tuple |
| 4 | + |
| 5 | +import tensorrt_bindings.plugin as trtp |
| 6 | +import torch |
| 7 | +from sympy import lambdify |
| 8 | +from torch._dynamo.source import LocalSource |
| 9 | +from torch._subclasses.fake_tensor import FakeTensorMode |
| 10 | +from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv |
| 11 | + |
| 12 | +_LOGGER: logging.Logger = logging.getLogger(__name__) |
| 13 | + |
| 14 | + |
| 15 | +def mksym(shape_env, value, source, dynamic_dim): |
| 16 | + return shape_env.create_symintnode( |
| 17 | + shape_env.create_symbol( |
| 18 | + value, |
| 19 | + source=source, |
| 20 | + dynamic_dim=dynamic_dim, |
| 21 | + ), |
| 22 | + hint=value, |
| 23 | + source=source, |
| 24 | + ) |
| 25 | + |
| 26 | + |
| 27 | +def generate_plugin(plugin_name: str): |
| 28 | + namespace, name = plugin_name.split("::") |
| 29 | + |
| 30 | + # retrieve the corresponding torch operation using the passed in string |
| 31 | + torch_op = getattr(getattr(torch.ops, namespace), name) |
| 32 | + |
| 33 | + # helper function that generates the required signature based on the torch operation |
| 34 | + def generate_signature(torch_op): |
| 35 | + schema = torch_op._schemas[""] |
| 36 | + |
| 37 | + arg_list = [] |
| 38 | + |
| 39 | + register_func_annotation = {} |
| 40 | + impl_func_annotation = {} |
| 41 | + |
| 42 | + for arg in schema.arguments: |
| 43 | + arg_list.append(arg.name) |
| 44 | + |
| 45 | + # TODO: Torch types need to be converted to python primitive types here |
| 46 | + # Some other types are not handled: |
| 47 | + # - torch._C.ListType.ofT(<type>) |
| 48 | + # - torch._C.TupleType.get() |
| 49 | + # - torch._C.DictType.get(<key_type>, <value_type>) |
| 50 | + # - torch._C.OptionalType.ofT(<type>) |
| 51 | + # - torch._C.DeviceObjType.get() |
| 52 | + # - torch._C.FunctionType.get() |
| 53 | + # - torch._C.ClassType |
| 54 | + |
| 55 | + if arg.type.isSubtypeOf(torch._C.TensorType.get()): |
| 56 | + register_func_annotation[arg.name] = trtp.TensorDesc |
| 57 | + impl_func_annotation[arg.name] = trtp.Tensor |
| 58 | + elif arg.type.isSubtypeOf(torch._C.FloatType.get()): |
| 59 | + register_func_annotation[arg.name] = float |
| 60 | + impl_func_annotation[arg.name] = float |
| 61 | + elif arg.type.isSubtypeOf(torch._C.IntType.get()): |
| 62 | + register_func_annotation[arg.name] = int |
| 63 | + impl_func_annotation[arg.name] = int |
| 64 | + elif arg.type.isSubtypeOf(torch._C.Booltype.get()): |
| 65 | + register_func_annotation[arg.name] = bool |
| 66 | + impl_func_annotation[arg.name] = bool |
| 67 | + elif arg.type.isSubtypeOf(torch._C.Stringtype.get()): |
| 68 | + register_func_annotation[arg.name] = str |
| 69 | + impl_func_annotation[arg.name] = str |
| 70 | + else: |
| 71 | + raise ValueError("arg type is not handled") |
| 72 | + |
| 73 | + input_signature = ", ".join(arg_list) |
| 74 | + |
| 75 | + plugin_signature = f"def add_plugin_desc({input_signature}):" |
| 76 | + |
| 77 | + plugin_impl_arg_list = arg_list |
| 78 | + plugin_impl_arg_list.append("outputs") |
| 79 | + plugin_impl_arg_list.append("stream") |
| 80 | + plugin_impl_input = ", ".join(plugin_impl_arg_list) |
| 81 | + plugin_impl_signature = f"def add_plugin_impl({plugin_impl_input}):" |
| 82 | + |
| 83 | + register_func_annotation["return"] = Tuple[trtp.TensorDesc] |
| 84 | + |
| 85 | + impl_func_annotation["outputs"] = Tuple[trtp.Tensor] |
| 86 | + impl_func_annotation["stream"] = int |
| 87 | + |
| 88 | + return ( |
| 89 | + input_signature, |
| 90 | + plugin_signature, |
| 91 | + plugin_impl_signature, |
| 92 | + register_func_annotation, |
| 93 | + impl_func_annotation, |
| 94 | + ) |
| 95 | + |
| 96 | + # Use the helper function to get the required signatures |
| 97 | + ( |
| 98 | + input_signature, |
| 99 | + plugin_signature, |
| 100 | + plugin_impl_signature, |
| 101 | + register_func_annotation, |
| 102 | + impl_func_annotation, |
| 103 | + ) = generate_signature(torch_op) |
| 104 | + |
| 105 | + def _generic_plugin_desc(*args, **kwargs) -> Tuple[trtp.TensorDesc]: |
| 106 | + shape_env = ShapeEnv() |
| 107 | + fake_mode = FakeTensorMode(shape_env=shape_env) |
| 108 | + syms_args = [] |
| 109 | + tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)] |
| 110 | + |
| 111 | + for tensor_arg in tensor_args: |
| 112 | + |
| 113 | + sample = {f"{i}": 5 for i in range(tensor_arg.ndim)} |
| 114 | + syms_arg = [ |
| 115 | + mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) |
| 116 | + for k, v in sample.items() |
| 117 | + ] |
| 118 | + syms_args.append(syms_arg) |
| 119 | + |
| 120 | + with FakeTensorMode() as fake_mode: |
| 121 | + fake_args = [] |
| 122 | + for syms_arg in syms_args: |
| 123 | + fake_arg = torch.randn(syms_arg) |
| 124 | + fake_args.append(fake_arg) |
| 125 | + |
| 126 | + output = torch_op(*fake_args, **kwargs) |
| 127 | + |
| 128 | + # We assume that number of dimensions are the same in torch op |
| 129 | + shape_calc_fns = [None] * args[0].ndim |
| 130 | + for i in range(args[0].ndim): |
| 131 | + input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args] |
| 132 | + shape_calc_fns[i] = lambdify( |
| 133 | + tuple(input_node_expr), output.shape[i].node.expr, "math" |
| 134 | + ) |
| 135 | + |
| 136 | + out_desc = tensor_args[0].like() |
| 137 | + for i in range(out_desc.ndim): |
| 138 | + input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args] |
| 139 | + out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) |
| 140 | + |
| 141 | + return (out_desc,) |
| 142 | + |
| 143 | + codegen_plugin = f""" |
| 144 | +{plugin_signature} |
| 145 | + return _generic_plugin_desc({input_signature}) |
| 146 | + """ |
| 147 | + |
| 148 | + _LOGGER.warning(f"Plugin registration function: \n{codegen_plugin}") |
| 149 | + |
| 150 | + plugin_code = compile(codegen_plugin, "<string>", "exec") |
| 151 | + |
| 152 | + globals()["_generic_plugin_desc"] = _generic_plugin_desc |
| 153 | + |
| 154 | + plugin = FunctionType( |
| 155 | + plugin_code.co_consts[0], |
| 156 | + globals(), |
| 157 | + "plugin", |
| 158 | + ) |
| 159 | + |
| 160 | + # Function annotation is required for dynamic function to work in TensorRT.Plugin |
| 161 | + plugin.__annotations__ = register_func_annotation |
| 162 | + |
| 163 | + trtp.register(plugin_name)(plugin) |
| 164 | + |
| 165 | + def _generic_plugin_impl(outputs, stream, *args, **kwargs): |
| 166 | + tensor_args = [elem for elem in args if isinstance(elem, trtp.Tensor)] |
| 167 | + print(args) |
| 168 | + non_tensor_args = [elem for elem in args if not isinstance(elem, trtp.Tensor)] |
| 169 | + in_tensors = [torch.as_tensor(i, device="cuda") for i in tensor_args] |
| 170 | + |
| 171 | + dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs] |
| 172 | + |
| 173 | + stream = torch.cuda.ExternalStream(stream) |
| 174 | + with torch.cuda.stream(stream): |
| 175 | + out_tensors = torch_op(*in_tensors, *non_tensor_args, **kwargs) |
| 176 | + if isinstance(out_tensors, torch.Tensor): |
| 177 | + out_tensors = (out_tensors,) |
| 178 | + [d.copy_(o) for (d, o) in zip(dest_tensors, out_tensors)] |
| 179 | + |
| 180 | + plugin_impl_func = f""" |
| 181 | +{plugin_impl_signature} |
| 182 | + _generic_plugin_impl(outputs, stream, {input_signature}) |
| 183 | + """ |
| 184 | + |
| 185 | + _LOGGER.warning(f"Plugin implementation function: \n{plugin_impl_func}") |
| 186 | + |
| 187 | + plugin_impl_code = compile(plugin_impl_func, "<string>", "exec") |
| 188 | + |
| 189 | + globals()["_generic_plugin_impl"] = _generic_plugin_impl |
| 190 | + |
| 191 | + plugin_impl = FunctionType(plugin_impl_code.co_consts[0], globals(), "plugin_impl") |
| 192 | + |
| 193 | + plugin_impl.__annotations__ = impl_func_annotation |
| 194 | + |
| 195 | + trtp.impl(plugin_name)(plugin_impl) |
0 commit comments