Skip to content

Commit 84d345c

Browse files
authored
build: update llvm tag to 2dde4ba (#1229)
Summary of changes: - Tensor dialect now sets `emitAccessorPrefix` to prefixed, thus requring updates to methods that retrieve arguments [https://reviews.llvm.org/D131361] - Update MHLO to build with LLVM commit hash 2dde4ba - Replace `AbsOp` with `AbsFOp` [https://reviews.llvm.org/D131325] - Replace deprecated `getValue()` with `value()` [https://reviews.llvm.org/D131349] - Remove `AnalysisState::defaultInitialize()` [https://reviews.llvm.org/D131746] - Update MHLO MLIR tests to use the updated assembly format - Disabled two failing TOSA tests (Github Issue link: #1231)
1 parent 3b3cb99 commit 84d345c

File tree

18 files changed

+75
-85
lines changed

18 files changed

+75
-85
lines changed

e2e_testing/torchscript/xfail_sets.py

-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"ElementwiseBinaryModule_basic",
2929
"ElementwiseSigmoidModule_basic",
3030
"ElementwiseExpModule_basic",
31-
"ElementwiseReluModule_basic",
3231
"ElementwiseFloorModule_basic",
3332
"ElementwiseLogModule_basic",
3433
"ElementwiseBinaryStaticShapeModule_basic",
@@ -103,7 +102,6 @@
103102
"ElementwiseFlattenBroadcastModule_basic",
104103
"SquareModule_basic",
105104
"MaxPool2dStaticModule_basic",
106-
"ResNet18StaticModule_basic",
107105
"NativeLayerNormModule4D_basic",
108106
"LayerNormNormalizeOverAllDimsModule_basic",
109107
"PermuteModule_basic",

externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
516516
for (OpOperand *opOperand : op.getInputOperands()) {
517517
auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
518518
newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
519-
? tensorCastOp.source()
519+
? tensorCastOp.getSource()
520520
: opOperand->get());
521521
}
522522
// Init tensors may fold, in which case the resultType must also change.

externals/llvm-project

Submodule llvm-project updated 3202 files

externals/mlir-hlo

Submodule mlir-hlo updated 96 files

lib/Conversion/TorchToLinalg/DataMovement.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
576576
rewriter
577577
.create<tensor::CollapseShapeOp>(loc, intermediateResultType,
578578
castedInput, inputAssociations)
579-
.result();
579+
.getResult();
580580
}
581581

582582
if (llvm::any_of(outputAssociations, [](ReassociationIndices indices) {
@@ -588,7 +588,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
588588
expandedInput.has_value() ? expandedInput.value()
589589
: castedInput,
590590
outputAssociations)
591-
.result();
591+
.getResult();
592592
}
593593

594594
Value result = collapsedInput.has_value() ? collapsedInput.value()

lib/Conversion/TorchToLinalg/Reduction.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
239239
Value elem = payloadArgs[0];
240240
Value result = payloadArgs[1];
241241
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
242-
auto abs = b.create<math::AbsOp>(loc, self);
242+
auto abs = b.create<math::AbsFOp>(loc, self);
243243
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
244244
Value ord = convertScalarToDtype(b, loc, adaptor.ord(), resultElementType);
245245
auto pow = b.create<math::PowFOp>(loc, abs, ord);

lib/Conversion/TorchToLinalg/Uncategorized.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
210210
return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
211211
}
212212
if (isa<AtenAbsOp>(op))
213-
return b.create<math::AbsOp>(loc, payloadArgs[0]);
213+
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
214214
if (isa<AtenSigmoidOp>(op)) {
215215
auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
216216
b, converter, payloadArgs[0], op);

lib/Conversion/TorchToMhlo/Basic.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -1063,9 +1063,6 @@ LogicalResult ConvertAtenOp<ValsemVariantAtenUniformOp>::matchAndRewrite(
10631063
op.getLoc(),
10641064
rewriter.getFloatAttr(inputTy.getElementType(), toDoubleValue));
10651065

1066-
auto outType = getTypeConverter()
1067-
->convertType(op.getType())
1068-
.template dyn_cast<TensorType>();
10691066
rewriter.replaceOpWithNewOp<mhlo::RngOp>(
10701067
op, inputTy, fromTensor, toTensor, mhloShape, mhlo::RngDistribution::UNIFORM);
10711068
return success();

lib/Conversion/TorchToMhlo/Linear.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
531531
std::copy(outputPadding.begin(), outputPadding.end(),
532532
edgePaddingHighVec.begin() + 2);
533533
Value paddingValue =
534-
mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).getValue();
534+
mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
535535
paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy);
536536
mlir::DenseIntElementsAttr edgePaddingLow =
537537
rewriter.getI64VectorAttr(edgePaddingLowVec);

lib/Conversion/TorchToMhlo/Reduction.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
8787
if (!initValue) return llvm::None;
8888
Value initIndex;
8989
if (mlir::mhlo::kMhloDimSizeBits == 32) {
90-
initIndex =
91-
mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).getValue();
90+
initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
9291
} else {
93-
initIndex =
94-
mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
92+
initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
9593
}
9694

