Skip to content

Commit 7dbe618

Browse files
feat: Add multimodal embedding field in LlmRequest (#3855)
* Add a new param to LlmRequest and Request to natively support mm Signed-off-by: Kate Cheng <[email protected]> * update comment Signed-off-by: Kate Cheng <[email protected]> * Update tests to match the new LlmRequest constructor parameters Signed-off-by: Kate Cheng <[email protected]> * Modify unitTest and modify mm_embedding's dict name in llama4 Signed-off-by: Kate Cheng <[email protected]> * Fix based on comments Signed-off-by: Kate Cheng <[email protected]> * Fix comment Signed-off-by: Kate Cheng <[email protected]> * Fix LlmRequest initialization in kvCacheManagerTest Signed-off-by: Kate Cheng <[email protected]> * Clean up code for prompt_tuning_config Signed-off-by: Kate Cheng <[email protected]> * Clean up prompt_tuning_config in GenerationRequest Signed-off-by: Kate Cheng <[email protected]> --------- Signed-off-by: Kate Cheng <[email protected]> Co-authored-by: Haohang Huang <[email protected]>
1 parent 1e317c9 commit 7dbe618

File tree

26 files changed

+290
-164
lines changed

26 files changed

+290
-164
lines changed

benchmarks/cpp/disaggServerBenchmark.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@ texec::Request makeExecutorContextRequest(Sample const& sample, SizeType32 const
535535
std::nullopt, // embeddingBias
536536
std::nullopt, // speculativeDecoding
537537
std::nullopt, // pTuning
538+
std::nullopt, // multimodalEmbedding
538539
std::nullopt, // mRopeConfig
539540
loraConfig, // loraConfig
540541
lookaheadConfig, // lookaheadConfig

benchmarks/cpp/gptManagerBenchmark.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,7 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
828828
std::nullopt, // embeddingBias
829829
std::nullopt, // speculativeDecoding
830830
std::nullopt, // pTuning
831+
std::nullopt, // multimodalEmbedding
831832
std::nullopt, // mRopeConfig
832833
loraConfig, // loraConfig
833834
lookaheadConfig, // lookaheadConfig

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

+26-15
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,15 @@ class GenericLlmRequest
9595
using RequestPtr = std::shared_ptr<GenericLlmRequest>;
9696
using MillisecondsType = std::chrono::milliseconds;
9797

98-
// 45 parameters, 52 items in initialization list
98+
// 46 parameters, 53 items in initialization list
9999
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> const& inputTokens,
100100
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
101101
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
102102
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
103103
std::optional<std::shared_ptr<std::vector<SizeType32>>> positionIds = std::nullopt,
104104
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
105105
std::optional<SizeType32> promptVocabSize = std::nullopt,
106+
std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
106107
std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
107108
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
108109
std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
@@ -147,6 +148,7 @@ class GenericLlmRequest
147148
, mPositionIds(std::move(positionIds))
148149
, mPromptEmbeddingTable(std::move(promptEmbeddingTable))
149150
, mPromptVocabSize(promptVocabSize)
151+
, mMultimodalEmbedding(std::move(multimodalEmbedding))
150152
, mMropeRotaryCosSin(std::move(mropeRotaryCosSin))
151153
, mMropePositionDeltas(mropePositionDeltas)
152154
, mLoraTaskId(loraTaskId)
@@ -854,6 +856,11 @@ class GenericLlmRequest
854856
return mPromptVocabSize;
855857
}
856858

859+
[[nodiscard]] std::optional<TensorPtr> getMultimodalEmbedding() const
860+
{
861+
return mMultimodalEmbedding;
862+
}
863+
857864
[[nodiscard]] std::optional<TensorPtr> getMropeRotaryCosSin() const
858865
{
859866
return mMropeRotaryCosSin;
@@ -1818,6 +1825,7 @@ class GenericLlmRequest
18181825

18191826
std::optional<TensorPtr> mPromptEmbeddingTable{std::nullopt};
18201827
std::optional<SizeType32> mPromptVocabSize{std::nullopt};
1828+
std::optional<TensorPtr> mMultimodalEmbedding{std::nullopt};
18211829
std::optional<TensorPtr> mMropeRotaryCosSin{std::nullopt};
18221830
std::optional<SizeType32> mMropePositionDeltas{std::nullopt};
18231831

@@ -2076,14 +2084,15 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
20762084
using TokenExtraIdType = Base::TokenExtraIdType;
20772085
using VecTokenExtraIds = Base::VecTokenExtraIds;
20782086

2079-
// 45 parameters, 45 parameters in Base class constructor
2087+
// 46 parameters, 46 parameters in Base class constructor
20802088
LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
20812089
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
20822090
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
20832091
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
20842092
std::optional<std::shared_ptr<std::vector<SizeType32>>> positionIds = std::nullopt,
20852093
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
20862094
std::optional<SizeType32> promptVocabSize = std::nullopt,
2095+
std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
20872096
std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
20882097
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
20892098
std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
@@ -2111,26 +2120,27 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
21112120
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
21122121
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
21132122
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
2114-
std::move(promptEmbeddingTable), promptVocabSize, std::move(mropeRotaryCosSin), mropePositionDeltas,
2115-
loraTaskId, std::move(loraWeights), std::move(loraConfig), std::move(lookaheadConfig),
2116-
std::move(kvCacheRetentionConfig), returnLogProbs, returnContextLogits, returnGenerationLogits,
2117-
std::move(draftTokens), std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor),
2118-
applyLogitsPostProcessorBatched, std::move(encoderInputTokens), returnEncoderOutput, clientId, priority,
2119-
std::move(encoderInputFeatures), std::move(encoderOutputLength), std::move(crossAttentionMask),
2120-
llmRequestType, std::move(inputTokenExtraIds), numReturnSequences, std::move(eagleConfig),
2121-
std::move(skipCrossAttnBlocks), returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid,
2122-
allottedTimeMs, contextPhaseParams)
2123+
std::move(promptEmbeddingTable), promptVocabSize, std::move(multimodalEmbedding),
2124+
std::move(mropeRotaryCosSin), mropePositionDeltas, loraTaskId, std::move(loraWeights),
2125+
std::move(loraConfig), std::move(lookaheadConfig), std::move(kvCacheRetentionConfig), returnLogProbs,
2126+
returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits),
2127+
excludeInputFromOutput, std::move(logitsPostProcessor), applyLogitsPostProcessorBatched,
2128+
std::move(encoderInputTokens), returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures),
2129+
std::move(encoderOutputLength), std::move(crossAttentionMask), llmRequestType,
2130+
std::move(inputTokenExtraIds), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks),
2131+
returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
21232132
{
21242133
}
21252134

