
Commit f271e6a

Add verifiers for ToBuiltinTensorOp and FromBuiltinTensorOp (#1089)
This commit adds verifiers to the `ToBuiltinTensorOp` and `FromBuiltinTensorOp` ops that check that the operand and result have the same shape and data type.
1 parent 31fd812 commit f271e6a
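
As a rough illustration of what the new verifiers enforce (this sketch is not part of the commit; the function name @example is made up, and the op syntax and types mirror the existing test in test/Dialect/TorchConversion/ops.mlir): `!torch.vtensor<[3,?],si8>` corresponds to the builtin `tensor<3x?xi8>`, so a conversion between those types still verifies, while a mismatched result type is now rejected.

// Illustration only; @example is a hypothetical name, types taken from the existing test.
func.func @example(%arg0: !torch.vtensor<[3,?],si8>) -> tensor<3x?xi8> {
  // Accepted: [3,?]/si8 matches 3x?/i8 in both size and dtype.
  %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,?],si8> -> tensor<3x?xi8>
  // A result type such as tensor<?x?xi8> (different size) or tensor<3x?xi64>
  // (different dtype) would now fail verification with
  // "operand and result must have the same size and dtype".
  return %0 : tensor<3x?xi8>
}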

File tree

3 files changed (+77, -1)

include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
   }];
+  let hasVerifier = 1;
 }
 
 def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tensor">
@@ -60,6 +61,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
   }];
+  let hasVerifier = 1;
 }
 
 def TorchConversion_ToI1Op : TorchConversion_Op<"to_i1", [

lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp

Lines changed: 42 additions & 0 deletions
@@ -21,10 +21,29 @@ using namespace mlir::torch;
 using namespace mlir::torch::TorchConversion;
 using namespace mlir::torch;
 
+static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) {
+  if (lhs.hasRank() != rhs.hasRank())
+    return false;
+  bool sameSize = lhs.hasRank() ? lhs.getShape().equals(rhs.getShape()) : true;
+  bool sameElementType = lhs.getElementType() == rhs.getElementType();
+  return sameElementType && sameSize;
+}
+
 //===----------------------------------------------------------------------===//
 // ToBuiltinTensorOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult ToBuiltinTensorOp::verify() {
+  auto resultType = getResult().getType().cast<TensorType>();
+  auto operandType =
+      getOperand().getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
+  if (!haveSameSizeAndElementType(resultType, operandType)) {
+    return emitError()
+           << "operand and result must have the same size and dtype";
+  }
+  return success();
+}
+
 LogicalResult ToBuiltinTensorOp::inferReturnTypes(
     MLIRContext *context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
@@ -37,6 +56,25 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// FromBuiltinTensorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult FromBuiltinTensorOp::verify() {
+  auto resultType =
+      getResult().getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
+  auto operandType = getOperand().getType().cast<TensorType>();
+  if (!haveSameSizeAndElementType(resultType, operandType)) {
+    return emitError()
+           << "operand and result must have the same size and dtype";
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FromI64Op
+//===----------------------------------------------------------------------===//
+
 OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
   auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
   if (attr) {
@@ -46,6 +84,10 @@ OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
   }
 }
 
+//===----------------------------------------------------------------------===//
+// ToI64Op
+//===----------------------------------------------------------------------===//
+
 OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
   auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
   if (attr) {

test/Dialect/TorchConversion/ops.mlir

Lines changed: 33 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s
+// RUN: torch-mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL: func.func @builtin_tensor_interop(
 func.func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) {
@@ -14,3 +14,35 @@ func.func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi8>, %
   %4 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xi8>
   return
 }
+
+// -----
+
+func.func @to_builtin_tensor_invalid_size(%arg0: !torch.vtensor<[3,?],si8>) {
+  // expected-error @+1 {{operand and result must have the same size and dtype}}
+  %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,?],si8> -> tensor<?x?xi8>
+  return
+}
+
+// -----
+
+func.func @to_builtin_tensor_invalid_dtype(%arg0: !torch.vtensor<*,si8>) {
+  // expected-error @+1 {{operand and result must have the same size and dtype}}
+  %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<*,si8> -> tensor<*xi64>
+  return
+}
+
+// -----
+
+func.func @from_builtin_tensor_invalid_size(%arg0: tensor<3x?xi8>) {
+  // expected-error @+1 {{operand and result must have the same size and dtype}}
+  %1 = torch_c.from_builtin_tensor %arg0 : tensor<3x?xi8> -> !torch.vtensor<[?,?],si8>
+  return
+}
+
+// -----
+
+func.func @from_builtin_tensor_invalid_dtype(%arg0: tensor<*xi8>) {
+  // expected-error @+1 {{operand and result must have the same size and dtype}}
+  %1 = torch_c.from_builtin_tensor %arg0 : tensor<*xi8> -> !torch.vtensor<*,si64>
+  return
+}