9795
DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(

lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp

+3-6
Original file line numberDiff line numberDiff line change
@@ -94,19 +94,16 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
9494
/// unsafe
9595
class InlineGlobalSlotsAnalysisState : public AnalysisState {
9696
public:
97-
InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {}
97+
InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {
98+
setSafe();
99+
}
98100

99101
bool isUninitialized() const override {
100102
// We are an optimistic analysis, so we are always default initialized to
101103
// the optimistic "assumed safe" state.
102104
return false;
103105
}
104106

105-
ChangeResult defaultInitialize() override {
106-
// We are an optimistic analysis, so the default state is always "safe".
107-
return setSafe();
108-
}
109-
110107
void print(raw_ostream &os) const override {
111108
os << "InlineGlobalSlotsAnalysisState(" << (isSafe ? "safe" : "unsafe")
112109
<< ")";

test/Conversion/TorchToMhlo/basic.mlir

+6-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
88
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
99
// CHECK: %[[VAL_2:.*]] = torch.constant.none
10-
// CHECK: %[[VAL_3:.*]] = "mhlo.copy"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
10+
// CHECK: %[[VAL_3:.*]] = mhlo.copy %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
1111
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
1212
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
1313
func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -47,7 +47,7 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
4747
// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
4848
// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
4949
// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64>
50-
// CHECK: %[[T3:.*]] = "mhlo.reshape"(%[[T2]]) : (tensor<1xi64>) -> tensor<i64>
50+
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
5151
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
5252
// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
5353
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
@@ -229,16 +229,16 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],
229229
// CHECK: %true = torch.constant.bool true
230230
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
231231
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<[1, 21, 20]> : tensor<3xi64>
232-
// CHECK: %[[VAL_6:.*]] = "mhlo.dynamic_reshape"(%[[VAL_1]], %[[VAL_5]]) : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
232+
// CHECK: %[[VAL_6:.*]] = mhlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
233233
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<21xf32>
234234
// CHECK: %[[VAL_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<21xf32>
235235
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "mhlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
236236
// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
237-
// CHECK: %[[VAL_13:.*]] = "mhlo.dynamic_reshape"(%[[VAL_9]], %[[VAL_12]]) : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
237+
// CHECK: %[[VAL_13:.*]] = mhlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
238238
// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
239-
// CHECK: %[[VAL_15:.*]] = "mhlo.dynamic_reshape"(%[[VAL_10]], %[[VAL_14]]) : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
239+
// CHECK: %[[VAL_15:.*]] = mhlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
240240
// CHECK: %[[VAL_16:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
241-
// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_reshape"(%[[VAL_11]], %[[VAL_16]]) : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
241+
// CHECK: %[[VAL_17:.*]] = mhlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
242242
// CHECK: %[[VAL_18:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
243243
// CHECK: %[[VAL_19:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
244244
// CHECK: %[[VAL_20:.*]] = mhlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32>

test/Conversion/TorchToMhlo/dropout.mlir

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// CHECK: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f64
1111
// CHECK: %[[CST_3:.*]] = arith.subf %[[CST_2]], %[[ARG1]] : f64
1212
// CHECK: %[[T3:.*]] = tensor.from_elements %[[CST_3]] : tensor<1xf64>
13-
// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf64>) -> tensor<f64>
13+
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf64>) -> tensor<f64>
1414
// CHECK: %[[T5:.*]] = mhlo.convert(%[[ARG0]]) : (tensor<?x?xf32>) -> tensor<?x?xf64>
1515
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T5]], %[[CST_1]] : tensor<?x?xf64>
1616
// CHECK: %[[CST_I64_0:.*]] = arith.index_cast %[[DIM_0]] : index to i64
@@ -33,7 +33,7 @@
3333
// CHECK: shape.assuming_yield %[[T19]] : tensor<?x?xf32>
3434
// CHECK: }
3535
// CHECK: %[[T20:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32>
36-
// CHECK: %[[T21:.*]] = "mhlo.reshape"(%[[T20]]) : (tensor<1xf32>) -> tensor<f32>
36+
// CHECK: %[[T21:.*]] = mhlo.reshape %[[T20]] : (tensor<1xf32>) -> tensor<f32>
3737
// CHECK: %[[T22:.*]] = shape.shape_of %[[T15]] : tensor<?x?xf32> -> tensor<2xindex>
3838
// CHECK: %[[T23:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T21]], %[[T22]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
3939
// CHECK: %[[T24:.*]] = mhlo.multiply %[[T15]], %[[T23]] : tensor<?x?xf32>
@@ -44,4 +44,4 @@ func.func @torch.aten.native_dropout.train(%arg0: !torch.vtensor<[?,?],f32>, %ar
4444
%bool_true = torch.constant.bool true
4545
%result0, %result1 = torch.aten.native_dropout %arg0, %arg1, %bool_true: !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
4646
return %result0, %result1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
47-
}
47+
}

0 commit comments

Comments
 (0)