2126-
// 45 parameters, 45 parameters in Base class constructor
2135+
// 46 parameters, 46 parameters in Base class constructor
21272136
LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector<TokenIdType> inputTokens,
21282137
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
21292138
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
21302139
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
21312140
std::optional<std::vector<SizeType32>> positionIds = std::nullopt,
21322141
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
21332142
std::optional<SizeType32> promptVocabSize = std::nullopt,
2143+
std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
21342144
std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
21352145
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
21362146
std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
@@ -2159,9 +2169,10 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
21592169
std::move(stopWordsList),
21602170
positionIds.has_value() ? std::make_shared<std::vector<SizeType32>>(std::move(positionIds.value()))
21612171
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
2162-
std::move(promptEmbeddingTable), promptVocabSize, std::move(mropeRotaryCosSin), mropePositionDeltas,
2163-
loraTaskId, std::move(loraWeights), std::move(loraConfig), lookaheadConfig,
2164-
std::move(kvCacheRetentionConfig), returnLogProbs, returnContextLogits, returnGenerationLogits,
2172+
std::move(promptEmbeddingTable), promptVocabSize, std::move(multimodalEmbedding),
2173+
std::move(mropeRotaryCosSin), mropePositionDeltas, loraTaskId, std::move(loraWeights),
2174+
std::move(loraConfig), lookaheadConfig, std::move(kvCacheRetentionConfig), returnLogProbs,
2175+
returnContextLogits, returnGenerationLogits,
21652176
draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))
21662177
: std::make_shared<VecTokens>(),
21672178
std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor),

