Refactor engine.py to pull out some common functionality #1035

Merged: 8 commits, Jun 1, 2023
228 changes: 151 additions & 77 deletions src/deepsparse/engine.py
@@ -28,6 +28,7 @@
from deepsparse.benchmark import BenchmarkResults
from deepsparse.utils import (
generate_random_inputs,
get_output_names,
model_to_path,
override_onnx_input_shapes,
)
@@ -53,6 +54,8 @@
"Scheduler",
"Context",
"MultiModelEngine",
"KVCacheEngine",
"BaseEngine",
]

_LOGGER = logging.getLogger(__name__)
@@ -152,7 +155,95 @@ def _validate_scheduler(scheduler: Union[None, str, Scheduler]) -> Scheduler:
return scheduler


class Engine(object):
class Context(object):
"""
Contexts can be used to run multiple instances of the MultiModelEngine with the same
scheduler. This allows one scheduler to manage the resources of the system
effectively, keeping engines that are running different models from fighting over system
resources.

:param num_cores: The number of physical cores to run the model on. If more
cores are requested than are available on a single socket, the engine
will try to distribute them evenly across as few sockets as possible.
:param num_streams: The max number of requests the model can handle
concurrently.
"""

def __init__(
self,
num_cores: int = None,
num_streams: int = None,
):
self._num_cores = _validate_num_cores(num_cores)
self._scheduler = Scheduler.from_str("elastic")
self._deepsparse_context = LIB.deepsparse_context(
self._num_cores,
_validate_num_streams(num_streams, self._num_cores),
self._scheduler.value,
)
# num_streams can be adjusted by how we map optimally to the hardware topology,
# so let's use the context as the source of truth to be transparent
self._num_streams = self._deepsparse_context.num_streams()

@property
def value(self):
return self._deepsparse_context

@property
def num_cores(self):
return self._num_cores

@property
def num_streams(self):
return self._num_streams

@property
def scheduler(self):
return self._scheduler

def __repr__(self) -> str:
return f"Context(num_cores={self.num_cores}, num_streams={self.num_streams}, scheduler={self.scheduler})"
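A minimal usage sketch (not part of the diff) for the relocated Context class: one Context shared by two MultiModelEngine instances, so a single elastic scheduler arbitrates cores between the models. The model paths are placeholders.

from deepsparse import Context, MultiModelEngine

context = Context(num_cores=8, num_streams=2)
# Both engines are compiled against the same context, so the shared
# scheduler keeps them from competing for the same cores
engine_a = MultiModelEngine(model="model_a.onnx", batch_size=1, context=context)
engine_b = MultiModelEngine(model="model_b.onnx", batch_size=1, context=context)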


class BaseEngine(object):
def construct(
self,
model: Union[str, "Model", "File"],
batch_size: int = 1,
num_cores: int = None,
num_streams: int = None,
scheduler: Scheduler = None,
input_shapes: List[List[int]] = None,
):
_analytics.send_event("python__engine__init")
self._model_path = model_to_path(model)
self._batch_size = _validate_batch_size(batch_size)
self._num_cores = _validate_num_cores(num_cores)
self._num_streams = _validate_num_streams(num_streams, self._num_cores)
self._scheduler = _validate_scheduler(scheduler)
self._input_shapes = input_shapes
self._cpu_avx_type = AVX_TYPE
self._cpu_vnni = VNNI

def construct_with_context(
self,
model: Union[str, "Model", "File"],
batch_size: int,
context: Context,
input_shapes: List[List[int]] = None,
):
_analytics.send_event("python__engine__init")
self._model_path = model_to_path(model)
self._batch_size = _validate_batch_size(batch_size)
self._num_cores = context.num_cores
self._num_streams = context.num_streams
self._scheduler = _validate_scheduler(context.scheduler)
self._input_shapes = input_shapes
self._cpu_avx_type = AVX_TYPE
self._cpu_vnni = VNNI
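To illustrate what the new base class buys (a hypothetical subclass, not in the diff): every engine variant delegates argument validation and attribute setup to BaseEngine.construct, then adds only its own compilation logic.

class MyCustomEngine(BaseEngine):
    def __init__(self, model, batch_size=1):
        # Shared validation and attribute setup pulled out by this refactor
        BaseEngine.construct(self, model, batch_size)
        # self._model_path, self._batch_size, self._num_cores,
        # self._num_streams, and self._scheduler are now populated,
        # so only the engine-specific compilation step remains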


class Engine(BaseEngine):
"""
Create a new DeepSparse Engine that compiles the given ONNX file
for GPU-class performance on commodity CPUs.
@@ -186,16 +277,10 @@ def __init__(
scheduler: Scheduler = None,
input_shapes: List[List[int]] = None,
):
_analytics.send_event("python__engine__init")
self._model_path = model_to_path(model)
self._batch_size = _validate_batch_size(batch_size)
self._num_cores = _validate_num_cores(num_cores)
self._scheduler = _validate_scheduler(scheduler)
self._input_shapes = input_shapes
self._cpu_avx_type = AVX_TYPE
self._cpu_vnni = VNNI
BaseEngine.construct(
self, model, batch_size, num_cores, num_streams, scheduler, input_shapes
)

num_streams = _validate_num_streams(num_streams, self._num_cores)
if self._input_shapes:
with override_onnx_input_shapes(
self._model_path, self._input_shapes
@@ -204,7 +289,7 @@ def __init__(
model_path,
self._batch_size,
self._num_cores,
num_streams,
self._num_streams,
self._scheduler.value,
None,
)
@@ -213,7 +298,7 @@ def __init__(
self._model_path,
self._batch_size,
self._num_cores,
num_streams,
self._num_streams,
self._scheduler.value,
None,
)
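For reference, a minimal sketch of driving the refactored Engine end to end (the model path is a placeholder):

from deepsparse import Engine
from deepsparse.utils import generate_random_inputs

engine = Engine(model="model.onnx", batch_size=1)
# generate_random_inputs builds numpy arrays matching the model's
# input shapes, which is handy for a quick smoke test
inputs = generate_random_inputs("model.onnx", batch_size=1)
outputs = engine(inputs)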
@@ -645,15 +730,10 @@ def __init__(
imposed_as: Optional[float] = None,
imposed_ks: Optional[float] = None,
):
self._model_path = model_to_path(model)
self._batch_size = _validate_batch_size(batch_size)
self._num_cores = _validate_num_cores(num_cores)
self._scheduler = _validate_scheduler(scheduler)
self._input_shapes = input_shapes
self._cpu_avx_type = AVX_TYPE
self._cpu_vnni = VNNI
BaseEngine.construct(
self, model, batch_size, num_cores, None, scheduler, input_shapes
)

num_streams = _validate_num_streams(None, self._num_cores)
if self._input_shapes:
with override_onnx_input_shapes(
self._model_path, self._input_shapes
@@ -662,7 +742,7 @@ def __init__(
model_path,
self._batch_size,
self._num_cores,
num_streams,
self._num_streams,
self._scheduler.value,
None,
"external",
@@ -677,7 +757,7 @@ def __init__(
self._model_path,
self._batch_size,
self._num_cores,
num_streams,
self._num_streams,
self._scheduler.value,
None,
"external",
@@ -712,53 +792,6 @@ def analyze(
return bench_info


class Context(object):
"""
Contexts can be used to run multiple instances of the MultiModelEngine with the same
scheduler. This allows one scheduler to manage the resources of the system
effectively, keeping engines that are running different models from fighting over system
resources.

:param num_cores: The number of physical cores to run the model on. If more
cores are requested than are available on a single socket, the engine
will try to distribute them evenly across as few sockets as possible.
:param num_streams: The max number of requests the model can handle
concurrently.
"""

def __init__(
self,
num_cores: int = None,
num_streams: int = None,
):
self._num_cores = _validate_num_cores(num_cores)
self._scheduler = Scheduler.from_str("elastic")
self._deepsparse_context = LIB.deepsparse_context(
self._num_cores,
_validate_num_streams(num_streams, self._num_cores),
self._scheduler.value,
)

@property
def value(self):
return self._deepsparse_context

@property
def num_cores(self):
return self._num_cores

@property
def num_streams(self):
return self._deepsparse_context.num_streams()

@property
def scheduler(self):
return self._scheduler

def __repr__(self) -> str:
return f"Context(num_cores={self.num_cores}, num_streams={self.num_streams}, scheduler={self.scheduler})"


class MultiModelEngine(Engine):
"""
The MultiModelEngine, together with the Context class, can be used to run multiple models
@@ -785,14 +818,9 @@ def __init__(
context: Context,
input_shapes: List[List[int]] = None,
):
self._model_path = model_to_path(model)
self._batch_size = _validate_batch_size(batch_size)
self._num_cores = context.num_cores
self._num_streams = context.num_streams
self._scheduler = _validate_scheduler(context.scheduler)
self._input_shapes = input_shapes
self._cpu_avx_type = AVX_TYPE
self._cpu_vnni = VNNI
BaseEngine.construct_with_context(
self, model, batch_size, context, input_shapes
)

if self._input_shapes:
with override_onnx_input_shapes(
@@ -817,6 +845,52 @@ def __init__(
)


class KVCacheEngine(Engine):
"""
Engine that supports key-value (KV) caching.
"""

def __init__(
self,
model: Union[str, "Model", "File"],
batch_size: int = 1,
num_cores: int = None,
num_streams: int = None,
scheduler: Scheduler = None,
input_shapes: List[List[int]] = None,
kv_cache_bools: List[bool] = None,
prev_cache_length: int = 0,
):
BaseEngine.construct(
self, model, batch_size, num_cores, num_streams, scheduler, input_shapes
)

if kv_cache_bools is None:
# If no list was provided, then we assume all outputs except for the first are KV caches
# Note: In the future we can look at the names of outputs to be more sure
#
# Create a boolean list of every output of the model
output_names = get_output_names(self._model_path)
kv_cache_bools = [True] * len(output_names)
# Assume the first output is logits, and logits ought not to be cached
kv_cache_bools[0] = False

num_streams = _validate_num_streams(num_streams, self._num_cores)
if self._input_shapes:
raise NotImplementedError("input_shapes overrides are not yet supported for KVCacheEngine")
else:
self._eng_net = LIB.deepsparse_engine(
self._model_path,
self._batch_size,
self._num_cores,
num_streams,
self._scheduler.value,
None,
kv_cache_bools,
prev_cache_length,
)
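A construction sketch (hypothetical model path, assuming a decoder-style ONNX model whose first output is logits and whose remaining outputs are KV-cache tensors):

kv_engine = KVCacheEngine(
    model="decoder.onnx",
    batch_size=1,
    prev_cache_length=0,
)
# With kv_cache_bools left as None, every output except the first
# (the logits) is treated as a KV cache by the constructor above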


def compile_model(
model: Union[str, "Model", "File"],
batch_size: int = 1,
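The compile_model wrapper is truncated in this view; a typical call looks like the sketch below (placeholder path), returning a compiled Engine:

from deepsparse import compile_model

engine = compile_model("model.onnx", batch_size=1)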