@@ -800,6 +800,39 @@ auto select_registrations TORCHTRT_UNUSED =
800
800
801
801
layer->setName (util::node_info (n).c_str ());
802
802
803
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], layer->getOutput (0 ));
804
+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
805
+ return true ;
806
+ }})
807
+ .pattern(
808
+ {" aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> (Tensor)" ,
809
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
810
+ auto condition = args[0 ].ITensorOrFreeze (ctx);
811
+ auto condition_nbDims = condition->getDimensions ().nbDims ;
812
+ auto self = args[1 ].ITensorOrFreeze (ctx);
813
+ auto x_nbDims = self->getDimensions ().nbDims ;
814
+
815
+ // Get maximum rank of all input tensors
816
+ auto max_nbDims = std::max (condition_nbDims, x_nbDims);
817
+
818
+ // TensorRT requires all inputs to Select layers to have the same rank, so for each
819
+ // tensor input, ensure that its rank is equal to the maximum number of dimensions
820
+ // If not, left-pad the tensor dimension with 1s until the max rank is achieved
821
+ condition =
822
+ addPadding (ctx, n, condition, max_nbDims, /* bool trailing =*/ false , /* bool use_zeros =*/ false );
823
+ self = addPadding (ctx, n, self, max_nbDims, /* bool trailing =*/ false , /* bool use_zeros =*/ false );
824
+
825
+ // Create a scalar tensor of rank max_nbDims from scalar other
826
+ auto scalar_value = args[2 ].unwrapToScalar ();
827
+ std::vector<int64_t > dims_vec (max_nbDims, 1 );
828
+ auto self_dtype = util::TRTDataTypeToScalarType (self->getType ());
829
+ auto constant_tensor = torch::full (dims_vec, scalar_value, {torch::dtype (self_dtype)});
830
+ auto constant_itensor = converters::tensor_to_const (ctx, constant_tensor);
831
+
832
+ auto layer = ctx->net ->addSelect (*condition, *self, *constant_itensor);
833
+ TORCHTRT_CHECK (layer, " Unable to create select layer for aten::where.ScalarOther" );
834
+ layer->setName (util::node_info (n).c_str ());
835
+
803
836
auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], layer->getOutput (0 ));
804
837
LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
805
838
return true ;
0 commit comments