cpp/include/tensorrt_llm/executor/executor.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,8 @@ class Request
609609
/// @param embeddingBias The embedding bias tensor. Expected shape is [vocab_size]
610610
/// @param externalDraftTokensConfig The speculative decoding with external draft tokens configuration
611611
/// @param pTuningConfig The prompt tuning configuration
612+
/// @param multimodalEmbedding The multimodal embedding tensor. Expected shape is [num_multimodal_tokens,
613+
/// hidden_dim]
612614
/// @param mRopeConfig The mrope configuration
613615
/// @param loraConfig The LoRA configuration
614616
/// @param lookaheadConfig The lookahead speculative decoding configuration
@@ -646,7 +648,8 @@ class Request
646648
std::optional<Tensor> embeddingBias = std::nullopt,
647649
std::optional<ExternalDraftTokensConfig> externalDraftTokensConfig = std::nullopt,
648650
std::optional<PromptTuningConfig> pTuningConfig = std::nullopt,
649-
std::optional<MropeConfig> mRopeConfig = std::nullopt, std::optional<LoraConfig> loraConfig = std::nullopt,
651+
std::optional<Tensor> multimodalEmbedding = std::nullopt, std::optional<MropeConfig> mRopeConfig = std::nullopt,
652+
std::optional<LoraConfig> loraConfig = std::nullopt,
650653
std::optional<LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
651654
std::optional<KvCacheRetentionConfig> kvCacheRetentionConfig = std::nullopt,
652655
std::optional<std::string> logitsPostProcessorName = std::nullopt,
@@ -688,6 +691,7 @@ class Request
688691
[[nodiscard]] std::optional<Tensor> getEmbeddingBias() const;
689692
[[nodiscard]] std::optional<ExternalDraftTokensConfig> getExternalDraftTokensConfig() const;
690693
[[nodiscard]] std::optional<PromptTuningConfig> getPromptTuningConfig() const;
694+
[[nodiscard]] std::optional<Tensor> getMultimodalEmbedding() const;
691695
[[nodiscard]] std::optional<MropeConfig> getMropeConfig() const;
692696
[[nodiscard]] std::optional<LoraConfig> getLoraConfig() const;
693697
[[nodiscard]] std::optional<LookaheadDecodingConfig> getLookaheadConfig() const;
@@ -722,6 +726,7 @@ class Request
722726
void setEmbeddingBias(Tensor const& embeddingBias);
723727
void setExternalDraftTokensConfig(ExternalDraftTokensConfig const& externalDraftTokensConfig);
724728
void setPromptTuningConfig(PromptTuningConfig const& pTuningConfig);
729+
void setMultimodalEmbedding(Tensor const& multimodalEmbedding);
725730
void setMropeConfig(MropeConfig const& mRopeConfig);
726731
void setLoraConfig(LoraConfig const& loraConfig);
727732
void setLookaheadConfig(LookaheadDecodingConfig const& lookaheadConfig);

cpp/tensorrt_llm/executor/request.cpp

+19-9
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525

2626
namespace tensorrt_llm::executor
2727
{
28-
// 34 parameters
28+
// 35 parameters
2929
Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, SamplingConfig const& samplingConfig,
3030
OutputConfig const& outputConfig, std::optional<SizeType32> const& endId, std::optional<SizeType32> const& padId,
3131
std::optional<std::vector<SizeType32>> positionIds, std::optional<std::list<VecTokens>> badWords,
3232
std::optional<std::list<VecTokens>> stopWords, std::optional<Tensor> embeddingBias,
3333
std::optional<ExternalDraftTokensConfig> externalDraftTokensConfig, std::optional<PromptTuningConfig> pTuningConfig,
34-
std::optional<MropeConfig> mRopeConfig, std::optional<LoraConfig> loraConfig,
35-
std::optional<LookaheadDecodingConfig> lookaheadConfig,
34+
std::optional<Tensor> multimodalEmbedding, std::optional<MropeConfig> mRopeConfig,
35+
std::optional<LoraConfig> loraConfig, std::optional<LookaheadDecodingConfig> lookaheadConfig,
3636
std::optional<KvCacheRetentionConfig> kvCacheRetentionConfig, std::optional<std::string> logitsPostProcessorName,
3737
std::optional<LogitsPostProcessor> logitslogitsPostProcessor, std::optional<VecTokens> encoderInputTokenIds,
3838
std::optional<IdType> clientId, bool returnAllGeneratedTokens, float priority, RequestType type,
@@ -43,12 +43,12 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
4343
std::optional<MillisecondsType> allottedTimeMs)
4444
: mImpl(std::make_unique<Impl>(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId,
4545
padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
46-
std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(mRopeConfig), std::move(loraConfig),
47-
lookaheadConfig, std::move(kvCacheRetentionConfig), std::move(logitsPostProcessorName),
48-
std::move(logitslogitsPostProcessor), std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens,
49-
priority, type, std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength,
50-
crossAttentionMask, numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams),
51-
languageAdapterUid, allottedTimeMs))
46+
std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalEmbedding),
47+
std::move(mRopeConfig), std::move(loraConfig), lookaheadConfig, std::move(kvCacheRetentionConfig),
48+
std::move(logitsPostProcessorName), std::move(logitslogitsPostProcessor), std::move(encoderInputTokenIds),
49+
clientId, returnAllGeneratedTokens, priority, type, std::move(contextPhaseParams),
50+
std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask, numReturnSequences, eagleConfig,
51+
skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs))
5252
{
5353
}
5454

@@ -143,6 +143,11 @@ std::optional<PromptTuningConfig> Request::getPromptTuningConfig() const
143143
return mImpl->getPromptTuningConfig();
144144
}
145145

