|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -from typing import Dict, List, Optional |
| 15 | +from typing import Any, Dict, List, Optional |
| 16 | + |
| 17 | +import numpy |
| 18 | +from pydantic import BaseModel |
16 | 19 |
|
17 | 20 | from deepsparse import (
|
18 | 21 | BaseLogger,
|
|
24 | 27 | from deepsparse.server.config import EndpointConfig, ServerConfig
|
25 | 28 |
|
26 | 29 |
|
27 |
| -__all__ = ["server_logger_from_config"] |
| 30 | +__all__ = ["server_logger_from_config", "prep_outputs_for_serialization"] |
28 | 31 |
|
29 | 32 |
|
30 | 33 | def server_logger_from_config(config: ServerConfig) -> BaseLogger:
|
@@ -58,6 +61,26 @@ def server_logger_from_config(config: ServerConfig) -> BaseLogger:
|
58 | 61 | )
|
59 | 62 |
|
60 | 63 |
|
def prep_outputs_for_serialization(pipeline_outputs: Any):
    """
    Prepares a pipeline output for JSON serialization by converting any numpy
    array to a (possibly nested) plain Python list. For large numpy arrays,
    this operation will take a while to run.

    Recurses into pydantic model fields and plain lists so that arrays nested
    inside them are converted as well — a top-level-only conversion would
    still leave such outputs unserializable.

    :param pipeline_outputs: output data to clean; may be a pydantic
        ``BaseModel``, a ``numpy.ndarray``, a list containing either, or any
        already-serializable value (returned unchanged)
    :return: cleaned pipeline_outputs
    """
    if isinstance(pipeline_outputs, numpy.ndarray):
        # numpy arrays aren't JSON serializable
        pipeline_outputs = pipeline_outputs.tolist()
    elif isinstance(pipeline_outputs, list):
        # handle containers of arrays/models (e.g. batched pipeline outputs)
        pipeline_outputs = [
            prep_outputs_for_serialization(item) for item in pipeline_outputs
        ]
    elif isinstance(pipeline_outputs, BaseModel):
        for field_name in pipeline_outputs.__fields__.keys():
            field_value = getattr(pipeline_outputs, field_name)
            cleaned_value = prep_outputs_for_serialization(field_value)
            # only write back when something actually changed, to avoid
            # re-triggering pydantic validation/assignment hooks needlessly
            if cleaned_value is not field_value:
                setattr(pipeline_outputs, field_name, cleaned_value)

    return pipeline_outputs
| 82 | + |
| 83 | + |
61 | 84 | def _extract_system_logging_from_endpoints(
|
62 | 85 | endpoints: List[EndpointConfig],
|
63 | 86 | ) -> Dict[str, SystemLoggingGroup]:
|
|
0 commit comments