Skip to content

Commit ce06f6e

Browse files
authored
Merge pull request #2006 from pytorch/swin_unet
feat: Implement dynamic shape support for floordiv, NumToTensor, layer_norm
2 parents f7b03f4 + c37eeec commit ce06f6e

File tree

3 files changed

+36
-10
lines changed

3 files changed

+36
-10
lines changed

core/conversion/converters/impl/layer_norm.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
2020

2121
/* Layer_Norm normalizes over last N dimensions.
2222
normalized_shape could be (C,H,W), (H,W), or (W). */
23-
auto normalized_shape = args[1].unwrapToIntList();
24-
auto normalized_shape_vec = util::toVec(util::toDims(normalized_shape));
23+
// This could be an IntList or ITensorList. We only need the size of this list.
24+
auto normalized_shape = args[1].IValue()->toList();
2525

2626
// Unwrap eps.
2727
auto eps = args[4].unwrapToDouble();
@@ -30,7 +30,7 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
3030

3131
// Set up axis_mask for E[x].
3232
uint32_t axis_mask = 0;
33-
for (size_t i = 0; i < normalized_shape_vec.size(); i++) {
33+
for (size_t i = 0; i < normalized_shape.size(); i++) {
3434
axis_mask |= 1 << (shape.size() - i - 1);
3535
}
3636
LOG_DEBUG("Axis Mask for E[x]" << std::bitset<32>(axis_mask));

core/conversion/evaluators/aten.cpp

+29-7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "torch/csrc/jit/ir/ir.h"
1010
#include "torch/torch.h"
1111

12+
#include "core/conversion/converters/converter_util.h"
1213
#include "core/conversion/evaluators/eval_macros.h"
1314
#include "core/conversion/evaluators/eval_util.h"
1415
#include "core/conversion/evaluators/evaluators.h"
@@ -298,20 +299,22 @@ auto aten_registrations TORCHTRT_UNUSED =
298299
} else {
299300
auto dim = args.at(n->input(1)).unwrapToInt();
300301
if (tensor_var.isITensor()) {
301-
if (ctx->input_is_dynamic) {
302+
auto tensor = tensor_var.ITensor();
303+
auto dims = util::toVec(tensor->getDimensions());
304+
auto nbDims = tensor->getDimensions().nbDims;
305+
if (dim < 0) {
306+
dim += nbDims;
307+
}
308+
// Check if selected dimension size is -1 else return static size
309+
if (ctx->input_is_dynamic && dims[dim] == -1) {
302310
if (ctx->settings.allow_shape_tensors) {
303311
return dynamic_size_layer(ctx, n, args);
304312
} else {
305313
LOG_WARNING(
306314
"There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors");
307315
}
308316
}
309-
auto tensor = tensor_var.ITensor();
310-
auto dims = util::toVec(tensor->getDimensions());
311-
auto nbDims = tensor->getDimensions().nbDims;
312-
if (dim < 0) {
313-
dim += nbDims;
314-
}
317+
315318
return dims[dim];
316319
} else if (tensor_var.IValue()->isTensor()) {
317320
auto tensor = tensor_var.unwrapToTensor();
@@ -677,6 +680,25 @@ auto aten_registrations TORCHTRT_UNUSED =
677680
.evaluator(
678681
{c10::Symbol::fromQualString("aten::floordiv"),
679682
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
683+
// Dynamic version of aten::floordiv
684+
if (args.at(n->input(0)).isITensor()) {
685+
if (args.at(n->input(1)).IValue()->isInt()) {
686+
auto int_tensor = scalar_to_tensor(args.at(n->input(1)).IValue()->toInt());
687+
auto int_itensor = converters::tensor_to_const(ctx, int_tensor, util::node_info(n) + "_constant");
688+
auto elementwise_layer = converters::add_elementwise(
689+
ctx,
690+
nvinfer1::ElementWiseOperation::kFLOOR_DIV,
691+
args.at(n->input(0)).ITensor(),
692+
int_itensor,
693+
util::node_info(n));
694+
auto output_tensor = elementwise_layer->getOutput(0);
695+
auto tensor_holder = TensorContainer();
696+
tensor_holder.hold_tensor(output_tensor);
697+
auto output_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
698+
return output_ivalue;
699+
}
700+
}
701+
// Static version
680702
if (args.at(n->input(0)).IValue()->isInt()) {
681703
auto a = args.at(n->input(0)).unwrapToInt();
682704
auto b = args.at(n->input(1)).unwrapToInt();

core/conversion/evaluators/prim.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ auto prim_registrations =
3232
.evaluator(
3333
{torch::jit::prim::NumToTensor,
3434
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
35+
// Dynamic version receives an ITensor here so pass that as output directly.
36+
if (args.at(n->input(0)).isITensor()) {
37+
return args.at(n->input(0)).ITensor();
38+
}
3539
return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
3640
}})
3741
.evaluator(

0 commit comments

Comments
 (0)