146+
std::optional<Tensor> Request::getMultimodalEmbedding() const
147+
{
148+
return mImpl->getMultimodalEmbedding();
149+
}
150+
146151
std::optional<MropeConfig> Request::getMropeConfig() const
147152
{
148153
return mImpl->getMropeConfig();
@@ -306,6 +311,11 @@ void Request::setPromptTuningConfig(PromptTuningConfig const& pTuningConfig)
306311
return mImpl->setPromptTuningConfig(pTuningConfig);
307312
}
308313

314+
void Request::setMultimodalEmbedding(Tensor const& multimodalEmbedding)
315+
{
316+
return mImpl->setMultimodalEmbedding(multimodalEmbedding);
317+
}
318+
309319
void Request::setMropeConfig(MropeConfig const& mRopeConfig)
310320
{
311321
return mImpl->setMropeConfig(mRopeConfig);

cpp/tensorrt_llm/executor/requestImpl.h

+17-3
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,15 @@ class Request::Impl
3232
{
3333

3434
public:
35-
// 34 parameters, 34 items in initialization list
35+
// 35 parameters, 35 items in initialization list
3636
Impl(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming, SamplingConfig const& samplingConfig,
3737
OutputConfig outputConfig, std::optional<TokenIdType> const& endId, std::optional<TokenIdType> const& padId,
3838
std::optional<std::vector<SizeType32>> positionIds, std::optional<std::list<VecTokens>> badWords,
3939
std::optional<std::list<VecTokens>> stopWords, std::optional<Tensor> embeddingBias,
4040
std::optional<ExternalDraftTokensConfig> externalDraftTokensConfig,
41-
std::optional<PromptTuningConfig> pTuningConfig, std::optional<MropeConfig> mRopeConfig,
42-
std::optional<LoraConfig> loraConfig, std::optional<LookaheadDecodingConfig> lookaheadConfig,
41+
std::optional<PromptTuningConfig> pTuningConfig, std::optional<Tensor> multimodalEmbedding,
42+
std::optional<MropeConfig> mRopeConfig, std::optional<LoraConfig> loraConfig,
43+
std::optional<LookaheadDecodingConfig> lookaheadConfig,
4344
std::optional<KvCacheRetentionConfig> kvCacheRetentionConfig,
4445
std::optional<std::string> logitsPostProcessorName, std::optional<LogitsPostProcessor> logitsPostProcessor,
4546
std::optional<VecTokens> encoderInputTokenIds, std::optional<IdType> clientId, bool returnAllGeneratedTokens,
@@ -61,6 +62,7 @@ class Request::Impl
6162
, mEmbeddingBias(checkEmbeddingBias(std::move(embeddingBias)))
6263
, mExternalDraftTokensConfig(std::move(externalDraftTokensConfig))
6364
, mPTuningConfig(std::move(pTuningConfig))
65+
, mMultimodalEmbedding(std::move(multimodalEmbedding))
6466
, mMropeConfig(std::move(mRopeConfig))
6567
, mLoraConfig(std::move(loraConfig))
6668
, mLookaheadConfig(lookaheadConfig)
@@ -175,6 +177,11 @@ class Request::Impl
175177
return mPTuningConfig;
176178
}
177179

180+
[[nodiscard]] std::optional<Tensor> getMultimodalEmbedding() const
181+
{
182+
return mMultimodalEmbedding;
183+
}
184+
178185
[[nodiscard]] std::optional<MropeConfig> getMropeConfig() const
179186
{
180187
return mMropeConfig;
@@ -338,6 +345,11 @@ class Request::Impl
338345
mPTuningConfig = pTuningConfig;
339346
}
340347

348+
void setMultimodalEmbedding(Tensor const& multimodalEmbedding)
349+
{
350+
mMultimodalEmbedding = multimodalEmbedding;
351+
}
352+
341353
void setMropeConfig(MropeConfig const& mRopeConfig)
342354
{
343355
mMropeConfig = mRopeConfig;
@@ -498,6 +510,7 @@ class Request::Impl
498510
lambda(mEmbeddingBias);
499511
lambda(mExternalDraftTokensConfig);
500512
lambda(mPTuningConfig);
513+
lambda(mMultimodalEmbedding);
501514
lambda(mMropeConfig);
502515
lambda(mLoraConfig);
503516
lambda(mLookaheadConfig);
@@ -533,6 +546,7 @@ class Request::Impl
533546
std::optional<Tensor> mEmbeddingBias;
534547
std::optional<ExternalDraftTokensConfig> mExternalDraftTokensConfig;
535548
std::optional<PromptTuningConfig> mPTuningConfig;
549+
std::optional<Tensor> mMultimodalEmbedding;
536550
std::optional<MropeConfig> mMropeConfig;
537551
std::optional<LoraConfig> mLoraConfig;
538552
std::optional<LookaheadDecodingConfig> mLookaheadConfig;

0 commit comments

Comments
 (0)