19
19
20
20
include "mlir/IR/OpBase.td"
21
21
include "mlir/Dialect/Shape/IR/ShapeBase.td"
22
+ include "mlir/Interfaces/CastInterfaces.td"
22
23
include "mlir/Interfaces/LoopLikeInterface.td"
23
24
include "mlir/Interfaces/SideEffectInterfaces.td"
24
25
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -42,8 +43,8 @@ def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
42
43
intend to optimize.
43
44
}];
44
45
45
- let arguments = (ins);
46
46
let results = (outs Variadic<AnyType>);
47
+
47
48
let skipDefaultBuilders = 1;
48
49
let builders = [ OpBuilder<(ins "int64_t":$num_loops)> ];
49
50
@@ -53,14 +54,14 @@ def KrnlDefineLoopsOp : Op<Krnl_Dialect, "define_loops"> {
53
54
let extraClassDeclaration = [{
54
55
static StringRef getNumLoopsAttrName() { return "num_loops"; }
55
56
56
- // Helper function to extract the number of loops being defined.
57
- int64_t getNumLoops() {
58
- auto num_loops = (*this)->getAttrOfType<IntegerAttr>(getNumLoopsAttrName())
59
- .getValue()
60
- .getSExtValue();
61
- return num_loops;
62
- }
63
- }];
57
+ // Helper function to extract the number of loops being defined.
58
+ int64_t getNumLoops() {
59
+ auto num_loops = (*this)->getAttrOfType<IntegerAttr>(getNumLoopsAttrName())
60
+ .getValue()
61
+ .getSExtValue();
62
+ return num_loops;
63
+ }
64
+ }];
64
65
}
65
66
66
67
def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator,
@@ -88,7 +89,9 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator,
88
89
}];
89
90
90
91
let arguments = (ins Variadic<AnyType>);
92
+
91
93
let regions = (region SizedRegion<1>:$bodyRegion);
94
+
92
95
let skipDefaultBuilders = 1;
93
96
let builders = [
94
97
// Main builder.
@@ -106,7 +109,12 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator,
106
109
CArg<"ArrayRef<IndexExpr>">:$lbs, CArg<"ArrayRef<IndexExpr>">:$ubs,
107
110
CArg<"ValueRange">:$iterArgs,
108
111
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>">:$bodyBuilderFn)>
109
- ];
112
+ ];
113
+
114
+ let printer = [{ return ::print(p, *this); }];
115
+ let parser = [{ return ::parse$cppClass(parser, result); }];
116
+
117
+ let hasVerifier = 1;
110
118
111
119
let extraClassDeclaration = [{
112
120
// In krnl.iterate operation, operands are stored as such
@@ -127,11 +135,7 @@ def KrnlIterateOp : Op<Krnl_Dialect, "iterate", [ImplicitKrnlTerminator,
127
135
128
136
// Get name of the attribute for storing bound represented using affine maps.
129
137
static StringRef getBoundsAttrName() { return "bounds"; }
130
- }];
131
-
132
- let printer = [{ return ::print(p, *this); }];
133
- let parser = [{ return ::parse$cppClass(parser, result); }];
134
- let verifier = [{ return ::verify(*this); }];
138
+ }];
135
139
}
136
140
137
141
def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
@@ -148,9 +152,6 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
148
152
// No custom parsing/printing form.
149
153
let parser = ?;
150
154
let printer = ?;
151
-
152
- // Fully specified by traits.
153
- let verifier = ?;
154
155
}
155
156
156
157
def KrnlEntryPointOp : Op<Krnl_Dialect, "entry_point"> {
@@ -223,16 +224,16 @@ def KrnlGetRefOp : Op<Krnl_Dialect, "getref", [MemRefsNormalizable]> {
223
224
}]>,
224
225
];
225
226
227
+ let parser = ?;
228
+ let printer = ?;
229
+
226
230
let extraClassDeclaration = [{
227
231
/// Returns the symbolic operands (the ones in square brackets), which bind
228
232
/// to the symbols of the memref's layout map.
229
233
operand_range getDynamicSizes() {
230
234
return {operand_begin() + 2, operand_end()};
231
235
}
232
236
}];
233
-
234
- let parser = ?;
235
- let printer = ?;
236
237
}
237
238
238
239
def KrnlBlockOp : Op<Krnl_Dialect, "block"> {
@@ -243,9 +244,9 @@ def KrnlBlockOp : Op<Krnl_Dialect, "block"> {
243
244
means to block the for loop referred to by %i using a tile size of 4.
244
245
}];
245
246
246
- let arguments = (ins
247
- AnyType:$loop, I64Attr:$tile_size);
247
+ let arguments = (ins AnyType:$loop, I64Attr:$tile_size);
248
248
let results = (outs AnyType:$loop_block, AnyType:$loop_local);
249
+
249
250
let builders = [ OpBuilder<(ins "Value": $loop, "int64_t":$tile_size)> ];
250
251
let assemblyFormat = [{
251
252
$loop $tile_size attr-dict `:` functional-type($loop, results)
@@ -312,7 +313,7 @@ def KrnlPermuteOp : Op<Krnl_Dialect, "permute"> {
312
313
}];
313
314
314
315
let arguments = (ins Variadic<AnyType>:$loops, I64ArrayAttr:$map);
315
- let results = (outs);
316
+
316
317
let builders = [ OpBuilder<(ins "ValueRange": $loops, "ArrayRef<int64_t>":$map)> ];
317
318
let assemblyFormat = [{
318
319
`(` $loops `)` $map attr-dict `:` type($loops)
@@ -330,7 +331,7 @@ def KrnlUnrollOp : Op<Krnl_Dialect, "unroll"> {
330
331
}];
331
332
332
333
let arguments = (ins AnyType:$loop);
333
- let results = (outs);
334
+
334
335
let assemblyFormat = [{
335
336
$loop attr-dict `:` type($loop)
336
337
}];
@@ -507,6 +508,8 @@ def KrnlLoadOp : Op<Krnl_Dialect, "load",
507
508
$_state.types.push_back(memrefType.getElementType());
508
509
}]>];
509
510
511
+ let assemblyFormat = [{$memref `[` $indices `]` attr-dict `:` type($memref)}];
512
+
510
513
let extraClassDeclaration = [{
511
514
Value getMemRef() { return getOperand(0); }
512
515
void setMemRef(Value value) { setOperand(0, value); }
@@ -516,8 +519,6 @@ def KrnlLoadOp : Op<Krnl_Dialect, "load",
516
519
517
520
operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
518
521
}];
519
-
520
- let assemblyFormat = [{$memref `[` $indices `]` attr-dict `:` type($memref)}];
521
522
}
522
523
523
524
def KrnlStoreOp : Op<Krnl_Dialect, "store",
@@ -544,6 +545,10 @@ def KrnlStoreOp : Op<Krnl_Dialect, "store",
544
545
$_state.addOperands(memref);
545
546
}]>];
546
547
548
+ let assemblyFormat = [{
549
+ $value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
550
+ }];
551
+
547
552
let extraClassDeclaration = [{
548
553
Value getValueToStore() { return getOperand(0); }
549
554
@@ -557,10 +562,6 @@ def KrnlStoreOp : Op<Krnl_Dialect, "store",
557
562
return {operand_begin() + 2, operand_end()};
558
563
}
559
564
}];
560
-
561
- let assemblyFormat = [{
562
- $value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
563
- }];
564
565
}
565
566
566
567
def KrnlMovableOp : Op<Krnl_Dialect, "movable", [ImplicitKrnlTerminator]> {
@@ -576,8 +577,6 @@ def KrnlMovableOp : Op<Krnl_Dialect, "movable", [ImplicitKrnlTerminator]> {
576
577
are nested imperfectly between an "eager" and a "lazy" loop.
577
578
}];
578
579
579
- let arguments = (ins );
580
-
581
580
let regions = (region AnyRegion:$region);
582
581
583
582
let assemblyFormat = [{
@@ -600,6 +599,7 @@ def KrnlGetInductionVariableValueOp : Op<Krnl_Dialect, "get_induction_var_value"
600
599
601
600
let arguments = (ins Variadic<AnyType> : $loops);
602
601
let results = (outs Variadic<AnyType> : $ind_var_vals);
602
+
603
603
let builders = [ OpBuilder<(ins "ValueRange": $loops)>];
604
604
605
605
let assemblyFormat = [{
@@ -611,7 +611,7 @@ def KrnlGetInductionVariableValueOp : Op<Krnl_Dialect, "get_induction_var_value"
611
611
// =============================================================================
612
612
613
613
def KrnlVectorTypeCastOp : Op<Krnl_Dialect, "vector_type_cast", [NoSideEffect,
614
- MemRefsNormalizable, ViewLikeOpInterface]> {
614
+ MemRefsNormalizable, DeclareOpInterfaceMethods<CastOpInterface>, ViewLikeOpInterface]> {
615
615
let summary = "vector type cast operation";
616
616
let description = [{
617
617
The "vector_type_cast" operation converts a memref from an non-vector
@@ -627,30 +627,20 @@ def KrnlVectorTypeCastOp : Op<Krnl_Dialect, "vector_type_cast", [NoSideEffect,
627
627
let arguments = (ins AnyMemRef:$source);
628
628
let results = (outs AnyMemRef:$result);
629
629
630
- let parser = ? ;
631
- let printer = ? ;
630
+ let hasFolder = 1 ;
631
+ let builders = [ OpBuilder<(ins "Value": $source, "int64_t": $vectorLen)> ] ;
632
632
633
- let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }];
633
+ let assemblyFormat = [{
634
+ $source attr-dict `:` type($source) `to` type($result)
635
+ }];
634
636
635
637
let extraClassDeclaration = [{
636
- /// Return true if `a` and `b` are valid operand and result pairs for
637
- /// the operation.
638
- static bool areCastCompatible(Type a, Type b);
639
-
640
638
/// The result of a vector_type_cast is always a memref.
641
639
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
642
640
643
641
/// Return the view source.
644
642
Value getViewSource() { return source(); }
645
643
}];
646
-
647
- let hasFolder = 1;
648
- let builders = [ OpBuilder<(ins "Value": $source, "int64_t": $vectorLen)> ];
649
-
650
- let assemblyFormat = [{
651
- $source attr-dict `:` type($source) `to` type($result)
652
- }];
653
-
654
644
}
655
645
656
646
// =============================================================================
@@ -663,7 +653,6 @@ def KrnlSpecializedKernel : Op<Krnl_Dialect, "specialized_kernel",
663
653
}];
664
654
665
655
let arguments = (ins Variadic<AnyType> : $loops);
666
- let results = (outs );
667
656
668
657
let assemblyFormat = [{
669
658
`(` $loops `)` attr-dict `:` type($loops)
@@ -841,7 +830,7 @@ def KrnlMatMulOp : Op<Krnl_Dialect, "matmul", [AttrSizedOperandSegments,
841
830
"bool": $overcompute)>
842
831
];
843
832
844
- let verifier = [{ return ::verify(*this); }] ;
833
+ let hasVerifier = 1 ;
845
834
846
835
let assemblyFormat = [{
847
836
$A `[` $aMemStart `]` `,`
@@ -922,8 +911,9 @@ def KrnlCopyToBufferOp : Op<Krnl_Dialect, "copy_to_tile_buffer", [
922
911
"bool": $transpose)>
923
912
];
924
913
925
- let verifier = [{ return ::verify(*this); }];
926
- let assemblyFormat = [{
914
+ let hasVerifier = 1;
915
+
916
+ let assemblyFormat = [{
927
917
$buffer `,` $source `[` $starts `]` `,` $padValue attr-dict
928
918
`:` type($buffer) `,` type($source)
929
919
}];
@@ -955,8 +945,9 @@ def KrnlCopyFromBufferOp : Op<Krnl_Dialect, "copy_from_tile_buffer",
955
945
OpBuilder<(ins "Value": $buffer, "Value": $dest, "ValueRange": $starts)>
956
946
];
957
947
958
- let verifier = [{ return ::verify(*this); }];
959
- let assemblyFormat = [{
948
+ let hasVerifier = 1;
949
+
950
+ let assemblyFormat = [{
960
951
$buffer `,` $dest `[` $starts `]` attr-dict `:` type($buffer) `,` type($dest)
961
952
}];
962
953
}
@@ -984,6 +975,7 @@ def KrnlMemsetOp : Op<Krnl_Dialect, "memset", [MemRefsNormalizable,
984
975
}];
985
976
986
977
let arguments = (ins AnyMemRef:$dest, AnyType: $value);
978
+
987
979
let assemblyFormat = [{ $dest `,` $value attr-dict `:` type($dest) }];
988
980
}
989
981
@@ -1019,8 +1011,6 @@ def KrnlRandomNormalOp : Op<Krnl_Dialect, "random_normal",
1019
1011
AnyFloat:$mean,
1020
1012
AnyFloat:$scale,
1021
1013
AnyFloat:$seed);
1022
-
1023
- let results = (outs );
1024
1014
}
1025
1015
1026
1016
def KrnlFindIndexOp : Op<Krnl_Dialect, "find_index",
0 commit comments