Trtllm backend improvements #3231

Status: Open · wants to merge 9 commits into main
backends/llamacpp/src/main.rs · 1 addition, 1 deletion
```diff
@@ -119,7 +119,7 @@ struct Args {
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
 
-    #[clap(default_value = "9000", long, short, env)]
+    #[clap(default_value = "9000", long, env)]
     prometheus_port: u16,
 
     /// Enable JSON output format.
```
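A note on why `short` had to go: clap's derive API takes the short flag from the first letter of the field name, so `port` and `prometheus_port` would both claim `-p`, and clap rejects the duplicate when the parser is built. A minimal sketch of the collision and the fix (field names from the diff; the surrounding struct is assumed):

```rust
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// `-p` stays reserved for the server port.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,

    /// `short` dropped: with it, clap would also derive `-p` here
    /// (first letter of `prometheus_port`) and reject the collision.
    #[clap(default_value = "9000", long, env)]
    prometheus_port: u16,
}

fn main() {
    let args = Args::parse();
    println!("port={}, prometheus_port={}", args.port, args.prometheus_port);
}
```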
backends/trtllm/csrc/backend.cpp · 8 additions, 1 deletion
```diff
@@ -59,7 +59,14 @@ namespace huggingface::tgi::backends::trtllm {
                 static_cast<tle::SizeType32>(g_params.max_new_tokens),
                 true,
                 (tle::SamplingConfig) s_params,
-                tle::OutputConfig{ /* returnLogProbs= */ true},
+                tle::OutputConfig{
+                        /* returnLogProbs= */ true,
+                        false,
+                        false,
+                        false,
+                        false,
+                        /* returnPerfMetrics= */ true,
+                },
                 std::nullopt,
                 std::nullopt,
                 std::nullopt,
```
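Note on the `OutputConfig` change: the struct is initialized positionally, so every flag between `returnLogProbs` and `returnPerfMetrics` has to be written out even though only those two are enabled. Going by the executor API's declared parameter order, the four `false`s should correspond to `returnContextLogits`, `returnGenerationLogits`, `excludeInputFromOutput`, and `returnEncoderOutput` (an assumption worth double-checking against the installed TensorRT-LLM headers); commenting each `false` the way the two named flags are commented would make the initializer self-documenting and robust to reordering.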
backends/trtllm/csrc/ffi.hpp · 52 additions, 10 deletions
```diff
@@ -1,6 +1,8 @@
 #ifndef TGI_BACKEND_TRTLLM_FFI
 #define TGI_BACKEND_TRTLLM_FFI
 
+#include <chrono>
 #include <exception>
+#include <memory>
 #include <thread>
 
```
```diff
@@ -17,7 +19,7 @@ namespace rust::behavior {
     template<typename Try, typename Fail>
     static void trycatch(Try &&func, Fail &&fail) noexcept try {
         func();
-    } catch (tensorrt_llm::common::TllmException &e) {
+    } catch (const std::exception &e) {
         fail(e.what());
     }
 }
```
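Widening the catch from `tensorrt_llm::common::TllmException` to `const std::exception &` matters for the cxx bridge: this `trycatch` shim is what converts a C++ exception into a Rust `Result::Err`. Any exception type that escaped the old, narrower handler would have propagated out of the `noexcept` function and terminated the process instead of surfacing as an error on the Rust side.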
```diff
@@ -42,22 +44,46 @@ namespace huggingface::tgi::backends::trtllm {
                 return finish_reason_t::kEND_ID;
             case tle::FinishReason::kLENGTH:
                 return finish_reason_t::kLENGTH;
+            case tle::FinishReason::kTIMED_OUT:
+                return finish_reason_t::kTIMED_OUT;
+            case tle::FinishReason::kCANCELLED:
+                return finish_reason_t::kCANCELLED;
             default:
                 std::unreachable();
         }
     }
 
-    static auto as_generation_step = [](const tle::Response &r) {
+    static auto as_generation_step = [](const tle::Response &r, const std::chrono::time_point<std::chrono::steady_clock> created) {
         const auto reqId = r.getRequestId();
         if (!r.hasError()) [[likely]] {
             const auto result = r.getResult();
-            const auto logits = result.logProbs.value()[0];
+            std::optional<uint32_t> token_id = std::nullopt;
+            if (!result.outputTokenIds.empty() && !result.outputTokenIds[0].empty()) {
+                token_id = static_cast<uint32_t>(result.outputTokenIds[0][0]);
+            }
+
+            std::optional<float> log_prob = std::nullopt;
+            if (result.logProbs && !result.logProbs->empty() && !result.logProbs.value()[0].empty()) {
+                log_prob = result.logProbs.value()[0].back();
+            }
+
+            std::optional<int64_t> first_scheduled_time_ns = std::nullopt;
+            if (result.requestPerfMetrics) {
+                const auto &t = result.requestPerfMetrics->timingMetrics;
+                const auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(t.firstScheduledTime - created).count();
+                first_scheduled_time_ns = static_cast<int64_t>(ns);
+            }
+
             return generation_step_t{
                     reqId,
-                    static_cast<uint32_t>(result.outputTokenIds[0][0]),
-                    logits.back(),
+                    token_id.value_or(0),
+                    log_prob.value_or(0.0),
+                    first_scheduled_time_ns.value_or(0),
                     result.isFinal,
                     as_finish_reason_t(result.finishReasons[0]),
+                    token_id.has_value(),
+                    log_prob.has_value(),
+                    first_scheduled_time_ns.has_value(),
                     false,
                     std::string()
             };
```
```diff
@@ -66,8 +92,12 @@ namespace huggingface::tgi::backends::trtllm {
                     reqId,
                     0,
                     0.0,
+                    0,
                     true,
                     finish_reason_t::kNOT_FINISHED,
+                    false,
+                    false,
+                    false,
                     true,
                     std::move(r.getErrorMsg())
             };
```
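The `value_or(0)`/`has_value()` pairs above exist because cxx shared structs cannot carry `Option<T>`, so each optional field crosses the FFI as a value plus a `*_valid` flag, with `0` as the placeholder. A sketch of how the Rust side might fold the pairs back into `Option`s (the struct name and helpers are assumed; the fields mirror the `lib.rs` diff below):

```rust
/// Mirror of the shared `generation_step_t` fields used here (name assumed).
struct GenerationStep {
    token_id: u32,
    log_prob: f32,
    first_scheduled_time_ns: i64,
    token_id_valid: bool,
    log_prob_valid: bool,
    first_scheduled_time_ns_valid: bool,
}

impl GenerationStep {
    /// Each accessor collapses a (value, flag) pair into an Option.
    fn token_id(&self) -> Option<u32> {
        self.token_id_valid.then_some(self.token_id)
    }

    fn log_prob(&self) -> Option<f32> {
        self.log_prob_valid.then_some(self.log_prob)
    }

    fn first_scheduled_time_ns(&self) -> Option<i64> {
        self.first_scheduled_time_ns_valid
            .then_some(self.first_scheduled_time_ns)
    }
}
```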
```diff
@@ -79,9 +109,16 @@ namespace huggingface::tgi::backends::trtllm {
     private:
         backend_t inner_;
 
+        // m_created_time is the reference point used to convert C++ time_points
+        // into Rust Instants.
+        std::chrono::time_point<std::chrono::steady_clock> m_created_time;
+
+
     public:
-        tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
-                : inner_(engine_folder, executor_worker_path) {}
+        tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path, const std::chrono::time_point<std::chrono::steady_clock> &created_time)
+                : inner_(engine_folder, executor_worker_path),
+                  m_created_time{created_time}
+        {}
 
         size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }
 
```
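On the conversion that comment alludes to: a Rust `Instant` is opaque and cannot be built from an absolute timestamp, so the bridge ships a nanosecond offset relative to backend creation and lets the Rust side re-anchor it. A sketch, assuming the Rust side captures its own `Instant` when it calls `create_backend_from_engine_folder`, so the two reference points coincide up to call overhead (names illustrative):

```rust
use std::time::{Duration, Instant};

/// Re-anchor the FFI's nanosecond offset onto the Rust timeline.
/// `created_at` must be captured at backend creation so it corresponds
/// to the C++ `m_created_time` reference point.
fn first_scheduled_instant(created_at: Instant, offset_ns: i64) -> Option<Instant> {
    u64::try_from(offset_ns)
        .ok()
        .map(|ns| created_at + Duration::from_nanos(ns))
}
```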
```diff
@@ -121,13 +158,16 @@ namespace huggingface::tgi::backends::trtllm {
 
             SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());
 
+            auto f = [this](const tle::Response &r) {
+                return as_generation_step(r, m_created_time);
+            };
             // Transform tle::Response to generation_step_t
 #ifdef __cpp_lib_ranges_to_container
-            auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to<std::vector>();
+            auto steps = responses | std::views::transform(f) | std::ranges::to<std::vector>();
 #else
             auto steps = std::vector<generation_step_t>();
             steps.reserve(responses.size());
-            std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
+            std::transform(responses.begin(), responses.end(), std::back_inserter(steps), f);
 #endif
             return std::make_unique<std::vector<generation_step_t>>(steps);
 
```
```diff
@@ -179,12 +219,14 @@ namespace huggingface::tgi::backends::trtllm {
 
     std::unique_ptr<tensorrt_llm_backend_t>
     create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
+        const auto created_time = std::chrono::steady_clock::now();
         std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
         return std::make_unique<tensorrt_llm_backend_t>(
                 std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
                                       std::filesystem::path::format::auto_format),
                 std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
-                                      std::filesystem::path::format::auto_format)
+                                      std::filesystem::path::format::auto_format),
+                created_time
         );
     }
 }
```
backends/trtllm/src/errors.rs · 4 additions
```diff
@@ -19,4 +19,8 @@ pub enum TensorRtLlmBackendError {
     WebServer(#[from] server::WebServerError),
     #[error("Tokio runtime failed to start: {0}")]
     Tokio(#[from] std::io::Error),
+    #[error("config.json doesn't exist in engine folder {0}")]
+    ConfigNotFound(PathBuf),
+    #[error("generation_config.json doesn't exist in engine folder {0}")]
+    GenerationConfigNotFound(PathBuf),
 }
```
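The two new variants imply a pre-flight check on the engine folder before the executor is started; note they also require `std::path::PathBuf` to be in scope in `errors.rs`. A sketch of a plausible call site (the function itself is assumed; only the variants come from this diff):

```rust
use std::path::PathBuf;

/// Fail fast, with a specific error, if either config file is missing.
fn validate_engine_folder(folder: PathBuf) -> Result<PathBuf, TensorRtLlmBackendError> {
    if !folder.join("config.json").is_file() {
        return Err(TensorRtLlmBackendError::ConfigNotFound(folder));
    }
    if !folder.join("generation_config.json").is_file() {
        return Err(TensorRtLlmBackendError::GenerationConfigNotFound(folder));
    }
    Ok(folder)
}
```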
backends/trtllm/src/lib.rs · 14 additions
```diff
@@ -24,6 +24,14 @@ mod ffi {
         /// The request finished because the maximum number of tokens was reached.
         #[cxx_name = "kLENGTH"]
         MaxLength = 3u8,
+
+        #[cxx_name = "kTIMED_OUT"]
+        /// The request finished because it timed out (via the mAllotedTime parameter).
+        TimedOut = 4u8,
+
+        #[cxx_name = "kCANCELLED"]
+        /// The request was cancelled by calling cancelRequest.
+        Cancelled = 5u8,
     }
 
     /// Struct used as shared type between rust and C++ to represent the result
```
```diff
@@ -34,8 +42,14 @@ mod ffi {
         request_id: u64,
         token_id: u32,
         log_prob: f32,
+
+        /// Time at which the request was first scheduled, in nanoseconds
+        /// since backend creation.
+        first_scheduled_time_ns: i64,
         is_final: bool,
         finish_reason: FinishReason,
+        token_id_valid: bool,
+        log_prob_valid: bool,
+        first_scheduled_time_ns_valid: bool,
         has_error: bool,
         error_msg: String,
     }
```
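For consumers of the enum, a finished request can now end for reasons other than EOS or length exhaustion. A sketch of a downstream check (cxx-generated shared enums support `==` comparison; the helper itself is assumed):

```rust
/// True when TensorRT-LLM ended the request for an exceptional reason
/// rather than an end token or the max-tokens limit.
fn finished_abnormally(reason: FinishReason) -> bool {
    reason == FinishReason::TimedOut || reason == FinishReason::Cancelled
}
```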