-
Notifications
You must be signed in to change notification settings - Fork 362
feat: rmsnorm lowering #3440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
feat: rmsnorm lowering #3440
Conversation
py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
Outdated
Show resolved
Hide resolved
shape_calc_fns = [None] * output.ndim | ||
|
||
for i in range(output.ndim): | ||
input_node_expr = input_node_expr = list( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo here?
shape_calc_fns = [None] * args[0].ndim | ||
for i in range(args[0].ndim): | ||
input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args] | ||
shape_calc_fns = [None] * output.ndim |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like fake_mode
above was defined twice.
tensor_inputs = plugin.input_tensor_names | ||
tensor_args = args[0 : len(tensor_inputs)] | ||
|
||
random_id = random.randint(0, 10000) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use UUID (as short as possible) instead of random for now
torch_tensorrt.dynamo.conversion.plugins.custom_op( | ||
"flashinfer::rmsnorm", supports_dynamic_shapes=True | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: After merge, extend this example to include an aot_impl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There needs to be a modification to the plugin_converter generation that figure out if you can use aot_impl
RMSNORM lowering pass
Checklist: