Skip to content

Commit be32233

Browse files
Exec session updates (llvm#1173)
use ONTensorUniquePtr and add run with OMTensorList Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent a43f9b4 commit be32233

12 files changed

+61
-62
lines changed

src/Runtime/ExecutionSession.cpp

+10-5
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ ExecutionSession::ExecutionSession(
6363
}
6464
}
6565

66-
std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>>
67-
ExecutionSession::run(
68-
std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>> ins) {
66+
std::vector<OMTensorUniquePtr> ExecutionSession::run(
67+
std::vector<OMTensorUniquePtr> ins) {
6968

7069
std::vector<OMTensor *> omts;
7170
for (const auto &inOmt : ins)
@@ -74,15 +73,21 @@ ExecutionSession::run(
7473

7574
auto *wrappedOutput = _entryPointFunc(wrappedInput);
7675

77-
std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>> outs;
76+
std::vector<OMTensorUniquePtr> outs;
7877

7978
for (int64_t i = 0; i < omTensorListGetSize(wrappedOutput); i++) {
80-
outs.emplace_back(std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
79+
outs.emplace_back(OMTensorUniquePtr(
8180
omTensorListGetOmtByIndex(wrappedOutput, i), omTensorDestroy));
8281
}
8382
return outs;
8483
}
8584

85+
// Run using public interface. Explicit calls are needed to free tensor & tensor
86+
// lists.
87+
OMTensorList *ExecutionSession::run(OMTensorList *input) {
88+
return _entryPointFunc(input);
89+
}
90+
8691
std::string ExecutionSession::inputSignature() { return _inputSignatureFunc(); }
8792

8893
std::string ExecutionSession::outputSignature() {

src/Runtime/ExecutionSession.hpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,18 @@ namespace onnx_mlir {
2727
typedef OMTensorList *(*entryPointFuncType)(OMTensorList *);
2828
typedef const char *(*signatureFuncType)();
2929

30+
using OMTensorUniquePtr = std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>;
31+
3032
class ExecutionSession {
3133
public:
3234
ExecutionSession(std::string sharedLibPath, std::string entryPointName);
3335

3436
// Use custom deleter since forward declared OMTensor hides destructor
35-
std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>> run(
36-
std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>>);
37+
std::vector<OMTensorUniquePtr> run(std::vector<OMTensorUniquePtr>);
38+
39+
// Run using public interface. Explicit calls are needed to free tensor &
40+
// tensor lists.
41+
OMTensorList *run(OMTensorList *input);
3742

3843
// Get input and output signature as a Json string. For example for nminst:
3944
// `[ { "type" : "f32" , "dims" : [1 , 1 , 28 , 28] , "name" : "image" } ]`

test/backend-cpp/ModelBuilder.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ bool ModelBuilder::compileTest(const CompilerOptionList &compileOptions) {
5858
compileModule(modRef, ctx, sharedLibBaseName, onnx_mlir::EmitLib) == 0);
5959
}
6060

61-
bool ModelBuilder::runAndVerifyTest(std::vector<OMTensorPtr> &inputs,
62-
std::vector<OMTensorPtr> &expectedOutputs,
61+
bool ModelBuilder::runAndVerifyTest(std::vector<OMTensorUniquePtr> &inputs,
62+
std::vector<OMTensorUniquePtr> &expectedOutputs,
6363
std::function<bool(OMTensor *, OMTensor *)> verifyFunction) {
6464
assert(!inputs.empty() && "Expecting valid inputs");
6565

@@ -72,8 +72,8 @@ bool ModelBuilder::runAndVerifyTest(std::vector<OMTensorPtr> &inputs,
7272

7373
// Verify the result(s).
7474
for (size_t i = 0; i < outputs.size(); ++i) {
75-
OMTensorPtr &output = outputs.at(i);
76-
OMTensorPtr &expectedOutput = expectedOutputs.at(i);
75+
OMTensorUniquePtr &output = outputs.at(i);
76+
OMTensorUniquePtr &expectedOutput = expectedOutputs.at(i);
7777
if (!verifyFunction(output.get(), expectedOutput.get()))
7878
return false;
7979
}

test/backend-cpp/ModelBuilder.hpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
using namespace std;
3030
using namespace mlir;
31-
using OMTensorPtr = unique_ptr<OMTensor, decltype(&omTensorDestroy)>;
3231
namespace BackendCppTests {
3332

3433
// Helper class containing useful functions for creating, compiling and running
@@ -65,8 +64,8 @@ class ModelBuilder {
6564
// Run the model and verify the result(s). The \p verifyFunction parameter
6665
// is used to pass in the function object used to verify the correctness of
6766
// the test result.
68-
bool runAndVerifyTest(std::vector<OMTensorPtr> &inputs,
69-
std::vector<OMTensorPtr> &expectedOutputs,
67+
bool runAndVerifyTest(std::vector<onnx_mlir::OMTensorUniquePtr> &inputs,
68+
std::vector<onnx_mlir::OMTensorUniquePtr> &expectedOutputs,
7069
std::function<bool(OMTensor *, OMTensor *)> verifyFunction);
7170

7271
void reset();

test/backend-cpp/TestCategoryMapper.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ class CategoryMapperTester {
6464
}
6565

6666
// Run the test and verify the result.
67-
std::vector<OMTensorPtr> inputOMTs, expectedOutputOMTs;
68-
auto inputOMT = OMTensorPtr(
67+
std::vector<onnx_mlir::OMTensorUniquePtr> inputOMTs, expectedOutputOMTs;
68+
auto inputOMT = onnx_mlir::OMTensorUniquePtr(
6969
omTensorCreate(static_cast<void *>(const_cast<int64_t *>(input.data())),
7070
inputShape, 1 /*rank*/, ONNX_TYPE_INT64),
7171
omTensorDestroy);
72-
auto expectedOutputOMT = OMTensorPtr(
72+
auto expectedOutputOMT = onnx_mlir::OMTensorUniquePtr(
7373
omTensorCreate(static_cast<void *>(
7474
const_cast<const char **>(expectedOutput.data())),
7575
inputShape, 1 /*rank*/, ONNX_TYPE_STRING),

test/numerical/TestConv.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,11 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
144144
onnx_mlir::ExecutionSession sess(
145145
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
146146

147-
std::vector<unique_ptr<OMTensor, decltype(&omTensorDestroy)>> inputs;
148-
auto xOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
147+
std::vector<OMTensorUniquePtr> inputs;
148+
auto xOmt = OMTensorUniquePtr(
149149
omTensorCreateWithRandomData<float>({N, C, H, W}), omTensorDestroy);
150150
inputs.emplace_back(move(xOmt));
151-
auto wOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
151+
auto wOmt = OMTensorUniquePtr(
152152
omTensorCreateWithRandomData<float>({C, C, kH, kW}), omTensorDestroy);
153153
inputs.emplace_back(move(wOmt));
154154

test/numerical/TestGRU.cpp

+6-9
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,12 @@ bool isOMGRUTheSameAsNaiveImplFor(const int direction, const int S, const int B,
4949
onnx_mlir::ExecutionSession sess(
5050
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
5151

52-
std::vector<unique_ptr<OMTensor, decltype(&omTensorDestroy)>> inputs;
53-
auto xOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
52+
std::vector<OMTensorUniquePtr> inputs;
53+
auto xOmt = OMTensorUniquePtr(
5454
omTensorCreateWithRandomData<float>(llvm::makeArrayRef(xShape), 0, 1),
5555
omTensorDestroy);
5656
inputs.emplace_back(move(xOmt));
57-
auto hOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
57+
auto hOmt = OMTensorUniquePtr(
5858
omTensorCreateWithRandomData<float>(llvm::makeArrayRef(hShape), 0, 1),
5959
omTensorDestroy);
6060
inputs.emplace_back(move(hOmt));
@@ -72,12 +72,9 @@ bool isOMGRUTheSameAsNaiveImplFor(const int direction, const int S, const int B,
7272
// Ht = (1 - zt) (.) ht + zt (.) Ht-1
7373
auto &input = inputs.at(0);
7474
auto &initialH = inputs.at(1);
75-
auto weight =
76-
unique_ptr<OMTensor, decltype(&omTensorDestroy)>(wOmt, omTensorDestroy);
77-
auto recurr =
78-
unique_ptr<OMTensor, decltype(&omTensorDestroy)>(rOmt, omTensorDestroy);
79-
auto bias =
80-
unique_ptr<OMTensor, decltype(&omTensorDestroy)>(bOmt, omTensorDestroy);
75+
auto weight = OMTensorUniquePtr(wOmt, omTensorDestroy);
76+
auto recurr = OMTensorUniquePtr(rOmt, omTensorDestroy);
77+
auto bias = OMTensorUniquePtr(bOmt, omTensorDestroy);
8178

8279
// Initialize refYh and refYc.
8380
for (int64_t d = 0; d < D; d++)

test/numerical/TestGemm.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,16 @@ bool isOMGemmTheSameAsNaiveImplFor(const int I, const int J, const int K,
7979
onnx_mlir::ExecutionSession sess(
8080
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
8181

82-
std::vector<unique_ptr<OMTensor, decltype(&omTensorDestroy)>> inputs;
83-
auto aOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
82+
std::vector<OMTensorUniquePtr> inputs;
83+
auto aOmt = OMTensorUniquePtr(
8484
omTensorCreateWithRandomData<float>(llvm::makeArrayRef(aShape)),
8585
omTensorDestroy);
8686
inputs.emplace_back(move(aOmt));
87-
auto bOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
87+
auto bOmt = OMTensorUniquePtr(
8888
omTensorCreateWithRandomData<float>(llvm::makeArrayRef(bShape)),
8989
omTensorDestroy);
9090
inputs.emplace_back(move(bOmt));
91-
auto cOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
91+
auto cOmt = OMTensorUniquePtr(
9292
omTensorCreateWithRandomData<float>(llvm::makeArrayRef(cShape)),
9393
omTensorDestroy);
9494
inputs.emplace_back(move(cOmt));

test/numerical/TestLSTM.cpp

+8-12
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,16 @@ bool isOMLSTMTheSameAsNaiveImplFor(const int direction, const int S,
5151
onnx_mlir::ExecutionSession sess(
5252
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
5353

54-
std::vector<unique_ptr<OMTensor, decltype(&omTensorDestroy)>> inputs;
55-
auto xOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
54+
std::vector<OMTensorUniquePtr> inputs;
55+
auto xOmt = OMTensorUniquePtr(
5656
omTensorCreateWithRandomData<float>(llvm::makeArrayRef(xShape), 0, 1),
5757
omTensorDestroy);
5858
inputs.emplace_back(move(xOmt));
59-
auto hOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
59+
auto hOmt = OMTensorUniquePtr(
6060
omTensorCreateWithRandomData<float>(llvm::makeArrayRef(hShape), 0, 1),
6161
omTensorDestroy);
6262
inputs.emplace_back(move(hOmt));
63-
auto cOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
63+
auto cOmt = OMTensorUniquePtr(
6464
omTensorCreateWithRandomData<float>(llvm::makeArrayRef(cShape), 0, 1),
6565
omTensorDestroy);
6666
inputs.emplace_back(move(cOmt));
@@ -77,14 +77,10 @@ bool isOMLSTMTheSameAsNaiveImplFor(const int direction, const int S,
7777
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
7878
// Ht = ot (.) h(Ct)
7979

80-
auto weight =
81-
unique_ptr<OMTensor, decltype(&omTensorDestroy)>(wOmt, omTensorDestroy);
82-
auto recurr =
83-
unique_ptr<OMTensor, decltype(&omTensorDestroy)>(rOmt, omTensorDestroy);
84-
auto bias =
85-
unique_ptr<OMTensor, decltype(&omTensorDestroy)>(bOmt, omTensorDestroy);
86-
auto peepholes =
87-
unique_ptr<OMTensor, decltype(&omTensorDestroy)>(pOmt, omTensorDestroy);
80+
auto weight = OMTensorUniquePtr(wOmt, omTensorDestroy);
81+
auto recurr = OMTensorUniquePtr(rOmt, omTensorDestroy);
82+
auto bias = OMTensorUniquePtr(bOmt, omTensorDestroy);
83+
auto peepholes = OMTensorUniquePtr(pOmt, omTensorDestroy);
8884

8985
auto &input = inputs.at(0);
9086
auto &initialH = inputs.at(1);

test/numerical/TestLoop.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -90,21 +90,21 @@ bool isOMLoopTheSameAsNaiveImplFor(std::string moduleIR,
9090
onnx_mlir::ExecutionSession sess(
9191
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
9292

93-
std::vector<unique_ptr<OMTensor, decltype(&omTensorDestroy)>> inputs;
94-
auto tripCountTensor = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
93+
std::vector<OMTensorUniquePtr> inputs;
94+
auto tripCountTensor = OMTensorUniquePtr(
9595
omTensorCreateEmpty(nullptr, 0, OM_DATA_TYPE::ONNX_TYPE_INT64),
9696
omTensorDestroy);
9797
omTensorGetElem<int64_t>(tripCountTensor.get(), {}) = tripCount;
9898
inputs.emplace_back(move(tripCountTensor));
9999

100-
auto condTensor = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
100+
auto condTensor = OMTensorUniquePtr(
101101
omTensorCreateEmpty(nullptr, 0, OM_DATA_TYPE::ONNX_TYPE_BOOL),
102102
omTensorDestroy);
103103
omTensorGetElem<bool>(condTensor.get(), {}) = true;
104104
inputs.emplace_back(move(condTensor));
105105

106106
auto *yInitShape = new int64_t[1]{1};
107-
auto yInitTensor = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
107+
auto yInitTensor = OMTensorUniquePtr(
108108
omTensorCreateEmpty(&yInitShape[0], 1, OM_DATA_TYPE::ONNX_TYPE_INT64),
109109
omTensorDestroy);
110110
omTensorGetElem<int64_t>(yInitTensor.get(), {0}) = yInit;
@@ -113,7 +113,7 @@ bool isOMLoopTheSameAsNaiveImplFor(std::string moduleIR,
113113
auto outputs = sess.run(move(inputs));
114114

115115
auto *yRefInitShape = new int64_t[1]{1};
116-
auto vFinalRef = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
116+
auto vFinalRef = OMTensorUniquePtr(
117117
omTensorCreateEmpty(&yRefInitShape[0], 1, OM_DATA_TYPE::ONNX_TYPE_INT64),
118118
omTensorDestroy);
119119

test/numerical/TestMatMul2D.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ bool isOMMatmulTheSameAsNaiveImplFor(const int I, const int J, const int K) {
4242
onnx_mlir::ExecutionSession sess(
4343
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
4444

45-
std::vector<unique_ptr<OMTensor, decltype(&omTensorDestroy)>> inputs;
46-
auto aOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
45+
std::vector<OMTensorUniquePtr> inputs;
46+
auto aOmt = OMTensorUniquePtr(
4747
omTensorCreateWithRandomData<float>({I, K}), omTensorDestroy);
4848
inputs.emplace_back(move(aOmt));
49-
auto bOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
49+
auto bOmt = OMTensorUniquePtr(
5050
omTensorCreateWithRandomData<float>({K, J}), omTensorDestroy);
5151
inputs.emplace_back(move(bOmt));
5252

test/numerical/TestRNN.cpp

+6-9
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ bool isOMRNNTheSameAsNaiveImplFor(const int direction, const int S, const int B,
5050
onnx_mlir::ExecutionSession sess(
5151
getSharedLibName(SHARED_LIB_BASE.str()), "run_main_graph");
5252

53-
std::vector<unique_ptr<OMTensor, decltype(&omTensorDestroy)>> inputs;
54-
auto xOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
53+
std::vector<OMTensorUniquePtr> inputs;
54+
auto xOmt = OMTensorUniquePtr(
5555
omTensorCreateWithRandomData<float>(llvm::makeArrayRef(xShape), 0, 1),
5656
omTensorDestroy);
5757
inputs.emplace_back(move(xOmt));
58-
auto hOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
58+
auto hOmt = OMTensorUniquePtr(
5959
omTensorCreateWithRandomData<float>(llvm::makeArrayRef(hShape), 0, 1),
6060
omTensorDestroy);
6161
inputs.emplace_back(move(hOmt));
@@ -68,12 +68,9 @@ bool isOMRNNTheSameAsNaiveImplFor(const int direction, const int S, const int B,
6868
auto &input = inputs.at(0);
6969
auto &initialH = inputs.at(1);
7070

71-
auto weight =
72-
unique_ptr<OMTensor, decltype(&omTensorDestroy)>(wOmt, omTensorDestroy);
73-
auto recurr =
74-
unique_ptr<OMTensor, decltype(&omTensorDestroy)>(rOmt, omTensorDestroy);
75-
auto bias =
76-
unique_ptr<OMTensor, decltype(&omTensorDestroy)>(bOmt, omTensorDestroy);
71+
auto weight = OMTensorUniquePtr(wOmt, omTensorDestroy);
72+
auto recurr = OMTensorUniquePtr(rOmt, omTensorDestroy);
73+
auto bias = OMTensorUniquePtr(bOmt, omTensorDestroy);
7774

7875
// Initialize refYh.
7976
for (int64_t d = 0; d < D; d++)

0 commit comments

Comments
 (0)