
Commit e48d9db

Authored by kylesayrs, dsikka, and mgoin
VLM: Model Tracing Guide (#1030)
## Purpose ##
This guide explains the concepts of tracing as they relate to LLM Compressor and how to modify your model to support recipes which require using the Sequential Pipeline. Through reading this guide, you will learn

1. Why tracing is required when compressing with recipes involving the Sequential Pipeline and modifiers such as GPTQModifier
2. How to determine if your model is traceable for your dataset
3. How to modify your model definition to be traceable

## Prerequisites ##
* #1031

## Changes ##
* Add a model tracing guide `src/llmcompressor/transformers/tracing/GUIDE.md` with pictures
* Add a readme for the sequential pipeline which points to the Tracing Guide: `src/llmcompressor/pipelines/sequential/README.md`
* Add a debug script to help users debug their models for traceability: `src/llmcompressor/transformers/tracing/debug.py`
* Add the `llmcompressor.trace` entrypoint for ease of use
* Swap the order of arguments in `llava_example.py` and `pixtral_example.py` to match the order of arguments on the modifier

## Testing ##
Use the `llmcompressor.trace` debug script:

```bash
llmcompressor.trace \
    --model_id llava-hf/llava-1.5-7b-hf \
    --model_class TraceableLlavaForConditionalGeneration \
    --sequential_targets LlamaDecoderLayer \
    --ignore "re:.*lm_head" "re:vision_tower.*" "re:multi_modal_projector.*" \
    --modality vision
```

## Stretch ##
It might be nice if this tracing debugger tool also printed the model graph to an SVG.

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent 6377f1e commit e48d9db

File tree

10 files changed: +5908 −1 lines changed

setup.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -94,6 +94,7 @@
             "llmcompressor.transformers.text_generation.finetune=llmcompressor.transformers.finetune.text_generation:train",  # noqa 501
             "llmcompressor.transformers.text_generation.eval=llmcompressor.transformers.finetune.text_generation:eval",  # noqa 501
             "llmcompressor.transformers.text_generation.oneshot=llmcompressor.transformers.finetune.text_generation:oneshot",  # noqa 501
+            "llmcompressor.trace=llmcompressor.transformers.tracing.debug:main",
         ]
     },
     python_requires=">=3.8",
```
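The added console script maps the `llmcompressor.trace` command onto `main()` in `debug.py`. As a quick post-install sanity check, the registered scripts can be listed; a minimal sketch, assuming Python 3.10+ for the `group=` keyword of `entry_points()`:

```python
# Sanity check (illustrative, not part of this commit): list the llmcompressor
# console scripts after installation. Assumes Python 3.10+ for the `group=`
# keyword of entry_points().
from importlib.metadata import entry_points

for script in entry_points(group="console_scripts"):
    if script.name.startswith("llmcompressor."):
        # expect: llmcompressor.trace -> llmcompressor.transformers.tracing.debug:main
        print(script.name, "->", script.value)
```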

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -244,7 +244,11 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         except Exception as exception:
             if isinstance(exception, torch.fx.proxy.TraceError):
-                warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
+                warnings.warn(
+                    f"Failed to trace {model_name} with inputs {input_names}. For more "
+                    "information on tracing with the sequential pipeline, see "
+                    "`src/llmcompressor/transformers/tracing/GUIDE.md`"
+                )
             if isinstance(exception, unfixable_errors):
                 raise exception
```
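The `torch.fx.proxy.TraceError` caught above typically comes from data-dependent control flow, which symbolic tracing cannot resolve. A minimal standalone illustration (not from this commit) of the failure mode:

```python
# Illustration of the kind of code that makes torch.fx symbolic tracing fail.
import torch
import torch.fx


class Untraceable(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Branching on a tensor's value requires a concrete bool, but symbolic
        # tracing passes a Proxy here, so tracing raises a TraceError.
        if x.sum() > 0:
            return x + 1
        return x - 1


try:
    torch.fx.symbolic_trace(Untraceable())
except torch.fx.proxy.TraceError as err:
    print(f"TraceError: {err}")
```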

src/llmcompressor/pipelines/sequential/README.md

Lines changed: 6 additions & 0 deletions

```markdown
# Sequential Pipeline #
The sequential pipeline is a data pipeline, primarily used for compressing models with the
[GPTQModifier](/src/llmcompressor/modifiers/quantization/gptq/base.py).

If, when using this pipeline, you encounter a `torch.fx.proxy.TraceError`, see the
[Model Tracing Guide](/src/llmcompressor/transformers/tracing/GUIDE.md).
```
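For context, the pipeline is reached through a GPTQ oneshot run. A minimal sketch, assuming the `oneshot` API and the `ultrachat-200k` dataset registry name that `debug.py` below also uses; the model stub and calibration sizes are illustrative assumptions:

```python
# Minimal sketch of a GPTQ oneshot run that exercises the sequential pipeline.
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # assumed example model
    dataset="ultrachat-200k",  # same dataset registry name used by debug.py
    splits={"calibration": "train_sft[:16]"},  # illustrative calibration size
    recipe=recipe,
    num_calibration_samples=16,
)
```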

src/llmcompressor/transformers/tracing/GUIDE.md

Lines changed: 441 additions & 0 deletions
Large diffs are not rendered by default.

src/llmcompressor/transformers/tracing/assets/Llama_3.2-Vision.svg

Lines changed: 5319 additions & 0 deletions
src/llmcompressor/transformers/tracing/debug.py

Lines changed: 136 additions & 0 deletions

```python
from typing import List, Type, Union, Optional, Dict

import argparse

import torch
import transformers
from transformers import AutoProcessor, PreTrainedModel

from llmcompressor.transformers import tracing
from llmcompressor.utils.pytorch.module import get_no_split_params
from llmcompressor.pipelines.sequential.helpers import trace_subgraphs
from llmcompressor.transformers import DataTrainingArguments, TextGenerationDataset


def parse_args():
    parser = argparse.ArgumentParser(description="Trace a model into subgraphs")
    parser.add_argument("--model_id", type=str, required=True, help="The stub of the model to load")  # noqa: E501
    parser.add_argument("--model_class", type=str, required=True, help="The class name of the model")  # noqa: E501
    parser.add_argument("--sequential_targets", type=str, nargs="*", default=None, metavar="TARGET", help="List of targets for sequential tracing")  # noqa: E501
    parser.add_argument("--ignore", type=str, nargs="*", default=[], metavar="PATTERN", help="List of patterns to ignore during tracing")  # noqa: E501
    parser.add_argument("--modality", type=str, default="text", help="Modality of calibration dataset, defaults to text")  # noqa: E501
    return parser.parse_args()


def trace(
    model_id: str,
    model_class: Type[PreTrainedModel],
    sequential_targets: Optional[Union[List[str], str]] = None,
    ignore: Union[List[str], str] = [],
    modality: str = "text",
):
    """
    Debug traceability by tracing a pre-trained model into subgraphs

    :param model_id: stub of the model to load
    :param model_class: class constructor of the pre-trained model. Can use either
        HF transformers classes or `Traceable` classes defined by LLM Compressor
    :param sequential_targets: targets for sequential tracing, defaults to automatic
        inference
    :param ignore: patterns to ignore during tracing
    :param modality: data modality for dummy tracing data, defaults to 'text'

    Example usage from CLI
    llmcompressor.trace \
        --model_id Qwen/Qwen2-VL-2B-Instruct \
        --model_class Qwen2VLForConditionalGeneration \
        --sequential_targets Qwen2VLDecoderLayer \
        --ignore "lm_head" "re:visual.*" \
        --modality text
    """
    # Load model
    model = model_class.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype="auto",
    )
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    print("Loaded model")

    # Prepare sample data
    data_args = DataTrainingArguments(**get_dataset_kwargs(modality))
    dataset = TextGenerationDataset.load_from_registry(
        data_args.dataset,
        data_args=data_args,
        split=data_args.splits["calibration"],
        processor=processor,
    )(add_labels=False)
    sample_input = next(iter(dataset))
    sample_input = {k: torch.tensor(v) for k, v in sample_input.items()}
    print("Loaded sample data")

    # Infer sequential targets
    if sequential_targets is None:
        sequential_targets = get_no_split_params(model)
    if isinstance(sequential_targets, str):
        sequential_targets = [sequential_targets]

    # Infer ignore
    if isinstance(ignore, str):
        ignore = [ignore]

    # Attempt trace
    print(
        "\nAttempting trace\n"
        f"    model_id={model_id}\n"
        f"    model_class={model_class.__name__}\n"
        f"    dataset={data_args.dataset}\n"
        f"    split={dataset.split}\n"
        f"    inputs={sample_input.keys()}\n"
        f"    sequential_targets={sequential_targets}\n"
        f"    ignore={ignore}\n"
    )
    subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)
    print(f"Successfully traced model into {len(subgraphs)} subgraphs!\n")


def get_model_class(model_class: str) -> Type[PreTrainedModel]:
    model_cls = getattr(tracing, model_class, getattr(transformers, model_class, None))
    if model_cls is None:
        raise ValueError(f"Could not import model class {model_class}")

    return model_cls


def get_dataset_kwargs(modality: str) -> Dict[str, str]:
    dataset_kwargs = {
        "text": {
            "dataset": "ultrachat-200k",
            "splits": {"calibration": "test_sft[:1]"},
        },
        "vision": {
            "dataset": "flickr",
            "splits": {"calibration": "test[:1]"},
        },
    }

    if modality not in dataset_kwargs:
        raise ValueError(f"Modality must be one of {list(dataset_kwargs.keys())}")

    return dataset_kwargs[modality]


def main():
    args = parse_args()

    trace(
        model_id=args.model_id,
        model_class=get_model_class(args.model_class),
        sequential_targets=args.sequential_targets,
        ignore=args.ignore,
        modality=args.modality,
    )


if __name__ == "__main__":
    main()
```
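The testing command from the commit message can also be run programmatically through `trace()`; a minimal sketch, assuming `TraceableLlavaForConditionalGeneration` is importable from `llmcompressor.transformers.tracing` as the commit message indicates:

```python
# Programmatic equivalent of the `llmcompressor.trace` CLI invocation from the
# Testing section; assumes TraceableLlavaForConditionalGeneration is exported
# by llmcompressor.transformers.tracing.
from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration
from llmcompressor.transformers.tracing.debug import trace

trace(
    model_id="llava-hf/llava-1.5-7b-hf",
    model_class=TraceableLlavaForConditionalGeneration,
    sequential_targets=["LlamaDecoderLayer"],
    ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
    modality="vision",  # multimodal calibration data, per get_dataset_kwargs above
)
```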
