Skip to content

Commit d282012

Browse files
committed
initial functionality and working example with image classification
1 parent 4df53f9 commit d282012

File tree

14 files changed

+635
-112
lines changed

14 files changed

+635
-112
lines changed

src/deepsparse/v2/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
from .pipeline import *
1817
from .operators import *
18+
from .pipeline import *
1919
from .routers import *
2020
from .schedulers import *
2121
from .utils import *
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# flake8: noqa
16+
from .postprocess_operator import *
17+
from .preprocess_operator import *
18+
19+
20+
from .pipeline import * # isort:skip
Loading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict, Optional, Tuple, Union
16+
17+
from deepsparse.utils import model_to_path
18+
from deepsparse.v2.image_classification import (
19+
ImageClassificationPostProcess,
20+
ImageClassificationPreProcess,
21+
)
22+
from deepsparse.v2.operators.engine_operator import EngineOperator
23+
from deepsparse.v2.pipeline import Pipeline
24+
from deepsparse.v2.routers.router import LinearRouter
25+
from deepsparse.v2.schedulers.scheduler import OperatorScheduler
26+
27+
28+
__all__ = ["ImageClassificationPipeline"]
29+
30+
31+
class ImageClassificationPipeline(Pipeline):
    """
    Image classification pipeline built from three operators run in order:
    ImageClassificationPreProcess -> EngineOperator -> ImageClassificationPostProcess.

    :param model_path: model location; resolved with ``model_to_path`` before
        being handed to the pre-processing operator and the engine
    :param class_names: forwarded unchanged to ImageClassificationPostProcess
    :param image_size: forwarded unchanged to ImageClassificationPreProcess
    :param top_k: forwarded unchanged to ImageClassificationPostProcess
    :param engine_kwargs: optional keyword arguments for EngineOperator. Its
        ``model_path`` entry is always set to the resolved model path. The
        dictionary passed in is copied and never modified
    """

    def __init__(
        self,
        model_path: str,
        class_names: Union[None, str, Dict[str, str]] = None,
        image_size: Optional[Tuple[int, ...]] = None,
        top_k: int = 1,
        engine_kwargs: Optional[Dict] = None,
    ):
        model_path = model_to_path(model_path)

        # work on a copy so the caller's dictionary is left untouched
        engine_kwargs = dict(engine_kwargs) if engine_kwargs else {}
        if engine_kwargs and engine_kwargs.get("model_path") != model_path:
            # TODO: swap to use logger
            print(f"Updating engine_kwargs to use {model_path}")

        engine_kwargs["model_path"] = model_path

        preprocess = ImageClassificationPreProcess(
            model_path=model_path, image_size=image_size
        )
        postprocess = ImageClassificationPostProcess(
            top_k=top_k, class_names=class_names
        )

        engine = EngineOperator(**engine_kwargs)

        ops = [preprocess, engine, postprocess]
        router = LinearRouter(end_route=len(ops))
        scheduler = [OperatorScheduler()]
        super().__init__(ops=ops, router=router, schedulers=scheduler)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
from typing import Any, Dict, List, Optional, Union
17+
18+
import numpy
19+
from pydantic import BaseModel, Field
20+
21+
from deepsparse.v2.operators import Operator
22+
from deepsparse.v2.utils import Context
23+
24+
25+
class ImageClassificationOutput(BaseModel):
    """
    Output model for image classification
    """

    # Predicted labels: each entry is a single int/str label or a list of them
    labels: List[Union[int, str, List[int], List[str]]] = Field(
        description="List of labels, one for each prediction"
    )
    # Prediction scores: each entry is a single float or a list of floats
    scores: List[Union[float, List[float]]] = Field(
        description="List of scores, one for each prediction"
    )
