@@ -2507,8 +2507,9 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
2507
2507
2508
2508
for (Operation *op : dynamicSizeProducers.back ()) {
2509
2509
if (op->getNumResults () == 1 &&
2510
- isa<IndexType>(op->getResult (0 ).getType ()))
2510
+ isa<IndexType>(op->getResult (0 ).getType ())) {
2511
2511
continue ;
2512
+ }
2512
2513
2513
2514
DiagnosedSilenceableFailure diag =
2514
2515
emitSilenceableError () << " expected sizes to be produced by ops "
@@ -2525,11 +2526,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
2525
2526
auto scalableSizes = getScalableSizes ();
2526
2527
for (auto [i, op] : llvm::enumerate (targets)) {
2527
2528
auto tilingInterface = dyn_cast<TilingInterface>(op);
2528
- auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(op);
2529
- if (!tilingInterface || !dpsInterface) {
2529
+ if (!tilingInterface) {
2530
2530
DiagnosedSilenceableFailure diag =
2531
- emitSilenceableError () << " only ops implementing TilingInterface and "
2532
- " DestinationStyleOpInterface are supported" ;
2531
+ emitSilenceableError ()
2532
+ << " only ops implementing TilingInterface are supported" ;
2533
2533
diag.attachNote (op->getLoc ()) << " target op" ;
2534
2534
return diag;
2535
2535
}
@@ -2578,10 +2578,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
2578
2578
if (failed (maybeTilingResult))
2579
2579
return DiagnosedSilenceableFailure::definiteFailure ();
2580
2580
2581
- if (dpsInterface.hasBufferSemantics ())
2582
- rewriter.eraseOp (op);
2583
- else
2584
- rewriter.replaceOp (op, maybeTilingResult->loops .front ()->getResults ());
2581
+ rewriter.replaceOp (op, maybeTilingResult->replacements );
2585
2582
2586
2583
tiled.append (maybeTilingResult->tiledOps );
2587
2584
for (const auto &en2 : llvm::enumerate (maybeTilingResult->loops ))
@@ -2895,204 +2892,6 @@ LogicalResult TileToForallOp::verify() {
2895
2892
return success ();
2896
2893
}
2897
2894
2898
- // ===----------------------------------------------------------------------===//
2899
- // TileToScfForOp
2900
- // ===----------------------------------------------------------------------===//
2901
-
2902
- void transform::TileToScfForOp::build (OpBuilder &builder,
2903
- OperationState &result, Value target,
2904
- ArrayRef<OpFoldResult> mixedTileSizes,
2905
- ArrayRef<int64_t > interchange) {
2906
- SmallVector<int64_t > staticTileSizes;
2907
- SmallVector<Value> dynamicTileSizes;
2908
- dispatchIndexOpFoldResults (mixedTileSizes, dynamicTileSizes, staticTileSizes);
2909
- // Call the default builder which sets up the proper operands segment sizes
2910
- // attributes for multiple variadic operands. In the absence of this,
2911
- // horrible bugs ensue.
2912
- auto staticTileSizesAttr = builder.getDenseI64ArrayAttr (staticTileSizes);
2913
- int64_t numExpectedLoops =
2914
- staticTileSizes.size () - llvm::count (staticTileSizes, 0 );
2915
- SmallVector<Type> resultTypes (
2916
- numExpectedLoops, transform::AnyOpType::get (builder.getContext ()));
2917
- build (builder, result,
2918
- /* tiled_linalg_op=*/ target.getType (),
2919
- /* loops=*/ resultTypes,
2920
- /* target=*/ target,
2921
- /* dynamic_sizes=*/ dynamicTileSizes,
2922
- /* static_sizes=*/ staticTileSizesAttr,
2923
- /* interchange=*/ builder.getDenseI64ArrayAttr (interchange));
2924
- }
2925
-
2926
- DiagnosedSilenceableFailure
2927
- transform::TileToScfForOp::apply (transform::TransformRewriter &rewriter,
2928
- TransformResults &transformResults,
2929
- TransformState &state) {
2930
- ArrayRef<int64_t > tileSizes = getStaticSizes ();
2931
-
2932
- SmallVector<Operation *> targets =
2933
- llvm::to_vector (state.getPayloadOps (getTarget ()));
2934
- SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
2935
- dynamicSizeProducers.reserve (getDynamicSizes ().size ());
2936
- for (Value dynamicSizeProducerHandle : getDynamicSizes ()) {
2937
- dynamicSizeProducers.push_back (
2938
- llvm::to_vector (state.getPayloadOps (dynamicSizeProducerHandle)));
2939
-
2940
- if (dynamicSizeProducers.back ().size () != targets.size ()) {
2941
- DiagnosedSilenceableFailure diag =
2942
- emitSilenceableError ()
2943
- << " expected as many dynamic size-producing operations ("
2944
- << dynamicSizeProducers.back ().size () << " ) as target ops ("
2945
- << targets.size () << " )" ;
2946
- diag.attachNote (dynamicSizeProducerHandle.getLoc ()) << " for this handle" ;
2947
- return diag;
2948
- }
2949
-
2950
- for (Operation *op : dynamicSizeProducers.back ()) {
2951
- if (op->getNumResults () == 1 &&
2952
- isa<IndexType>(op->getResult (0 ).getType ()))
2953
- continue ;
2954
- DiagnosedSilenceableFailure diag =
2955
- emitSilenceableError () << " expected sizes to be produced by ops "
2956
- " with a single index-type result" ;
2957
- diag.attachNote (op->getLoc ()) << " size producer op" ;
2958
- diag.attachNote (dynamicSizeProducerHandle.getLoc ()) << " for this handle" ;
2959
- return diag;
2960
- }
2961
- }
2962
-
2963
- SmallVector<Operation *> tiled;
2964
- SmallVector<SmallVector<Operation *, 4 >, 4 > loops;
2965
- loops.resize (getLoops ().size ());
2966
- for (auto en : llvm::enumerate (targets)) {
2967
- auto tilingInterfaceOp = dyn_cast<TilingInterface>(en.value ());
2968
- if (!tilingInterfaceOp) {
2969
- DiagnosedSilenceableFailure diag =
2970
- emitSilenceableError () << " only TilingInterface ops are supported" ;
2971
- diag.attachNote (en.value ()->getLoc ()) << " target op" ;
2972
- return diag;
2973
- }
2974
-
2975
- scf::SCFTilingOptions tilingOptions;
2976
- unsigned index = en.index ();
2977
- if (!tileSizes.empty ()) {
2978
- tilingOptions.setTileSizeComputationFunction (
2979
- [&, index ](OpBuilder &b, Operation *) {
2980
- SmallVector<Value, 4 > sizes;
2981
- sizes.reserve (tileSizes.size ());
2982
- unsigned dynamicIdx = 0 ;
2983
- for (OpFoldResult ofr : getMixedSizes ()) {
2984
- if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
2985
- sizes.push_back (b.create <arith::ConstantIndexOp>(
2986
- getLoc (), cast<IntegerAttr>(attr).getInt ()));
2987
- } else {
2988
- sizes.push_back (
2989
- dynamicSizeProducers[dynamicIdx++][index ]->getResult (0 ));
2990
- }
2991
- }
2992
- return sizes;
2993
- });
2994
- }
2995
-
2996
- tilingOptions.setInterchange (getInterchange ());
2997
- FailureOr<scf::SCFTilingResult> tilingResult =
2998
- tileUsingSCFForOp (rewriter, tilingInterfaceOp, tilingOptions);
2999
- if (failed (tilingResult))
3000
- return DiagnosedSilenceableFailure::definiteFailure ();
3001
-
3002
- rewriter.replaceOp (tilingInterfaceOp, tilingResult->replacements );
3003
-
3004
- tiled.append (tilingResult->tiledOps );
3005
- for (const auto &en2 : llvm::enumerate (tilingResult->loops ))
3006
- loops[en2.index ()].push_back (en2.value ());
3007
- }
3008
-
3009
- transformResults.set (cast<OpResult>(getTiledLinalgOp ()), tiled);
3010
- for (const auto &en : llvm::enumerate (loops))
3011
- transformResults.set (cast<OpResult>(getLoops ()[en.index ()]), en.value ());
3012
-
3013
- return DiagnosedSilenceableFailure::success ();
3014
- }
3015
-
3016
- SmallVector<OpFoldResult> transform::TileToScfForOp::getMixedSizes () {
3017
- ValueRange dynamic = getDynamicSizes ();
3018
- ArrayRef<int64_t > tileSizes = getStaticSizes ();
3019
- SmallVector<OpFoldResult> results;
3020
- results.reserve (tileSizes.size ());
3021
- unsigned dynamicPos = 0 ;
3022
- Builder builder (getContext ());
3023
- for (int64_t size : tileSizes) {
3024
- if (size == ShapedType::kDynamic ) {
3025
- results.push_back (dynamic[dynamicPos++]);
3026
- } else {
3027
- results.push_back (builder.getIndexAttr (size));
3028
- }
3029
- }
3030
- return results;
3031
- }
3032
-
3033
- ParseResult transform::TileToScfForOp::parse (OpAsmParser &parser,
3034
- OperationState &result) {
3035
- OpAsmParser::UnresolvedOperand target;
3036
- SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
3037
- DenseI64ArrayAttr staticSizes;
3038
- FunctionType trailingType;
3039
- llvm::SMLoc typeLoc;
3040
- if (parser.parseOperand (target) ||
3041
- parseDynamicIndexList (parser, dynamicSizes, staticSizes) ||
3042
- parseOptionalInterchange (parser, result) ||
3043
- parser.parseOptionalAttrDict (result.attributes ) ||
3044
- parser.getCurrentLocation (&typeLoc) ||
3045
- parser.parseColonType (trailingType)) {
3046
- return ParseResult::failure ();
3047
- }
3048
-
3049
- result.addAttribute (getStaticSizesAttrName (result.name ), staticSizes);
3050
- size_t numExpectedLoops =
3051
- staticSizes.size () - llvm::count (staticSizes.asArrayRef (), 0 );
3052
-
3053
- unsigned numExpectedInputTypes = 1 + dynamicSizes.size ();
3054
- if (trailingType.getNumInputs () != numExpectedInputTypes) {
3055
- return parser.emitError (typeLoc)
3056
- << " expected " << numExpectedInputTypes << " operand types, got "
3057
- << trailingType.getNumInputs ();
3058
- }
3059
-
3060
- unsigned numExpectedOutputTypes = 1 + numExpectedLoops;
3061
- if (trailingType.getNumResults () != numExpectedOutputTypes) {
3062
- return parser.emitError (typeLoc)
3063
- << " expected " << numExpectedOutputTypes << " result types, got "
3064
- << trailingType.getNumResults ();
3065
- }
3066
-
3067
- if (parser.resolveOperand (target, trailingType.getInput (0 ),
3068
- result.operands ) ||
3069
- parser.resolveOperands (dynamicSizes,
3070
- trailingType.getInputs ().drop_front (), typeLoc,
3071
- result.operands ) ||
3072
- parser.addTypesToList (trailingType.getResults (), result.types )) {
3073
- return failure ();
3074
- }
3075
- return success ();
3076
- }
3077
-
3078
- void TileToScfForOp::print (OpAsmPrinter &p) {
3079
- p << ' ' << getTarget ();
3080
- printDynamicIndexList (p, getOperation (), getDynamicSizes (), getStaticSizes ());
3081
- printOptionalInterchange (p, getInterchange ());
3082
- p.printOptionalAttrDict (getOperation ()->getAttrs (), getAttributeNames ());
3083
- p << " : " ;
3084
- p.printFunctionalType (getOperation ());
3085
- }
3086
-
3087
- void transform::TileToScfForOp::getEffects (
3088
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3089
- consumesHandle (getTarget (), effects);
3090
- onlyReadsHandle (getDynamicSizes (), effects);
3091
- producesHandle (getTiledLinalgOp (), effects);
3092
- producesHandle (getLoops (), effects);
3093
- modifiesPayload (effects);
3094
- }
3095
-
3096
2895
// ===----------------------------------------------------------------------===//
3097
2896
// VectorizeOp
3098
2897
// ===----------------------------------------------------------------------===//
0 commit comments