From 8251a58cf42a326812cc26e11bd3ec0cf21bce71 Mon Sep 17 00:00:00 2001
From: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Date: Tue, 24 Sep 2024 14:51:48 +0000
Subject: [PATCH 1/2] open source 7a776af38eccd9c94ccc23ff069959f2f629745e

---
 all_models/gpt/postprocessing/config.pbtxt   |  1 +
 .../postprocessing/config.pbtxt              |  1 +
 .../tensorrt_llm/1/model.py                  | 20 +++++++++++------
 .../tensorrt_llm/config.pbtxt                |  7 ++++++
 build.sh                                     |  1 +
 dockerfile/Dockerfile.triton.trt_llm_backend | 22 +++++++++++++++++--
 .../client/inflight_batcher_llm_client.py    | 14 ++++++++++--
 .../src/model_instance_state.cc              |  5 ++++-
 inflight_batcher_llm/src/utils.cc            | 18 ++++++++++++---
 inflight_batcher_llm/src/utils.h             |  1 +
 tensorrt_llm                                 |  2 +-
 tools/version.txt                            |  2 +-
 12 files changed, 77 insertions(+), 17 deletions(-)

diff --git a/all_models/gpt/postprocessing/config.pbtxt b/all_models/gpt/postprocessing/config.pbtxt
index 432acbab..04b56cab 100755
--- a/all_models/gpt/postprocessing/config.pbtxt
+++ b/all_models/gpt/postprocessing/config.pbtxt
@@ -1,6 +1,7 @@
 name: "postprocessing"
 backend: "python"
 max_batch_size: 1024