36+
37+
38+
__all__ = ["ImageClassificationPostProcess"]
39+
40+
41+
class ImageClassificationPostProcess(Operator):
    """
    Operator that converts raw engine outputs into top-k labels and scores.

    :param top_k: number of highest scoring classes to keep per prediction
    :param class_names: optional mapping from class id to class name, given
        either as a dictionary or as a path to a ``.json`` file holding one.
        Any other value disables the mapping and integer class ids are returned
    """

    input_schema = None
    output_schema = ImageClassificationOutput

    def __init__(
        self, top_k: int = 1, class_names: Union[None, str, Dict[str, str]] = None
    ):
        self.top_k = top_k
        if isinstance(class_names, str) and class_names.endswith(".json"):
            # context manager guarantees the file handle is released
            with open(class_names) as class_names_file:
                self._class_names = json.load(class_names_file)
        elif isinstance(class_names, dict):
            self._class_names = class_names
        else:
            self._class_names = None

    def _class_name(self, class_id: int):
        """
        Look up the name for an integer class id.

        JSON object keys are always strings, so when the integer key is
        missing the lookup is retried with its string form.

        :param class_id: integer class id produced by ``run``
        :return: the mapped class name
        :raises KeyError: if neither the int nor the str key is present
        """
        try:
            return self._class_names[class_id]
        except KeyError:
            return self._class_names[str(class_id)]

    def run(self, inp: Any, context: Optional[Context]) -> Dict:
        """
        Select the ``top_k`` best scoring classes for each prediction.

        :param inp: engine outputs; ``inp[0]`` is iterated as one array of
            class scores per prediction
        :param context: unused by this operator
        :return: dictionary with ``scores`` and ``labels``. When there is
            exactly one prediction both values are un-nested to that
            prediction's own lists; with no predictions both are empty lists
        """
        labels, scores = [], []

        for prediction_batch in inp[0]:
            # negate so argsort yields the highest scores first
            label = (-prediction_batch).argsort()[: self.top_k]
            score = prediction_batch[label]
            labels.append(label.tolist())
            scores.append(score.tolist())

        if self._class_names is not None:
            labels = [
                [self._class_name(class_id) for class_id in top_ids]
                for top_ids in labels
            ]

        if len(labels) == 1:
            labels = labels[0]
            scores = scores[0]

        return {"scores": scores, "labels": labels}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, List, Optional, Tuple
16+
17+
import numpy
18+
import onnx
19+
from PIL import Image
20+
from torchvision import transforms
21+
22+
from deepsparse.image_classification.constants import (
23+
IMAGENET_RGB_MEANS,
24+
IMAGENET_RGB_STDS,
25+
)
26+
from deepsparse.pipelines.computer_vision import ComputerVisionSchema
27+
from deepsparse.v2.operators import Operator
28+
from deepsparse.v2.utils import Context
29+
30+
31+
# Input schema used by ImageClassificationPreProcess; it declares no fields of
# its own, everything comes from ComputerVisionSchema.
class ImageClassificationInput(ComputerVisionSchema):
    """
    Input model for image classification
    """
