@@ -21,10 +21,29 @@ using namespace mlir::torch;
21
21
using namespace mlir ::torch::TorchConversion;
22
22
using namespace mlir ::torch;
23
23
24
+ static bool haveSameSizeAndElementType (TensorType lhs, TensorType rhs) {
25
+ if (lhs.hasRank () != rhs.hasRank ())
26
+ return false ;
27
+ bool sameSize = lhs.hasRank () ? lhs.getShape ().equals (rhs.getShape ()) : true ;
28
+ bool sameElementType = lhs.getElementType () == rhs.getElementType ();
29
+ return sameElementType && sameSize;
30
+ }
31
+
24
32
// ===----------------------------------------------------------------------===//
25
33
// ToBuiltinTensorOp
26
34
// ===----------------------------------------------------------------------===//
27
35
36
+ LogicalResult ToBuiltinTensorOp::verify () {
37
+ auto resultType = getResult ().getType ().cast <TensorType>();
38
+ auto operandType =
39
+ getOperand ().getType ().cast <Torch::ValueTensorType>().toBuiltinTensor ();
40
+ if (!haveSameSizeAndElementType (resultType, operandType)) {
41
+ return emitError ()
42
+ << " operand and result must have the same size and dtype" ;
43
+ }
44
+ return success ();
45
+ }
46
+
28
47
LogicalResult ToBuiltinTensorOp::inferReturnTypes (
29
48
MLIRContext *context, Optional<Location> location, ValueRange operands,
30
49
DictionaryAttr attributes, RegionRange regions,
@@ -37,6 +56,25 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
37
56
return success ();
38
57
}
39
58
59
+ // ===----------------------------------------------------------------------===//
60
+ // FromBuiltinTensorOp
61
+ // ===----------------------------------------------------------------------===//
62
+
63
+ LogicalResult FromBuiltinTensorOp::verify () {
64
+ auto resultType =
65
+ getResult ().getType ().cast <Torch::ValueTensorType>().toBuiltinTensor ();
66
+ auto operandType = getOperand ().getType ().cast <TensorType>();
67
+ if (!haveSameSizeAndElementType (resultType, operandType)) {
68
+ return emitError ()
69
+ << " operand and result must have the same size and dtype" ;
70
+ }
71
+ return success ();
72
+ }
73
+
74
+ // ===----------------------------------------------------------------------===//
75
+ // FromI64Op
76
+ // ===----------------------------------------------------------------------===//
77
+
40
78
OpFoldResult FromI64Op::fold (llvm::ArrayRef<mlir::Attribute> operands) {
41
79
auto attr = operands[0 ].dyn_cast_or_null <mlir::IntegerAttr>();
42
80
if (attr) {
@@ -46,6 +84,10 @@ OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
46
84
}
47
85
}
48
86
87
+ // ===----------------------------------------------------------------------===//
88
+ // ToI64Op
89
+ // ===----------------------------------------------------------------------===//
90
+
49
91
OpFoldResult ToI64Op::fold (llvm::ArrayRef<mlir::Attribute> operands) {
50
92
auto attr = operands[0 ].dyn_cast_or_null <mlir::IntegerAttr>();
51
93
if (attr) {
0 commit comments