Skip to content

Commit 52fb808

Browse files
ExecutionSession constructor without entry function name (llvm#1163)
* ExecutionSession constructor without entry function name — Signed-off-by: Tung D. Le <[email protected]>
* clang-format — Signed-off-by: Tung D. Le <[email protected]>
* Do not use run_main_graph in ExecutionSession — Signed-off-by: Tung D. Le <[email protected]>
* Revise — Signed-off-by: Tung D. Le <[email protected]>
* Add omEntryPointName to pass Windows CI — Signed-off-by: Tung D. Le <[email protected]>
* Address review comments — Signed-off-by: Tung D. Le <[email protected]>
* Use run_main_graph as the default entry point — Signed-off-by: Tung D. Le <[email protected]>
* undo numerical.def — Signed-off-by: Tung D. Le <[email protected]>
* Change messages — Signed-off-by: Tung D. Le <[email protected]>

Co-authored-by: Alexandre Eichenberger <[email protected]>
1 parent be32233 commit 52fb808

17 files changed

+37
-31
lines changed

docs/UsingPyRuntime.md

+2-4
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,10 @@ The complete interface to ExecutionSession can be seen in the sources mentioned
3434
using the constructor and run method is enough to perform inferences.
3535

3636
```python
37-
def __init__(self, path: str, entry_point: str):
37+
def __init__(self, path: str):
3838
"""
3939
Args:
4040
path: relative or absolute path to your .so model.
41-
entry_point: function generated by onnx-mlir to call inferences.
42-
Use 'run_main_graph'.
4341
"""
4442

4543
def run(self, input: List[ndarray]) -> List[ndarray]:
@@ -72,7 +70,7 @@ import numpy as np
7270
from PyRuntime import ExecutionSession
7371

7472
model = 'model.so' # LeNet from ONNX Zoo compiled with onnx-mlir
75-
session = ExecutionSession(model, "run_main_graph")
73+
session = ExecutionSession(model)
7674
print("input signature in json", session.input_signature())
7775
print("output signature in json",session.output_signature())
7876
input = np.full((1, 1, 28, 28), 1, np.dtype(np.float32))

docs/mnist_example/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ The runtime use an `ExecutionSession` object to hold a specific model and entry
198198
``` Python
199199
# Load the model mnist.so compiled with onnx-mlir.
200200
model = 'mnist.so'
201-
session = ExecutionSession(model, "run_main_graph")
201+
session = ExecutionSession(model)
202202
# Print the models input/output signature, for display.
203203
# If there are problems with the signature functions, they can be simply commented out.
204204
print("input signature in json", session.input_signature())

docs/mnist_example/mnist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# Load the model mnist.so compiled with onnx-mlir.
55
model = './mnist.so'
6-
session = ExecutionSession(model, "run_main_graph")
6+
session = ExecutionSession(model)
77
# Print the models input/output signature, for display.
88
# Signature functions for info only, commented out if they cause problems.
99
print("input signature in json", session.input_signature())

src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -1194,14 +1194,19 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
11941194
KrnlEntryPointOp::getEntryPointFuncAttrName())
11951195
.getLeafReference()
11961196
.getValue();
1197-
auto dynEntryPointName = "run_" + staticEntryPointFuncName;
1198-
assert(module.lookupSymbol(dynEntryPointName.str()) == nullptr &&
1199-
"dynamic entry point name is not unique");
1197+
1198+
// When there is only a single entry point function in a model, use
1199+
// "run_main_graph" as the default name.
1200+
// TODO(tung): support multiple entry point functions.
1201+
std::string entryPointName = "run_main_graph";
1202+
assert(module.lookupSymbol(entryPointName) == nullptr &&
1203+
"Only support a single entry point function.");
1204+
12001205
rewriter.eraseOp(op);
12011206
auto dynEntryPointFuncTy =
12021207
LLVM::LLVMFunctionType::get(opaquePtrTy, {opaquePtrTy}, false);
12031208
auto dynamicEntryPointFunc = rewriter.create<LLVM::LLVMFuncOp>(
1204-
loc, dynEntryPointName.str(), dynEntryPointFuncTy);
1209+
loc, entryPointName, dynEntryPointFuncTy);
12051210
auto &entryPointEntryBlock =
12061211
createEntryBlock(dynEntryPointFuncTy, dynamicEntryPointFunc, loc);
12071212
rewriter.setInsertionPointToStart(&entryPointEntryBlock);

