# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
- from typing import Optional
+ from typing import List, Optional

import numpy
@@ -54,6 +54,33 @@ def _create_generated_text_output(
            finished=False,
        )

+    def _generate_streamed_text_from_past_tokens(
+        self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
+    ) -> List[str]:
+        """
+        An auxiliary method that helps to properly generate the streamed text.
+        Some models, such as llama2 and mistral, use LlamaTokenizer, which is
+        based on the SentencePiece tokenizer. This tokenizer doesn't seem to
+        output the appropriate prefix spaces when decoding token by token.
+        One can make it work if the previously generated tokens are included.
+        This allows the tokenizer to figure out the appropriate spaces from
+        the last n consecutive tokens.
+
+        :param generated_tokens: the generated tokens from the engine
+        :param past_tokens_queue: the queue of the last n tokens (n is the
+            original prompt length in tokens)
+        :return: the newly generated text chunk, as a single-element list
+        """
+        string_from_n_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.append(generated_tokens[0])
+        string_from_n_plus_1_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.pop(0)
+        return [string_from_n_plus_1_tokens[len(string_from_n_tokens):]]
+
    def run(
        self,
        generated_tokens: numpy.ndarray,
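For context, the prefix-space behavior described in the docstring above can be reproduced with a Llama-family tokenizer directly. The sketch below is not part of the diff; the checkpoint name and example strings are placeholders, and exact token splits may vary by model.

# Illustrative sketch of the SentencePiece prefix-space issue (assumes a
# Llama-family checkpoint is available; the name below is only an example).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

prompt_ids = tokenizer("The quick brown", add_special_tokens=False)["input_ids"]
new_id = tokenizer(" fox", add_special_tokens=False)["input_ids"][-1]

# Decoding the new token on its own tends to drop its leading space:
print(repr(tokenizer.decode([new_id], skip_special_tokens=True)))  # e.g. 'fox'

# Decoding together with the previous tokens and diffing the strings keeps it,
# which is the trick the new helper relies on:
before = tokenizer.decode(prompt_ids, skip_special_tokens=True)
after = tokenizer.decode(prompt_ids + [new_id], skip_special_tokens=True)
print(repr(after[len(before):]))  # e.g. ' fox'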
@@ -64,9 +91,24 @@ def run(
    ):
        generation_config = inference_state.current_state.get("generation_config")
        generated_logits = generated_logits if generation_config.output_scores else None
-        sequences = self.tokenizer.batch_decode(
-            generated_tokens, skip_special_tokens=True
-        )
+
+        import transformers
+
+        # Fix for LLAMA-specific models when running streaming
+        # TODO: make streaming a conditional input to this operator. using inference
+        # state is a quick fix.
+        if isinstance(
+            self.tokenizer,
+            (transformers.LlamaTokenizer, transformers.LlamaTokenizerFast),
+        ) and inference_state.current_state.get("streaming"):
+            past_tokens_queue = inference_state.current_state.get("past_tokens_queue")
+            sequences = self._generate_streamed_text_from_past_tokens(
+                generated_tokens, past_tokens_queue
+            )
+        else:
+            sequences = self.tokenizer.batch_decode(
+                generated_tokens, skip_special_tokens=True
+            )

        try:
            finished_reason = [f[-1] for f in finished_reason]
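To show how the streaming branch of run() could be exercised end to end, here is a hypothetical driver loop built around the same sliding-window decode. It is a sketch under assumptions, not the pipeline's actual wiring: stream_decode, engine_step, prompt_ids, and max_new_tokens are placeholders, and in the real operator the queue is carried in inference_state under "past_tokens_queue".

# Hypothetical driver loop (not from this PR) around the same sliding-window decode.
# `engine_step` is a stand-in for whatever produces the next generated token id.
from typing import Callable, Iterator, List

def stream_decode(
    tokenizer,
    prompt_ids: List[int],
    engine_step: Callable[[], int],
    max_new_tokens: int = 32,
) -> Iterator[str]:
    # Seed the queue with the prompt tokens; its length stays fixed at n = len(prompt_ids).
    past_tokens_queue = list(prompt_ids)
    for _ in range(max_new_tokens):
        new_token = engine_step()
        before = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
        past_tokens_queue.append(new_token)
        after = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
        past_tokens_queue.pop(0)
        yield after[len(before):]  # the same string-diff used by the new helper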