35+
36+
37+
__all__ = ["ImageClassificationPreProcess"]
38+
39+
40+
class ImageClassificationPreProcess(Operator):
    """
    Operator that turns image classification inputs into one float32 batch.

    :param model_path: path to the ONNX model; only read when ``image_size``
        is not given, to infer the size from the model's first input
    :param image_size: target size of the center crop. Images are first
        resized to ``256 / 224`` times this size
    """

    input_schema = ImageClassificationInput
    output_schema = None

    def __init__(self, model_path: str, image_size: Optional[Tuple[int, ...]] = None):
        self.model_path = model_path
        self._image_size = image_size or self._infer_image_size()
        non_rand_resize_scale = 256.0 / 224.0  # standard used
        resize_size = tuple(
            round(non_rand_resize_scale * size) for size in self._image_size
        )
        self._pre_normalization_transforms = transforms.Compose(
            [
                transforms.Resize(resize_size),
                transforms.CenterCrop(self._image_size),
            ]
        )

    def run(self, inp: ImageClassificationInput, context: Optional[Context]) -> Any:
        """
        Pre-Process the Inputs for DeepSparse Engine

        :param inp: input model; ``inp.images`` may be a numpy array holding
            the whole batch, a single image file path, or an iterable of
            images accepted by ``_preprocess_image``
        :param context: unused by this operator
        :return: list holding one contiguous float32 numpy array, the batch
        """
        # use a local so the caller's input model is not modified
        images = inp.images

        if isinstance(images, numpy.ndarray):
            image_batch = images
        else:
            if isinstance(images, str):
                images = [images]

            # build batch
            image_batch = numpy.stack(
                [self._preprocess_image(image) for image in images], axis=0
            )

        original_dtype = image_batch.dtype
        image_batch = numpy.ascontiguousarray(image_batch, dtype=numpy.float32)

        if original_dtype == numpy.uint8:
            image_batch /= 255
            # normalize entire batch
            image_batch -= numpy.asarray(IMAGENET_RGB_MEANS).reshape((-1, 3, 1, 1))
            image_batch /= numpy.asarray(IMAGENET_RGB_STDS).reshape((-1, 3, 1, 1))

        return [image_batch]

    @staticmethod
    def _array_to_pil(image: numpy.ndarray) -> Image.Image:
        """
        Convert a raw image array into a PIL image for torchvision processing.

        :param image: raw image array; cast to uint8 before conversion
        :return: PIL image built from the array
        """
        image = image.astype(numpy.uint8)
        if image.shape[0] < image.shape[-1]:
            # put channel last
            image = numpy.einsum("cwh->whc", image)
        return Image.fromarray(image)

    def _preprocess_image(self, image) -> numpy.ndarray:
        """
        Resize and center crop one image, returning it channel first.

        :param image: a file path, a raw image as nested list or numpy array,
            or a PIL image
        :return: the processed image as a numpy array with the channel
            dimension first. A list that converts to a float32 array is
            returned as that array without further processing
        :raises ValueError: if the input cannot be turned into a PIL image
        """
        if isinstance(image, list):
            # image given as raw list
            image = numpy.asarray(image)
            if image.dtype == numpy.float32:
                # image is already processed, return it as is
                return image
            # assume raw image input
            image = self._array_to_pil(image)
        elif isinstance(image, str):
            # load image from string filepath
            image = Image.open(image).convert("RGB")
        elif isinstance(image, numpy.ndarray):
            image = self._array_to_pil(image)

        if not isinstance(image, Image.Image):
            raise ValueError(
                f"inputs to {self.__class__.__name__} must be a string image "
                "file path(s), a list representing a raw image, "
                "PIL.Image.Image object(s), or a numpy array representing "
                f"the entire pre-processed batch. Found {type(image)}"
            )

        # apply resize and center crop
        image = self._pre_normalization_transforms(image)
        image_numpy = numpy.array(image)
        image.close()

        # make channel first dimension
        return image_numpy.transpose(2, 0, 1)

    def _infer_image_size(self) -> Tuple[int, ...]:
        """
        Infer and return the expected shape of the input tensor

        :return: dimensions 2 and 3 of the first graph input of the ONNX model
        :raises ValueError: if either dimension has no positive static value,
            in which case ``image_size`` must be passed explicitly
        """
        onnx_model = onnx.load(self.model_path)
        input_tensor = onnx_model.graph.input[0]
        dims = input_tensor.type.tensor_type.shape.dim
        image_size = (dims[2].dim_value, dims[3].dim_value)
        # a dimension without a static value reads back as 0 and would
        # otherwise produce an unusable resize/crop size
        if any(size <= 0 for size in image_size):
            raise ValueError(
                f"Unable to infer a static image size from {self.model_path}, "
                f"found {image_size}. Pass image_size explicitly"
            )
        return image_size

0 commit comments

Comments
 (0)