Skip to content

Commit 47adab9

Browse files
committed
Fix response factory cleanup
1 parent c42afe1 commit 47adab9

File tree

3 files changed

+32
-7
lines changed

3 files changed

+32
-7
lines changed

src/infer_request.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class InferRequest {
9696
InferenceTrace& GetTrace();
9797
uint32_t ReleaseFlags();
9898
void SetReleaseFlags(const uint32_t& flags);
99+
intptr_t GetResponseFactoryAddress() { return response_factory_address_; }
99100

100101
#ifdef TRITON_PB_STUB
101102
std::shared_ptr<InferResponse> Exec(const bool is_decoupled);

src/python_be.cc

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,17 @@ ModelInstanceState::ResponseSendDecoupled(
10891089
ResponseSendMessage* send_message_payload =
10901090
reinterpret_cast<ResponseSendMessage*>(send_message.data_.get());
10911091
std::unique_ptr<PbString> error_message;
1092+
ScopedDefer response_factory_deleter([send_message_payload] {
1093+
if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
1094+
TRITONBACKEND_ResponseFactory* response_factory =
1095+
reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
1096+
send_message_payload->response_factory_address);
1097+
std::unique_ptr<
1098+
TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
1099+
lresponse_factory(reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
1100+
response_factory));
1101+
}
1102+
});
10921103
ScopedDefer _([send_message_payload] {
10931104
{
10941105
bi::scoped_lock<bi::interprocess_mutex> guard{send_message_payload->mu};
@@ -1214,13 +1225,6 @@ ModelInstanceState::ResponseSendDecoupled(
12141225
SetErrorForResponseSendMessage(
12151226
send_message_payload, WrapTritonErrorInSharedPtr(error), error_message);
12161227
}
1217-
1218-
if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
1219-
std::unique_ptr<
1220-
TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
1221-
lresponse_factory(
1222-
reinterpret_cast<TRITONBACKEND_ResponseFactory*>(response_factory));
1223-
}
12241228
}
12251229

12261230
TRITONSERVER_Error*
@@ -1291,6 +1295,15 @@ ModelInstanceState::ProcessRequests(
12911295

12921296
if (response_batch_shm_ptr->has_error) {
12931297
if (response_batch_shm_ptr->is_error_set) {
1298+
for (uint32_t r = 0; r < request_count; r++) {
1299+
TRITONBACKEND_ResponseFactory* response_factory =
1300+
reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
1301+
pb_infer_requests[r]->GetResponseFactoryAddress());
1302+
std::unique_ptr<
1303+
TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
1304+
lresponse_factory(reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
1305+
response_factory));
1306+
}
12941307
auto error = PbString::LoadFromSharedMemory(
12951308
Stub()->ShmPool(), response_batch_shm_ptr->error);
12961309
return TRITONSERVER_ErrorNew(
@@ -1357,6 +1370,16 @@ ModelInstanceState::ProcessRequests(
13571370
(*responses)[r] = nullptr;
13581371
continue;
13591372
}
1373+
{
1374+
TRITONBACKEND_ResponseFactory* response_factory =
1375+
reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
1376+
pb_infer_requests[r]->GetResponseFactoryAddress());
1377+
std::unique_ptr<
1378+
TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter>
1379+
lresponse_factory(
1380+
reinterpret_cast<TRITONBACKEND_ResponseFactory*>(
1381+
response_factory));
1382+
}
13601383
infer_response = InferResponse::LoadFromSharedMemory(
13611384
Stub()->ShmPool(), response_shm_handle[r],
13621385
false /* open_cuda_handle */);

src/response_sender.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class ResponseSender {
4343
const std::set<std::string>& requested_output_names,
4444
std::unique_ptr<SharedMemoryManager>& shm_pool,
4545
const std::shared_ptr<PbCancel>& pb_cancel);
46+
intptr_t ResponseFactory() { return response_factory_address_; }
4647
~ResponseSender();
4748
void Send(std::shared_ptr<InferResponse> response, const uint32_t flags);
4849
bool IsCancelled();

0 commit comments

Comments
 (0)