gptDecoder.h
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/request.h"
#include "tensorrt_llm/runtime/samplingConfig.h"

#include <NvInferRuntime.h>
#include <curand_kernel.h>

#include <memory>
#include <numeric>
#include <optional>
#include <vector>

namespace tensorrt_llm
{

namespace layers
{
// Forward declaration
template <typename T>
class DynamicDecodeLayer;
} // namespace layers

namespace runtime
{
class SpeculativeDecodingModule;
class DecodingLayerWorkspace;
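
/// @brief Interface of the runtime token decoder: each call to forwardAsync/forwardSync consumes a DecodingInput and
/// writes newly decoded tokens and related state into a DecodingOutput.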
class IGptDecoder
{
public:
    using TensorPtr = runtime::ITensor::SharedPtr;
    using TensorConstPtr = runtime::ITensor::SharedConstPtr;

    virtual ~IGptDecoder() = default;
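
    /// @brief Prepares the decoder for a new batch: applies @p samplingConfig to the @p batchSize slots selected by
    /// @p batchSlots. The @p output and @p requests arguments are optional.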
    virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize, TensorConstPtr const& batchSlots,
        std::optional<DecodingOutput> const& output = std::nullopt,
        std::optional<std::vector<decoder_batch::Request> const> const& requests = std::nullopt)
        = 0;
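
    /// @brief Runs one decoding step asynchronously: consumes @p input and writes new tokens and state to @p output.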
    virtual void forwardAsync(DecodingOutput& output, DecodingInput const& input) = 0;
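
    /// @brief Synchronous counterpart of forwardAsync.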
    virtual void forwardSync(DecodingOutput& output, DecodingInput const& input) = 0;

    virtual SamplingConfig const& getSamplingConfig() = 0;
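
    /// @brief Disables lookahead decoding for the @p batchSize slots selected by @p batchSlots, optionally switching
    /// to a new @p samplingConfig.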
    virtual void disableLookahead(
        std::optional<SamplingConfig> const& samplingConfig, SizeType32 batchSize, TensorConstPtr batchSlots)
        = 0;
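
    /// @brief Factory that instantiates a GptDecoder<float> or GptDecoder<half> matching @p dtype (see the inline
    /// definition at the end of this file).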
    static std::unique_ptr<IGptDecoder> create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
        size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
        BufferManager::CudaStreamPtr const& stream,
        std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule = nullptr);
};
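
/// @brief Concrete IGptDecoder implementation backed by layers::DynamicDecodeLayer<T>, where T is the floating point
/// type selected in IGptDecoder::create (float or half).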
template <typename T>
class GptDecoder : public virtual IGptDecoder
{
public:
    using CudaStreamPtr = BufferManager::CudaStreamPtr;
    using TensorPtr = std::shared_ptr<ITensor>;

    GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
        size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream,
        std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule = nullptr);

    void setup(SamplingConfig const& samplingConfig, size_t batchSize, TensorConstPtr const& batchSlots,
        std::optional<DecodingOutput> const& output = std::nullopt,
        std::optional<std::vector<decoder_batch::Request> const> const& requests = std::nullopt) override;

    void forwardAsync(DecodingOutput& output, DecodingInput const& input) override;

    void forwardSync(DecodingOutput& output, DecodingInput const& input) override;

    SamplingConfig const& getSamplingConfig() override
    {
        return mSamplingConfig;
    }

    void disableLookahead(
        std::optional<SamplingConfig> const& samplingConfig, SizeType32 batchSize, TensorConstPtr batchSlots) override;

private:
    std::shared_ptr<BufferManager> mManager;
    std::shared_ptr<tensorrt_llm::layers::DynamicDecodeLayer<T>> mDynamicDecodeLayer;
    std::shared_ptr<tensorrt_llm::runtime::DecodingLayerWorkspace> mDecodingLayerWorkspace;

    SamplingConfig mSamplingConfig;

    size_t mMaxBatchSize;
    size_t mVocabSize;
    size_t mVocabSizePadded;

    executor::DecodingMode mDecodingMode;
};

inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
    size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
    BufferManager::CudaStreamPtr const& stream,
    std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule)
{
    switch (dtype)
    {
    case nvinfer1::DataType::kFLOAT:
        return std::make_unique<GptDecoder<float>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
            maxSequenceLength, stream, speculativeDecodingModule);
    case nvinfer1::DataType::kHALF:
        return std::make_unique<GptDecoder<half>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
            maxSequenceLength, stream, speculativeDecodingModule);
    default:
        TLLM_THROW("Unsupported decoder data type: %d. Use either kFLOAT or kHALF.", static_cast<int>(dtype));
        return nullptr;
    }
}

/// @brief Helper function to produce batch slots [0, 1, ..., batchSize - 1] for paths that do not explicitly provide
/// batch slots to the decoder.
inline runtime::ITensor::SharedConstPtr getDefaultBatchSlots(runtime::SizeType32 batchSize)
{
    auto defaultBatchSlots = runtime::BufferManager::pinnedPool(
        runtime::ITensor::makeShape({batchSize}), runtime::TRTDataType<runtime::SizeType32>::value);
    auto range = runtime::BufferRange<runtime::SizeType32>(*defaultBatchSlots);
    std::iota(range.begin(), range.end(), 0);
    return defaultBatchSlots;
}
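
// Example (illustrative sketch only, not part of this header): a typical call sequence using the factory and the
// helper above. The construction of `mode`, `stream`, `samplingConfig`, `input`, and `output` is assumed to happen
// elsewhere in the runtime and is omitted here.
//
//   auto decoder = IGptDecoder::create(mode, nvinfer1::DataType::kHALF, /*maxBatchSize=*/8, /*maxBeamWidth=*/1,
//       /*vocabSize=*/32000, /*vocabSizePadded=*/32000, /*maxSequenceLength=*/2048, stream);
//   auto batchSlots = getDefaultBatchSlots(/*batchSize=*/8);
//   decoder->setup(samplingConfig, /*batchSize=*/8, batchSlots);
//   decoder->forwardAsync(output, input); // enqueue one decoding step
//   decoder->forwardSync(output, input);  // synchronous counterpart
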
} // namespace runtime
} // namespace tensorrt_llm