
Commit 67be84b

Guang Yang authored and facebook-github-bot committed
Script to export 🤗 models (#4723)
Summary:
bypass-github-export-checks

[Done] ~~Require PR [Make StaticCache configurable at model construct time](huggingface/transformers#32830) in order to export, lower and run the 🤗 model OOTB.~~

[Done] ~~Require huggingface/transformers#33303 or huggingface/transformers#33287 to be merged into 🤗 `transformers` to resolve the export issue introduced by huggingface/transformers#32543.~~

-----------

Now we can take the integration point from 🤗 `transformers` to lower compatible models to ExecuTorch OOTB.

- This PR creates a simple export script with an XNNPACK recipe.
- This PR also creates a secret `EXECUTORCH_HF_TOKEN` to allow downloading checkpoints in the CI.
- This PR connects the 🤗 "Export to ExecuTorch" e2e workflow to the ExecuTorch CI.

### Instructions to run the demo:

1. Run export_hf_model.py to lower gemma-2b to ExecuTorch:
```
python -m extension.export_util.export_hf_model -hfm "google/gemma-2b" # The model is exported with static dims and a static KV cache
```
2. Run tokenizer.py to generate the binary format for the ExecuTorch runtime:
```
python -m extension.llm.tokenizer.tokenizer -t <path_to_downloaded_gemma_checkpoint_dir>/tokenizer.model -o tokenizer.bin
```
3. Build the llama runner by following [step 4](https://github.com/pytorch/executorch/tree/main/examples/models/llama2#step-4-run-on-your-computer-to-validate) of this guide.
4. Run the lowered model:
```
cmake-out/examples/models/llama2/llama_main --model_path=gemma.pte --tokenizer_path=tokenizer.bin --prompt="My name is"
```

OOTB output and perf:
```
I 00:00:00.003110 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version
I 00:00:00.003360 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version
I 00:00:00.003380 executorch:cpuinfo_utils.cpp:158] Number of efficient cores 4
I 00:00:00.003384 executorch:main.cpp:65] Resetting threadpool with num threads = 6
I 00:00:00.014716 executorch:runner.cpp:51] Creating LLaMa runner: model_path=gemma.pte, tokenizer_path=tokenizer_gemma.bin
I 00:00:03.065359 executorch:runner.cpp:66] Reading metadata from model
I 00:00:03.065391 executorch:metadata_util.h:43] get_n_bos: 1
I 00:00:03.065396 executorch:metadata_util.h:43] get_n_eos: 1
I 00:00:03.065399 executorch:metadata_util.h:43] get_max_seq_len: 123
I 00:00:03.065402 executorch:metadata_util.h:43] use_kv_cache: 1
I 00:00:03.065404 executorch:metadata_util.h:41] The model does not contain use_sdpa_with_kv_cache method, using default value 0
I 00:00:03.065405 executorch:metadata_util.h:43] use_sdpa_with_kv_cache: 0
I 00:00:03.065407 executorch:metadata_util.h:41] The model does not contain append_eos_to_prompt method, using default value 0
I 00:00:03.065409 executorch:metadata_util.h:43] append_eos_to_prompt: 0
I 00:00:03.065411 executorch:metadata_util.h:41] The model does not contain enable_dynamic_shape method, using default value 0
I 00:00:03.065412 executorch:metadata_util.h:43] enable_dynamic_shape: 0
I 00:00:03.130388 executorch:metadata_util.h:43] get_vocab_size: 256000
I 00:00:03.130405 executorch:metadata_util.h:43] get_bos_id: 2
I 00:00:03.130408 executorch:metadata_util.h:43] get_eos_id: 1
My name is Melle. I am a 20 year old girl from Belgium. I am living in the southern part of Belgium. I am 165 cm tall and I weigh 45kg. I like to play sports like swimming, running and playing tennis. I am very interested in music and I like to listen to classical music. I like to sing and I can play the piano. I would like to go to the USA because I like to travel a lot. I am looking for a boy from the USA who is between 18 and 25 years old. I
PyTorchObserver {"prompt_tokens":4,"generated_tokens":118,"model_load_start_ms":1723685715497,"model_load_end_ms":1723685718612,"inference_start_ms":1723685718612,"inference_end_ms":1723685732965,"prompt_eval_end_ms":1723685719087,"first_token_ms":1723685719087,"aggregate_sampling_time_ms":182,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:17.482472 executorch:stats.h:70] Prompt Tokens: 4 Generated Tokens: 118
I 00:00:17.482475 executorch:stats.h:76] Model Load Time: 3.115000 (seconds)
I 00:00:17.482481 executorch:stats.h:86] Total inference time: 14.353000 (seconds) Rate: 8.221278 (tokens/second)
I 00:00:17.482483 executorch:stats.h:94] Prompt evaluation: 0.475000 (seconds) Rate: 8.421053 (tokens/second)
I 00:00:17.482485 executorch:stats.h:105] Generated 118 tokens: 13.878000 (seconds) Rate: 8.502666 (tokens/second)
I 00:00:17.482486 executorch:stats.h:113] Time to first generated token: 0.475000 (seconds)
I 00:00:17.482488 executorch:stats.h:120] Sampling time over 122 tokens: 0.182000 (seconds)
```

Pull Request resolved: #4723

Reviewed By: huydhn, kirklandsign

Differential Revision: D62543933

Pulled By: guangy10

fbshipit-source-id: 00401a39ba03d7383e4b284d25c8fc62a6695b34
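For reference, the sketch below condenses what the new export script does end to end: load the 🤗 model with a static KV cache, export it with `convert_and_export_with_cache`, lower to XNNPACK, and write a `.pte` file. It mirrors the `extension/export_util/export_hf_model.py` added in this commit; the model repo and sequence length here are the demo values, and the full script additionally attaches runner metadata via `constant_methods` and passes `_skip_dim_order=True`.

```python
# Condensed sketch of the export recipe (see the full script in the diff below).
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
from torch.nn.attention import SDPBackend
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache

# Load the HF model with a static KV cache so the exported graph has static shapes.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    attn_implementation="sdpa",
    torch_dtype=torch.float32,
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=123,
        cache_config={"batch_size": 1, "max_cache_len": 123},
    ),
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
input_ids = tokenizer([""], return_tensors="pt")["input_ids"]
cache_position = torch.tensor([0], dtype=torch.long)

# Export with the static cache, lower to XNNPACK, and serialize a .pte file.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    exported = convert_and_export_with_cache(model, input_ids, cache_position)
    prog = (
        to_edge(exported, compile_config=EdgeCompileConfig(_check_ir_validity=False))
        .to_backend(XnnpackPartitioner())
        .to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
    )
    with open("gemma.pte", "wb") as f:
        prog.write_to_file(f)
```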
1 parent 2001b3c commit 67be84b

File tree

2 files changed: +200 −0 lines changed


.github/workflows/trunk.yml

Lines changed: 90 additions & 0 deletions
@@ -351,3 +351,93 @@ jobs:
        PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_model.sh "${MODEL_NAME}" "${BUILD_TOOL}" "${BACKEND}"
        echo "::endgroup::"
      done

  test-huggingface-transformers:
    name: test-huggingface-transformers
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    secrets: inherit
    strategy:
      matrix:
        hf_model_repo: [google/gemma-2b]
      fail-fast: false
    with:
      secrets-env: EXECUTORCH_HF_TOKEN
      runner: linux.12xlarge
      docker-image: executorch-ubuntu-22.04-clang12
      submodules: 'true'
      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
      timeout: 90
      script: |
        echo "::group::Set up ExecuTorch"
        # The generic Linux job chooses to use base env, not the one setup by the image
        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
        conda activate "${CONDA_ENV}"
        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh cmake

        echo "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a"
        rm -rf cmake-out
        cmake \
            -DCMAKE_INSTALL_PREFIX=cmake-out \
            -DCMAKE_BUILD_TYPE=Release \
            -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
            -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
            -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
            -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
            -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
            -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
            -DEXECUTORCH_BUILD_XNNPACK=ON \
            -DPYTHON_EXECUTABLE=python \
            -Bcmake-out .
        cmake --build cmake-out -j9 --target install --config Release

        echo "Build llama runner"
        dir="examples/models/llama2"
        cmake \
            -DCMAKE_INSTALL_PREFIX=cmake-out \
            -DCMAKE_BUILD_TYPE=Release \
            -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
            -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
            -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
            -DEXECUTORCH_BUILD_XNNPACK=ON \
            -DPYTHON_EXECUTABLE=python \
            -Bcmake-out/${dir} \
            ${dir}
        cmake --build cmake-out/${dir} -j9 --config Release
        echo "::endgroup::"

        echo "::group::Set up HuggingFace Dependencies"
        pip install -U "huggingface_hub[cli]"
        huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
        pip install accelerate sentencepiece
        # TODO(guangyang): Switch to use released transformers library after all required patches are included
        pip install "git+https://github.com/huggingface/transformers.git@6cc4dfe3f1e8d421c6d6351388e06e9b123cbfe1"
        pip list
        echo "::endgroup::"

        echo "::group::Export to ExecuTorch"
        TOKENIZER_FILE=tokenizer.model
        TOKENIZER_BIN_FILE=tokenizer.bin
        ET_MODEL_NAME=et_model
        # Fetch the file using a Python one-liner
        DOWNLOADED_TOKENIZER_FILE_PATH=$(python -c "
        from huggingface_hub import hf_hub_download
        # Download the file from the Hugging Face Hub
        downloaded_path = hf_hub_download(
          repo_id='${{ matrix.hf_model_repo }}',
          filename='${TOKENIZER_FILE}'
        )
        print(downloaded_path)
        ")
        if [ -f "$DOWNLOADED_TOKENIZER_FILE_PATH" ]; then
          echo "${TOKENIZER_FILE} downloaded successfully at: $DOWNLOADED_TOKENIZER_FILE_PATH"
          python -m extension.llm.tokenizer.tokenizer -t $DOWNLOADED_TOKENIZER_FILE_PATH -o ./${TOKENIZER_BIN_FILE}
          ls ./tokenizer.bin
        else
          echo "Failed to download ${TOKENIZER_FILE} from ${{ matrix.hf_model_repo }}."
          exit 1
        fi

        python -m extension.export_util.export_hf_model -hfm=${{ matrix.hf_model_repo }} -o ${ET_MODEL_NAME}

        cmake-out/examples/models/llama2/llama_main --model_path=${ET_MODEL_NAME}.pte --tokenizer_path=${TOKENIZER_BIN_FILE} --prompt="My name is"
        echo "::endgroup::"
extension/export_util/export_hf_model.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

import torch
import torch.export._trace
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
from torch.nn.attention import SDPBackend
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache
from transformers.modeling_utils import PreTrainedModel


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-hfm",
        "--hf_model_repo",
        required=True,
        default=None,
        help="a valid huggingface model repo name",
    )
    parser.add_argument(
        "-o",
        "--output_name",
        required=False,
        default=None,
        help="output name of the exported model",
    )

    args = parser.parse_args()

    # Configs to HF model
    device = "cpu"
    dtype = torch.float32
    batch_size = 1
    max_length = 123
    cache_implementation = "static"
    attn_implementation = "sdpa"

    # Load and configure a HF model
    model = AutoModelForCausalLM.from_pretrained(
        args.hf_model_repo,
        attn_implementation=attn_implementation,
        device_map=device,
        torch_dtype=dtype,
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation=cache_implementation,
            max_length=max_length,
            cache_config={
                "batch_size": batch_size,
                "max_cache_len": max_length,
            },
        ),
    )
    print(f"{model.config}")
    print(f"{model.generation_config}")

    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
    input_ids = tokenizer([""], return_tensors="pt").to(device)["input_ids"]
    cache_position = torch.tensor([0], dtype=torch.long)

    def _get_constant_methods(model: PreTrainedModel):
        return {
            "get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6,
            "get_bos_id": model.config.bos_token_id,
            "get_eos_id": model.config.eos_token_id,
            "get_head_dim": model.config.hidden_size / model.config.num_attention_heads,
            "get_max_batch_size": model.generation_config.cache_config.batch_size,
            "get_max_seq_len": model.generation_config.cache_config.max_cache_len,
            "get_n_bos": 1,
            "get_n_eos": 1,
            "get_n_kv_heads": model.config.num_key_value_heads,
            "get_n_layers": model.config.num_hidden_layers,
            "get_vocab_size": model.config.vocab_size,
            "use_kv_cache": model.generation_config.use_cache,
        }

    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
        exported_prog = convert_and_export_with_cache(model, input_ids, cache_position)
        prog = (
            to_edge(
                exported_prog,
                compile_config=EdgeCompileConfig(
                    _check_ir_validity=False,
                    _skip_dim_order=True,
                ),
                constant_methods=_get_constant_methods(model),
            )
            .to_backend(XnnpackPartitioner())
            .to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
        )
        out_name = args.output_name if args.output_name else model.config.model_type
        filename = os.path.join("./", f"{out_name}.pte")
        with open(filename, "wb") as f:
            prog.write_to_file(f)
            print(f"Saved exported program to {filename}")


if __name__ == "__main__":
    main()
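A `.pte` produced by this script can be sanity-checked without building the C++ runner by loading it through ExecuTorch's Python pybindings. The sketch below is an illustration under assumptions, not part of this commit: it presumes the `portable_lib` pybindings are built in the environment, and that the exported method takes `(input_ids, cache_position)` as set up by `convert_and_export_with_cache` above.

```python
# Hypothetical smoke test for the exported .pte (not part of this commit).
# Assumes the ExecuTorch pybindings are available as executorch.extension.pybindings.portable_lib.
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

module = _load_for_executorch("gemma.pte")  # file written by the script above

# The exported method mirrors convert_and_export_with_cache's inputs: token ids plus a cache position.
input_ids = torch.tensor([[2]], dtype=torch.long)      # single BOS token (id 2 for gemma-2b, per the runner log)
cache_position = torch.tensor([0], dtype=torch.long)   # write into the first KV-cache slot

logits = module.forward([input_ids, cache_position])[0]
print(logits.shape)  # expect [1, 1, vocab_size], i.e. [1, 1, 256000] for gemma-2b
```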