+dynamic_batching {}
 input [
   {
     name: "TOKENS_BATCH"
diff --git a/all_models/inflight_batcher_llm/postprocessing/config.pbtxt b/all_models/inflight_batcher_llm/postprocessing/config.pbtxt
index a1c2eb20..df875db4 100644
--- a/all_models/inflight_batcher_llm/postprocessing/config.pbtxt
+++ b/all_models/inflight_batcher_llm/postprocessing/config.pbtxt
@@ -27,6 +27,7 @@
 name: "postprocessing"
 backend: "python"
 max_batch_size: ${triton_max_batch_size}
+dynamic_batching {}
 input [
   {
     name: "TOKENS_BATCH"
diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py
index ef327fe4..1e0a84b6 100644
--- a/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py
+++ b/all_models/inflight_batcher_llm/tensorrt_llm/1/model.py
@@ -175,10 +175,7 @@ def get_sampling_config_from_request(request, batch_size=1, batch_index=0):
     return trtllm.SamplingConfig(**kwargs)
 
 
-def get_output_config_from_request(request,
-                                   exclude_input_from_output,
-                                   batch_size=1,
-                                   batch_index=0):
+def get_output_config_from_request(request, batch_size=1, batch_index=0):
     kwargs = {}
     kwargs["return_log_probs"] = get_input_scalar_by_name(
         request, 'return_log_probs', batch_size, batch_index)
@@ -186,7 +183,6 @@ def get_output_config_from_request(request,
         request, 'return_context_logits', batch_size, batch_index)
     kwargs["return_generation_logits"] = get_input_scalar_by_name(
         request, 'return_generation_logits', batch_size, batch_index)
-    kwargs["exclude_input_from_output"] = exclude_input_from_output
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
     return trtllm.OutputConfig(**kwargs)
 
@@ -312,8 +308,18 @@ def convert_request(request, exclude_input_from_output, decoupled):
         sampling_config = get_sampling_config_from_request(
             request, batch_size, batch_index)
-        output_config = get_output_config_from_request(
-            request, exclude_input_from_output, batch_size, batch_index)
+        output_config = get_output_config_from_request(request, batch_size,
+                                                       batch_index)
+        req_exclude_input_from_output = get_input_scalar_by_name(
+            request, 'exclude_input_in_output', batch_size, batch_index)
+        if req_exclude_input_from_output is None:
+            # if request doesn't specify exclude_input_from_output, try to use the parameter
+            output_config.exclude_input_from_output = (
+                exclude_input_from_output
+                if exclude_input_from_output is not None else False)
+        else:
+            output_config.exclude_input_from_output = req_exclude_input_from_output
+
         external_draft_tokens_config = get_external_draft_tokens_config_from_request(
             request, batch_size, batch_index)
         prompt_tuning_config = get_prompt_tuning_config_from_request(
diff --git a/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt b/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt
index 6aa88088..2f14106d 100644
--- a/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt
+++ b/all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt
@@ -253,6 +253,13 @@ input [
     reshape: { shape: [ ] }
     optional: true
   },
+  {
+    name: "exclude_input_in_output"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
   {
     name: "stop"
     data_type: TYPE_BOOL
diff --git a/build.sh b/build.sh
index 43d99f29..d99f274e 100755
--- a/build.sh
+++ b/build.sh
@@ -42,6 +42,7 @@ PYTHON_BACKEND_REPO_TAG=${PYTHON_BACKEND_REPO_TAG:-r24.08}
     --filesystem=gcs --filesystem=s3 --filesystem=azure_storage \
     --endpoint=http --endpoint=grpc --endpoint=sagemaker --endpoint=vertex-ai \
     --backend=ensemble --enable-gpu --no-container-pull \
+    --repoagent=checksum --cache=local --cache=redis \
     --image=base,${TRTLLM_BASE_IMAGE} \
     --backend=tensorrtllm:${TENSORRTLLM_BACKEND_REPO_TAG} \
     --backend=python:${PYTHON_BACKEND_REPO_TAG}
diff --git a/dockerfile/Dockerfile.triton.trt_llm_backend b/dockerfile/Dockerfile.triton.trt_llm_backend
index 88a8fa57..113777b3 100644
--- a/dockerfile/Dockerfile.triton.trt_llm_backend
+++ b/dockerfile/Dockerfile.triton.trt_llm_backend
@@ -8,6 +8,20 @@ ARG RELEASE_URL_TRT_ARM=https://developer.nvidia.com/downloads/compute/machine-l
 FROM ${PYTORCH_IMAGE} as pytorch_image
 FROM ${BASE_IMAGE} as install_dependencies
 
+ARG CCACHE_REMOTE_STORAGE
+ARG CCACHE_URL
+ENV CCACHE_DEBUG=1
+
+RUN if [ -n "${CCACHE_REMOTE_STORAGE}" ] ; then \
+      curl -k -L ${CCACHE_URL} -o ccache.tar.gz ; \
+      tar -xzf ccache.tar.gz -C /usr/local --strip-components=1 ; \
+      rm ccache.tar.gz ; \
+      ccache --set-config=remote_only=true ; \
+      ccache --set-config=remote_storage=${CCACHE_REMOTE_STORAGE} ; \
+      ccache --set-config=log_file=/tmp/ccache.log ; \
+      ccache -p ; \
+    fi
+
 # Copy PyTorch package from PyTorch image
 COPY --from=pytorch_image /usr/local/lib/lib* /usr/local/lib/
 COPY --from=pytorch_image /usr/local/lib/python3.10/dist-packages/torch /usr/local/lib/python3.10/dist-packages/torch
@@ -20,7 +34,6 @@ RUN apt-get update -q=2 && \
     apt-get install -y --no-install-recommends \
     python3-dev \
     python3-pip \
-    ccache \
     git-lfs && \
     # Remove previous TRT installation
     apt-get remove -y tensorrt* libnvinfer* && \
@@ -76,7 +89,12 @@ RUN pip3 install --no-cache-dir polygraphy==0.49.9 mpi4py==3.1.5 cmake==3.30.2
 COPY scripts scripts
 COPY tensorrt_llm tensorrt_llm
 
-RUN cd tensorrt_llm && python3 scripts/build_wheel.py --trt_root="${TRT_ROOT}" --clean
+RUN cd tensorrt_llm && \
+    if [ -n "${CCACHE_REMOTE_STORAGE}" ] ; then \
+      python3 scripts/build_wheel.py --trt_root="${TRT_ROOT}" --clean --use_ccache ; \
+    else \
+      python3 scripts/build_wheel.py --trt_root="${TRT_ROOT}" --clean ; \
+    fi
 
 # Final stage to build the TRT-LLM container
 FROM ${BASE_IMAGE} as final_stage
diff --git a/inflight_batcher_llm/client/inflight_batcher_llm_client.py b/inflight_batcher_llm/client/inflight_batcher_llm_client.py
index 82ff8b11..faad635c 100755
--- a/inflight_batcher_llm/client/inflight_batcher_llm_client.py
+++ b/inflight_batcher_llm/client/inflight_batcher_llm_client.py
@@ -123,7 +123,8 @@ def prepare_inputs(input_ids_data, input_lengths_data, request_output_len_data,
                    lora_weights_data,
                    lora_config_data, return_log_probs_data, top_k_data,
                    top_p_data, draft_ids_data, return_context_logits_data,
                    return_generation_logits_data,
-                   decoder_input_ids_data, prompt_table_extra_id_data):
+                   decoder_input_ids_data, prompt_table_extra_id_data,
+                   exclude_input_in_output):
     inputs = [
         prepare_tensor("input_ids", input_ids_data),
         prepare_tensor("input_lengths", input_lengths_data),
@@ -185,6 +186,10 @@ def prepare_inputs(input_ids_data, input_lengths_data, request_output_len_data,
             prepare_tensor("prompt_table_extra_ids",
                            prompt_table_extra_id_data),
         ]
+    if exclude_input_in_output is not None:
+        inputs += [
+            prepare_tensor("exclude_input_in_output", exclude_input_in_output),
+        ]
 
     return inputs
 
@@ -665,6 +670,11 @@ def callback(user_data, result, error):
     if decoder_input_ids is not None:
         decoder_input_ids_data = np.array(decoder_input_ids, dtype=np.int32)
 
+    exclude_input_in_output = None
+    if FLAGS.exclude_input_in_output:
+        exclude_input_in_output = np.array([[FLAGS.exclude_input_in_output]],
+                                           dtype=bool)
+
     if not FLAGS.vocab_size and tokenizer:
         FLAGS.vocab_size = tokenizer.vocab_size
     prompt_table_extra_id_data = None
@@ -690,7 +700,7 @@ def callback(user_data, result, error):
         lora_config_data, return_log_probs_data, top_k_data, top_p_data,
         draft_ids_data, return_context_logits_data,
         return_generation_logits_data, decoder_input_ids_data,
-        prompt_table_extra_id_data)
+        prompt_table_extra_id_data, exclude_input_in_output)
 
     if FLAGS.requested_outputs:
         # Must have at least output_ids in requested outputs
diff --git a/inflight_batcher_llm/src/model_instance_state.cc b/inflight_batcher_llm/src/model_instance_state.cc
index 40e480cf..93635d4a 100644
--- a/inflight_batcher_llm/src/model_instance_state.cc
+++ b/inflight_batcher_llm/src/model_instance_state.cc
@@ -211,6 +211,9 @@ executor::ParallelConfig ModelInstanceState::getParallelConfigFromParams()
     if (useOrchestratorMode && std::atoi(useOrchestratorMode) != 0)
     {
         parallelConfig.setCommunicationMode(executor::CommunicationMode::kORCHESTRATOR);
+
+        tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE);
+
         auto const workerExecutablePath = model_state_->GetExecutorWorkerPath();
         auto const spawnProcessesEnvVar = std::getenv("TRTLLM_ORCHESTRATOR_SPAWN_PROCESSES");
         auto const spawnProcesses = !spawnProcessesEnvVar || std::atoi(spawnProcessesEnvVar);
@@ -978,7 +981,7 @@ std::tuple ModelInstanceStat
     {
         size_t contextPhaseParamsSize
             = executor::Serialization::serializedSize(response.getResult().contextPhaseParams.value());
-        std::vector<int64_t> contextPhaseParamsShape{1, contextPhaseParamsSize};
+        std::vector<int64_t> contextPhaseParamsShape{1, static_cast<int64_t>(contextPhaseParamsSize)};
         TRITONSERVER_DataType contextPhaseParamsType = TRITONSERVER_TYPE_UINT8;
         auto contextPhaseParamsBuffer = utils::getResponseBuffer(tritonResponse, contextPhaseParamsShape,
             contextPhaseParamsType, OutputFieldsNames::contextPhaseParams);
diff --git a/inflight_batcher_llm/src/utils.cc b/inflight_batcher_llm/src/utils.cc
index f7430e15..49eb5e02 100644
--- a/inflight_batcher_llm/src/utils.cc
+++ b/inflight_batcher_llm/src/utils.cc
@@ -535,7 +535,6 @@ executor::OutputConfig getOutputConfigFromTensors(InputTensors const& inputsTens
     bool returnContextLogits{false};
    extractSingleton(inputsTensors, InputFieldsNames::returnContextLogits, returnContextLogits);
 
-    // Note that currently excludeInputFromOutput is set from the backend parameters.
 
     return executor::OutputConfig(returnLogProbs, returnContextLogits, returnGenerationLogits);
 }
 
@@ -628,7 +627,7 @@ std::optional getLoraConfigFromTensors(InputTensors const&
 }
 
 std::vector<executor::Request> createRequestsFromInputTensors(std::vector<InputTensors> const& inputsTensors,
-    bool excludeInputFromOutput, bool isDecoupled, bool streaming, executor::ModelType modelType,
+    bool paramExcludeInputFromOutput, bool isDecoupled, bool streaming, executor::ModelType modelType,
     executor::RequestType requestType)
 {
     if (!isDecoupled && inputsTensors.size() > 1)
@@ -644,7 +643,20 @@ std::vector<executor::Request> createRequestsFromInputTensors(std::vector<Input
+        std::optional<bool> reqExcludeInputFromOutput{std::nullopt};
+        extractOptionalSingleton(
+            inputTensors, InputFieldsNames::excludeInputFromOutput, reqExcludeInputFromOutput);
+
+        // If specified in request, set from request
+        if (reqExcludeInputFromOutput != std::nullopt)
+        {
+            outConfig.excludeInputFromOutput = reqExcludeInputFromOutput.value();
+        }
+        else // Set from parameter
+        {
+            outConfig.excludeInputFromOutput = paramExcludeInputFromOutput;
+        }
 
         executor::VecTokens inputTokens;
         if (!utils::extractVector(inputTensors, InputFieldsNames::inputTokens, inputTokens))
diff --git a/inflight_batcher_llm/src/utils.h b/inflight_batcher_llm/src/utils.h
index 490f7a3d..019540cb 100644
--- a/inflight_batcher_llm/src/utils.h
+++ b/inflight_batcher_llm/src/utils.h
@@ -62,6 +62,7 @@ struct InputFieldsNames
     static constexpr char const* returnLogProbs = "return_log_probs";
     static constexpr char const* returnGenerationLogits = "return_generation_logits";
     static constexpr char const* returnContextLogits = "return_context_logits";
+    static constexpr char const* excludeInputFromOutput = "exclude_input_in_output";
 
     // SamplingConfig
     static constexpr char const* beamWidth = "beam_width";
diff --git a/tensorrt_llm b/tensorrt_llm
index a65dba7a..57ea56bc 160000
--- a/tensorrt_llm
+++ b/tensorrt_llm
@@ -1 +1 @@
-Subproject commit a65dba7aaf7e2d8bb0120eea8f8f04deff145d6a
+Subproject commit 57ea56bcd2ae075ecd72236ee486ab6667f535d3
diff --git a/tools/version.txt b/tools/version.txt
index 5a0e272e..5ec4beb5 100644
--- a/tools/version.txt
+++ b/tools/version.txt
@@ -1 +1 @@
-9f42c546baf991a2e69cb605595b6484d5388709
+7a776af38eccd9c94ccc23ff069959f2f629745e

From 9ffe0d7b7437b0bd1dd89374303bf6df527c8ad3 Mon Sep 17 00:00:00 2001
From: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Date: Tue, 24 Sep 2024 08:29:05 -0700
Subject: [PATCH 2/2] Update submodule

---
 tensorrt_llm | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorrt_llm b/tensorrt_llm
index 57ea56bc..e1533727 160000
--- a/tensorrt_llm
+++ b/tensorrt_llm
@@ -1 +1 @@
-Subproject commit 57ea56bcd2ae075ecd72236ee486ab6667f535d3
+Subproject commit e15337275966ee9dea22aba8a6dba138354f6027