Skip to content

Commit a4b311d

Browse files
authored
Update numpy version to latest (#1094)
* Update numpy version to latest * Update problem line to use a helper that takes care of numpy raising ValueErrors when trying to convert ragged sequences
1 parent 79c4ada commit a4b311d

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _parse_requirements_file(file_path):
8585

8686

8787
_deps = [
88-
"numpy>=1.16.3,<=1.21.6",
88+
"numpy>=1.16.3",
8989
"onnx>=1.5.0,<1.15.0",
9090
"pydantic>=1.8.2",
9191
"requests>=2.0.0",

Diff for: src/deepsparse/transformers/pipelines/question_answering.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def process_inputs(
250250
span_engine_inputs = []
251251
span_extra_info = []
252252
num_spans = len(tokenized_example["input_ids"])
253+
253254
for span in range(num_spans):
254255
span_input = [
255256
numpy.array(tokenized_example[key][span])
@@ -259,7 +260,7 @@ def process_inputs(
259260

260261
span_extra_info.append(
261262
{
262-
key: numpy.array(tokenized_example[key][span])
263+
key: _convert_to_numpy_array(tokenized_example[key][span])
263264
for key in tokenized_example.keys()
264265
if key not in self.onnx_input_names
265266
}
@@ -583,3 +584,18 @@ def _save_predictions(self, all_predictions, all_nbest_json, scores_diff_json):
583584
mode = "a" if os.path.exists(null_odds_file) else "w"
584585
with open(null_odds_file, mode) as writer:
585586
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
587+
588+
589+
def _convert_to_numpy_array(array):
590+
"""
591+
Convert a list to a numpy array. If the list contains non-numeric values,
592+
convert to an array of objects.
593+
594+
:param array: The list to convert
595+
:return: The converted numpy array
596+
"""
597+
try:
598+
return numpy.array(array)
599+
except ValueError:
600+
# If the array contains non-numeric values, convert to an array of objects
601+
return numpy.array(array, dtype=object)

0 commit comments

Comments
 (0)