Skip to content

Commit b319fa3

Browse files
committed
initial implementation
1 parent 8a26435 commit b319fa3

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

src/deepsparse/transformers/pipelines/text_generation.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# limitations under the License.
1414

1515
import logging
16+
import warnings
1617
from typing import List, Optional, Tuple, Type, Union
1718

1819
import numpy
1920
from pydantic import BaseModel, Field
2021

2122
from deepsparse import Pipeline
23+
from deepsparse.cpu import cpu_avx512_compatible
2224
from deepsparse.pipeline import DEEPSPARSE_ENGINE
2325
from deepsparse.transformers.engines import NLDecoderEngine
2426
from deepsparse.transformers.pipelines import TransformersPipeline
@@ -115,9 +117,19 @@ def __init__(
115117
# TODO: Set this to 64 once we modify the OPT injection logic
116118
prompt_processing_sequence_length: int = 128,
117119
force_max_tokens: bool = False,
118-
use_deepsparse_cache: bool = False,
120+
use_deepsparse_cache: bool = True,
119121
**kwargs,
120122
):
123+
print(cpu_avx512_compatible())
124+
if not cpu_avx512_compatible() and kwargs["engine_type"] == DEEPSPARSE_ENGINE:
125+
warnings.warn(
126+
"Detected CPU is not AVX512 compatible. "
127+
"The kv cache management will not be supported "
128+
"by the optimized engine. The user may experience "
129+
"non optimal performance."
130+
)
131+
use_deepsparse_cache = False
132+
121133
if use_deepsparse_cache:
122134
if kwargs["engine_type"] != DEEPSPARSE_ENGINE:
123135
raise ValueError(
@@ -126,10 +138,6 @@ def __init__(
126138
f"is {kwargs['engine_type']}. "
127139
f"Make sure to set `engine_type` to {DEEPSPARSE_ENGINE}"
128140
)
129-
raise NotImplementedError(
130-
"The deepsparse kv cache is not yet "
131-
"supported for text generation pipelines"
132-
)
133141

134142
super().__init__(
135143
**kwargs, _delay_engine_initialize=True, _delay_overwriting_inputs=True

0 commit comments

Comments
 (0)