Skip to content

Commit a00a0d4

Browse files
stellaraccidenttatwaichongeric-k256
authored
Integrate llvm-project and mlir-hlo. (#2454)
Corresponding commits: * mlir-hlo: 16886a108eff5197f816ca0f1950cc5ff1b078d9 * stablehlo: 77a59815a82b34f7b08ed2d42a711d9920682d0e * llvm-project: 4acc3ff * Adapt to ByteCodeOpInterface changes. * Adapt to RegionBranchPoint changes: https://reviews.llvm.org/D159116 * Adapt inferReturnTypes to get the value from properties. * Adapt invalid.mlir to properties syntax * [TOSA] Align with custom assembly format change. * [TOSA] handle change of axis to int32 type * [TOSA] Restore improper convert to i32 Landing with Windows broken (it cannot be fixed because of the way the mlir-hlo dep is inserted). Will followup with an untangling. --------- Co-authored-by: TatWai Chong <[email protected]> Co-authored-by: Eric Kunze <[email protected]>
1 parent 106b585 commit a00a0d4

File tree

11 files changed

+199
-192
lines changed

11 files changed

+199
-192
lines changed

externals/llvm-project

Submodule llvm-project updated 9061 files

externals/mlir-hlo

Submodule mlir-hlo updated 3994 files

include/torch-mlir/Dialect/Torch/IR/TorchOps.h

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
1111
#define TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H
1212

13+
#include "mlir/Bytecode/BytecodeOpInterface.h"
1314
#include "mlir/IR/BuiltinTypes.h"
1415
#include "mlir/IR/Matchers.h"
1516
#include "mlir/IR/OpDefinition.h"

lib/Conversion/TorchToTosa/TorchToTosa.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -2346,7 +2346,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
23462346
op.getLoc(),
23472347
RankedTensorType::get(makeShapeLLVMCompatible(toReduceShape),
23482348
inputType.getElementType()),
2349-
sumDiv, rewriter.getI64IntegerAttr(i));
2349+
sumDiv, rewriter.getI32IntegerAttr(i));
23502350
}
23512351

23522352
return rewriter.create<tosa::ReshapeOp>(
@@ -3214,7 +3214,7 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
32143214
prunedShape.push_back(en.value());
32153215
}
32163216

3217-
auto dimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), dim);
3217+
auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
32183218
auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);
32193219

32203220
Value reduceMax = rewriter.create<tosa::ReduceMaxOp>(
@@ -4787,7 +4787,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
47874787
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
47884788

47894789
auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
4790-
rewriter, loc, outType, builtinTensors, rewriter.getI64IntegerAttr(dim));
4790+
rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim));
47914791
rewriter.replaceOp(op, result.getResult());
47924792
return success();
47934793
}

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
382382
rewriter, op->getLoc(),
383383
GetTypeFromTensorShape(indicesMatrixReducesumShape,
384384
indicesType.getElementType()),
385-
flattenedIndicesMulOp.getResult(), rewriter.getI64IntegerAttr(1));
385+
flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1));
386386

387387
// And reshape to [N, W]
388388
// %7 = "tosa.reshape"(%6) {new_shape = [1, 8]} : (tensor<8x1xi32>) ->
@@ -648,7 +648,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
648648
rewriter, op->getLoc(),
649649
GetTypeFromTensorShape(indicesMatrixReducesumShape,
650650
indicesType.getElementType()),
651-
flattenedIndicesMulOp.getResult(), rewriter.getI64IntegerAttr(1));
651+
flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1));
652652

