Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

[Text Generation] Enable internal kv cache if CPU architecture is avx512 #1122

Merged
merged 5 commits into from
Jul 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.

import logging
import warnings
from typing import List, Optional, Tuple, Type, Union

import numpy
from pydantic import BaseModel, Field

from deepsparse import Pipeline
from deepsparse.cpu import cpu_avx512_compatible
from deepsparse.pipeline import DEEPSPARSE_ENGINE
from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.pipelines import TransformersPipeline
Expand Down Expand Up @@ -115,9 +117,17 @@ def __init__(
# TODO: Set this to 64 once we modify the OPT injection logic
prompt_processing_sequence_length: int = 128,
force_max_tokens: bool = False,
use_deepsparse_cache: bool = False,
use_deepsparse_cache: bool = True,
**kwargs,
):
if not cpu_avx512_compatible() and kwargs["engine_type"] == DEEPSPARSE_ENGINE:
warnings.warn(
"AVX512 support not detected, disabling internal management "
"of KV cache which may affect performance. To enable full "
"performance, deploy on an AVX512-compatible system."
)
use_deepsparse_cache = False

if use_deepsparse_cache:
if kwargs["engine_type"] != DEEPSPARSE_ENGINE:
raise ValueError(
Expand All @@ -126,10 +136,6 @@ def __init__(
f"is {kwargs['engine_type']}. "
f"Make sure to set `engine_type` to {DEEPSPARSE_ENGINE}"
)
raise NotImplementedError(
"The deepsparse kv cache is not yet "
"supported for text generation pipelines"
)

super().__init__(
**kwargs, _delay_engine_initialize=True, _delay_overwriting_inputs=True
Expand Down