
Commit e52e886

build: update llvm tag to 00d648b (#1307)
- Update MHLO commit to build with LLVM commit hash 00d648b
- Update TorchToMhlo code to work with Stablehlo
- Re-enabled two failing TOSA tests, thus resolving GitHub issue #1231
1 parent: 51ef1b1
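
Concretely, "update TorchToMhlo code to work with Stablehlo" boils down to taking the CHLO dialect from the Stablehlo tree instead of from mlir-hlo: each TorchToMhlo source swaps one include, and the comparison attribute helpers move from the mhlo:: to the chlo:: namespace. A minimal sketch of the include-level change repeated across the C++ diffs below (header paths exactly as they appear in those diffs):

// Before: CHLO ops were pulled from the mlir-hlo tree.
// #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"

// After: CHLO ops come from the Stablehlo tree; the MHLO include itself is unchanged.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "stablehlo/dialect/ChloOps.h"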

9 files changed: +31 -28 lines
e2e_testing/xfail_sets.py (+2)

@@ -163,6 +163,7 @@
     "ElementwiseBinaryModule_basic",
     "ElementwiseSigmoidModule_basic",
     "ElementwiseExpModule_basic",
+    "ElementwiseReluModule_basic",
     "ElementwiseFloorModule_basic",
     "ElementwiseLogModule_basic",
     "ElementwiseBinaryStaticShapeModule_basic",
@@ -237,6 +238,7 @@
     "ElementwiseFlattenBroadcastModule_basic",
     "SquareModule_basic",
     "MaxPool2dStaticModule_basic",
+    "ResNet18StaticModule_basic",
     "NativeLayerNormModule4D_basic",
     "LayerNormNormalizeOverAllDimsModule_basic",
     "PermuteModule_basic",

externals/llvm-project

Submodule llvm-project updated 5399 files

externals/mlir-hlo

Submodule mlir-hlo updated 201 files

lib/Conversion/TorchToMhlo/Basic.cpp (+16 -15)

@@ -12,10 +12,11 @@
 #include "../PassDetail.h"
 #include "./MhloLegalizeUtils.h"
 #include "./PopulatePatterns.h"
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/utils/hlo_utils.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "stablehlo/dialect/ChloOps.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -291,33 +292,33 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     // TODO: what is the PyTorch default type promotion?
     rhs = mhlo::promoteType(rewriter, rhs, lhsTy);
 
-    mhlo::ComparisonTypeAttr compareTypeAttr;
-    mhlo::ComparisonDirectionAttr compareDirectionAttr;
+    chlo::ComparisonTypeAttr compareTypeAttr;
+    chlo::ComparisonDirectionAttr compareDirectionAttr;
 
     if (lhsElemTy.isa<mlir::FloatType>()) {
-      compareTypeAttr = mhlo::ComparisonTypeAttr::get(
-          op->getContext(), mhlo::ComparisonType::FLOAT);
+      compareTypeAttr = chlo::ComparisonTypeAttr::get(
+          op->getContext(), chlo::ComparisonType::FLOAT);
     } else if (lhsElemTy.isa<mlir::IntegerType>()) {
-      compareTypeAttr = mhlo::ComparisonTypeAttr::get(
-          op->getContext(), mhlo::ComparisonType::SIGNED);
+      compareTypeAttr = chlo::ComparisonTypeAttr::get(
+          op->getContext(), chlo::ComparisonType::SIGNED);
     }
 
     if (std::is_same<AtenOpT, AtenLtTensorOp>() ||
         std::is_same<AtenOpT, AtenLtScalarOp>()) {
-      compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
-          op->getContext(), mhlo::ComparisonDirection::LT);
+      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
+          op->getContext(), chlo::ComparisonDirection::LT);
     } else if (std::is_same<AtenOpT, AtenGtTensorOp>() ||
               std::is_same<AtenOpT, AtenGtScalarOp>()) {
-      compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
-          op->getContext(), mhlo::ComparisonDirection::GT);
+      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
+          op->getContext(), chlo::ComparisonDirection::GT);
     } else if (std::is_same<AtenOpT, AtenEqTensorOp>() ||
               std::is_same<AtenOpT, AtenEqScalarOp>()) {
-      compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
-          op->getContext(), mhlo::ComparisonDirection::EQ);
+      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
+          op->getContext(), chlo::ComparisonDirection::EQ);
     } else if (std::is_same<AtenOpT, AtenNeTensorOp>() ||
               std::is_same<AtenOpT, AtenNeScalarOp>()) {
-      compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
-          op->getContext(), mhlo::ComparisonDirection::NE);
+      compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
+          op->getContext(), chlo::ComparisonDirection::NE);
     }
     DenseIntElementsAttr bcastDimensions;
     rewriter.replaceOpWithNewOp<chlo::BroadcastCompareOp>(
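Read end to end, the hunk above now builds both comparison attributes in the chlo namespace and feeds them to chlo.broadcast_compare. The condensed C++ sketch below restates that flow in one place; the hunk is cut off before the argument list of replaceOpWithNewOp, so the operands shown there (op, outType, lhs, rhs) and the values assumed in scope (rewriter, lhsElemTy) are taken from context, not from this diff.

// Sketch: post-migration comparison lowering, GT over floats shown.
// LT/EQ/NE and the integer (SIGNED) case follow the same shape as the hunk above.
chlo::ComparisonTypeAttr compareTypeAttr = chlo::ComparisonTypeAttr::get(
    op->getContext(), chlo::ComparisonType::FLOAT);
chlo::ComparisonDirectionAttr compareDirectionAttr =
    chlo::ComparisonDirectionAttr::get(op->getContext(),
                                       chlo::ComparisonDirection::GT);
DenseIntElementsAttr bcastDimensions; // left null: default broadcast handling
// Assumed call shape; the hunk ends before these arguments.
rewriter.replaceOpWithNewOp<chlo::BroadcastCompareOp>(
    op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr,
    compareTypeAttr);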

lib/Conversion/TorchToMhlo/CMakeLists.txt (+2 -2)

@@ -13,7 +13,6 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
 
   DEPENDS
   MhloDialect
-  ChloDialect
   MhloToLinalg
   MLIRMhloPassIncGen
   TorchMLIRConversionPassIncGen
@@ -22,11 +21,12 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
   Core
 
   LINK_LIBS PUBLIC
+  ChloOps
   MLIRIR
   MLIRPass
   MhloDialect
-  ChloDialect
   MhloToLinalg
+  StablehloBase
   TorchMLIRTorchDialect
 )
 
lib/Conversion/TorchToMhlo/Linear.cpp (+1 -1)

@@ -12,10 +12,10 @@
 #include "../PassDetail.h"
 #include "./MhloLegalizeUtils.h"
 #include "./PopulatePatterns.h"
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "stablehlo/dialect/ChloOps.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

lib/Conversion/TorchToMhlo/Pooling.cpp (+1 -1)

@@ -12,10 +12,10 @@
 #include "../PassDetail.h"
 #include "./MhloLegalizeUtils.h"
 #include "./PopulatePatterns.h"
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "stablehlo/dialect/ChloOps.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

lib/Conversion/TorchToMhlo/TorchToMhlo.cpp (+1 -1)

@@ -11,13 +11,13 @@
 
 #include "../PassDetail.h"
 #include "./PopulatePatterns.h"
-#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "stablehlo/dialect/ChloOps.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"

test/Conversion/TorchToMhlo/elementwise.mlir (+6 -6)

@@ -372,7 +372,7 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
-// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
+// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
 // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
 // CHECK: return %[[T6]] : !torch.vtensor<[?,?],i1>
 func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
@@ -387,7 +387,7 @@ func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
 // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
 // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
 func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
@@ -401,7 +401,7 @@ func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction LT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
 // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
 // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
 func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
@@ -415,7 +415,7 @@ func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction EQ>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
 // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
 // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
 func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
@@ -429,7 +429,7 @@ func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction NE>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction NE>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
 // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
 // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
 func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
@@ -553,7 +553,7 @@ func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
 // CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
-// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
+// CHECK: %[[T5:.*]] = chlo.broadcast_compare %[[T0]], %[[T4]] {compare_type = #chlo<comparison_type FLOAT>, comparison_direction = #chlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
 // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
 // CHECK: return %[[T6]] : !torch.vtensor<[?,?],i1>
 func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,?],i1> {
