Skip to content

Commit 40528e6

Browse files
committed
Expose IGridSampleLayer
1 parent fe0d8e0 commit 40528e6

File tree

5 files changed

+107
-0
lines changed

5 files changed

+107
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+16
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,22 @@ def aten_ops_fmod(
152152
return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
153153

154154

155+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.out)
156+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_backward.out)
157+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.out)
158+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d_backward.out)
159+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.out)
160+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d_backward.out)
161+
def aten_ops_grid(
162+
network: TRTNetwork,
163+
target: Target,
164+
args: Tuple[Argument, ...],
165+
kwargs: Dict[str, Argument],
166+
name: str,
167+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
168+
return impl.grid.grid(network, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4])
169+
170+
155171
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
156172
def aten_ops_relu(
157173
network: TRTNetwork,

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+26
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,32 @@
1818

1919
_LOGGER: logging.Logger = logging.getLogger(__name__)
2020

21+
#nearesr, linear, cubc
22+
class GridSamplerInterpolation:
23+
def __init__(self):
24+
self.interpolator_mode = None
25+
def __call__(self, interpolator_int):
26+
if(interpolator_int == 0) :
27+
self.interpolator_mode = trt.InterpolationMode.NEAREST
28+
elif(interpolator_int == 1) :
29+
self.interpolator_mode = trt.InterpolationMode.LINEAR
30+
elif(interpolator_int == 2) :
31+
self.interpolator_mode = trt.InterpolationMode.CUBIC
32+
return self.interpolator_mode
33+
34+
35+
#zeros, border, reflection
36+
class GridSamplerPadding:
37+
def __init__(self):
38+
self.padding_mode = None
39+
def __call__(self, padding_int):
40+
if(padding_int == 0) :
41+
self.padding_mode = trt.SampleMode.kFILL
42+
elif(padding_int == 1) :
43+
self.padding_mode = trt.SampleMode.kCLAMP
44+
elif(padding_int == 2) :
45+
self.padding_mode = trt.SampleMode.kREFLECT
46+
return self.padding_mode
2147

2248
def get_node_name(node: torch.fx.Node) -> str:
2349
# nn_module_stack preserves the call stack of pytorch nn.modules

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
conv,
88
elementwise,
99
embedding,
10+
grid,
1011
matmul,
1112
normalization,
1213
permutation,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Optional
2+
3+
import torch
4+
from torch.fx.node import Target
5+
from torch_tensorrt.dynamo._SourceIR import SourceIR
6+
from torch_tensorrt.dynamo.conversion.converter_utils import GridSamplerInterpolation, GridSamplerPadding
7+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
8+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
9+
10+
def grid(
11+
network: TRTNetwork,
12+
target: Target,
13+
source_ir: Optional[SourceIR],
14+
name: str,
15+
input: TRTTensor,
16+
grid: TRTTensor,
17+
interpolation_mode: int,
18+
padding_mode: int,
19+
align_corners: bool,
20+
) -> TRTTensor:
21+
grid_layer = network.add_grid_sample(input, grid)
22+
grid_layer.interpolation_mode = GridSamplerInterpolation(interpolation_mode)
23+
grid_layer.padding_mode = GridSamplerPadding(padding_mode)
24+
grid_layer.align_corners = align_corners
25+
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
26+
return grid_layer.get_output(0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pytest
2+
import torch
3+
import torch.nn as nn
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
6+
from parameterized import parameterized
7+
from .harness import DispatchTestCase
8+
9+
class TestGridConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0),
13+
("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1),
14+
("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2),
15+
("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0),
16+
("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1),
17+
("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2),
18+
("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0),
19+
("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1),
20+
("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2),
21+
]
22+
)
23+
def test_grid(self,_, input_shape, dim_shape, interpolation, sample):
24+
class TestModule(nn.Module):
25+
def forward(self, x):
26+
input = torch.randn(10).reshape(input_shape)
27+
grid = torch.randint(-1, 1, dim_shape)
28+
return nn.functional.grid(input, grid, interpolation, sample)
29+
30+
inputs = [torch.randn(1, 10)]
31+
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out})
32+
33+
34+
35+
36+
37+
38+

0 commit comments

Comments
 (0)