Skip to content

Commit a4882c6

Browse files
committed
fix: support expand/repeat with IValue type input
Signed-off-by: inocsin <[email protected]>
1 parent aec4e1a commit a4882c6

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

Diff for: core/conversion/converters/impl/expand.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ auto expand_registrations TRTORCH_UNUSED =
198198
RegisterNodeConversionPatterns()
199199
.pattern({"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
200200
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
201-
auto in = args[0].ITensor();
201+
auto in = args[0].ITensorOrFreeze(ctx);
202202
auto input_dims = in->getDimensions();
203203
auto expanded_size = args[1].unwrapToIntList();
204204
auto expandedDims = util::toDims(expanded_size);
@@ -213,9 +213,9 @@ auto expand_registrations TRTORCH_UNUSED =
213213
}})
214214
.pattern({"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
215215
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
216-
auto in = args[0].ITensor();
216+
auto in = args[0].ITensorOrFreeze(ctx);
217217
auto input_dims = in->getDimensions();
218-
auto targetTensor = args[1].ITensor();
218+
auto targetTensor = args[1].ITensorOrFreeze(ctx);
219219
auto targetDims = targetTensor->getDimensions();
220220
LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
221221
if (ctx->input_is_dynamic) {
@@ -227,7 +227,7 @@ auto expand_registrations TRTORCH_UNUSED =
227227
}})
228228
.pattern({"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
229229
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
230-
auto in = args[0].ITensor();
230+
auto in = args[0].ITensorOrFreeze(ctx);
231231
auto input_dims = in->getDimensions();
232232
auto repeats = args[1].unwrapToIntList().vec();
233233
int repeats_rank = repeats.size();

0 commit comments

Comments
 (0)