@@ -301,21 +301,20 @@ LogicalResult ClassTypeOp::verify() {
301
301
// PrimLoopOp
302
302
// ===----------------------------------------------------------------------===//
303
303
304
- OperandRange
305
- PrimLoopOp::getEntrySuccessorOperands (std::optional<unsigned int > index) {
306
- assert (index .has_value () && index .value () == 0 );
304
+ OperandRange PrimLoopOp::getEntrySuccessorOperands (RegionBranchPoint point) {
305
+ assert (point == getRegion ());
307
306
return getIterArgsInit ();
308
307
}
309
308
310
309
void PrimLoopOp::getSuccessorRegions (
311
- std::optional< unsigned > index , SmallVectorImpl<RegionSuccessor> ®ions) {
312
-
313
- if (!index . has_value ()) {
314
- regions.emplace_back (&getRegion (), getRegion () .getArguments ().slice (1 ));
310
+ RegionBranchPoint point , SmallVectorImpl<RegionSuccessor> ®ions) {
311
+ Region ®ion = getRegion ();
312
+ if (!point. getRegionOrNull ()) {
313
+ regions.emplace_back (®ion, region .getArguments ().slice (1 ));
315
314
return ;
316
315
}
317
- assert (* index == 0 );
318
- regions.emplace_back (&getRegion (), getRegion () .getArguments ().slice (1 ));
316
+ assert (point == region );
317
+ regions.emplace_back (®ion, region .getArguments ().slice (1 ));
319
318
regions.emplace_back (getResults ());
320
319
}
321
320
@@ -328,8 +327,8 @@ bool PrimLoopOp::isForLike() {
328
327
// PrimLoopConditionOp
329
328
// ===----------------------------------------------------------------------===//
330
329
331
- MutableOperandRange PrimLoopConditionOp::getMutableSuccessorOperands (
332
- std::optional< unsigned > index ) {
330
+ MutableOperandRange
331
+ PrimLoopConditionOp::getMutableSuccessorOperands (RegionBranchPoint point ) {
333
332
// Pass all operands except the condition to the successor which is the
334
333
// parent loop op.
335
334
return getIterArgsMutable ();
@@ -378,10 +377,10 @@ void PrimIfOp::print(OpAsmPrinter &p) {
378
377
p.printOptionalAttrDict ((*this )->getAttrs ());
379
378
}
380
379
381
- void PrimIfOp::getSuccessorRegions (std::optional< unsigned > index ,
380
+ void PrimIfOp::getSuccessorRegions (RegionBranchPoint point ,
382
381
SmallVectorImpl<RegionSuccessor> ®ions) {
383
382
// The `then` and the `else` region branch back to the parent operation.
384
- if (index . has_value ()) {
383
+ if (point. getRegionOrNull ()) {
385
384
regions.push_back (RegionSuccessor (getResults ()));
386
385
return ;
387
386
}
@@ -1595,7 +1594,9 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
1595
1594
MLIRContext *context, std::optional<Location> location, ValueRange operands,
1596
1595
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1597
1596
SmallVectorImpl<Type> &inferredReturnTypes) {
1598
- auto attr = attributes.get (" value" ).dyn_cast_or_null <ElementsAttr>();
1597
+ auto attr = properties.as <Properties *>()
1598
+ ->getValue ()
1599
+ .dyn_cast_or_null <ElementsAttr>();
1599
1600
if (!attr)
1600
1601
return failure ();
1601
1602
RankedTensorType tensorType = attr.getType ().cast <RankedTensorType>();
@@ -1635,7 +1636,9 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
1635
1636
MLIRContext *context, std::optional<Location> location, ValueRange operands,
1636
1637
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1637
1638
SmallVectorImpl<Type> &inferredReturnTypes) {
1638
- auto attr = attributes.get (" value" ).dyn_cast_or_null <ElementsAttr>();
1639
+ auto attr = properties.as <Properties *>()
1640
+ ->getValue ()
1641
+ .dyn_cast_or_null <ElementsAttr>();
1639
1642
if (!attr)
1640
1643
return failure ();
1641
1644
RankedTensorType tensorType = attr.getType ().cast <RankedTensorType>();
@@ -2768,43 +2771,43 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
2768
2771
2769
2772
template <typename CalculateOp>
2770
2773
static void
2771
- getSuccessorRegionsForCalculateOp (CalculateOp op, std::optional< unsigned > index ,
2774
+ getSuccessorRegionsForCalculateOp (CalculateOp op, RegionBranchPoint point ,
2772
2775
SmallVectorImpl<RegionSuccessor> ®ions) {
2773
- if (!index . has_value ()) {
2776
+ if (!point. getRegionOrNull ()) {
2774
2777
// First thing the op does is branch into the calculation.
2775
2778
regions.emplace_back (&op.getCalculation ());
2776
2779
return ;
2777
2780
}
2778
- if (* index == 0 ) {
2781
+ if (point == op. getBody () ) {
2779
2782
// Body returns control to the outer op, passing through results.
2780
2783
regions.emplace_back (op.getResults ());
2781
2784
return ;
2782
2785
}
2783
- assert (* index == 1 );
2786
+ assert (point == op. getCalculation () );
2784
2787
// Calculation branches to the body.
2785
2788
regions.emplace_back (&op.getBody ());
2786
2789
}
2787
2790
2788
2791
void ShapeCalculateOp::getSuccessorRegions (
2789
- std::optional< unsigned > index , SmallVectorImpl<RegionSuccessor> ®ions) {
2790
- getSuccessorRegionsForCalculateOp (*this , index , regions);
2792
+ RegionBranchPoint point , SmallVectorImpl<RegionSuccessor> ®ions) {
2793
+ getSuccessorRegionsForCalculateOp (*this , point , regions);
2791
2794
}
2792
2795
2793
2796
// ===----------------------------------------------------------------------===//
2794
2797
// DtypeCalculateOp
2795
2798
// ===----------------------------------------------------------------------===//
2796
2799
2797
2800
void DtypeCalculateOp::getSuccessorRegions (
2798
- std::optional< unsigned > index , SmallVectorImpl<RegionSuccessor> ®ions) {
2799
- getSuccessorRegionsForCalculateOp (*this , index , regions);
2801
+ RegionBranchPoint point , SmallVectorImpl<RegionSuccessor> ®ions) {
2802
+ getSuccessorRegionsForCalculateOp (*this , point , regions);
2800
2803
}
2801
2804
2802
2805
// ===----------------------------------------------------------------------===//
2803
2806
// ShapeCalculateYieldShapesOp
2804
2807
// ===----------------------------------------------------------------------===//
2805
2808
2806
2809
MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands (
2807
- std::optional< unsigned > index ) {
2810
+ RegionBranchPoint point ) {
2808
2811
// The shape operands don't get forwarded to the body.
2809
2812
// MutableOperandRange always has an owning operation, even if empty, so
2810
2813
// create a 0-length range.
@@ -2823,7 +2826,7 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
2823
2826
// ===----------------------------------------------------------------------===//
2824
2827
2825
2828
MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands (
2826
- std::optional< unsigned > index ) {
2829
+ RegionBranchPoint point ) {
2827
2830
// The dtype operands don't get forwarded to the body.
2828
2831
// MutableOperandRange always has an owning operation, even if empty, so
2829
2832
// create a 0-length range.
0 commit comments