
Commit 639c9f7

Authored by dhuangnm, Dipika Sikka (dsikka), dbogunowicz, and dhuang
[TextGeneration] Fix llama tokenizer (#1635) (#1636)
* [TextGeneration] Fix llama tokenizer (#1635)
  * add llama tokenizer fix
  * fix generated string
  * only run for streaming
  * add TODO

  Co-authored-by: Dipika Sikka <[email protected]>

* Retire `flaky` in favour of `pytest-rerunfailures` (#1628)

* pick up another fix and bump up version to 1.7.1

Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: dbogunowicz <[email protected]>
Co-authored-by: dhuang <[email protected]>
1 parent 5fc5f73 commit 639c9f7

File tree: 6 files changed, +53 −11 lines changed

setup.py

+1 −1

@@ -99,7 +99,7 @@ def _parse_requirements_file(file_path):
     "black==22.12.0",
     "flake8>=3.8.3",
     "isort>=5.7.0",
-    "flaky~=3.7.0",
+    "pytest-rerunfailures>=13.0",
     "ndjson>=0.3.1",
     "wheel>=0.36.2",
     "pytest>=6.0.0",

src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py

+1 −0

@@ -101,6 +101,7 @@ def run(
             else [],
             "finished_reason": [],
             "token_generator": token_generator,
+            "past_tokens_queue": copy.copy(tokens),
         }

         if kv_cache is None:
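For readers skimming the diff: the new state entry seeds a sliding queue with a copy of the prompt token ids, so the output-processing step can mutate the queue without touching the original token list. A minimal sketch of that design choice (the token ids below are made up for illustration):

```python
import copy

# Hypothetical prompt token ids, standing in for `tokens` from prompt preparation.
tokens = [1, 15043, 29892, 825]

# Seed the queue with a shallow copy; later steps append the newly generated
# token and pop the oldest one, keeping the queue at the prompt's length.
past_tokens_queue = copy.copy(tokens)
past_tokens_queue.append(526)
past_tokens_queue.pop(0)

assert tokens == [1, 15043, 29892, 825]  # the original prompt tokens are untouched
```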

src/deepsparse/transformers/pipelines/text_generation/process_outputs.py

+46 −4

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import datetime
-from typing import Optional
+from typing import List, Optional

 import numpy

@@ -54,6 +54,33 @@ def _create_generated_text_output(
             finished=False,
         )

+    def _generate_streamed_text_from_past_tokens(
+        self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
+    ) -> str:
+        """
+        An auxiliary method that helps to properly generate the streamed text.
+        Some models like llama2 and mistral use LlamaTokenizer, which is
+        based on the SentencePiece tokenizer. This tokenizer doesn't seem
+        to output appropriate prefix spaces when decoding token by token.
+        One can make it work if the previously generated tokens are included.
+        This allows the tokenizer to figure out the appropriate spaces
+        from the last n consecutive tokens.
+
+        :param generated_tokens: the generated tokens from the engine
+        :param past_tokens_queue: the queue of last n tokens (n is the
+            original prompt length in tokens)
+        :return: the generated string
+        """
+        string_from_n_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.append(generated_tokens[0])
+        string_from_n_plus_1_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.pop(0)
+        return [string_from_n_plus_1_tokens[len(string_from_n_tokens) :]]
+
     def run(
         self,
         generated_tokens: numpy.ndarray,

@@ -64,9 +91,24 @@ def run(
     ):
         generation_config = inference_state.current_state.get("generation_config")
         generated_logits = generated_logits if generation_config.output_scores else None
-        sequences = self.tokenizer.batch_decode(
-            generated_tokens, skip_special_tokens=True
-        )
+
+        import transformers
+
+        # Fix for LLAMA-specific models when running streaming
+        # TODO: make streaming a conditional input to this operator. using inference
+        # state is a quick fix.
+        if isinstance(
+            self.tokenizer,
+            (transformers.LlamaTokenizer, transformers.LlamaTokenizerFast),
+        ) and inference_state.current_state.get("streaming"):
+            past_tokens_queue = inference_state.current_state.get("past_tokens_queue")
+            sequences = self._generate_streamed_text_from_past_tokens(
+                generated_tokens, past_tokens_queue
+            )
+        else:
+            sequences = self.tokenizer.batch_decode(
+                generated_tokens, skip_special_tokens=True
+            )

         try:
             finished_reason = [f[-1] for f in finished_reason]
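As background for the hunk above, here is a standalone sketch (not part of the commit) of the decoding behaviour it works around. It assumes `transformers` is installed and uses `hf-internal-testing/llama-tokenizer` purely as an illustrative SentencePiece-based tokenizer; any LlamaTokenizer should show a similar effect:

```python
from typing import List

from transformers import AutoTokenizer

# Assumption: an illustrative Llama/SentencePiece tokenizer, not the one used in the commit.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

prompt_ids: List[int] = tokenizer("The quick brown", add_special_tokens=False)["input_ids"]
new_ids: List[int] = tokenizer("fox jumps", add_special_tokens=False)["input_ids"]

# Naive per-token decoding tends to drop the prefix space before each word.
naive = "".join(tokenizer.decode([t], skip_special_tokens=True) for t in new_ids)

# Suffix-diff decoding, mirroring _generate_streamed_text_from_past_tokens:
# decode the queue, append one new token, decode again, emit only the new tail.
queue = list(prompt_ids)
streamed = ""
for token in new_ids:
    before = tokenizer.decode(queue, skip_special_tokens=True)
    queue.append(token)
    after = tokenizer.decode(queue, skip_special_tokens=True)
    queue.pop(0)  # keep the queue at the original prompt length
    streamed += after[len(before):]

print(repr(naive))     # likely missing word boundaries, e.g. 'foxjumps'
print(repr(streamed))  # expected to keep spacing, e.g. ' fox jumps'
```

The per-step string diff is what the new helper returns; popping the oldest token keeps the queue at the prompt length, so memory does not grow with the number of generated tokens.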

src/deepsparse/version.py

+1 −1

@@ -39,7 +39,7 @@
     from deepsparse.generated_version import is_enterprise, is_release, splash, version
 except Exception:
     # otherwise, fall back to version info in this file
-    version = "1.7.0"
+    version = "1.7.1"
     is_release = False
     is_enterprise = False
     splash = (

tests/deepsparse/pipelines/test_pipeline.py

+1 −2

@@ -16,7 +16,6 @@
 from concurrent.futures import ThreadPoolExecutor
 from unittest import mock

-import flaky
 import pytest
 from deepsparse.legacy.base_pipeline import BasePipeline

@@ -125,7 +124,7 @@ def test_pipeline_executor_num_workers():
     assert executor._max_workers >= 1


-@flaky.flaky(max_runs=2, min_passes=1)
+@pytest.mark.flaky(reruns=2, min_passes=1)
 @mock_engine(rng_seed=0)
 def test_pipeline_call_is_async(engine_mock):
     # attempts to verify that pipeline calls to engine are async

tests/server/test_legacy_loggers.py

+3 −3

@@ -16,6 +16,7 @@
 from collections import Counter
 from unittest import mock

+import pytest
 from deepsparse.legacy.loggers import PythonLogger
 from deepsparse.legacy.loggers.config import (
     PipelineSystemLoggingConfig,

@@ -30,7 +31,6 @@
 from deepsparse.server.deepsparse_server import DeepsparseServer
 from deepsparse.server.helpers import server_logger_from_config
 from fastapi.testclient import TestClient
-from flaky import flaky
 from tests.deepsparse.legacy.loggers.helpers import fetch_leaf_logger
 from tests.helpers import find_free_port
 from tests.test_data.server_test_data import SAMPLE_LOGS_DICT

@@ -106,7 +106,7 @@ def test_data_logging_from_predefined():
     assert log == expected_log


-@flaky(max_runs=4, min_passes=3)
+@pytest.mark.flaky(reruns=4, min_passes=3)
 def test_logging_only_system_info():
     server_config = ServerConfig(
         endpoints=[EndpointConfig(task=task, name=name, model=stub)],

@@ -195,7 +195,7 @@ def test_multiple_targets_logging():
     )


-@flaky(max_runs=3, min_passes=2)
+@pytest.mark.flaky(reruns=3, min_passes=2)
 def test_function_metric_with_target_loggers():
     server_config = ServerConfig(
         endpoints=[
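For anyone porting other tests off `flaky`: with `pytest-rerunfailures` the marker keeps the `flaky` name, so the migration is mostly a decorator swap. A minimal, self-contained sketch (assuming `pytest` and `pytest-rerunfailures` are installed; `reruns` and `reruns_delay` are the plugin options shown here):

```python
import random

import pytest


@pytest.mark.flaky(reruns=3, reruns_delay=1)
def test_occasionally_flaky():
    # pytest-rerunfailures reruns this test up to 3 times, waiting 1 second
    # between attempts, before reporting it as failed.
    assert random.random() > 0.1
```

The same behaviour can also be applied suite-wide with the `--reruns N` command-line flag instead of per-test markers.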
