@@ -51,16 +51,16 @@ bool isOMLSTMTheSameAsNaiveImplFor(const int direction, const int S,
51
51
onnx_mlir::ExecutionSession sess (
52
52
getSharedLibName (SHARED_LIB_BASE.str ()), " run_main_graph" );
53
53
54
- std::vector<unique_ptr<OMTensor, decltype (&omTensorDestroy)> > inputs;
55
- auto xOmt = unique_ptr<OMTensor, decltype (&omTensorDestroy)> (
54
+ std::vector<OMTensorUniquePtr > inputs;
55
+ auto xOmt = OMTensorUniquePtr (
56
56
omTensorCreateWithRandomData<float >(llvm::makeArrayRef (xShape), 0 , 1 ),
57
57
omTensorDestroy);
58
58
inputs.emplace_back (move (xOmt));
59
- auto hOmt = unique_ptr<OMTensor, decltype (&omTensorDestroy)> (
59
+ auto hOmt = OMTensorUniquePtr (
60
60
omTensorCreateWithRandomData<float >(llvm::makeArrayRef (hShape), 0 , 1 ),
61
61
omTensorDestroy);
62
62
inputs.emplace_back (move (hOmt));
63
- auto cOmt = unique_ptr<OMTensor, decltype (&omTensorDestroy)> (
63
+ auto cOmt = OMTensorUniquePtr (
64
64
omTensorCreateWithRandomData<float >(llvm::makeArrayRef (cShape), 0 , 1 ),
65
65
omTensorDestroy);
66
66
inputs.emplace_back (move (cOmt));
@@ -77,14 +77,10 @@ bool isOMLSTMTheSameAsNaiveImplFor(const int direction, const int S,
77
77
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
78
78
// Ht = ot (.) h(Ct)
79
79
80
- auto weight =
81
- unique_ptr<OMTensor, decltype (&omTensorDestroy)>(wOmt, omTensorDestroy);
82
- auto recurr =
83
- unique_ptr<OMTensor, decltype (&omTensorDestroy)>(rOmt, omTensorDestroy);
84
- auto bias =
85
- unique_ptr<OMTensor, decltype (&omTensorDestroy)>(bOmt, omTensorDestroy);
86
- auto peepholes =
87
- unique_ptr<OMTensor, decltype (&omTensorDestroy)>(pOmt, omTensorDestroy);
80
+ auto weight = OMTensorUniquePtr (wOmt, omTensorDestroy);
81
+ auto recurr = OMTensorUniquePtr (rOmt, omTensorDestroy);
82
+ auto bias = OMTensorUniquePtr (bOmt, omTensorDestroy);
83
+ auto peepholes = OMTensorUniquePtr (pOmt, omTensorDestroy);
88
84
89
85
auto &input = inputs.at (0 );
90
86
auto &initialH = inputs.at (1 );
0 commit comments