653653
// And reshape to [N, W]
654654
// [[1],[2],[3]] -> [[1,2,3]]
@@ -717,7 +717,7 @@ std::optional<Value> convertReduceOpCommon(
717717
int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
718718
if (axis_val < 0)
719719
axis_val += input_rank;
720-
auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
720+
auto axis_attr = rewriter.getI32IntegerAttr(axis_val);
721721

722722
shape_vec[axis_val] = 1;
723723
RankedTensorType reduce_type = RankedTensorType::get(

lib/Dialect/Torch/IR/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ add_mlir_library(TorchMLIRTorchDialect
1616
Core
1717

1818
LINK_LIBS PUBLIC
19+
MLIRBytecodeOpInterface
20+
MLIRBytecodeReader
21+
MLIRBytecodeWriter
1922
MLIRFuncDialect
2023
MLIRIR
2124
MLIRSupport

lib/Dialect/Torch/IR/TorchOps.cpp

+28-25
Original file line numberDiff line numberDiff line change
@@ -301,21 +301,20 @@ LogicalResult ClassTypeOp::verify() {
301301
// PrimLoopOp
302302
//===----------------------------------------------------------------------===//
303303

304-
OperandRange
305-
PrimLoopOp::getEntrySuccessorOperands(std::optional<unsigned int> index) {
306-
assert(index.has_value() && index.value() == 0);
304+
OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
305+
assert(point == getRegion());
307306
return getIterArgsInit();
308307
}
309308

310309
void PrimLoopOp::getSuccessorRegions(
311-
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
312-
313-
if (!index.has_value()) {
314-
regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
310+
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
311+
Region &region = getRegion();
312+
if (!point.getRegionOrNull()) {
313+
regions.emplace_back(&region, region.getArguments().slice(1));
315314
return;
316315
}
317-
assert(*index == 0);
318-
regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
316+
assert(point == region);
317+
regions.emplace_back(&region, region.getArguments().slice(1));
319318
regions.emplace_back(getResults());
320319
}
321320

@@ -328,8 +327,8 @@ bool PrimLoopOp::isForLike() {
328327
// PrimLoopConditionOp
329328
//===----------------------------------------------------------------------===//
330329

331-
MutableOperandRange PrimLoopConditionOp::getMutableSuccessorOperands(
332-
std::optional<unsigned> index) {
330+
MutableOperandRange
331+
PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
333332
// Pass all operands except the condition to the successor which is the
334333
// parent loop op.
335334
return getIterArgsMutable();
@@ -378,10 +377,10 @@ void PrimIfOp::print(OpAsmPrinter &p) {
378377
p.printOptionalAttrDict((*this)->getAttrs());
379378
}
380379

381-
void PrimIfOp::getSuccessorRegions(std::optional<unsigned> index,
380+
void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
382381
SmallVectorImpl<RegionSuccessor> &regions) {
383382
// The `then` and the `else` region branch back to the parent operation.
384-
if (index.has_value()) {
383+
if (point.getRegionOrNull()) {
385384
regions.push_back(RegionSuccessor(getResults()));
386385
return;
387386
}
@@ -1595,7 +1594,9 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
15951594
MLIRContext *context, std::optional<Location> location, ValueRange operands,
15961595
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
15971596
SmallVectorImpl<Type> &inferredReturnTypes) {
1598-
auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
1597+
auto attr = properties.as<Properties *>()
1598+
->getValue()
1599+
.dyn_cast_or_null<ElementsAttr>();
15991600
if (!attr)
16001601
return failure();
16011602
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
@@ -1635,7 +1636,9 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
16351636
MLIRContext *context, std::optional<Location> location, ValueRange operands,
16361637
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
16371638
SmallVectorImpl<Type> &inferredReturnTypes) {
1638-
auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
1639+
auto attr = properties.as<Properties *>()
1640+
->getValue()
1641+
.dyn_cast_or_null<ElementsAttr>();
16391642
if (!attr)
16401643
return failure();
16411644
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
@@ -2768,43 +2771,43 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
27682771

27692772
template <typename CalculateOp>
27702773
static void
2771-
getSuccessorRegionsForCalculateOp(CalculateOp op, std::optional<unsigned> index,
2774+
getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
27722775
SmallVectorImpl<RegionSuccessor> &regions) {
2773-
if (!index.has_value()) {
2776+
if (!point.getRegionOrNull()) {
27742777
// First thing the op does is branch into the calculation.
27752778
regions.emplace_back(&op.getCalculation());
27762779
return;
27772780
}
2778-
if (*index == 0) {
2781+
if (point == op.getBody()) {
27792782
// Body returns control to the outer op, passing through results.
27802783
regions.emplace_back(op.getResults());
27812784
return;
27822785
}
2783-
assert(*index == 1);
2786+
assert(point == op.getCalculation());
27842787
// Calculation branches to the body.
27852788
regions.emplace_back(&op.getBody());
27862789
}
27872790

27882791
void ShapeCalculateOp::getSuccessorRegions(
2789-
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
2790-
getSuccessorRegionsForCalculateOp(*this, index, regions);
2792+
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2793+
getSuccessorRegionsForCalculateOp(*this, point, regions);
27912794
}
27922795

27932796
//===----------------------------------------------------------------------===//
27942797
// DtypeCalculateOp
27952798
//===----------------------------------------------------------------------===//
27962799

27972800
void DtypeCalculateOp::getSuccessorRegions(
2798-
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
2799-
getSuccessorRegionsForCalculateOp(*this, index, regions);
2801+
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2802+
getSuccessorRegionsForCalculateOp(*this, point, regions);
28002803
}
28012804

28022805
//===----------------------------------------------------------------------===//
28032806
// ShapeCalculateYieldShapesOp
28042807
//===----------------------------------------------------------------------===//
28052808

28062809
MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
2807-
std::optional<unsigned> index) {
2810+
RegionBranchPoint point) {
28082811
// The shape operands don't get forwarded to the body.
28092812
// MutableOperandRange always has an owning operation, even if empty, so
28102813
// create a 0-length range.
@@ -2823,7 +2826,7 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
28232826
//===----------------------------------------------------------------------===//
28242827

28252828
MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
2826-
std::optional<unsigned> index) {
2829+
RegionBranchPoint point) {
28272830
// The dtype operands don't get forwarded to the body.
28282831
// MutableOperandRange always has an owning operation, even if empty, so
28292832
// create a 0-length range.

0 commit comments

Comments
 (0)