@@ -65,6 +65,104 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
65
65
return true ;
66
66
}})
67
67
.pattern(
68
+ {" aten::unflatten.int(Tensor self, int dim, int[] sizes) -> (Tensor)" ,
69
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
70
+ auto in = args[0 ].ITensorOrFreeze (ctx);
71
+ auto dim = args[1 ].unwrapToInt ();
72
+ auto in_shape = util::toVec (in->getDimensions ());
73
+ std::vector<int64_t > new_shape;
74
+ nvinfer1::ITensor* shape_tensor;
75
+ if (ctx->input_is_dynamic ) {
76
+ /*
77
+ * In case the dim is negative
78
+ * If the dim in negative range is larger than in_shape,
79
+ * then it should run into index out of bound error as expected
80
+ */
81
+ if (dim < 0 ) {
82
+ dim = in_shape.size () + dim;
83
+ }
84
+ std::cout << " Dynamic shape case" << std::endl;
85
+ LOG_DEBUG (" Using dynamic version of reshape layer" );
86
+ if (args[2 ].isITensorList ()) {
87
+ std::cout << " isTensorList case" << std::endl;
88
+ LOG_DEBUG (" Shape tensor is an ITensorList" );
89
+ auto expand_shape = args[2 ].unwrapToITensorList ();
90
+ auto shape_layer = ctx->net ->addShape (*in);
91
+ TORCHTRT_CHECK (shape_layer, " Unable to create shape layer from node: " << *n);
92
+ auto shape_1d_tensor = shape_layer->getOutput (0 );
93
+
94
+ std::vector<int > before_dim_indices_vector (dim);
95
+ std::iota (before_dim_indices_vector.begin (), before_dim_indices_vector.end (), 0 );
96
+
97
+ nvinfer1::ITensor* before_dim_gather_out = nullptr ;
98
+ if (before_dim_indices_vector.size ()) {
99
+ at::Tensor before_dim_indices = torch::tensor (before_dim_indices_vector).to (torch::kI32 );
100
+ auto before_dim_indices_out = converters::tensor_to_const (ctx, before_dim_indices);
101
+ auto before_dim_gather_layer = ctx->net ->addGather (*shape_1d_tensor, *before_dim_indices_out, 0 );
102
+ TORCHTRT_CHECK (before_dim_gather_layer, " Unable to create gather layer from node: " << *n);
103
+ before_dim_gather_out = before_dim_gather_layer->getOutput (0 );
104
+ }
105
+
106
+ std::vector<int > after_dim_indices_vector (in_shape.size () - (dim + 1 ));
107
+ std::iota (after_dim_indices_vector.begin (), after_dim_indices_vector.end (), dim + 1 );
108
+
109
+ nvinfer1::ITensor* after_dim_gather_out = nullptr ;
110
+ if (after_dim_indices_vector.size ()) {
111
+ at::Tensor after_dim_indices = torch::tensor (after_dim_indices_vector).to (torch::kI32 );
112
+ auto after_dim_indices_out = converters::tensor_to_const (ctx, after_dim_indices);
113
+ auto after_dim_gather_layer = ctx->net ->addGather (*shape_1d_tensor, *after_dim_indices_out, 0 );
114
+ TORCHTRT_CHECK (after_dim_gather_layer, " Unable to create gather layer from node: " << *n);
115
+ after_dim_gather_out = after_dim_gather_layer->getOutput (0 );
116
+ }
117
+
118
+ std::vector<nvinfer1::ITensor*> shape_tensors;
119
+ if (before_dim_gather_out) {
120
+ shape_tensors.push_back (before_dim_gather_out);
121
+ }
122
+ for (auto new_shape_tensor : expand_shape) {
123
+ shape_tensors.push_back (new_shape_tensor);
124
+ }
125
+ if (after_dim_gather_out) {
126
+ shape_tensors.push_back (after_dim_gather_out);
127
+ }
128
+
129
+ auto shape_cat_layer = ctx->net ->addConcatenation (shape_tensors.data (), shape_tensors.size ());
130
+ TORCHTRT_CHECK (shape_cat_layer, " Unable to create cat layer from node: " << *n);
131
+ shape_tensor = shape_cat_layer->getOutput (0 );
132
+ LOG_DEBUG (" Shape tensor shape: " << shape_tensor->getDimensions ());
133
+ } else if (args[2 ].isIntList ()) {
134
+ auto shape_vec = args[2 ].unwrapToIntList ().vec ();
135
+ // New shape
136
+ new_shape.insert (new_shape.end (), in_shape.begin (), in_shape.begin () + dim);
137
+ new_shape.insert (new_shape.end (), shape_vec.begin (), shape_vec.end ());
138
+ new_shape.insert (new_shape.end (), in_shape.begin () + dim + 1 , in_shape.end ());
139
+
140
+ shape_tensor = tensor_to_const (ctx, torch::tensor (new_shape).to (torch::kI32 ));
141
+ } else {
142
+ LOG_ERROR (
143
+ " Invalid IValue type of " << args[2 ].IValue ()->type ()
144
+ << " detected for shape tensor from node: " << *n);
145
+ }
146
+ } else {
147
+ new_shape =
148
+ torch::unflatten (torch::rand (in_shape), dim, args[2 ].unwrapToIntList ().vec ()).sizes ().vec ();
149
+ }
150
+ auto shuffle = ctx->net ->addShuffle (*in);
151
+ shuffle->setName (util::node_info (n).c_str ());
152
+ TORCHTRT_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
153
+
154
+ if (ctx->input_is_dynamic ) {
155
+ shuffle->setInput (1 , *shape_tensor);
156
+ } else {
157
+ shuffle->setReshapeDimensions (util::toDims (new_shape));
158
+ }
159
+
160
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle->getOutput (0 ));
161
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
162
+
163
+ return true ;
164
+ }})
165
+ .pattern(
68
166
{" aten::reshape(Tensor self, int[] shape) -> (Tensor)" ,
69
167
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
70
168
auto in = args[0 ].ITensorOrFreeze (ctx);
0 commit comments