@@ -41,9 +41,9 @@ class TrtGptModelOptionalParams
41
41
std::optional<std::vector<SizeType32>> deviceIds = std::nullopt, bool normalizeLogProbs = true ,
42
42
bool enableChunkedContext = true ,
43
43
PeftCacheManagerConfig const & peftCacheManagerConfig = PeftCacheManagerConfig{},
44
- executor::DecodingConfig decodingConfig = executor::DecodingConfig{}, float gpuWeightsPercent = 1 ,
45
- std::optional<SizeType32> maxBeamWidth = std::nullopt , std::optional<SizeType32> maxBatchSize = std::nullopt,
46
- std::optional<SizeType32> maxNumTokens = std::nullopt,
44
+ executor::DecodingConfig decodingConfig = executor::DecodingConfig{}, bool useGpuDirectStorage = false ,
45
+ float gpuWeightsPercent = 1 , std::optional<SizeType32> maxBeamWidth = std::nullopt,
46
+ std::optional<SizeType32> maxBatchSize = std::nullopt, std::optional<SizeType32> maxNumTokens = std::nullopt,
47
47
executor::SchedulerConfig schedulerConfig = executor::SchedulerConfig{},
48
48
executor::ExtendedRuntimePerfKnobConfig const & extendedRuntimePerfKnobConfig
49
49
= executor::ExtendedRuntimePerfKnobConfig{},
@@ -61,6 +61,7 @@ class TrtGptModelOptionalParams
61
61
, enableChunkedContext{enableChunkedContext}
62
62
, peftCacheManagerConfig(peftCacheManagerConfig)
63
63
, decodingConfig(std::move(decodingConfig))
64
+ , useGpuDirectStorage(useGpuDirectStorage)
64
65
, gpuWeightsPercent(gpuWeightsPercent)
65
66
, maxBeamWidth(maxBeamWidth)
66
67
, maxBatchSize(maxBatchSize)
@@ -87,12 +88,12 @@ class TrtGptModelOptionalParams
87
88
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(),
88
89
PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig())),
89
90
executorConfig.getDecodingConfig().value_or(executor::DecodingConfig{}),
90
- executorConfig.getGpuWeightsPercent (), executorConfig.getMaxBeamWidth(), executorConfig.getMaxBatchSize (),
91
- executorConfig.getMaxNumTokens (), executorConfig.getSchedulerConfig (),
92
- executorConfig.getExtendedRuntimePerfKnobConfig (), executorConfig.getDebugConfig (),
93
- executorConfig.getMaxSeqIdleMicroseconds (), executorConfig.getSpecDecConfig (),
94
- executorConfig.getGuidedDecodingConfig (), isLeaderInOrchMode, executorConfig.getAdditionalModelOutputs() ,
95
- executorConfig.getGatherGenerationLogits())
91
+ executorConfig.getUseGpuDirectStorage (), executorConfig.getGpuWeightsPercent (),
92
+ executorConfig.getMaxBeamWidth (), executorConfig.getMaxBatchSize(), executorConfig.getMaxNumTokens (),
93
+ executorConfig.getSchedulerConfig (), executorConfig.getExtendedRuntimePerfKnobConfig (),
94
+ executorConfig.getDebugConfig (), executorConfig.getMaxSeqIdleMicroseconds (),
95
+ executorConfig.getSpecDecConfig (), executorConfig.getGuidedDecodingConfig(), isLeaderInOrchMode ,
96
+ executorConfig.getAdditionalModelOutputs(), executorConfig. getGatherGenerationLogits())
96
97
{
97
98
}
98
99
@@ -106,6 +107,8 @@ class TrtGptModelOptionalParams
106
107
bool enableChunkedContext;
107
108
PeftCacheManagerConfig peftCacheManagerConfig;
108
109
executor::DecodingConfig decodingConfig;
110
+ // Use GDS to load the engines?
111
+ bool useGpuDirectStorage;
109
112
// Percentage of weights on the gpu at runtime
110
113
float gpuWeightsPercent;
111
114
std::optional<SizeType32> maxBeamWidth;
0 commit comments