From 4221e848962acde53e1b73532ca428ee7c7efba9 Mon Sep 17 00:00:00 2001 From: Kris Hung Date: Tue, 8 Oct 2024 12:59:28 -0700 Subject: [PATCH] Add back 24.05 response sending path to fix performance (#381) * Add back 24.05 response sender path * Improve perf * Fix cleanup * Review comments * Fix up * Fix up * Fix response factory cleanup * Fix segfault * Fix error handling * Remove extra logs * Fix up, add comments * Address comment * Fix up --------- Co-authored-by: Iman Tabrizian --- src/infer_request.cc | 2 +- src/infer_request.h | 1 + src/ipc_message.cc | 23 +++ src/ipc_message.h | 9 + src/pb_stub.cc | 146 +++++++++++--- src/pb_stub.h | 5 +- src/pb_utils.h | 3 + src/python_be.cc | 438 +++++++++++++++++++++++++++++++++++------ src/python_be.h | 21 +- src/response_sender.cc | 23 ++- src/response_sender.h | 5 +- 11 files changed, 579 insertions(+), 97 deletions(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index 8a95b524..e5733662 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -484,7 +484,7 @@ InferRequest::Exec(const bool is_decoupled) { bi::scoped_lock lock{ *(ipc_message->ResponseMutex())}; - stub->SendIPCMessage(ipc_message); + stub->SendIPCUtilsMessage(ipc_message); ipc_message->ResponseCondition()->wait(lock); } diff --git a/src/infer_request.h b/src/infer_request.h index c67e2fb0..f368d692 100644 --- a/src/infer_request.h +++ b/src/infer_request.h @@ -96,6 +96,7 @@ class InferRequest { InferenceTrace& GetTrace(); uint32_t ReleaseFlags(); void SetReleaseFlags(const uint32_t& flags); + intptr_t GetResponseFactoryAddress() { return response_factory_address_; } #ifdef TRITON_PB_STUB std::shared_ptr Exec(const bool is_decoupled); diff --git a/src/ipc_message.cc b/src/ipc_message.cc index ea1dc5b0..2fa13ba3 100644 --- a/src/ipc_message.cc +++ b/src/ipc_message.cc @@ -56,6 +56,21 @@ IPCMessage::Create( new IPCMessage(ipc_message_shm, response_mutex_shm, response_cond_shm)); } +std::unique_ptr +IPCMessage::Create( + IPCMessageShm* 
ipc_message_shm, + bi::managed_external_buffer::handle_t& message_handle) +{ + return std::unique_ptr( + new IPCMessage(ipc_message_shm, message_handle)); +} + +AllocatedSharedMemory& +IPCMessage::GetAllocatedSharedMemory() +{ + return ipc_message_shm_; +} + std::unique_ptr IPCMessage::LoadFromSharedMemory( std::unique_ptr& shm_pool, @@ -133,4 +148,12 @@ IPCMessage::IPCMessage( ipc_message_handle_ = ipc_message_shm_.handle_; } +IPCMessage::IPCMessage( + IPCMessageShm* ipc_message_shm, + bi::managed_external_buffer::handle_t& handle) +{ + ipc_message_handle_ = handle; + ipc_message_shm_ptr_ = ipc_message_shm; +} + }}}; // namespace triton::backend::python diff --git a/src/ipc_message.h b/src/ipc_message.h index 8e762b8f..c3d1472e 100644 --- a/src/ipc_message.h +++ b/src/ipc_message.h @@ -97,6 +97,10 @@ class IPCMessage { static std::unique_ptr Create( const std::unique_ptr& shm_pool, bool inline_response); + + static std::unique_ptr Create( + IPCMessageShm* ipc_message_shm, + bi::managed_external_buffer::handle_t& message_handle); static std::unique_ptr LoadFromSharedMemory( std::unique_ptr& shm_pool, bi::managed_external_buffer::handle_t message_handle); @@ -108,6 +112,7 @@ class IPCMessage { bi::interprocess_mutex* ResponseMutex(); bi::managed_external_buffer::handle_t& Args(); bi::managed_external_buffer::handle_t ShmHandle(); + AllocatedSharedMemory& GetAllocatedSharedMemory(); private: AllocatedSharedMemory ipc_message_shm_; @@ -129,6 +134,10 @@ class IPCMessage { AllocatedSharedMemory& ipc_message_shm, AllocatedSharedMemory& response_mutex_shm, AllocatedSharedMemory& response_cond_shm); + + IPCMessage( + IPCMessageShm* ipc_message_shm, + bi::managed_external_buffer::handle_t& handle); }; }}}; // namespace triton::backend::python diff --git a/src/pb_stub.cc b/src/pb_stub.cc index 007e7f29..a26719d2 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -653,27 +653,20 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) { py::list py_request_list = 
LoadRequestsFromSharedMemory(request_batch_shm_ptr); - std::unique_ptr execute_response = - IPCMessage::Create(shm_pool_, false /* Inline response */); - execute_response->Command() = PYTHONSTUB_ExecuteResponse; + std::unique_ptr execute_response; - AllocatedSharedMemory response_batch = - shm_pool_->Construct(); - ResponseBatch* response_batch_shm_ptr = - reinterpret_cast(response_batch.data_.get()); - execute_response->Args() = response_batch.handle_; + std::optional> response_batch; bool has_exception = false; std::string error_string; std::unique_ptr error_string_shm; + std::string err_message; ScopedDefer execute_finalize([this] { stub_message_queue_->Pop(); }); ScopedDefer _( [this, &execute_response] { SendIPCMessage(execute_response); }); - + py::object execute_return; + py::object coroutine_return; try { - response_batch_shm_ptr->has_error = false; - response_batch_shm_ptr->is_error_set = false; - if (!py::hasattr(model_instance_, "execute")) { std::string message = "Python model " + model_context_.PythonModelPath() + " does not implement `execute` method."; @@ -683,8 +676,7 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) { NVTX_RANGE(nvtx_, "PyExecute " + name_); - py::object execute_return = - model_instance_.attr("execute")(py_request_list); + execute_return = model_instance_.attr("execute")(py_request_list); bool is_coroutine = py::module::import("asyncio") .attr("iscoroutine")(execute_return) @@ -694,12 +686,14 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) // Do not wait for async decoupled execute to return. 
RunCoroutine(execute_return, true /* in_background */); } else { - py::object coroutine_return = + coroutine_return = RunCoroutine(execute_return, false /* in_background */); - ProcessReturnedResponses(py_request_list, coroutine_return); + ProcessReturnedResponses( + py_request_list, coroutine_return, response_batch); } } else { - ProcessReturnedResponses(py_request_list, execute_return); + ProcessReturnedResponses( + py_request_list, execute_return, response_batch); } } } @@ -713,16 +707,36 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) } if (has_exception) { - std::string err_message = - std::string( - "Failed to process the request(s) for model '" + name_ + - "', message: ") + - error_string; + err_message = std::string( + "Failed to process the request(s) for model '" + name_ + + "', message: ") + + error_string; LOG_ERROR << err_message.c_str(); + if (!response_batch) { + response_batch = shm_pool_->Construct( + sizeof(ResponseBatch) + sizeof(IPCMessageShm)); + } + ResponseBatch* response_batch_shm_ptr = reinterpret_cast( + response_batch.value().data_.get() + sizeof(IPCMessageShm)); + + // The backend will clean up the response factory if there is an error in + // the response batch. For decoupled mode, it is necessary to handle cases + // where the response sender should have already cleaned up, ensuring the + // backend does not delete the response factory again during error handling. 
+ if (IsDecoupled()) { + for (py::handle py_request : py_request_list) { + InferRequest* request = py_request.cast(); + if (request->GetResponseSender()->IsClosed()) { + response_batch_shm_ptr->is_response_factory_deleted = true; + } + } + } + response_batch_shm_ptr->has_error = true; error_string_shm = PbString::Create(shm_pool_, err_message); response_batch_shm_ptr->error = error_string_shm->ShmHandle(); response_batch_shm_ptr->is_error_set = true; + response_batch_shm_ptr->batch_size = 0; // Once the error is sent to the backend, the backend is supposed to close // all response factories if not already closed, so closing all response // senders if not already closed to prevent the model from sending more @@ -731,12 +745,47 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) InferRequest* request = py_request.cast(); request->GetResponseSender()->Close(); } + } else { + if (!response_batch) { + response_batch = shm_pool_->Construct( + sizeof(ResponseBatch) + sizeof(IPCMessageShm)); + ResponseBatch* response_batch_shm_ptr = reinterpret_cast( + response_batch.value().data_.get() + sizeof(IPCMessageShm)); + response_batch_shm_ptr->batch_size = 0; + } + ResponseBatch* response_batch_shm_ptr = reinterpret_cast( + response_batch.value().data_.get() + sizeof(IPCMessageShm)); + response_batch_shm_ptr->has_error = false; + response_batch_shm_ptr->is_error_set = false; + } + + execute_response = IPCMessage::Create( + reinterpret_cast(response_batch.value().data_.get()), + response_batch.value().handle_); + execute_response->Args() = + response_batch.value().handle_ + sizeof(IPCMessageShm); + execute_response->InlineResponse() = false; + execute_response->Command() = PYTHONSTUB_ExecuteResponse; + _.Complete(); + execute_finalize.Complete(); +} + +void +Stub::ProcessResponse(InferResponse* response) +{ + response->SaveToSharedMemory(shm_pool_, false /* copy_gpu */); + + for (auto& output_tensor : response->OutputTensors()) { + if (!output_tensor->IsCPU()) { + 
gpu_tensors_.push_back(output_tensor); + } } } void Stub::ProcessReturnedResponses( - py::list py_requests, py::object py_responses_obj) + py::list py_requests, py::object py_responses_obj, + std::optional>& response_batch) { // Return if there is nothing to process. if (py::isinstance(py_responses_obj)) { @@ -784,12 +833,55 @@ Stub::ProcessReturnedResponses( "return list, found type '" + std::string(py::str(py_responses[i].get_type())) + "'."); } - std::shared_ptr response = - py_responses[i].cast>(); - request->GetResponseSender()->Send( - response, TRITONSERVER_RESPONSE_COMPLETE_FINAL); + + InferResponse* response = py_responses[i].cast(); + try { + request->GetResponseSender()->UpdateStateAndCounters( + response, TRITONSERVER_RESPONSE_COMPLETE_FINAL); + } + catch (const PythonBackendException& pb_exception) { + // Handle the exception here to catch the error when there's a response + // returned from `execute()`. + if (request->GetResponseSender()->IsClosed()) { + response_batch = std::move(shm_pool_->Construct( + sizeof(ResponseBatch) + sizeof(IPCMessageShm))); + ResponseBatch* response_batch_shm_ptr = + reinterpret_cast( + response_batch.value().data_.get() + sizeof(IPCMessageShm)); + response_batch_shm_ptr->batch_size = 0; + response_batch_shm_ptr->is_response_factory_deleted = true; + } + throw pb_exception; + } + } + } + // Return all the created responses using response_batch. The reason + // that both of the paths are available is that sending the responses + // using response_batch is faster than using `response_sender`. 
+ response_batch = std::move(shm_pool_->Construct( + sizeof(IPCMessageShm) + + requests_size * sizeof(bi::managed_external_buffer::handle_t) + + sizeof(ResponseBatch))); + ResponseBatch* response_batch_shm_ptr = reinterpret_cast( + response_batch.value().data_.get() + sizeof(IPCMessageShm)); + + bi::managed_external_buffer::handle_t* responses_shm_handle = + reinterpret_cast( + response_batch.value().data_.get() + sizeof(ResponseBatch) + + sizeof(IPCMessageShm)); + for (size_t i = 0; i < responses_size; i++) { + // Check the return type of execute function. + InferRequest* infer_request = py_requests[i].cast(); + InferResponse* infer_response = py_responses[i].cast(); + if (!py::isinstance(py_responses[i])) { + infer_response->PruneOutputTensors(infer_request->RequestedOutputNames()); + ProcessResponse(infer_response); + responses_shm_handle[i] = infer_response->ShmHandle(); + } else { + responses_shm_handle[i] = 0; } } + response_batch_shm_ptr->batch_size = requests_size; } py::object diff --git a/src/pb_stub.h b/src/pb_stub.h index 9ed74d9a..7d76ec9a 100644 --- a/src/pb_stub.h +++ b/src/pb_stub.h @@ -254,7 +254,10 @@ class Stub { void ProcessRequests(RequestBatch* request_batch_shm_ptr); void ProcessReturnedResponses( - py::list py_requests, py::object py_responses_obj); + py::list py_requests, py::object py_responses_obj, + std::optional>& response_batch); + + void ProcessResponse(InferResponse* response); py::object GetAsyncEventLoop(); diff --git a/src/pb_utils.h b/src/pb_utils.h index e68cfb0f..aacf6b49 100644 --- a/src/pb_utils.h +++ b/src/pb_utils.h @@ -167,6 +167,9 @@ struct ResponseBatch : SendMessageBase { bool is_error_set; uint32_t response_size; + + // Indicates whether the response factory has been deleted or not. 
+ bool is_response_factory_deleted = false; }; enum LogLevel { kInfo = 0, kWarning, kError, kVerbose }; diff --git a/src/python_be.cc b/src/python_be.cc index 761abdbf..bdf7b95f 100644 --- a/src/python_be.cc +++ b/src/python_be.cc @@ -153,6 +153,23 @@ ModelInstanceState::SetErrorForResponseSendMessage( } } +bool +ModelInstanceState::IsStubProcessAlive() +{ + boost::posix_time::ptime timeout = + boost::get_system_time() + boost::posix_time::seconds(1); + bi::scoped_lock lock(*Stub()->HealthMutex(), timeout); + + // Check if lock has been acquired. + if (lock) { + return Stub()->IpcControl()->stub_health; + } else { + // If it failed to obtain the lock, it means that the stub has been + // stuck or exited while holding the health mutex lock. + return false; + } +} + TRITONSERVER_Error* ModelInstanceState::SaveRequestsToSharedMemory( TRITONBACKEND_Request** requests, const uint32_t request_count, @@ -290,7 +307,7 @@ ModelInstanceState::SaveRequestsToSharedMemory( request, &request_timeout)); std::unique_ptr infer_request; - TRITONBACKEND_ResponseFactory* factory_ptr; + TRITONBACKEND_ResponseFactory* factory_ptr = nullptr; RETURN_IF_ERROR(TRITONBACKEND_ResponseFactoryNew(&factory_ptr, request)); infer_request = std::make_unique( @@ -322,8 +339,6 @@ ModelInstanceState::LaunchStubProcess() thread_pool_ = std::make_unique( model_state->StateForBackend()->thread_pool_size); - queue_monitor_thread_ = true; - queue_monitor_ = std::thread(&ModelInstanceState::MessageQueueMonitor, this); request_executor_ = std::make_unique( Stub()->ShmPool(), model_state->TritonServer()); @@ -685,44 +700,6 @@ ModelInstanceState::ExecuteBLSRequest( } } -void -ModelInstanceState::MessageQueueMonitor() -{ - while (queue_monitor_thread_) { - bi::managed_external_buffer::handle_t handle = - Stub()->ParentMessageQueue()->Pop(); - if (handle == DUMMY_MESSAGE) { - break; - } - std::unique_ptr message = - IPCMessage::LoadFromSharedMemory(Stub()->ShmPool(), handle); - - // Need to notify the model 
instance thread that the execute response has - // been received. - if (message->Command() == PYTHONSTUB_ExecuteResponse) { - std::lock_guard guard{mu_}; - received_message_ = std::move(message); - cv_.notify_one(); - } else if (message->Command() == PYTHONSTUB_ResponseSend) { - std::shared_ptr response_send_message = std::move(message); - std::packaged_task task([this, response_send_message] { - ResponseSendDecoupled(response_send_message); - }); - boost::asio::post(*thread_pool_, std::move(task)); - } else if ( - message->Command() == PYTHONSTUB_InferExecRequest || - message->Command() == PYTHONSTUB_InferStreamExecRequest) { - std::shared_ptr bls_execute = std::move(message); - std::packaged_task task([this, bls_execute] { - ExecuteBLSRequest( - bls_execute, - (bls_execute->Command() == PYTHONSTUB_InferStreamExecRequest)); - }); - boost::asio::post(*thread_pool_, std::move(task)); - } - } -} - void ModelInstanceState::StubToParentMQMonitor() { @@ -769,6 +746,25 @@ ModelInstanceState::StubToParentMQMonitor() ProcessModelControlRequest(message); break; } + case PYTHONSTUB_ResponseSend: { + std::shared_ptr response_send_message = std::move(message); + std::packaged_task task([this, response_send_message] { + ResponseSendDecoupled(response_send_message); + }); + boost::asio::post(*thread_pool_, std::move(task)); + break; + } + case PYTHONSTUB_InferExecRequest: + case PYTHONSTUB_InferStreamExecRequest: { + std::shared_ptr bls_execute = std::move(message); + std::packaged_task task([this, bls_execute] { + ExecuteBLSRequest( + bls_execute, + (bls_execute->Command() == PYTHONSTUB_InferStreamExecRequest)); + }); + boost::asio::post(*thread_pool_, std::move(task)); + break; + } default: { LOG_MESSAGE( TRITONSERVER_LOG_ERROR, "Unexpected message type received."); @@ -1030,6 +1026,100 @@ ModelInstanceState::ProcessModelControlRequest( }); } +TRITONSERVER_Error* +ModelInstanceState::SendMessageToStub( + bi::managed_external_buffer::handle_t message) +{ + bool success = false; 
+ while (!success) { + uint64_t timeout_miliseconds = 1000; + { + boost::posix_time::ptime timeout = + boost::get_system_time() + + boost::posix_time::milliseconds(timeout_miliseconds); + + bi::scoped_lock lock( + *(Stub()->HealthMutex()), timeout); + + // Check if lock has been acquired. + if (lock) { + Stub()->IpcControl()->stub_health = false; + } else { + // If it failed to obtain the lock, it means that the stub has been + // stuck or exited while holding the health mutex lock. + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "Failed to obtain the health mutex."); + } + } + + Stub()->StubMessageQueue()->Push( + message, timeout_miliseconds /* duration ms */, success); + + if (!success && !IsStubProcessAlive()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "Stub process is not healthy."); + } + } + + return nullptr; // success +} + +void +ModelInstanceState::SendMessageAndReceiveResponse( + bi::managed_external_buffer::handle_t message, + bi::managed_external_buffer::handle_t& response, + std::shared_ptr>& responses, + TRITONBACKEND_Request** requests, const uint32_t request_count) +{ + auto error = SendMessageToStub(message); + if (error != nullptr) { + RespondErrorToAllRequests( + TRITONSERVER_ErrorMessage(error), responses, requests, request_count); + + return; + } + + bi::managed_external_buffer::handle_t response_message; + error = Stub()->ReceiveMessageFromStub(response_message); + if (error != nullptr) { + RespondErrorToAllRequests( + TRITONSERVER_ErrorMessage(error), responses, requests, request_count); + + return; + } + + response = response_message; +} + +void +ModelInstanceState::RespondErrorToAllRequests( + const char* message, + std::shared_ptr>& responses, + TRITONBACKEND_Request** requests, const uint32_t request_count) +{ + for (uint32_t r = 0; r < request_count; ++r) { + if ((*responses)[r] == nullptr) + continue; + + std::string err_message = + std::string( + "Failed to process the request(s) for model instance 
'" + Name() + + "', message: ") + + message; + + TRITONSERVER_Error* err = + TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, err_message.c_str()); + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + (*responses)[r], TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), + "failed sending response"); + + (*responses)[r] = nullptr; + TRITONSERVER_ErrorDelete(err); + } +} + + void ModelInstanceState::StartMonitor() { @@ -1060,6 +1150,17 @@ ModelInstanceState::ResponseSendDecoupled( ResponseSendMessage* send_message_payload = reinterpret_cast(send_message.data_.get()); std::unique_ptr error_message; + ScopedDefer response_factory_deleter([send_message_payload] { + if (send_message_payload->flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { + TRITONBACKEND_ResponseFactory* response_factory = + reinterpret_cast( + send_message_payload->response_factory_address); + std::unique_ptr< + TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> + lresponse_factory(reinterpret_cast( + response_factory)); + } + }); ScopedDefer _([send_message_payload] { { bi::scoped_lock guard{send_message_payload->mu}; @@ -1228,31 +1329,48 @@ ModelInstanceState::ProcessRequests( IPCMessage::Create(Stub()->ShmPool(), false /*inline_response*/)); ipc_message->Command() = PYTHONSTUB_CommandType::PYTHONSTUB_ExecuteRequest; ipc_message->Args() = request_batch.handle_; - received_message_ = nullptr; - ScopedDefer _([this] { + + ScopedDefer execute_finalize([this] { // Push a dummy message to signal the thread to terminate. 
Stub()->StubMessageQueue()->Push(DUMMY_MESSAGE); }); + std::unique_ptr response; { - std::unique_lock guard{mu_}; Stub()->StubMessageQueue()->Push(ipc_message->ShmHandle()); - cv_.wait(guard, [this] { return received_message_ != nullptr; }); + bi::managed_external_buffer::handle_t response_message; + RETURN_IF_ERROR(Stub()->ReceiveMessageFromStub(response_message)); + response = + IPCMessage::LoadFromSharedMemory(Stub()->ShmPool(), response_message); } - - AllocatedSharedMemory response_batch = - Stub()->ShmPool()->Load(received_message_->Args()); - received_message_.reset(); + char* ipc_message_shm = + reinterpret_cast(response->GetAllocatedSharedMemory().data_.get()); + ResponseBatch* response_batch_shm_ptr = + reinterpret_cast(ipc_message_shm + sizeof(IPCMessageShm)); uint64_t compute_end_ns = 0; SET_TIMESTAMP(compute_end_ns); reporter.SetComputeEndNs(compute_end_ns); reporter.SetBatchStatistics(total_batch_size); - if (response_batch.data_->has_error) { - if (response_batch.data_->is_error_set) { + if (response_batch_shm_ptr->has_error) { + // Clean up the response factory if an error occurred. The + // `is_response_factory_deleted` flag indicates whether the response factory + // has been deleted for some corner cases. 
+ if (!response_batch_shm_ptr->is_response_factory_deleted) { + for (uint32_t r = 0; r < request_count; r++) { + TRITONBACKEND_ResponseFactory* response_factory = + reinterpret_cast( + pb_infer_requests[r]->GetResponseFactoryAddress()); + std::unique_ptr< + TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> + lresponse_factory(reinterpret_cast( + response_factory)); + } + } + if (response_batch_shm_ptr->is_error_set) { auto error = PbString::LoadFromSharedMemory( - Stub()->ShmPool(), response_batch.data_->error); + Stub()->ShmPool(), response_batch_shm_ptr->error); return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INTERNAL, error->String().c_str()); } @@ -1261,6 +1379,218 @@ ModelInstanceState::ProcessRequests( TRITONSERVER_ERROR_INTERNAL, "Failed to process the requests."); } + if (response_batch_shm_ptr->batch_size > 0) { + bi::managed_external_buffer::handle_t* response_shm_handle = + reinterpret_cast( + ipc_message_shm + sizeof(ResponseBatch) + sizeof(IPCMessageShm)); + + std::shared_ptr> responses( + new std::vector()); + responses->reserve(request_count); + for (size_t i = 0; i < request_count; i++) { + // It is possible to have multiple responses batched together in a single + // response batch shm, where some of the responses are None due to the + // usage of response sender, so only create a TRITONBACKEND_Response + // object for the valid responses. 
+ if (response_shm_handle[i] == 0) { + responses->emplace_back(nullptr); + } else { + TRITONBACKEND_Response* response; + auto err = TRITONBACKEND_ResponseNew(&response, requests[i]); + if (err == nullptr) { + responses->emplace_back(response); + } else { + responses->emplace_back(nullptr); + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response"); + TRITONSERVER_ErrorDelete(err); + } + } + } + + std::vector requires_deferred_callback; + + bool has_gpu_output = false; + std::vector> shm_responses; + std::vector, void*>>> + gpu_output_buffers(request_count); + GPUBuffersHelper gpu_buffer_helper; + + for (uint32_t r = 0; r < request_count; ++r) { + NVTX_RANGE(nvtx_, "LoadingResponse " + Name()); + requires_deferred_callback.push_back(false); + if (response_shm_handle[r] == 0) { + continue; + } + TRITONBACKEND_Response* response = (*responses)[r]; + TRITONBACKEND_Request* request = requests[r]; + uint32_t requested_output_count = 0; + + shm_responses.emplace_back(nullptr); + std::unique_ptr& infer_response = shm_responses.back(); + try { + if (pb_infer_requests[r]->ReleaseFlags() == + TRITONSERVER_REQUEST_RELEASE_RESCHEDULE) { + // For rescheduled requests, we do not need to send a response. 
+ LOG_IF_ERROR( + TRITONBACKEND_ResponseDelete((*responses)[r]), + "failed to delete response"); + (*responses)[r] = nullptr; + continue; + } + { + TRITONBACKEND_ResponseFactory* response_factory = + reinterpret_cast( + pb_infer_requests[r]->GetResponseFactoryAddress()); + std::unique_ptr< + TRITONBACKEND_ResponseFactory, backend::ResponseFactoryDeleter> + lresponse_factory( + reinterpret_cast( + response_factory)); + } + infer_response = InferResponse::LoadFromSharedMemory( + Stub()->ShmPool(), response_shm_handle[r], + false /* open_cuda_handle */); + if (infer_response->HasError()) { + TRITONSERVER_Error* err = TRITONSERVER_ErrorNew( + infer_response->Error()->Code(), + infer_response->Error()->Message().c_str()); + + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + (*responses)[r], TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), + "failed sending response"); + TRITONSERVER_ErrorDelete(err); + (*responses)[r] = nullptr; + + // Reset the release flags for the request. + pb_infer_requests[r]->SetReleaseFlags( + TRITONSERVER_REQUEST_RELEASE_ALL); + + // If has_error is true, we do not look at the response tensors. + continue; + } + } + catch (const PythonBackendException& pb_exception) { + TRITONSERVER_Error* err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, pb_exception.what()); + LOG_IF_ERROR( + TRITONBACKEND_ResponseSend( + (*responses)[r], TRITONSERVER_RESPONSE_COMPLETE_FINAL, err), + "failed sending response"); + TRITONSERVER_ErrorDelete(err); + (*responses)[r] = nullptr; + + // Reset the release flags for the request. 
+ pb_infer_requests[r]->SetReleaseFlags(TRITONSERVER_REQUEST_RELEASE_ALL); + + continue; + } + + GUARDED_RESPOND_IF_ERROR( + responses, r, + TRITONBACKEND_RequestOutputCount(request, &requested_output_count)); + std::set requested_output_names; + for (size_t j = 0; j < requested_output_count; ++j) { + const char* output_name; + GUARDED_RESPOND_IF_ERROR( + responses, r, + TRITONBACKEND_RequestOutputName(request, j, &output_name)); + requested_output_names.insert(output_name); + } + + bool require_deferred_callback = false; + +#ifdef TRITON_ENABLE_GPU + for (auto& output_tensor : infer_response->OutputTensors()) { + if (output_tensor->MemoryType() == TRITONSERVER_MEMORY_GPU) { + // Attempt to use the cuda shared memory pool for GPU tensor. + ShareCUDAMemoryPool(output_tensor->MemoryTypeId()); + } + } +#endif // TRITON_ENABLE_GPU + + gpu_output_buffers[r] = + std::vector, void*>>{}; + infer_response->Send( + response, CudaStream(), require_deferred_callback, + TRITONSERVER_RESPONSE_COMPLETE_FINAL, Stub()->ShmPool(), + gpu_buffer_helper, gpu_output_buffers[r], requested_output_names); + + requires_deferred_callback[r] = require_deferred_callback; + + if (requires_deferred_callback[r]) { + has_gpu_output = true; + } + } + + execute_finalize.Complete(); + + // If the output tensor is in GPU, there will be a second round trip + // required for filling the GPU buffers provided by the main process. 
+ if (has_gpu_output) { + ipc_message->Command() = + PYTHONSTUB_CommandType::PYTHONSTUB_LoadGPUBuffers; + gpu_buffer_helper.Complete(Stub()->ShmPool()); + ipc_message->Args() = gpu_buffer_helper.ShmHandle(); + bi::managed_external_buffer::handle_t response_message; + SendMessageAndReceiveResponse( + ipc_message->ShmHandle(), response_message, responses, requests, 0); + + bool cuda_copy = false; + + uint32_t response_index = 0; + for (auto& gpu_output_buffer : gpu_output_buffers) { + for (auto& buffer_memory_pair : gpu_output_buffer) { + auto& pb_memory = buffer_memory_pair.first; + void* pointer = buffer_memory_pair.second; + bool cuda_used = false; + + if (pb_memory->MemoryType() == TRITONSERVER_MEMORY_CPU) { + GUARDED_RESPOND_IF_ERROR( + responses, response_index, + CopyBuffer( + "Failed to copy the output tensor to buffer.", + TRITONSERVER_MEMORY_CPU, 0, TRITONSERVER_MEMORY_CPU, 0, + pb_memory->ByteSize(), pb_memory->DataPtr(), pointer, + CudaStream(), &cuda_used)); + cuda_copy |= cuda_used; + } else if ( + (pb_memory->MemoryType() == TRITONSERVER_MEMORY_GPU) && + pb_memory->UseCUDASharedPool() && + (pb_memory->DataPtr() != pointer)) { + // If the data pointer from pb_memory is not the same as the + // pointer, it means that the Triton-provided buffer is not used + // during tensor transfer. Instead, an intermediate buffer that uses + // CUDA shared memory pool is used. In this case, we need to copy + // the data from the intermediate buffer back to the Triton-provided + // buffer. 
+ GUARDED_RESPOND_IF_ERROR( + responses, response_index, + CopyBuffer( + "Failed to copy the output tensor to buffer.", + TRITONSERVER_MEMORY_GPU, pb_memory->MemoryTypeId(), + TRITONSERVER_MEMORY_GPU, pb_memory->MemoryTypeId(), + pb_memory->ByteSize(), pb_memory->DataPtr(), pointer, + CudaStream(), &cuda_used)); + cuda_copy |= cuda_used; + } + } + response_index++; +#ifdef TRITON_ENABLE_GPU + if (cuda_copy) { + cudaStreamSynchronize(stream_); + } +#endif // TRITON_ENABLE_GPU + } + } + + for (uint32_t r = 0; r < request_count; ++r) { + if (requires_deferred_callback[r]) { + shm_responses[r]->DeferredSendCallback(); + } + } + } + return nullptr; // success } @@ -1401,16 +1731,12 @@ ModelInstanceState::~ModelInstanceState() if (Stub()->IsHealthy()) { // Wait for all the pending tasks to finish. thread_pool_->wait(); - // Push a dummy message to signal the thread to terminate. - Stub()->ParentMessageQueue()->Push(DUMMY_MESSAGE); - queue_monitor_.join(); } // Terminate stub first to allow any last messages to be received by the back // end before deallocating the queue memory Stub()->TerminateStub(); TerminateMonitor(); Stub()->ClearQueues(); - received_message_.reset(); Stub().reset(); } diff --git a/src/python_be.h b/src/python_be.h index 59660fc4..c98e1284 100644 --- a/src/python_be.h +++ b/src/python_be.h @@ -287,9 +287,6 @@ class ModelInstanceState : public BackendModelInstance { std::thread stub_to_parent_queue_monitor_; bool stub_to_parent_thread_; - // Queue monitor thread - std::thread queue_monitor_; - bool queue_monitor_thread_; std::mutex mu_; std::condition_variable cv_; std::unique_ptr received_message_; @@ -361,6 +358,24 @@ class ModelInstanceState : public BackendModelInstance { AllocatedSharedMemory& request_batch, std::shared_ptr>& responses); + void SendMessageAndReceiveResponse( + bi::managed_external_buffer::handle_t message, + bi::managed_external_buffer::handle_t& response, + std::shared_ptr>& responses, + TRITONBACKEND_Request** requests, const 
uint32_t request_count); + + void RespondErrorToAllRequests( + const char* message, + std::shared_ptr>& responses, + TRITONBACKEND_Request** requests, const uint32_t request_count); + + // void SendMessageToStub(bi::managed_external_buffer::handle_t message); + TRITONSERVER_Error* SendMessageToStub( + bi::managed_external_buffer::handle_t message); + + // Checks whether the stub process is live + bool IsStubProcessAlive(); + // Model instance stub std::unique_ptr& Stub() { return model_instance_stub_; } diff --git a/src/response_sender.cc b/src/response_sender.cc index 0a88fb6b..ef3b09dd 100644 --- a/src/response_sender.cc +++ b/src/response_sender.cc @@ -74,7 +74,7 @@ ResponseSender::~ResponseSender() void ResponseSender::UpdateStateAndCounters( - const std::shared_ptr& response, const uint32_t flags) + InferResponse* response, const uint32_t flags) { if (is_decoupled_ == nullptr) { // TODO: Can a model access the response sender on a BLS infer request? @@ -106,6 +106,7 @@ ResponseSender::UpdateStateAndCounters( } if (flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { + response_factory_deleted_.exchange(true); closed_ = true; } number_of_response_sent_++; @@ -123,7 +124,7 @@ ResponseSender::Send( py::gil_scoped_release release; CheckResponseSenderArguments(infer_response, flags); - UpdateStateAndCounters(infer_response, flags); + UpdateStateAndCounters(infer_response.get(), flags); if (infer_response) { infer_response->PruneOutputTensors(requested_output_names_); } @@ -172,7 +173,11 @@ ResponseSender::Send( { bi::scoped_lock guard{send_message_payload->mu}; - stub->SendIPCMessage(ipc_message); + // The server will destruct the response factory if the final flag is set. 
+ if (flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { + response_factory_deleted_.exchange(true); + } + stub->SendIPCUtilsMessage(ipc_message); while (!send_message_payload->is_stub_turn) { send_message_payload->cv.wait(guard); } @@ -246,10 +251,6 @@ ResponseSender::Send( "An error occurred while sending a response."); } } - - if (flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL) { - DeleteResponseFactory(); - } } bool @@ -258,11 +259,19 @@ ResponseSender::IsCancelled() return pb_cancel_->IsCancelled(); } +bool +ResponseSender::IsClosed() +{ + std::lock_guard lk(mu_); + return closed_; +} + void ResponseSender::Close() { std::lock_guard lk(mu_); closed_ = true; + response_factory_deleted_.exchange(true); } void diff --git a/src/response_sender.h b/src/response_sender.h index 69f416c2..a696f9eb 100644 --- a/src/response_sender.h +++ b/src/response_sender.h @@ -43,16 +43,17 @@ class ResponseSender { const std::set& requested_output_names, std::unique_ptr& shm_pool, const std::shared_ptr& pb_cancel); + intptr_t ResponseFactory() { return response_factory_address_; } ~ResponseSender(); void Send(std::shared_ptr response, const uint32_t flags); bool IsCancelled(); + void UpdateStateAndCounters(InferResponse* response, const uint32_t flags); // Can be useful at stopping the model from sending any more responses. void Close(); + bool IsClosed(); private: - void UpdateStateAndCounters( - const std::shared_ptr& response, const uint32_t flags); void DeleteResponseFactory(); intptr_t request_address_;