Skip to content

Commit 0aec5fd

Browse files
author
Tanyo Kwok
committed
[MHLO] Init end-to-end unit tests
See RFC #999 Co-authored-by: Bairen Yi [email protected] Co-authored-by: Jiawei Wu [email protected] Co-authored-by: Tianyou Guo [email protected] Co-authored-by: Xu Yan [email protected] Co-authored-by: Ziheng Jiang [email protected]
1 parent c935795 commit 0aec5fd

File tree

9 files changed

+298
-4
lines changed

9 files changed

+298
-4
lines changed

e2e_testing/torchscript/main.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,35 @@
1515

1616
# Available test configs.
1717
from torch_mlir_e2e_test.torchscript.configs import (
18-
LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
18+
LazyTensorCoreTestConfig,
19+
LinalgOnTensorsBackendTestConfig,
20+
MhloBackendTestConfig,
21+
NativeTorchTestConfig,
22+
TorchScriptTestConfig,
23+
TosaBackendTestConfig,
24+
EagerModeTestConfig
1925
)
2026

2127
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
28+
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
2229
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
2330

24-
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
31+
from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
2532

2633
# Import tests to register them in the global registry.
2734
from torch_mlir_e2e_test.test_suite import register_all_tests
2835
register_all_tests()
2936

3037
def _get_argparse():
31-
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core']
38+
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core']
3239
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
3340
parser.add_argument('-c', '--config',
3441
choices=config_choices,
3542
default='refbackend',
3643
help=f'''
3744
Meaning of options:
3845
"refbackend": run through torch-mlir's RefBackend.
46+
"mhlo": run through torch-mlir's default MHLO backend.
3947
"tosa": run through torch-mlir's default TOSA backend.
4048
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
4149
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
@@ -78,6 +86,9 @@ def main():
7886
if args.config == 'tosa':
7987
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
8088
xfail_set = all_test_unique_names - TOSA_PASS_SET
89+
if args.config == 'mhlo':
90+
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
91+
xfail_set = all_test_unique_names - MHLO_PASS_SET
8192
elif args.config == 'native_torch':
8293
config = NativeTorchTestConfig()
8394
xfail_set = {}

e2e_testing/torchscript/xfail_sets.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,121 @@
2121
"Matmul_vecmat"
2222
}
2323

24+
MHLO_PASS_SET = {
25+
"AvgPool2dStaticModule_basic",
26+
"ElementwiseCloneContiguousModule_basic",
27+
"ElementwiseCloneModule_basic",
28+
"ElementwiseBinaryStaticShapeModule_basic",
29+
"ReturnThreeTensorFloat32_basic",
30+
"BoolTensorReturnFalseModule_basic",
31+
"BoolTensorReturnTrueModule_basic",
32+
"BoolTensorReturnMixedModule_basic",
33+
"SqueezeModule_static",
34+
"TModuleRank1_basic",
35+
"TModuleRank0_basic",
36+
"ElementwiseToDtypeIdentityModule_basic",
37+
"View1DFoldModule_basic",
38+
"UnsafeView1DFoldModule_basic",
39+
"SqueezeDimModule_static",
40+
"SqueezeDimModule_identity",
41+
"SliceModule_basic",
42+
"SliceNegIdxModule_basic",
43+
"SliceOutOfLowerBoundStartIndexModule_basic",
44+
"SliceSizeTwoStepModule_basic",
45+
"SliceWholeTensorModule_basic",
46+
"ReturnTwoTensorF32I64_basic",
47+
"Matmul4dStatic_basic",
48+
"Matmul_dot",
49+
"Matmul_2d",
50+
"Matmul_matvec",
51+
"Matmul_vecmat",
52+
"MaxPool2dWithIndicesStaticModule_basic",
53+
"MmDagModule_basic",
54+
"MmModule_basic",
55+
"MmModule_chained",
56+
"MaxPool2dStaticModule_basic",
57+
"PermuteModule_basic",
58+
"PermuteNegativeIndexModule_basic",
59+
"ZerosModuleDefaultDtype_basic",
60+
"ZerosModuleInt2D_basic",
61+
"ZerosModuleInt3D_basic",
62+
"ZerosModuleFloat2D_basic",
63+
"ZerosModuleFloat3D_basic",
64+
"ZerosModuleFalsePinMemory_basic",
65+
"OnesModuleDefaultDtype_basic",
66+
"OnesModuleInt_basic",
67+
"OnesModuleFloat_basic",
68+
"OnesModuleFalsePinMemory_basic",
69+
"NewZerosModuleDefaultDtype_basic",
70+
"NewZerosModuleInt2D_basic",
71+
"NewZerosModuleInt3D_basic",
72+
"NewZerosModuleFloat2D_basic",
73+
"NewZerosModuleFloat3D_basic",
74+
"NewZerosModuleFalsePinMemory_basic",
75+
"NewOnesModuleDefaultDtype_basic",
76+
"NewOnesModuleInt2D_basic",
77+
"NewOnesModuleInt3D_basic",
78+
"NewOnesModuleFloat2D_basic",
79+
"NewOnesModuleFloat3D_basic",
80+
"NewOnesModuleFalsePinMemory_basic",
81+
"DropoutEvalIntModule_basic",
82+
"DropoutEvalFloatModule_basic",
83+
"ContiguousModule_basic",
84+
"DropoutModule_basic",
85+
"ViewCollapseModule_basic",
86+
"ViewCollapseInferredDimModule_basic",
87+
"ViewDynamicExpandCollapseModule_basic",
88+
"ViewDynamicExpandModule_basic",
89+
"ViewExpandModule_basic",
90+
"ViewExpandOnesModule_basic",
91+
"ViewExpandOnesBeforeAndAfterModule_basic",
92+
"ViewExpandOnesMiddleModule_basic",
93+
"ViewExpandCollapseModule_basic",
94+
"ViewExpandCollapseWithOnesModule_basic",
95+
"ViewExpandInferredDimModule_basic",
96+
"ViewNoChangeStaticModule_basic",
97+
"ViewNoChange1dModule_basic",
98+
"ViewNoChange2dModule_basic",
99+
"ViewNoChange3dModule_basic",
100+
"UnsafeViewExpandModule_basic",
101+
"ReduceMaxAllDims_basic",
102+
"ReduceMaxFloatModule_basic",
103+
"ReduceMaxSignedIntModule_basic",
104+
"ReduceMaxUnsignedIntModule_basic",
105+
"ReduceSumDimIntListFloatModule_basic",
106+
"ReduceSumDimIntListIntModule_basic",
107+
"ReduceSumFloatModule_basic",
108+
"ReduceSumSignedIntModule_basic",
109+
"ReduceSumUnsignedIntModule_basic",
110+
"RepeatModule_basic",
111+
"ReshapeAliasCollapseModule_basic",
112+
"ReshapeAliasExpandModule_basic",
113+
"ReshapeExpandModule_basic",
114+
"TestMultipleTensorReturn_basic",
115+
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
116+
"BaddbmmStaticModule_basic",
117+
"BaddbmmBroadcast1DInputModule_basic",
118+
"BaddbmmBroadcast2DInputModule_basic",
119+
"NarrowHorizontalTest2_basic",
120+
"NarrowHorizontalTest_basic",
121+
"NarrowVerticalTest2_basic",
122+
"NarrowVerticalTest_basic",
123+
"NumToTensorIntModule_basic",
124+
"NumpyTRank0Module_basic",
125+
"NumpyTRank1Module_basic",
126+
"NumpyTRank2Module_basic",
127+
"NumpyTRankNStaticModule_basic",
128+
"NumpyTRankNDynamicModule_basic",
129+
"TModuleRank2_basic",
130+
"TensorLiteralModule_basic",
131+
"TensorOpaqueLiteralModule_basic",
132+
"TransposeIntModule_basic",
133+
"TransposeIntNegDimsModule_basic",
134+
"Permute0RankModule_basic",
135+
"UnsafeViewCollapseModule_basic",
136+
"UnsafeViewDynamicExpandModule_basic",
137+
}
138+
24139
# Write the TOSA set as a "passing" set as it is very early in development
25140
# and very few tests work yet.
26141
TOSA_PASS_SET = {

lib/Conversion/Passes.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
#include "torch-mlir/Conversion/Passes.h"
1111

12+
#ifdef TORCH_MLIR_ENABLE_MHLO
13+
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
14+
#endif // TORCH_MLIR_ENABLE_MHLO
1215
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
1316
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
1417
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
@@ -25,4 +28,11 @@ namespace {
2528
#include "torch-mlir/Conversion/Passes.h.inc"
2629
} // end namespace
2730

28-
void mlir::torch::registerConversionPasses() { ::registerPasses(); }
31+
void mlir::torch::registerConversionPasses() {
32+
::registerPasses();
33+
#ifdef TORCH_MLIR_ENABLE_MHLO
34+
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
35+
return mlir::mhlo::createLegalizeHloToLinalgPass();
36+
});
37+
#endif // TORCH_MLIR_ENABLE_MHLO
38+
}

lib/Conversion/TorchToMhlo/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
1414
DEPENDS
1515
MhloDialect
1616
ChloDialect
17+
MhloToLinalg
18+
MLIRMhloPassIncGen
1719
TorchMLIRConversionPassIncGen
1820

1921
LINK_COMPONENTS
@@ -24,6 +26,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
2426
MLIRPass
2527
MhloDialect
2628
ChloDialect
29+
MhloToLinalg
2730
TorchMLIRTorchDialect
2831
)
2932

python/torch_mlir_e2e_test/mhlo_backends/__init__.py

Whitespace-only changes.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
# Also available under a BSD-style license. See LICENSE.
5+
6+
import abc
7+
from typing import TypeVar
8+
9+
import torch
10+
11+
from torch_mlir.ir import Module
12+
13+
# A type shared between the result of `MhloBackend.compile` and the
14+
# input to `MhloBackend.load`. Each backend will likely have a
15+
# different definition of this type.
16+
CompiledArtifact = TypeVar('CompiledArtifact')
17+
18+
# A wrapper around a backend-specific loaded program representation
19+
# that uniformly translates the `x.method(...)` interface expected of
20+
# Torch modules into appropriate lower-level operations.
21+
Invoker = TypeVar('Invoker')
22+
23+
24+
class MhloBackend(abc.ABC):
25+
"""The interface to an MHLO backend.
26+
27+
Backends are recommended to raise meaningful exceptions in case of error,
28+
ideally with easy reproduction instructions.
29+
"""
30+
@abc.abstractmethod
31+
def compile(self, module: Module) -> CompiledArtifact:
32+
"""Compile the provided MLIR module into a compiled artifact.
33+
34+
The module adheres to the MHLO backend contract
35+
(see the VerifyMhloBackendContract pass).
36+
37+
The compiled artifact can be any type, but must be correctly
38+
interpreted by the `load` method.
39+
"""
40+
41+
@abc.abstractmethod
42+
def load(self, artifact: CompiledArtifact) -> Invoker:
43+
"""Load the compiled artifact into a uniformly invokable form.
44+
45+
The compiled artifact is the result of a previous call to `compile`.
46+
47+
See the description of `Invoker` for the requirements on the returned
48+
type.
49+
"""
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
# Also available under a BSD-style license. See LICENSE.
5+
6+
from torch_mlir.ir import *
7+
from torch_mlir.passmanager import *
8+
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
9+
10+
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
11+
12+
from .abc import MhloBackend
13+
14+
__all__ = [
15+
"LinalgOnTensorsMhloBackend",
16+
]
17+
18+
class LinalgOnTensorsMhloBackend(MhloBackend):
19+
"""Main entry-point for the linalg-on-tensors based MHLO backend.
20+
21+
This currently uses the linalg-on-tensors RefBackend for actual execution.
22+
"""
23+
def __init__(self):
24+
super().__init__()
25+
self.refbackend = RefBackendLinalgOnTensorsBackend()
26+
27+
def compile(self, imported_module: Module):
28+
"""Compiles an imported module that satisfied the MHLO backend contract.
29+
30+
Args:
31+
imported_module: The MLIR module consisting of funcs in the MHLO
32+
dialect.
33+
Returns:
34+
An opaque, backend specific compiled artifact object that can be
35+
passed to `load`.
36+
"""
37+
run_pipeline_with_repro_report(
38+
imported_module,
39+
"func.func(hlo-legalize-to-linalg)",
40+
"Lowering MLIR-HLO to Linalg-on-Tensors")
41+
return self.refbackend.compile(imported_module)
42+
43+
def load(self, module):
44+
"""Loads a compiled artifact into the runtime."""
45+
return self.refbackend.load(module)

python/torch_mlir_e2e_test/torchscript/configs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@
77
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
88
from .native_torch import NativeTorchTestConfig
99
from .torchscript import TorchScriptTestConfig
10+
from .mhlo_backend import MhloBackendTestConfig
1011
from .tosa_backend import TosaBackendTestConfig
1112
from .eager_mode import EagerModeTestConfig
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
# Also available under a BSD-style license. See LICENSE.
5+
6+
import sys
7+
from typing import Any
8+
from io import StringIO
9+
import os
10+
import tempfile
11+
12+
import numpy as np
13+
import torch
14+
15+
from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend
16+
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
17+
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
18+
from .utils import (
19+
recursively_convert_to_numpy,
20+
recursively_convert_from_numpy,
21+
convert_torchscript_module_to_torch_backend_contract_mlir,
22+
)
23+
24+
25+
class MhloBackendTestConfig(TestConfig):
26+
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
27+
28+
This class handles all the common lowering that torch-mlir does before
29+
reaching the linalg-on-tensors abstraction level.
30+
"""
31+
def __init__(self, backend: MhloBackend):
32+
super().__init__()
33+
self.backend = backend
34+
35+
def compile(self, program: torch.nn.Module) -> Any:
36+
37+
module = convert_torchscript_module_to_torch_backend_contract_mlir(
38+
program)
39+
40+
run_pipeline_with_repro_report(
41+
module,
42+
"torch-backend-to-mhlo-backend-pipeline",
43+
"Lower Torch Backend IR -> MHLO Backend IR")
44+
45+
return self.backend.compile(module)
46+
47+
48+
49+
def run(self, artifact: Any, trace: Trace) -> Trace:
50+
backend_module = self.backend.load(artifact)
51+
result: Trace = []
52+
for item in trace:
53+
numpy_inputs = recursively_convert_to_numpy(item.inputs)
54+
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
55+
output = recursively_convert_from_numpy(outputs)
56+
result.append(
57+
TraceItem(symbol=item.symbol,
58+
inputs=item.inputs,
59+
output=output))
60+
return result

0 commit comments

Comments
 (0)