Skip to content

Commit f6e7e0c

Browse files
committed
Add type checking for lhs and rhs
1 parent f5253d4 commit f6e7e0c

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -825,14 +825,16 @@ void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
825825
if (!lhs.hasValue() || !rhs.hasValue())
826826
return failure();
827827

828+
if (!lhs.getValue().getType().isa<Torch::IntType>() ||
829+
!rhs.getValue().getType().isa<Torch::IntType>())
830+
return failure();
831+
828832
Value mul =
829833
rewriter.create<AtenMulIntOp>(op->getLoc(), rhs.getValue(), op.alpha());
830-
if (!mul.getType().isa<Torch::IntType>())
831-
return failure();
834+
832835
Value add =
833836
rewriter.create<AtenAddIntOp>(op->getLoc(), lhs.getValue(), mul);
834-
if (!add.getType().isa<Torch::IntType>())
835-
return failure();
837+
836838
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
837839
op, op.self().getType(), add);
838840
return success();

0 commit comments

Comments
 (0)