
VLM: Model Tracing Guide #1030


Merged · 369 commits · Jan 23, 2025

Commits (369)
3830696
preliminary data pipeline
kylesayrs Nov 26, 2024
1ecaa39
WIP
kylesayrs Nov 26, 2024
9aa9679
delete unnecessary files
kylesayrs Nov 26, 2024
7e6fe17
Merge remote-tracking branch 'origin' into kylesayrs/gptq-partition
kylesayrs Nov 26, 2024
034c0b1
Merge branch 'kylesayrs/gptq-hooks' into kylesayrs/gptq-partition
kylesayrs Nov 26, 2024
a62617c
clean up CustomDataset
kylesayrs Nov 28, 2024
57b5e02
chchchchanges
kylesayrs Nov 29, 2024
fa317fd
wip: use rename to processor, going through tests
kylesayrs Dec 2, 2024
f3f5875
remove labels from calibration dataset rather than assuming that all …
kylesayrs Dec 2, 2024
58c3afe
cleanup
kylesayrs Dec 2, 2024
72aecfc
cleanup, etc
kylesayrs Dec 2, 2024
77217fb
Merge remote-tracking branch 'origin' into kylesayrs/cleanup-custom-d…
kylesayrs Dec 2, 2024
4461a3e
fix typehinting
kylesayrs Dec 2, 2024
fb33001
add typechecking imports
kylesayrs Dec 2, 2024
bf4744a
remove sparseml utilities
kylesayrs Dec 3, 2024
62ae31d
Merge branch 'kylesayrs/remove-sparseml-utilities' into kylesayrs/cle…
kylesayrs Dec 3, 2024
7e516c1
use in model_load
kylesayrs Dec 3, 2024
d69106e
Merge branch 'main' into kylesayrs/calculate_offload_default_gpus
kylesayrs Dec 3, 2024
9e33641
remove use of RECIPE FILE NAME
kylesayrs Dec 3, 2024
58c0fba
rename to RECIPE_FILE_NAME, avoid circular import
kylesayrs Dec 3, 2024
b28aaae
Merge branch 'kylesayrs/remove-sparseml-utilities' into kylesayrs/cle…
kylesayrs Dec 3, 2024
8d13013
image dataset collation
kylesayrs Dec 3, 2024
17cf9f3
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 3, 2024
163ee8f
cleanup, do not handle case where processor is None
kylesayrs Dec 3, 2024
1180b34
remove qa ignore
kylesayrs Dec 3, 2024
ad20ae7
Merge branch 'kylesayrs/remove-sparseml-utilities' into kylesayrs/cle…
kylesayrs Dec 3, 2024
c431958
add documentation
kylesayrs Dec 3, 2024
b48d55d
add data collator arg
kylesayrs Dec 3, 2024
2d201e0
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 3, 2024
0ed5c2c
use default factor
kylesayrs Dec 3, 2024
ca61e90
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 3, 2024
41dd463
wip mllama
kylesayrs Dec 4, 2024
8527e0e
cleanup
kylesayrs Dec 4, 2024
0a8a03f
merge-implement hessian offloading
kylesayrs Dec 4, 2024
fc044e2
better concrete arg handling
kylesayrs Dec 4, 2024
4576712
validate flickr
kylesayrs Dec 4, 2024
5276c58
discover bug, tests and multimodal working
kylesayrs Dec 4, 2024
dffcbc3
dataset split fallbacks
kylesayrs Dec 4, 2024
b3cb229
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 4, 2024
779c9a2
Merge branch 'kylesayrs/dataset-split-fallbacks' into kylesayrs/clean…
kylesayrs Dec 4, 2024
85e3f59
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 4, 2024
e9f150d
move typing
kylesayrs Dec 4, 2024
d061567
cleanup, depreciate remove_columns argument
kylesayrs Dec 4, 2024
55a31ca
silently assign tokenizer to processor
kylesayrs Dec 5, 2024
c14e40e
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 5, 2024
1aba16d
replace tokenizer with processor
kylesayrs Dec 5, 2024
135e459
Merge branch 'kylesayrs/processor-replaces-tokenizer' into kylesayrs/…
kylesayrs Dec 5, 2024
dde2fa7
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 5, 2024
89bda30
defer data collator changes
kylesayrs Dec 5, 2024
0fa4102
reduce warnings
kylesayrs Dec 5, 2024
bc505bf
typehinting, add not-implemented error
kylesayrs Dec 5, 2024
c91ba77
remove todos
kylesayrs Dec 5, 2024
e916936
Delete mllama.py
kylesayrs Dec 5, 2024
0a573a1
update dataset manager api in tests
kylesayrs Dec 5, 2024
853c0a8
typehinting, add not-implemented error
kylesayrs Dec 5, 2024
234ef79
remove todos
kylesayrs Dec 5, 2024
8972dd5
update dataset manager api in tests
kylesayrs Dec 5, 2024
acb1a18
Delete examples/multimodal_vision/qwen_vl2.py
kylesayrs Dec 5, 2024
56b5d12
Delete examples/multimodal_vision/mllama.py
kylesayrs Dec 5, 2024
57c293e
WIP: add pixtral
kylesayrs Dec 5, 2024
537c5ab
pixtral working
kylesayrs Dec 5, 2024
15b3508
move to data pipeline
kylesayrs Dec 6, 2024
42b5fc0
disable_hf_hook context
kylesayrs Dec 6, 2024
bc33e8e
woof
kylesayrs Dec 6, 2024
ca72bbb
change desc
kylesayrs Dec 6, 2024
293640a
fix docstring
kylesayrs Dec 6, 2024
17b3a70
rely on compressed tensors, support offloading
kylesayrs Dec 6, 2024
5e185f2
sequential targets
kylesayrs Dec 6, 2024
4d82180
support match_layers_params
kylesayrs Dec 6, 2024
6a1b2c2
make _update_size private and inferred
kylesayrs Dec 6, 2024
f9ab6fc
make a module
kylesayrs Dec 6, 2024
0dc74dd
fallback
kylesayrs Dec 6, 2024
9e07188
implement basic pipeline
kylesayrs Dec 6, 2024
ed099ef
balance between gpus
kylesayrs Dec 6, 2024
4bbbc49
add proper ignore list
kylesayrs Dec 6, 2024
ae74f45
treat offloaded modules as leaves, treat ignore as sequential target
kylesayrs Dec 7, 2024
31eeb8c
redisable piecewise for vision datasets
kylesayrs Dec 7, 2024
1b24090
implement pipeline fallback
kylesayrs Dec 9, 2024
d97ef2b
Merge remote-tracking branch 'origin' into kylesayrs/processor-replac…
kylesayrs Dec 9, 2024
e87e019
remove subbatch event
kylesayrs Dec 9, 2024
d5c08fb
input device inference
kylesayrs Dec 9, 2024
39ed8ca
do not disable hf hook during tracing
kylesayrs Dec 9, 2024
47ca742
Merge remote-tracking branch 'origin' into kylesayrs/gptq-partition
kylesayrs Dec 9, 2024
c1f5cb2
Merge remote-tracking branch 'origin' into kylesayrs/cleanup-custom-d…
kylesayrs Dec 9, 2024
4711e9f
remove import
kylesayrs Dec 9, 2024
e468197
use find_nodes
kylesayrs Dec 9, 2024
f8591ca
rename piecewise to sequential
kylesayrs Dec 9, 2024
cea02d2
add docstring
kylesayrs Dec 9, 2024
f1f6c0f
begin sequential pipeline testing
kylesayrs Dec 9, 2024
3b0b49f
remove todos, add tests for sequential pipeline
kylesayrs Dec 10, 2024
2c035b3
move function placement
kylesayrs Dec 10, 2024
b93868d
slight partition algorithm change
kylesayrs Dec 10, 2024
146e4be
revert llama3 example
kylesayrs Dec 10, 2024
0e4d8f3
Merge branch 'main' into kylesayrs/dataset-split-fallbacks
kylesayrs Dec 10, 2024
b8e867d
Merge branch 'main' into kylesayrs/processor-replaces-tokenizer
kylesayrs Dec 10, 2024
ccb007f
remove test, fix default in order to fix tests
kylesayrs Dec 10, 2024
e1055b0
bump memory requirements
kylesayrs Dec 11, 2024
70421ed
fix memory and offloading issues
kylesayrs Dec 12, 2024
b102bf5
add missing cache file
kylesayrs Dec 12, 2024
229d3ae
make mllama tracable
kylesayrs Dec 12, 2024
4e0b118
write using comprehesion
kylesayrs Dec 12, 2024
7dc4d2a
fix hessian requirements
kylesayrs Dec 12, 2024
377b2a4
implement offloading for tuple
kylesayrs Dec 12, 2024
adb1627
add save
kylesayrs Dec 12, 2024
ab3fc81
change num samples
kylesayrs Dec 12, 2024
1bf683e
implement intermediates offloading for dataclasses
kylesayrs Dec 12, 2024
8918917
Merge branch 'main' into kylesayrs/processor-replaces-tokenizer
kylesayrs Dec 12, 2024
b75fe15
wrap ignore but do not treat as sequential target
kylesayrs Dec 13, 2024
aa4a23d
tracable pixtral/mistral
kylesayrs Dec 13, 2024
aa532b5
remove double saving
kylesayrs Dec 13, 2024
19e4f97
revert dampening frac
kylesayrs Dec 13, 2024
f95b77f
do not cache model outputs to save memory
kylesayrs Dec 13, 2024
2d890db
fix dataclass case, add tests
kylesayrs Dec 13, 2024
7e69b9d
Merge remote-tracking branch 'origin' into kylesayrs/gptq-partition
kylesayrs Dec 13, 2024
4a22032
Remove docstring
kylesayrs Dec 13, 2024
8d72269
Merge branch 'main' into kylesayrs/processor-replaces-tokenizer
kylesayrs Dec 14, 2024
a71352a
move IntermediatesCache location
kylesayrs Dec 14, 2024
2d249a2
add fake_sequential
kylesayrs Dec 14, 2024
995cb2d
rename fake_sequential to layer_sequential
kylesayrs Dec 14, 2024
e4bca34
pipeline inference
kylesayrs Dec 14, 2024
4a046a5
update docstrings
kylesayrs Dec 14, 2024
f24a2af
fix last layer bug
kylesayrs Dec 14, 2024
691bac4
better inference
kylesayrs Dec 14, 2024
1e15d3e
even better inference
kylesayrs Dec 14, 2024
a4744d9
do now throw warning for calibration with training
kylesayrs Dec 16, 2024
9617e53
add information about how to silence warning
kylesayrs Dec 16, 2024
3b4cac1
nice
kylesayrs Dec 16, 2024
f53a3dd
remove unnecessary warning silencing
kylesayrs Dec 16, 2024
f45d0fa
Merge branch 'kylesayrs/processor-replaces-tokenizer', remote-trackin…
kylesayrs Dec 16, 2024
70a2811
Merge branch 'kylesayrs/dataset-split-fallbacks' into kylesayrs/gptq-…
kylesayrs Dec 16, 2024
fd151e4
add unmerged thing
kylesayrs Dec 16, 2024
d1d42de
fix deleted columns
kylesayrs Dec 16, 2024
92151a1
handle dataset dict case
kylesayrs Dec 17, 2024
4c049db
support torch.nn.Conv2d, silently ignore embeddings
kylesayrs Dec 17, 2024
7667998
handle columns better
kylesayrs Dec 17, 2024
f0eb640
fix tokenizer args
kylesayrs Dec 18, 2024
af86f45
filter_tokenizer_args
kylesayrs Dec 18, 2024
5567a90
Merge remote-tracking branch 'origin' into kylesayrs/gptq-partition
kylesayrs Dec 18, 2024
0438e17
Merge remote-tracking branch 'origin' into kylesayrs/cleanup-custom-d…
kylesayrs Dec 18, 2024
9b61145
update docstring
kylesayrs Dec 18, 2024
2f65d01
remove unused util
kylesayrs Dec 18, 2024
338d1cb
remove debug
kylesayrs Dec 18, 2024
f4fa9c3
more tests
kylesayrs Dec 18, 2024
6bd1721
Merge remote-tracking branch 'origin' into kylesayrs/cleanup-custom-d…
kylesayrs Dec 18, 2024
e757e61
remove duplicate file
kylesayrs Dec 18, 2024
bdfa3d4
better help texts
kylesayrs Dec 18, 2024
cd9dd21
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 18, 2024
f674579
Merge branch 'kylesayrs/calculate_offload_default_gpus' into kylesayr…
kylesayrs Dec 18, 2024
f1e1335
remove future notes, todos
kylesayrs Dec 19, 2024
e59c2e7
remove skipping patching
kylesayrs Dec 19, 2024
4932ec5
remove skipping for none args
kylesayrs Dec 19, 2024
6b7c11f
revert data split fallbacks
kylesayrs Dec 19, 2024
601cb0e
rvert data split fallbacks
kylesayrs Dec 19, 2024
4123636
propagate oom errors, separate data collators
kylesayrs Dec 19, 2024
c1e66e8
apply style, ignore visual on qwen
kylesayrs Dec 19, 2024
dc14e95
remove qwen while unsupported
kylesayrs Dec 19, 2024
47249c5
remove smoothquant while unsupported
kylesayrs Dec 19, 2024
de40a84
clean up examples
kylesayrs Dec 19, 2024
56ca97c
Merge remote-tracking branch 'origin' into kylesayrs/gptq-partition
kylesayrs Dec 19, 2024
7f6e8cd
handle non-fast tokenizers
kylesayrs Dec 20, 2024
1c8afe4
handle non-fast tokenizers
kylesayrs Dec 20, 2024
3a9816c
address nits, add logging
kylesayrs Dec 20, 2024
7be0c88
add back copyrights
kylesayrs Dec 20, 2024
bedbf8c
correctly update helptext
kylesayrs Dec 20, 2024
7c54bed
Merge remote-tracking branch 'origin' into kylesayrs/cleanup-custom-d…
kylesayrs Dec 20, 2024
d27dad3
Merge branch 'main' into kylesayrs/cleanup-custom-dataset
dsikka Dec 20, 2024
42f7892
do not remove prompt key
kylesayrs Dec 20, 2024
4139628
add no copyright to hf files
kylesayrs Dec 20, 2024
15fa27d
remove prompt key
kylesayrs Dec 23, 2024
ae16da3
do not process tokenized datasets, including adding labels
kylesayrs Dec 23, 2024
9a08725
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 23, 2024
1eb7f83
Merge branch 'main' into kylesayrs/cleanup-custom-dataset
dsikka Dec 23, 2024
c3a663a
rename classes so the saved config is the original class
kylesayrs Dec 23, 2024
0d484bf
Merge branch 'main' into kylesayrs/cleanup-custom-dataset
dsikka Dec 23, 2024
ddb6fc3
Merge remote-tracking branch 'origin/kylesayrs/cleanup-custom-dataset…
kylesayrs Dec 23, 2024
e71f4e5
remove default chat template
kylesayrs Dec 23, 2024
966b96b
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 23, 2024
0195fab
support llava-1.5 via installing metadata
kylesayrs Dec 23, 2024
148e617
account for models which improperly do not override the abstract methods
kylesayrs Dec 27, 2024
5ae2300
Merge branch 'kylesayrs/patch-mal-models' into kylesayrs/gptq-partition
kylesayrs Dec 27, 2024
e5dd582
add ChatGLMForConditionalGeneration
kylesayrs Dec 27, 2024
5303df2
list of unfixable errors
kylesayrs Dec 27, 2024
aa16223
Merge branch 'main' into kylesayrs/gptq-partition
dsikka Dec 28, 2024
5124e24
Merge remote-tracking branch 'origin' into kylesayrs/gptq-partition
kylesayrs Dec 29, 2024
14cbc97
add glm license, style
kylesayrs Dec 29, 2024
4ac9018
Merge branch 'main' into kylesayrs/gptq-partition
dsikka Jan 1, 2025
ff470b3
add suggestion to use offload_hessians
kylesayrs Jan 2, 2025
c1c3eaa
update names and comments
kylesayrs Jan 2, 2025
e5af728
change tqdm description, add comment
kylesayrs Jan 2, 2025
8fd93a7
add no vllm copyright to glm
kylesayrs Jan 2, 2025
8e5f693
update comments, remove unnecessary default values
kylesayrs Jan 2, 2025
0499bb1
Merge branch 'main' into kylesayrs/gptq-partition
dsikka Jan 2, 2025
c12c9f0
use text kwarg
kylesayrs Jan 2, 2025
ba360d7
WIP: provide tracing script
kylesayrs Jan 2, 2025
7ba6f60
rename examples to have _example suffix
kylesayrs Jan 3, 2025
c582b00
remove hardcoded value
kylesayrs Jan 3, 2025
435cf0d
update all list
kylesayrs Jan 3, 2025
0d25307
update examples to use w4a16
kylesayrs Jan 3, 2025
9abdea8
llava: clarify changes, undo style changes
kylesayrs Jan 3, 2025
3dca7b3
glm comments, fix isort
kylesayrs Jan 3, 2025
f416674
correct typo 'tracable'
kylesayrs Jan 3, 2025
71faee7
mllama: remove unnecessary definitions
kylesayrs Jan 3, 2025
557467b
add keyboard interrupts to list of unfixable errors
kylesayrs Jan 3, 2025
e158b9b
mistral: remove unnecessary definitions
kylesayrs Jan 3, 2025
dfadc11
remove propagate_error argument
kylesayrs Jan 3, 2025
d146771
pipeline docstrings
kylesayrs Jan 4, 2025
bb77a44
add gptq lifecycle docstring
kylesayrs Jan 4, 2025
14f5d88
layer sequential helpers docstrings
kylesayrs Jan 4, 2025
fde309a
update comments
kylesayrs Jan 4, 2025
e6a8fa8
sequential helpers docstrings
kylesayrs Jan 4, 2025
954cd4e
more docstrings
kylesayrs Jan 4, 2025
00309e9
IntermediatesCache docstrings
kylesayrs Jan 4, 2025
57e8f21
free hessians on finalize
kylesayrs Jan 4, 2025
378afb3
remove unnecessary examples
kylesayrs Jan 4, 2025
83b81be
make diff closer to original implementation
kylesayrs Jan 4, 2025
b6c0a50
Merge branch 'main' into kylesayrs/gptq-partition
kylesayrs Jan 4, 2025
5363d40
use original mask padding function
kylesayrs Jan 4, 2025
ae89688
reduce diff
kylesayrs Jan 4, 2025
1af401f
Merge branch 'kylesayrs/gptq-partition' into kylesayrs/traceability-r…
kylesayrs Jan 4, 2025
8913155
Merge remote-tracking branch 'origin' into kylesayrs/traceability-readme
kylesayrs Jan 9, 2025
d1f9352
merge dreggs
kylesayrs Jan 9, 2025
3230f88
fix link
kylesayrs Jan 9, 2025
f62dadd
sequential targets and ignore
kylesayrs Jan 9, 2025
5a92be0
guide roadmapping
kylesayrs Jan 9, 2025
d906bc5
Defining your own Traceable Model Definitions
kylesayrs Jan 9, 2025
eedfc5a
fix links
kylesayrs Jan 9, 2025
7040bdf
WIP
kylesayrs Jan 9, 2025
0161feb
WIP
kylesayrs Jan 9, 2025
ea46517
add argparse
kylesayrs Jan 10, 2025
76e6078
WIP: more progress
kylesayrs Jan 10, 2025
adadbec
add attempt_trace entrypoint
kylesayrs Jan 10, 2025
eef15b4
general readability, typos
kylesayrs Jan 10, 2025
407f325
first draft readme
kylesayrs Jan 10, 2025
50301b7
fix link
kylesayrs Jan 10, 2025
0e3e8bd
Merge branch 'main' into kylesayrs/traceability-readme
kylesayrs Jan 13, 2025
feeb67e
partial derivatives are not alphanumeric
kylesayrs Jan 13, 2025
d6441f5
rename attempt_trace to trace
kylesayrs Jan 13, 2025
3bd3ca7
Merge remote-tracking branch 'origin' into kylesayrs/traceability-readme
kylesayrs Jan 13, 2025
bb7ca2e
Merge branch 'main' into kylesayrs/traceability-readme
kylesayrs Jan 14, 2025
08fad5d
rename to guide, link to guide in warning
kylesayrs Jan 15, 2025
5f23f52
typos
kylesayrs Jan 19, 2025
6c71263
typo
kylesayrs Jan 19, 2025
08f9f79
add summary
kylesayrs Jan 19, 2025
9e6ceb8
Update src/llmcompressor/pipelines/sequential/README.md
kylesayrs Jan 20, 2025
7536b7d
Merge branch 'main' into kylesayrs/traceability-readme
kylesayrs Jan 20, 2025
32dd0e3
Merge branch 'main' into kylesayrs/traceability-readme
dsikka Jan 23, 2025
68586cb
Merge branch 'main' into kylesayrs/traceability-readme
dsikka Jan 23, 2025
f3d9162
use modality kwarg
kylesayrs Jan 23, 2025
5547e98
add image descriptions, fix typos
kylesayrs Jan 23, 2025
c8659ef
remove mention of sgpt until those changes land
kylesayrs Jan 23, 2025
1 change: 1 addition & 0 deletions setup.py

@@ -94,6 +94,7 @@
 "llmcompressor.transformers.text_generation.finetune=llmcompressor.transformers.finetune.text_generation:train", # noqa 501
 "llmcompressor.transformers.text_generation.eval=llmcompressor.transformers.finetune.text_generation:eval", # noqa 501
 "llmcompressor.transformers.text_generation.oneshot=llmcompressor.transformers.finetune.text_generation:oneshot", # noqa 501
+"llmcompressor.trace=llmcompressor.transformers.tracing.debug:main",
 ]
 },
 python_requires=">=3.8",
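The added entry registers a `llmcompressor.trace` console script that dispatches to `debug:main` (the full module is shown further below). As a minimal sketch, the same trace can also be driven from Python by populating `sys.argv` before calling `main()`; the argument values here are illustrative, taken from the CLI example in the docstring:

```python
# Minimal sketch: drive the new `llmcompressor.trace` entrypoint programmatically.
# The console script registered in setup.py resolves to debug:main, which parses
# sys.argv with argparse (see debug.py below). Argument values are illustrative.
import sys

from llmcompressor.transformers.tracing.debug import main

sys.argv = [
    "llmcompressor.trace",
    "--model_id", "Qwen/Qwen2-VL-2B-Instruct",
    "--model_class", "Qwen2VLForConditionalGeneration",
    "--ignore", "lm_head", "re:visual.*",
    "--modality", "text",
]
main()
```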
6 changes: 5 additions & 1 deletion src/llmcompressor/modifiers/quantization/gptq/base.py

@@ -244,7 +244,11 @@ def on_initialize(self, state: State, **kwargs) -> bool:

 except Exception as exception:
     if isinstance(exception, torch.fx.proxy.TraceError):
-        warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
+        warnings.warn(
+            f"Failed to trace {model_name} with inputs {input_names}. For more "
+            "information on tracing with the sequential pipeline, see "
+            "`src/llmcompressor/transformers/tracing/GUIDE.md`"
+        )
     if isinstance(exception, unfixable_errors):
         raise exception
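The surrounding control flow is a fallback pattern: attempt the sequential pipeline, warn with a pointer to the tracing guide when tracing fails, re-raise errors that cannot be recovered from, and otherwise fall back to a simpler pipeline. A rough sketch, with names assumed from this diff and the commit history (not the exact implementation):

```python
# Rough sketch of the fallback pattern around the warning above.
# `unfixable_errors` is assumed to include OOM and keyboard interrupts,
# per the "propagate oom errors" and "add keyboard interrupts" commits.
import warnings

import torch

unfixable_errors = (torch.cuda.OutOfMemoryError, KeyboardInterrupt)


def run_with_fallback(sequential_pipeline, fallback_pipeline):
    try:
        return sequential_pipeline()
    except Exception as exception:
        if isinstance(exception, torch.fx.proxy.TraceError):
            warnings.warn(
                "Failed to trace model. For more information, see "
                "`src/llmcompressor/transformers/tracing/GUIDE.md`"
            )
        if isinstance(exception, unfixable_errors):
            raise exception
        return fallback_pipeline()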
6 changes: 6 additions & 0 deletions src/llmcompressor/pipelines/sequential/README.md
@@ -0,0 +1,6 @@
# Sequential Pipeline #
The sequential pipeline is a data pipeline, primarily used for compressing models with the
[GPTQModifier](/src/llmcompressor/modifiers/quantization/gptq/base.py).

If, when using this pipeline, you encounter a `torch.fx.proxy.TraceError`, see the
[Model Tracing Guide](/src/llmcompressor/transformers/tracing/GUIDE.md).
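For context, the pipeline referenced above is exercised by the GPTQ one-shot flow. A minimal sketch, with model stub, dataset name, and sample counts chosen for illustration (following the style of llm-compressor's example scripts):

```python
# Minimal sketch (illustrative values): GPTQ one-shot quantization, which
# runs the sequential data pipeline internally to calibrate layer by layer.
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # illustrative model stub
    dataset="ultrachat-200k",                    # dataset name used in this PR's debug script
    splits={"calibration": "train_sft[:512]"},   # assumed split selection
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)
```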
441 changes: 441 additions & 0 deletions src/llmcompressor/transformers/tracing/GUIDE.md

Large diffs are not rendered by default.

5,319 changes: 5,319 additions & 0 deletions src/llmcompressor/transformers/tracing/assets/Llama_3.2-Vision.svg
(SVG diagrams are not rendered in the diff view.)
136 changes: 136 additions & 0 deletions src/llmcompressor/transformers/tracing/debug.py
@@ -0,0 +1,136 @@
from typing import List, Type, Union, Optional, Dict

import argparse

import torch
import transformers
from transformers import AutoProcessor, PreTrainedModel

from llmcompressor.transformers import tracing
from llmcompressor.utils.pytorch.module import get_no_split_params
from llmcompressor.pipelines.sequential.helpers import trace_subgraphs
from llmcompressor.transformers import DataTrainingArguments, TextGenerationDataset


def parse_args():
parser = argparse.ArgumentParser(description="Trace a model into subgraphs")
parser.add_argument("--model_id", type=str, required=True, help="The stub of the model to load") # noqa: E501
parser.add_argument("--model_class", type=str, required=True, help="The class name of the model") # noqa: E501
parser.add_argument("--sequential_targets", type=str, nargs="*", default=None, metavar="TARGET", help="List of targets for sequential tracing") # noqa: E501
parser.add_argument("--ignore", type=str, nargs="*", default=[], metavar="PATTERN", help="List of patterns to ignore during tracing") # noqa: E501
parser.add_argument("--modality", type=str, default="text", help="Modality of calibration dataset, defaults to text") # noqa: E501
return parser.parse_args()


def trace(
model_id: str,
model_class: Type[PreTrainedModel],
sequential_targets: Optional[Union[List[str], str]] = None,
ignore: Union[List[str], str] = [],
modality: str = "text",
):
"""
Debug traceability by tracing a pre-trained model into subgraphs

:param model_id: stub of the model to load
:param model_class: class constructor of the pre-trained model. Can use either
HF transformers classes or `Traceable` classes defined by LLM Compressor
:param sequential_targets: targets for sequential tracing, defaults to automatic
inference
:param ignore: patterns to ignore during tracing
:param modality: data modality for dummy tracing data, defaults to 'text'

Example usage from CLI
llmcompressor.trace \
--model_id Qwen/Qwen2-VL-2B-Instruct \
--model_class Qwen2VLForConditionalGeneration \
--sequential_targets Qwen2VLDecoderLayer \
--ignore "lm_head" "re:visual.*" \
--modality text
"""
# Load model
model = model_class.from_pretrained(
model_id,
device_map="auto",
torch_dtype="auto",
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
print("Loaded model")

# Prepare sample data
data_args = DataTrainingArguments(**get_dataset_kwargs(modality))
dataset = TextGenerationDataset.load_from_registry(
data_args.dataset,
data_args=data_args,
split=data_args.splits["calibration"],
processor=processor,
)(add_labels=False)
sample_input = next(iter(dataset))
sample_input = {k: torch.tensor(v) for k, v in sample_input.items()}
print("Loaded sample data")

# infer sequential targets
if sequential_targets is None:
sequential_targets = get_no_split_params(model)
if isinstance(sequential_targets, str):
sequential_targets = [sequential_targets]

# infer ignore
if isinstance(ignore, str):
ignore = [ignore]

# Attempt trace
print(
"\nAttempting trace\n"
f" model_id={model_id}\n"
f" model_class={model_class.__name__}\n"
f" dataset={data_args.dataset}\n"
f" split={dataset.split}\n"
f" inputs={sample_input.keys()}\n"
f" sequential_targets={sequential_targets}\n"
f" ignore={ignore}\n"
)
subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
print(f"Successfully traced model into {len(subgraphs)} subgraphs!\n")


def get_model_class(model_class: str) -> Type[PreTrainedModel]:
model_cls = getattr(tracing, model_class, getattr(transformers, model_class, None))
if model_cls is None:
raise ValueError(f"Could not import model class {model_class}")

return model_cls


def get_dataset_kwargs(modality: str) -> Dict[str, str]:
dataset_kwargs = {
"text": {
"dataset": "ultrachat-200k",
"splits": {"calibration": "test_sft[:1]"},
},
"vision": {
"dataset": "flickr",
"splits": {"calibration": "test[:1]"},
},
}

if modality not in dataset_kwargs:
raise ValueError(f"Modality must be one of {list(dataset_kwargs.keys())}")

return dataset_kwargs[modality]


def main():
args = parse_args()

trace(
model_id=args.model_id,
model_class=get_model_class(args.model_class),
sequential_targets=args.sequential_targets,
ignore=args.ignore,
modality=args.modality,
)


if __name__ == "__main__":
main()
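The helper can also be called directly from Python. A sketch mirroring the CLI example in the `trace()` docstring above (same illustrative model, targets, and ignore patterns):

```python
# Sketch: programmatic equivalent of the CLI example in trace()'s docstring.
from llmcompressor.transformers.tracing.debug import get_model_class, trace

trace(
    model_id="Qwen/Qwen2-VL-2B-Instruct",
    model_class=get_model_class("Qwen2VLForConditionalGeneration"),
    sequential_targets=["Qwen2VLDecoderLayer"],
    ignore=["lm_head", "re:visual.*"],
    modality="text",
)
```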