Skip to content

Commit 745ecc3

Browse files
authored
[Text Generation] Enable internal kv cache if CPU architecture is avx512 (#1122)
* initial implementation * initial commit * refactor the warning
1 parent ad998df commit 745ecc3

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

Diff for: src/deepsparse/transformers/pipelines/text_generation.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# limitations under the License.
1414

1515
import logging
16+
import warnings
1617
from typing import List, Optional, Tuple, Type, Union
1718

1819
import numpy
1920
from pydantic import BaseModel, Field
2021

2122
from deepsparse import Pipeline
23+
from deepsparse.cpu import cpu_avx512_compatible
2224
from deepsparse.pipeline import DEEPSPARSE_ENGINE
2325
from deepsparse.transformers.engines import NLDecoderEngine
2426
from deepsparse.transformers.pipelines import TransformersPipeline
@@ -115,9 +117,17 @@ def __init__(
115117
# TODO: Set this to 64 once we modify the OPT injection logic
116118
prompt_processing_sequence_length: int = 128,
117119
force_max_tokens: bool = False,
118-
use_deepsparse_cache: bool = False,
120+
use_deepsparse_cache: bool = True,
119121
**kwargs,
120122
):
123+
if not cpu_avx512_compatible() and kwargs["engine_type"] == DEEPSPARSE_ENGINE:
124+
warnings.warn(
125+
"AVX512 support not detected, disabling internal management "
126+
"of KV cache which may affect performance. To enable full "
127+
"performance, deploy on an AVX512-compatible system."
128+
)
129+
use_deepsparse_cache = False
130+
121131
if use_deepsparse_cache:
122132
if kwargs["engine_type"] != DEEPSPARSE_ENGINE:
123133
raise ValueError(
@@ -126,10 +136,6 @@ def __init__(
126136
f"is {kwargs['engine_type']}. "
127137
f"Make sure to set `engine_type` to {DEEPSPARSE_ENGINE}"
128138
)
129-
raise NotImplementedError(
130-
"The deepsparse kv cache is not yet "
131-
"supported for text generation pipelines"
132-
)
133139

134140
super().__init__(
135141
**kwargs, _delay_engine_initialize=True, _delay_overwriting_inputs=True

0 commit comments

Comments
 (0)