@@ -21,7 +21,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
21
21
auto padding = args[1 ].unwrapToIntList ().vec ();
22
22
int64_t padSize = padding.size ();
23
23
auto value = args[2 ].unwrapToScalar ().to <float >();
24
-
24
+ at::Tensor value_tensor = torch::tensor (value, util::TRTDataTypeToScalarType (in->getType ()));
25
+ auto valueTensor = tensor_to_const (ctx, value_tensor);
25
26
TORCHTRT_CHECK (padSize % 2 == 0 , " Length of pad must be even but instead it equals " << padSize);
26
27
27
28
int64_t l_pad = padSize / 2 ;
@@ -55,10 +56,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
55
56
auto fill_layer = ctx->net ->addFill (nvinfer1::Dims{1 , {1 }}, nvinfer1::FillOperation::kLINSPACE );
56
57
auto shape_gather_out = ctx->net ->addShape (*left_gather_out)->getOutput (0 );
57
58
fill_layer->setInput (0 , *shape_gather_out);
58
- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
59
- auto valueTensor = tensor_to_const (ctx, value_tensor);
60
59
fill_layer->setInput (1 , *valueTensor);
61
- at::Tensor delta_tensor = torch::zeros (inRank);
60
+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
62
61
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
63
62
fill_layer->setInput (2 , *deltaTensor);
64
63
auto padTensor = fill_layer->getOutput (0 );
@@ -69,10 +68,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
69
68
} else {
70
69
inDims.d [axis] = padding[padding_index];
71
70
auto fill_layer = ctx->net ->addFill (inDims, nvinfer1::FillOperation::kLINSPACE );
72
- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
73
- auto valueTensor = tensor_to_const (ctx, value_tensor);
74
71
fill_layer->setInput (1 , *valueTensor);
75
- at::Tensor delta_tensor = torch::zeros (inRank);
72
+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
76
73
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
77
74
fill_layer->setInput (2 , *deltaTensor);
78
75
auto padTensor = fill_layer->getOutput (0 );
@@ -112,10 +109,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
112
109
auto fill_layer = ctx->net ->addFill (nvinfer1::Dims{1 , {1 }}, nvinfer1::FillOperation::kLINSPACE );
113
110
auto shape_gather_out = ctx->net ->addShape (*right_gather_out)->getOutput (0 );
114
111
fill_layer->setInput (0 , *shape_gather_out);
115
- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
116
- auto valueTensor = tensor_to_const (ctx, value_tensor);
117
112
fill_layer->setInput (1 , *valueTensor);
118
- at::Tensor delta_tensor = torch::zeros (inRank);
113
+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
119
114
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
120
115
fill_layer->setInput (2 , *deltaTensor);
121
116
auto padTensor = fill_layer->getOutput (0 );
@@ -126,10 +121,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
126
121
} else {
127
122
inDims.d [axis] = padding[padding_index + 1 ];
128
123
auto fill_layer = ctx->net ->addFill (inDims, nvinfer1::FillOperation::kLINSPACE );
129
- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
130
- auto valueTensor = tensor_to_const (ctx, value_tensor);
131
124
fill_layer->setInput (1 , *valueTensor);
132
- at::Tensor delta_tensor = torch::zeros (inRank);
125
+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
133
126
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
134
127
fill_layer->setInput (2 , *deltaTensor);
135
128
auto padTensor = fill_layer->getOutput (0 );
0 commit comments