
Commit dbd9a83

feat: Integrate GPUDirect Storage (GDS) into Executor API (#3582)

* feat: Integrate GPUDirect Storage (GDS) into Executor API

  Squash of several dev commits

Signed-off-by: Dom Brown <[email protected]>

1 parent 90a28b9 · commit dbd9a83

23 files changed, +410 −82 lines
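Taken together, the commit threads a single useGpuDirectStorage boolean from the public ExecutorConfig through TrtGptModelOptionalParams down to TllmRuntime's engine loading, so a serialized engine can be read with GPUDirect Storage (NVIDIA's cuFile path) instead of a host read followed by a host-to-device copy. A minimal sketch of how a client might opt in after this commit (engine path and model type are illustrative):

#include "tensorrt_llm/executor/executor.h"

namespace tle = tensorrt_llm::executor;

int main()
{
    tle::ExecutorConfig config;
    // New in this commit: load the engine via GPUDirect Storage.
    config.setUseGpuDirectStorage(true);

    // Path and model type are placeholders for illustration.
    tle::Executor executor("/path/to/engine_dir", tle::ModelType::kDECODER_ONLY, config);
    return 0;
}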

benchmarks/cpp/bertBenchmark.cpp (+9 −5)

@@ -74,15 +74,17 @@ std::string engineFilename(
 }
 
 void benchmarkBert(std::string const& modelName, std::filesystem::path const& dataPath,
-    std::vector<int> const& batchSizes, std::vector<int> const& inLens, std::vector<float> const& gpuWeightsPercents,
-    std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp, int numRuns, int duration)
+    std::vector<int> const& batchSizes, std::vector<int> const& inLens, bool useGpuDirectStorage,
+    std::vector<float> const& gpuWeightsPercents, std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp,
+    int numRuns, int duration)
 {
     auto const worldConfig = WorldConfig::mpi();
     auto const enginePath = dataPath / engineFilename(dataPath, worldConfig, modelName);
 
     for (float gpuWeightsPercent : gpuWeightsPercents)
     {
-        auto rt = std::make_shared<TllmRuntime>(RawEngine(enginePath), logger.get(), gpuWeightsPercent);
+        auto rt = std::make_shared<TllmRuntime>(
+            RawEngine(enginePath), logger.get(), useGpuDirectStorage, gpuWeightsPercent);
         rt->addContext(0);
         for (auto inLen : inLens)
         {

@@ -174,6 +176,8 @@ int main(int argc, char* argv[])
         "by \";\", "
         "example: \"0.0;0.5;1.0\".",
         cxxopts::value<std::string>()->default_value("1.0"));
+    options.add_options()("use_gpu_direct_storage", "Enable GPUDirect Storage (GDS) for loading engine.",
+        cxxopts::value<bool>()->default_value("false"));
 
     auto result = options.parse(argc, argv);
 
@@ -258,8 +262,8 @@
     try
     {
         benchmarkBert(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes, inLens,
-            gpuWeightsPercents, logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(),
-            result["duration"].as<int>());
+            result["use_gpu_direct_storage"].as<bool>(), gpuWeightsPercents, logger, result["warm_up"].as<int>(),
+            result["num_runs"].as<int>(), result["duration"].as<int>());
     }
     catch (std::exception const& e)
     {
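With the flag threaded through main, GDS can be enabled from the benchmark command line. A hypothetical invocation (model name and paths illustrative; cxxopts treats the bare boolean flag as true):

bertBenchmark --model bert_base --engine_dir /path/to/engines --use_gpu_direct_storage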

cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h (+12 −9)

@@ -41,9 +41,9 @@ class TrtGptModelOptionalParams
         std::optional<std::vector<SizeType32>> deviceIds = std::nullopt, bool normalizeLogProbs = true,
         bool enableChunkedContext = true,
         PeftCacheManagerConfig const& peftCacheManagerConfig = PeftCacheManagerConfig{},
-        executor::DecodingConfig decodingConfig = executor::DecodingConfig{}, float gpuWeightsPercent = 1,
-        std::optional<SizeType32> maxBeamWidth = std::nullopt, std::optional<SizeType32> maxBatchSize = std::nullopt,
-        std::optional<SizeType32> maxNumTokens = std::nullopt,
+        executor::DecodingConfig decodingConfig = executor::DecodingConfig{}, bool useGpuDirectStorage = false,
+        float gpuWeightsPercent = 1, std::optional<SizeType32> maxBeamWidth = std::nullopt,
+        std::optional<SizeType32> maxBatchSize = std::nullopt, std::optional<SizeType32> maxNumTokens = std::nullopt,
         executor::SchedulerConfig schedulerConfig = executor::SchedulerConfig{},
         executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig
             = executor::ExtendedRuntimePerfKnobConfig{},

@@ -61,6 +61,7 @@ class TrtGptModelOptionalParams
         , enableChunkedContext{enableChunkedContext}
         , peftCacheManagerConfig(peftCacheManagerConfig)
         , decodingConfig(std::move(decodingConfig))
+        , useGpuDirectStorage(useGpuDirectStorage)
         , gpuWeightsPercent(gpuWeightsPercent)
         , maxBeamWidth(maxBeamWidth)
         , maxBatchSize(maxBatchSize)

@@ -87,12 +88,12 @@ class TrtGptModelOptionalParams
             executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(),
             PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig())),
             executorConfig.getDecodingConfig().value_or(executor::DecodingConfig{}),
-            executorConfig.getGpuWeightsPercent(), executorConfig.getMaxBeamWidth(), executorConfig.getMaxBatchSize(),
-            executorConfig.getMaxNumTokens(), executorConfig.getSchedulerConfig(),
-            executorConfig.getExtendedRuntimePerfKnobConfig(), executorConfig.getDebugConfig(),
-            executorConfig.getMaxSeqIdleMicroseconds(), executorConfig.getSpecDecConfig(),
-            executorConfig.getGuidedDecodingConfig(), isLeaderInOrchMode, executorConfig.getAdditionalModelOutputs(),
-            executorConfig.getGatherGenerationLogits())
+            executorConfig.getUseGpuDirectStorage(), executorConfig.getGpuWeightsPercent(),
+            executorConfig.getMaxBeamWidth(), executorConfig.getMaxBatchSize(), executorConfig.getMaxNumTokens(),
+            executorConfig.getSchedulerConfig(), executorConfig.getExtendedRuntimePerfKnobConfig(),
+            executorConfig.getDebugConfig(), executorConfig.getMaxSeqIdleMicroseconds(),
+            executorConfig.getSpecDecConfig(), executorConfig.getGuidedDecodingConfig(), isLeaderInOrchMode,
+            executorConfig.getAdditionalModelOutputs(), executorConfig.getGatherGenerationLogits())
     {
     }

@@ -106,6 +107,8 @@ class TrtGptModelOptionalParams
     bool enableChunkedContext;
     PeftCacheManagerConfig peftCacheManagerConfig;
     executor::DecodingConfig decodingConfig;
+    // Use GDS to load the engines?
+    bool useGpuDirectStorage;
     // Percentage of weights on the gpu at runtime
     float gpuWeightsPercent;
     std::optional<SizeType32> maxBeamWidth;
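Note that the flag is inserted ahead of gpuWeightsPercent in both the primary constructor and the ExecutorConfig-based delegating constructor, so any caller passing those arguments positionally must be updated. Since the struct's members are public (see the pybind bindings below), a minimal sketch can simply set the new field on a default-constructed object (assuming the remaining constructor parameters keep their defaults):

tensorrt_llm::batch_manager::TrtGptModelOptionalParams params{};
params.useGpuDirectStorage = true; // new member added by this commit
params.gpuWeightsPercent = 1.0F;   // existing member that now follows the flag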

cpp/include/tensorrt_llm/executor/executor.h (+7 −2)

@@ -1400,8 +1400,8 @@ class ExecutorConfig
         std::optional<ParallelConfig> parallelConfig = std::nullopt,
         std::optional<PeftCacheConfig> const& peftCacheConfig = std::nullopt,
         std::optional<LogitsPostProcessorConfig> logitsPostProcessorConfig = std::nullopt,
-        std::optional<DecodingConfig> decodingConfig = std::nullopt, float gpuWeightsPercent = 1,
-        std::optional<SizeType32> maxQueueSize = std::nullopt,
+        std::optional<DecodingConfig> decodingConfig = std::nullopt, bool useGpuDirectStorage = false,
+        float gpuWeightsPercent = 1, std::optional<SizeType32> maxQueueSize = std::nullopt,
         ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig = ExtendedRuntimePerfKnobConfig(),
         std::optional<DebugConfig> debugConfig = std::nullopt, SizeType32 recvPollPeriodMs = 0,
         uint64_t maxSeqIdleMicroseconds = kDefaultMaxSeqIdleMicroseconds,

@@ -1429,6 +1429,7 @@ class ExecutorConfig
     [[nodiscard]] std::optional<PeftCacheConfig> getPeftCacheConfig() const;
     [[nodiscard]] std::optional<LogitsPostProcessorConfig> getLogitsPostProcessorConfig() const;
     [[nodiscard]] std::optional<DecodingConfig> getDecodingConfig() const;
+    [[nodiscard]] bool getUseGpuDirectStorage() const;
     [[nodiscard]] float getGpuWeightsPercent() const;
     [[nodiscard]] std::optional<SizeType32> getMaxQueueSize() const;
     [[nodiscard]] ExtendedRuntimePerfKnobConfig getExtendedRuntimePerfKnobConfig() const;

@@ -1455,6 +1456,7 @@ class ExecutorConfig
     void setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig);
     void setLogitsPostProcessorConfig(LogitsPostProcessorConfig const& logitsPostProcessorConfig);
     void setDecodingConfig(DecodingConfig const& decodingConfig);
+    void setUseGpuDirectStorage(bool const& useGpuDirectStorage);
     void setGpuWeightsPercent(float const& gpuWeightsPercent);
     void setMaxQueueSize(std::optional<SizeType32> const& maxQueueSize);
     void setExtendedRuntimePerfKnobConfig(ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig);

@@ -1510,6 +1512,9 @@ class ExecutorConfig
     /// @brief Decoding configuration.
     std::optional<DecodingConfig> mDecodingConfig;
 
+    /// @brief Enable/disable use of GPU Direct Storage (GDS) to load engines.
+    bool mUseGpuDirectStorage;
+
     /// @brief GPU weights percent for weight streaming.
     float mGpuWeightsPercent;
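The new accessor pair mirrors the adjacent gpuWeightsPercent API, and the constructor default above means a freshly built config reports GDS disabled. A quick hypothetical sanity check (assuming the other constructor arguments keep their defaults):

#include <cassert>
#include "tensorrt_llm/executor/executor.h"

void checkGdsFlag()
{
    tensorrt_llm::executor::ExecutorConfig cfg;
    assert(!cfg.getUseGpuDirectStorage()); // default is false, per the constructor signature
    cfg.setUseGpuDirectStorage(true);
    assert(cfg.getUseGpuDirectStorage());
}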

cpp/include/tensorrt_llm/runtime/gptSession.h (+3 −0)

@@ -99,6 +99,9 @@ class [[deprecated("Use the executor API instead.")]] GptSession
         SizeType32 maxBeamWidth;
         // The length of the longest input sequence
         SizeType32 maxSequenceLength;
+        // Enable/disable GPUDirectStorage
+        // Not supported by GptSession so hard-coded as false
+        bool useGpuDirectStorage{false};
         // Percentage of weights on the gpu at runtime
         float gpuWeightsPercent;
         // Whether the session will use a different decoder per request.

cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp (+2 −1)

@@ -45,7 +45,8 @@ TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldC
     , mWorldConfig{worldConfig}
     , mDevice{runtime::utils::initDevice(worldConfig)}
     , mLogger{logger ? std::move(logger) : std::make_shared<TllmLogger>()}
-    , mRuntime{std::make_shared<TllmRuntime>(rawEngine, mLogger.get(), optionalParams.gpuWeightsPercent)}
+    , mRuntime{std::make_shared<TllmRuntime>(
+          rawEngine, mLogger.get(), optionalParams.useGpuDirectStorage, optionalParams.gpuWeightsPercent)}
     , mMicroBatchId(0)
     , mCopyBufferManager{std::make_shared<CudaStream>()}
 {

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp (+2 −2)

@@ -138,8 +138,8 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
     , mDebugConfig{optionalParams.debugConfig}
     , mAdditionalModelOutputs{optionalParams.additionalModelOutputs}
     , mLogger{logger ? std::move(logger) : std::make_shared<TllmLogger>()}
-    , mRuntime{std::make_shared<TllmRuntime>(
-          rawEngine, mLogger.get(), optionalParams.gpuWeightsPercent, modelConfig.useShapeInference())}
+    , mRuntime{std::make_shared<TllmRuntime>(rawEngine, mLogger.get(), optionalParams.useGpuDirectStorage,
+          optionalParams.gpuWeightsPercent, modelConfig.useShapeInference())}
     , mCopyBufferManager{std::make_shared<CudaStream>()}
     , mCtxGenFusion(ctxGenFusion)
     , mOperatingBeamWidth{getMaxBeamWidth()}
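Both batch-manager models now pass the flag to TllmRuntime directly after the logger. The runtime constructor itself is not part of this excerpt; reconstructed from the two call sites, its post-commit signature is presumably close to the following (parameter defaults are assumptions):

// Reconstructed from the call sites above; not the actual header.
TllmRuntime(RawEngine const& rawEngine, nvinfer1::ILogger* logger,
    bool useGpuDirectStorage, float gpuWeightsPercent = 1.0F,
    bool useShapeInference = true);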

cpp/tensorrt_llm/executor/executorConfig.cpp (+12 −1)

@@ -28,7 +28,7 @@ ExecutorConfig::ExecutorConfig(SizeType32 maxBeamWidth, SchedulerConfig schedule
     std::optional<SizeType32> maxNumTokens, std::optional<ParallelConfig> parallelConfig,
     std::optional<PeftCacheConfig> const& peftCacheConfig,
     std::optional<LogitsPostProcessorConfig> logitsPostProcessorConfig, std::optional<DecodingConfig> decodingConfig,
-    float gpuWeightPercent, std::optional<SizeType32> maxQueueSize,
+    bool useGpuDirectStorage, float gpuWeightPercent, std::optional<SizeType32> maxQueueSize,
     ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig, std::optional<DebugConfig> debugConfig,
     SizeType32 recvPollPeriodMs, uint64_t maxSeqIdleMicroseconds,
     std::optional<SpeculativeDecodingConfig> specDecConfig, std::optional<GuidedDecodingConfig> guidedDecodingConfig,

@@ -48,6 +48,7 @@ ExecutorConfig::ExecutorConfig(SizeType32 maxBeamWidth, SchedulerConfig schedule
     , mPeftCacheConfig(peftCacheConfig)
     , mLogitsPostProcessorConfig(std::move(logitsPostProcessorConfig))
     , mDecodingConfig(std::move(decodingConfig))
+    , mUseGpuDirectStorage((useGpuDirectStorage))
     , mGpuWeightsPercent(gpuWeightPercent)
     , mMaxQueueSize(maxQueueSize)
     , mExtendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig)

@@ -146,6 +147,11 @@ std::optional<DecodingConfig> ExecutorConfig::getDecodingConfig() const
     return mDecodingConfig;
 }
 
+bool ExecutorConfig::getUseGpuDirectStorage() const
+{
+    return mUseGpuDirectStorage;
+}
+
 float ExecutorConfig::getGpuWeightsPercent() const
 {
     return mGpuWeightsPercent;

@@ -276,6 +282,11 @@ void ExecutorConfig::setDecodingConfig(DecodingConfig const& decodingConfig)
     mDecodingConfig = decodingConfig;
 }
 
+void ExecutorConfig::setUseGpuDirectStorage(bool const& useGpuDirectStorage)
+{
+    mUseGpuDirectStorage = useGpuDirectStorage;
+}
+
 void ExecutorConfig::setGpuWeightsPercent(float const& gpuWeightsPercent)
 {
     mGpuWeightsPercent = gpuWeightsPercent;

cpp/tensorrt_llm/executor/serialization.cpp (+6 −3)

@@ -978,6 +978,7 @@ ExecutorConfig Serialization::deserializeExecutorConfig(std::istream& is)
     auto parallelConfig = su::deserializeWithGetterType<decltype(&ExecutorConfig::getParallelConfig)>(is);
     auto peftCacheConfig = su::deserializeWithGetterType<decltype(&ExecutorConfig::getPeftCacheConfig)>(is);
     auto decodingConfig = su::deserializeWithGetterType<decltype(&ExecutorConfig::getDecodingConfig)>(is);
+    auto useGpuDirectStorage = su::deserializeWithGetterType<decltype(&ExecutorConfig::getUseGpuDirectStorage)>(is);
     auto gpuWeightsPercent = su::deserializeWithGetterType<decltype(&ExecutorConfig::getGpuWeightsPercent)>(is);
     auto maxQueueSize = su::deserializeWithGetterType<decltype(&ExecutorConfig::getMaxQueueSize)>(is);
     auto extendedRuntimePerfKnobConfig

@@ -995,9 +996,9 @@ ExecutorConfig Serialization::deserializeExecutorConfig(std::istream& is)
 
     return ExecutorConfig{maxBeamWidth, schedulerConfig, kvCacheConfig, enableChunkedContext, normalizeLogProbs,
         iterStatsMaxIterations, requestStatsMaxIterations, batchingType, maxBatchSize, maxNumTokens, parallelConfig,
-        peftCacheConfig, std::nullopt, decodingConfig, gpuWeightsPercent, maxQueueSize, extendedRuntimePerfKnobConfig,
-        debugConfig, recvPollPeriodMs, maxSeqIdleMicroseconds, specDecConfig, guidedDecodingConfig,
-        additionalModelOutputs, gatherGenerationLogits};
+        peftCacheConfig, std::nullopt, decodingConfig, useGpuDirectStorage, gpuWeightsPercent, maxQueueSize,
+        extendedRuntimePerfKnobConfig, debugConfig, recvPollPeriodMs, maxSeqIdleMicroseconds, specDecConfig,
+        guidedDecodingConfig, additionalModelOutputs, gatherGenerationLogits};
 }
 
 size_t Serialization::serializedSize(ExecutorConfig const& executorConfig)

@@ -1020,6 +1021,7 @@ size_t Serialization::serializedSize(ExecutorConfig const& executorConfig)
     totalSize += su::serializedSize(executorConfig.getParallelConfig());
     totalSize += su::serializedSize(executorConfig.getPeftCacheConfig());
     totalSize += su::serializedSize(executorConfig.getDecodingConfig());
+    totalSize += su::serializedSize(executorConfig.getUseGpuDirectStorage());
     totalSize += su::serializedSize(executorConfig.getGpuWeightsPercent());
     totalSize += su::serializedSize(executorConfig.getMaxQueueSize());
     totalSize += su::serializedSize(executorConfig.getExtendedRuntimePerfKnobConfig());

@@ -1052,6 +1054,7 @@ void Serialization::serialize(ExecutorConfig const& executorConfig, std::ostream
     su::serialize(executorConfig.getParallelConfig(), os);
     su::serialize(executorConfig.getPeftCacheConfig(), os);
     su::serialize(executorConfig.getDecodingConfig(), os);
+    su::serialize(executorConfig.getUseGpuDirectStorage(), os);
     su::serialize(executorConfig.getGpuWeightsPercent(), os);
     su::serialize(executorConfig.getMaxQueueSize(), os);
     su::serialize(executorConfig.getExtendedRuntimePerfKnobConfig(), os);
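The flag is written and read between the decoding config and the GPU-weights percentage, which changes the wire layout: a config serialized before this commit will generally not deserialize correctly with the new code, and vice versa. A round-trip sketch using the two functions touched here (stream choice illustrative):

#include <sstream>

#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h"

namespace tle = tensorrt_llm::executor;

void roundTripExecutorConfig()
{
    tle::ExecutorConfig cfg;
    cfg.setUseGpuDirectStorage(true);

    std::stringstream ss;
    tle::Serialization::serialize(cfg, ss); // writes the flag right after decodingConfig
    auto const restored = tle::Serialization::deserializeExecutorConfig(ss);
    // restored.getUseGpuDirectStorage() is now true
}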

cpp/tensorrt_llm/pybind/bindings.cpp (+1 −0)

@@ -527,6 +527,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
         .def_readwrite("enable_chunked_context", &tb::TrtGptModelOptionalParams::enableChunkedContext)
         .def_readwrite("normalize_log_probs", &tb::TrtGptModelOptionalParams::normalizeLogProbs)
         .def_readwrite("decoding_config", &tb::TrtGptModelOptionalParams::decodingConfig)
+        .def_readwrite("use_gpu_direct_storage", &tb::TrtGptModelOptionalParams::useGpuDirectStorage)
         .def_readwrite("gpu_weights_percent", &tb::TrtGptModelOptionalParams::gpuWeightsPercent)
         .def_readwrite("max_beam_width", &tb::TrtGptModelOptionalParams::maxBeamWidth)
         .def_readwrite("scheduler_config", &tb::TrtGptModelOptionalParams::schedulerConfig)
