@@ -200,6 +200,61 @@ auto element_wise_registrations TRTORCH_UNUSED =
200
200
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
201
201
return true ;
202
202
}})
203
+ .pattern({" aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)" ,
204
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
205
+ // Should implement other - alpha * self
206
+ auto self = args[0 ].ITensorOrFreeze (ctx);
207
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
208
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
209
+ auto scalar = args[2 ].unwrapToScalar ().to <float >();
210
+
211
+ if (1 != scalar) {
212
+ auto scaleW = Weights (ctx, scalar);
213
+ auto unuse = Weights ();
214
+ // IScaleLayer assert shift, scale and power to have
215
+ // the same dtype
216
+ auto scaleLayer = ctx->net ->addScale (
217
+ *self, nvinfer1::ScaleMode::kUNIFORM , unuse.data , scaleW.data , unuse.data );
218
+ TRTORCH_CHECK (scaleLayer, " Unable to create scale layer from node: " << *n);
219
+ self = scaleLayer->getOutput (0 );
220
+ }
221
+
222
+ auto rsub =
223
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUB , other, self, util::node_info (n));
224
+ TRTORCH_CHECK (rsub, " Unable to create rsub layer from node: " << *n);
225
+
226
+ rsub->setName (util::node_info (n).c_str ());
227
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], rsub->getOutput (0 ));
228
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
229
+ return true ;
230
+ }})
231
+ .pattern({" aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)" ,
232
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
233
+ // Should implement other - alpha * self
234
+ auto self = args[0 ].ITensorOrFreeze (ctx);
235
+ auto other = args[1 ].ITensorOrFreeze (ctx);
236
+ auto scalar = args[2 ].unwrapToScalar ().to <float >();
237
+
238
+ if (1 != scalar) {
239
+ auto scaleW = Weights (ctx, scalar);
240
+ auto unuse = Weights ();
241
+ // IScaleLayer assert shift, scale and power to have
242
+ // the same dtype
243
+ auto scaleLayer = ctx->net ->addScale (
244
+ *self, nvinfer1::ScaleMode::kUNIFORM , unuse.data , scaleW.data , unuse.data );
245
+ TRTORCH_CHECK (scaleLayer, " Unable to create scale layer from node: " << *n);
246
+ self = scaleLayer->getOutput (0 );
247
+ }
248
+
249
+ auto rsub =
250
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUB , other, self, util::node_info (n));
251
+ TRTORCH_CHECK (rsub, " Unable to create rsub layer from node: " << *n);
252
+
253
+ rsub->setName (util::node_info (n).c_str ());
254
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], rsub->getOutput (0 ));
255
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
256
+ return true ;
257
+ }})
203
258
.pattern({" aten::div.Tensor(Tensor self, Tensor other) -> Tensor" ,
204
259
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
205
260
// Should implement self / other
@@ -352,6 +407,63 @@ auto element_wise_registrations TRTORCH_UNUSED =
352
407
pow ->setName (util::node_info (n).c_str ());
353
408
auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], pow ->getOutput (0 ));
354
409
410
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
411
+ return true ;
412
+ }})
413
+ .pattern({" aten::floor_divide(Tensor self, Tensor other) -> (Tensor)" ,
414
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
415
+ // TODO: Remove with functionalization
416
+ auto self = args[0 ].ITensorOrFreeze (ctx);
417
+ auto other = args[1 ].ITensorOrFreeze (ctx);
418
+ auto floor_divide =
419
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV , self, other, util::node_info (n));
420
+ TRTORCH_CHECK (floor_divide, " Unable to create floor_divide layer from node: " << *n);
421
+
422
+ floor_divide->setName (util::node_info (n).c_str ());
423
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], floor_divide->getOutput (0 ));
424
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
425
+ return true ;
426
+ }})
427
+ .pattern({" aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
428
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
429
+ // TODO: Remove with functionalization
430
+ auto self = args[0 ].ITensorOrFreeze (ctx);
431
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
432
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
433
+ auto floor_divide =
434
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV , self, other, util::node_info (n));
435
+ TRTORCH_CHECK (floor_divide, " Unable to create floor_divide layer from node: " << *n);
436
+
437
+ floor_divide->setName (util::node_info (n).c_str ());
438
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], floor_divide->getOutput (0 ));
439
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
440
+ return true ;
441
+ }})
442
+ .pattern({" aten::max.other(Tensor self, Tensor other) -> (Tensor)" ,
443
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
444
+ // TODO: Remove with functionalization
445
+ auto self = args[0 ].ITensorOrFreeze (ctx);
446
+ auto other = args[1 ].ITensorOrFreeze (ctx);
447
+ auto max =
448
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kMAX , self, other, util::node_info (n));
449
+ TRTORCH_CHECK (max, " Unable to create max layer from node: " << *n);
450
+
451
+ max->setName (util::node_info (n).c_str ());
452
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], max->getOutput (0 ));
453
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
454
+ return true ;
455
+ }})
456
+ .pattern({" aten::min.other(Tensor self, Tensor other) -> (Tensor)" ,
457
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
458
+ // TODO: Remove with functionalization
459
+ auto self = args[0 ].ITensorOrFreeze (ctx);
460
+ auto other = args[1 ].ITensorOrFreeze (ctx);
461
+ auto min =
462
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kMIN , self, other, util::node_info (n));
463
+ TRTORCH_CHECK (min, " Unable to create min layer from node: " << *n);
464
+
465
+ min->setName (util::node_info (n).c_str ());
466
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], min->getOutput (0 ));
355
467
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
356
468
return true ;
357
469
}});
0 commit comments