src/Runtime/ExecutionSession.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ namespace onnx_mlir {
2525
const std::string ExecutionSession::_inputSignatureName = "omInputSignature";
2626
const std::string ExecutionSession::_outputSignatureName = "omOutputSignature";
2727

28+
ExecutionSession::ExecutionSession(std::string sharedLibPath)
29+
: ExecutionSession::ExecutionSession(sharedLibPath, "") {}
30+
2831
ExecutionSession::ExecutionSession(
2932
std::string sharedLibPath, std::string entryPointName) {
3033

@@ -36,6 +39,11 @@ ExecutionSession::ExecutionSession(
3639
throw std::runtime_error(errStr.str());
3740
}
3841

42+
// When entry point name is not given, use the default "run_main_graph".
43+
// TODO(tung): support multiple entry point functions.
44+
if (entryPointName.empty())
45+
entryPointName = "run_main_graph";
46+
3947
_entryPointFunc = reinterpret_cast<entryPointFuncType>(
4048
_sharedLibraryHandle.getAddressOfSymbol(entryPointName.c_str()));
4149
if (!_entryPointFunc) {

src/Runtime/ExecutionSession.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ using OMTensorUniquePtr = std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>;
3131

3232
class ExecutionSession {
3333
public:
34+
ExecutionSession(std::string sharedLibPath);
3435
ExecutionSession(std::string sharedLibPath, std::string entryPointName);
3536

3637
// Use custom deleter since forward declared OMTensor hides destructor

src/Runtime/PyExecutionSession.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ namespace onnx_mlir {
2525

2626
class PyExecutionSession : public onnx_mlir::ExecutionSession {
2727
public:
28+
PyExecutionSession(std::string sharedLibPath)
29+
: onnx_mlir::ExecutionSession(sharedLibPath) {}
2830
PyExecutionSession(std::string sharedLibPath, std::string entryPointName)
29-
: onnx_mlir::ExecutionSession(sharedLibPath, entryPointName){};
31+
: onnx_mlir::ExecutionSession(sharedLibPath, entryPointName) {}
3032

3133
std::vector<py::array> pyRun(const std::vector<py::array> &inputsPyArray);
3234

@@ -37,6 +39,7 @@ class PyExecutionSession : public onnx_mlir::ExecutionSession {
3739

3840
PYBIND11_MODULE(PyRuntime, m) {
3941
py::class_<onnx_mlir::PyExecutionSession>(m, "ExecutionSession")
42+
.def(py::init<const std::string &>())
4043
.def(py::init<const std::string &, const std::string &>())
4144
.def("run", &onnx_mlir::PyExecutionSession::pyRun)
4245
.def("input_signature", &onnx_mlir::PyExecutionSession::pyInputSignature)

test/backend/inference_backend.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,6 @@ def JniExecutionSession(jar_name, inputs):
972972
class EndiannessAwareExecutionSession(object):
973973
def __init__(self, model):
974974
self.model = model
975-
self.entry_point = "run_main_graph"
976975
self.exec_name = None
977976
# Compiling the model in advance if not testing constants, so that
978977
# the model is compiled once and used multiple times.
@@ -1060,7 +1059,7 @@ def run(self, inputs, **kwargs):
10601059
inputs = self.turn_model_input_to_constant(inputs)
10611060
self.exec_name = compile_model(self.model, args.emit)
10621061
if args.emit == "lib":
1063-
session = ExecutionSession(self.exec_name, self.entry_point)
1062+
session = ExecutionSession(self.exec_name)
10641063
outputs = session.run(inputs)
10651064
# print('input='+str(inputs), file=sys.stderr)
10661065
# print('output='+str(outputs), file=sys.stderr)
@@ -1079,7 +1078,7 @@ def run(self, inputs, **kwargs):
10791078
"Cannot deduce desired output endianness, using native endianness by default."
10801079
)
10811080
if args.emit == "lib":
1082-
session = ExecutionSession(self.exec_name, self.entry_point)
1081+
session = ExecutionSession(self.exec_name)
10831082
outputs = session.run(inputs)
10841083
elif args.emit == "jni":
10851084
outputs = JniExecutionSession(self.exec_name, inputs)

test/backend/signature_backend.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -166,14 +166,13 @@ def run(test_self, device): # type: (Any, Text) -> None
166166
class SignatureExecutionSession(object):
167167
def __init__(self, model):
168168
self.model = model
169-
self.entry_point = "run_main_graph"
170169
self.exec_name = compile_model(self.model, args.emit)
171170

172171
def run(self, **kwargs):
173172
sys.path.append(RUNTIME_DIR)
174173
from PyRuntime import ExecutionSession
175174

176-
session = ExecutionSession(self.exec_name, self.entry_point)
175+
session = ExecutionSession(self.exec_name)
177176
output = session.input_signature()
178177
return output
179178

test/numerical/TestConv.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,7 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
141141
/*output conv param*/ NOut, COut, HOut, WOut))
142142
return false;
143143

144-
onnx_mlir::ExecutionSession sess(
145-
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
144+
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));
146145

147146
std::vector<OMTensorUniquePtr> inputs;
148147
auto xOmt = OMTensorUniquePtr(

test/numerical/TestGRU.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ bool isOMGRUTheSameAsNaiveImplFor(const int direction, const int S, const int B,
4646
/* GRU param out*/
4747
D, xShape, hShape, wOmt, rOmt, bOmt))
4848
return false;
49-
onnx_mlir::ExecutionSession sess(
50-
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
49+
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));
5150

5251
std::vector<OMTensorUniquePtr> inputs;
5352
auto xOmt = OMTensorUniquePtr(

test/numerical/TestGemm.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,7 @@ bool isOMGemmTheSameAsNaiveImplFor(const int I, const int J, const int K,
7676
aShape, bShape, cShape))
7777
return false;
7878

79-
onnx_mlir::ExecutionSession sess(
80-
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
79+
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));
8180

8281
std::vector<OMTensorUniquePtr> inputs;
8382
auto aOmt = OMTensorUniquePtr(

test/numerical/TestLSTM.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ bool isOMLSTMTheSameAsNaiveImplFor(const int direction, const int S,
4848
D, xShape, hShape, cShape, wOmt, rOmt, bOmt, pOmt))
4949
return false;
5050

51-
onnx_mlir::ExecutionSession sess(
52-
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
51+
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));
5352

5453
std::vector<OMTensorUniquePtr> inputs;
5554
auto xOmt = OMTensorUniquePtr(

test/numerical/TestLoop.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ bool isOMLoopTheSameAsNaiveImplFor(std::string moduleIR,
8787
auto module = mlir::parseSourceString(moduleIR, &ctx);
8888
OwningModuleRef moduleRef(std::move(module));
8989
compileModule(moduleRef, ctx, SHARED_LIB_BASE.str(), onnx_mlir::EmitLib);
90-
onnx_mlir::ExecutionSession sess(
91-
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
90+
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));
9291

9392
std::vector<OMTensorUniquePtr> inputs;
9493
auto tripCountTensor = OMTensorUniquePtr(

test/numerical/TestMatMul2D.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ bool isOMMatmulTheSameAsNaiveImplFor(const int I, const int J, const int K) {
3939
I, J, K))
4040
return false;
4141

42-
onnx_mlir::ExecutionSession sess(
43-
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
42+
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));
4443

4544
std::vector<OMTensorUniquePtr> inputs;
4645
auto aOmt = OMTensorUniquePtr(

test/numerical/TestRNN.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ bool isOMRNNTheSameAsNaiveImplFor(const int direction, const int S, const int B,
4747
D, xShape, hShape, wOmt, rOmt, bOmt))
4848
return false;
4949

50-
onnx_mlir::ExecutionSession sess(
51-
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
50+
onnx_mlir::ExecutionSession sess(getSharedLibName(SHARED_LIB_BASE.str()));
5251

5352
std::vector<OMTensorUniquePtr> inputs;
5453
auto xOmt = OMTensorUniquePtr(

utils/RunONNXModel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def main():
306306
# Use the generated shared library to create an execution session.
307307
print("Loading the compiled model ...")
308308
start = time.perf_counter()
309-
sess = ExecutionSession(shared_lib_path, "run_main_graph")
309+
sess = ExecutionSession(shared_lib_path)
310310
end = time.perf_counter()
311311
print(" took ", end - start, " seconds.\n")
312312

0 commit comments

Comments
 (0)