Skip to content

Commit 2cf112a

Browse files
author
Sara Adkins
authored
Bug Fix: Make Numpy Array Outputs JSON Serializable for Server (#1192)
* make numpy arrays json serializable * quality * Note runtime of tolist() * quality * fix case where output isn't a BaseModel * quality
1 parent 6a5560d commit 2cf112a

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

src/deepsparse/server/helpers.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Dict, List, Optional
15+
from typing import Any, Dict, List, Optional
16+
17+
import numpy
18+
from pydantic import BaseModel
1619

1720
from deepsparse import (
1821
BaseLogger,
@@ -24,7 +27,7 @@
2427
from deepsparse.server.config import EndpointConfig, ServerConfig
2528

2629

27-
__all__ = ["server_logger_from_config"]
30+
__all__ = ["server_logger_from_config", "prep_outputs_for_serialization"]
2831

2932

3033
def server_logger_from_config(config: ServerConfig) -> BaseLogger:
@@ -58,6 +61,26 @@ def server_logger_from_config(config: ServerConfig) -> BaseLogger:
5861
)
5962

6063

64+
def prep_outputs_for_serialization(pipeline_outputs: Any):
65+
"""
66+
Prepares a pipeline output for JSON serialization by converting any numpy array
67+
field to a list. For large numpy arrays, this operation will take a while to run.
68+
69+
:param pipeline_outputs: output data to clean
70+
:return: cleaned pipeline_outputs
71+
"""
72+
if isinstance(pipeline_outputs, BaseModel):
73+
for field_name in pipeline_outputs.__fields__.keys():
74+
field_value = getattr(pipeline_outputs, field_name)
75+
if isinstance(field_value, numpy.ndarray):
76+
# numpy arrays aren't JSON serializable
77+
setattr(pipeline_outputs, field_name, field_value.tolist())
78+
elif isinstance(pipeline_outputs, numpy.ndarray):
79+
pipeline_outputs = pipeline_outputs.tolist()
80+
81+
return pipeline_outputs
82+
83+
6184
def _extract_system_logging_from_endpoints(
6285
endpoints: List[EndpointConfig],
6386
) -> Dict[str, SystemLoggingGroup]:

src/deepsparse/server/server.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@
3434
SystemLoggingConfig,
3535
)
3636
from deepsparse.server.config_hot_reloading import start_config_watcher
37-
from deepsparse.server.helpers import server_logger_from_config
37+
from deepsparse.server.helpers import (
38+
prep_outputs_for_serialization,
39+
server_logger_from_config,
40+
)
3841
from deepsparse.server.system_logging import (
3942
SystemLoggingMiddleware,
4043
log_system_information,
@@ -261,6 +264,7 @@ def _predict(request: pipeline.input_schema):
261264
server_logger=server_logger,
262265
system_logging_config=system_logging_config,
263266
)
267+
pipeline_outputs = prep_outputs_for_serialization(pipeline_outputs)
264268
return pipeline_outputs
265269

266270
def _predict_from_files(request: List[UploadFile]):

0 commit comments

Comments
 (0)