Skip to content

Commit 3127f31

Browse files
cccclai authored and facebook-github-bot committed
Fix mobile bert fine tune
Summary: As title; it was broken by #9643. Differential Revision: D72472098
1 parent 95d38c4 commit 3127f31

File tree

1 file changed

+27
-14
lines changed

1 file changed

+27
-14
lines changed

Diff for: examples/qualcomm/scripts/mobilebert_fine_tune.py

+27-14
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
generate_htp_compiler_spec,
1818
generate_qnn_executorch_compiler_spec,
1919
skip_annotation,
20+
to_edge_transform_and_lower_to_qnn,
2021
)
2122
from executorch.examples.qualcomm.utils import (
2223
build_executorch_binary,
@@ -27,7 +28,7 @@
2728
setup_common_args_and_variables,
2829
SimpleADB,
2930
)
30-
from executorch.exir import to_edge
31+
from executorch.exir import ExecutorchBackendConfig, to_edge
3132
from transformers import BertTokenizer, MobileBertForSequenceClassification
3233

3334

@@ -273,30 +274,42 @@ def calibrator(gm):
273274

274275
quantizer = make_quantizer(quant_dtype=quant_dtype)
275276
backend_options = generate_htp_compiler_spec(quant_dtype is not None)
276-
partitioner = QnnPartitioner(
277-
generate_qnn_executorch_compiler_spec(
278-
soc_model=getattr(QcomChipset, args.model),
279-
backend_options=backend_options,
280-
),
281-
skip_node_id_set=skip_node_id_set,
282-
skip_node_op_set=skip_node_op_set,
277+
# partitioner = QnnPartitioner(
278+
# generate_qnn_executorch_compiler_spec(
279+
# soc_model=getattr(QcomChipset, args.model),
280+
# backend_options=backend_options,
281+
# ),
282+
# skip_node_id_set=skip_node_id_set,
283+
# skip_node_op_set=skip_node_op_set,
284+
# )
285+
backend_options = generate_htp_compiler_spec(
286+
use_fp16=False,
287+
)
288+
compile_spec = generate_qnn_executorch_compiler_spec(
289+
soc_model=QcomChipset.SM8550,
290+
backend_options=backend_options,
283291
)
284292
# skip embedding layer cause it's quantization sensitive
285293
graph_module, _ = skip_annotation(
286294
nn_module=model,
287295
quantizer=quantizer,
288-
partitioner=partitioner,
296+
compiler_specs=compile_spec,
289297
sample_input=inputs[0],
290298
calibration_cb=calibrator,
291299
fp_node_op_set={torch.ops.aten.embedding.default},
292300
)
293301
# lower all graph again, the skipped operators will be left in CPU
294-
exec_prog = to_edge(
295-
torch.export.export(graph_module, inputs[0], strict=True),
296-
).to_executorch()
297-
302+
# exec_prog = to_edge(
303+
# torch.export.export(graph_module, inputs[0], strict=True),
304+
# ).to_executorch()
305+
delegated_program = to_edge_transform_and_lower_to_qnn(
306+
graph_module, inputs[0], compile_spec
307+
)
308+
executorch_program = delegated_program.to_executorch(
309+
config=ExecutorchBackendConfig(extract_delegate_segments=True)
310+
)
298311
with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
299-
file.write(exec_prog.buffer)
312+
file.write(executorch_program.buffer)
300313

301314
if args.compile_only:
302315
return

0 commit comments

Comments
 (0)