@@ -198,7 +198,7 @@ auto expand_registrations TRTORCH_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern({"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensor();
+                    auto in = args[0].ITensorOrFreeze(ctx);
                     auto input_dims = in->getDimensions();
                     auto expanded_size = args[1].unwrapToIntList();
                     auto expandedDims = util::toDims(expanded_size);
@@ -213,9 +213,9 @@ auto expand_registrations TRTORCH_UNUSED =
                   }})
         .pattern({"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensor();
+                    auto in = args[0].ITensorOrFreeze(ctx);
                     auto input_dims = in->getDimensions();
-                    auto targetTensor = args[1].ITensor();
+                    auto targetTensor = args[1].ITensorOrFreeze(ctx);
                     auto targetDims = targetTensor->getDimensions();
                     LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
                     if (ctx->input_is_dynamic) {
@@ -227,7 +227,7 @@ auto expand_registrations TRTORCH_UNUSED =
                   }})
         .pattern({"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensor();
+                    auto in = args[0].ITensorOrFreeze(ctx);
                     auto input_dims = in->getDimensions();
                     auto repeats = args[1].unwrapToIntList().vec();
                     int repeats_rank = repeats.size();