Skip to content

Commit 9876cc7

Browse files
authored
more inputs support for LLM exporter (microsoft#19005)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 07d3aed commit 9876cc7

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

onnxruntime/python/tools/transformers/large_model_exporter.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -224,24 +224,35 @@ def fetch_onnx_inputs_outputs_name(
224224
if not num_of_past_key:
225225
num_of_past_key = model.config.num_hidden_layers
226226

227-
onnx_inp_names = ("input_ids", "attention_mask")
227+
# filter out constant inputs
228+
onnx_inp_names = tuple(
229+
[torch_input_names[i] for i in range(len(torch_input_names)) if isinstance(onnx_inputs[i], torch.Tensor)]
230+
)
231+
assert (
232+
"input_ids" in onnx_inp_names and "attention_mask" in onnx_inp_names
233+
), "input_ids and attention_mask must be existed in inputs"
228234
onnx_out_names = ("logits",)
229235
onnx_dynamic_axes = {
230236
"input_ids": {0: "batch_size", 1: "seq_len"},
231237
"attention_mask": {0: "batch_size", 1: "seq_len"},
232238
}
239+
# add dynamic dimensions for the unknown inputs
240+
for idx, name in enumerate(onnx_inp_names):
241+
if name not in onnx_dynamic_axes:
242+
unknown_dims = {i: f"{idx}__unknown_dims__{i}" for i in range(onnx_inputs[idx].dim())}
243+
onnx_dynamic_axes[name] = unknown_dims
233244
if input_with_past:
234245
for i in range(num_of_past_key):
235-
onnx_inp_names += (f"present_key.{i}",)
236-
onnx_inp_names += (f"present_values.{i}",)
246+
onnx_inp_names += (f"past_key_values.{i}.key",)
247+
onnx_inp_names += (f"past_key_values.{i}.value",)
237248

238249
onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis
239250
onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis
240251

241252
if with_past or input_with_past:
242253
for i in range(num_of_past_key):
243-
onnx_out_names += (f"past_key.{i}",)
244-
onnx_out_names += (f"past_values.{i}",)
254+
onnx_out_names += (f"present.{i}.key",)
255+
onnx_out_names += (f"present.{i}.value",)
245256
onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis
246257
onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis
247258

0 commit comments

Comments
 (0)