
Commit 6377f1e

[VLM] Multimodal Data Collator (#1087)
## Purpose ##
* Move data collators into the example scripts, as per @mgoin's suggestion in #1032 (comment)

## Changes ##
* Remove the data collator definitions from LC
* Add data collators to the examples, with a comment indicating that they handle multimodal inputs

## Testing ##
Ran all multimodal vision models:
* Qwen2
* Pixtral
* Mllama
* Llava
* Phi3_vision

---------

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 138bdaa commit 6377f1e
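The pattern the examples now share is sketched below: each script defines its own single-sample data collator and passes it to oneshot via data_collator=data_collator, instead of importing a per-model collator from llmcompressor.transformers.utils.data_collator. The collator itself is copied from the diffs in this commit; the toy sample used to exercise it is hypothetical and only illustrates the expected single-element batch.

import torch


# Define a oneshot data collator for multimodal inputs (as added to the
# llava, mllama, phi3_vision, and qwen2_vl examples): calibration runs one
# sample at a time, and every field of that sample becomes a tensor.
def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}


# Hypothetical single-sample batch, only to show the expected input shape;
# real inputs come from the processed calibration dataset.
sample = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}
collated = data_collator([sample])
print({key: tensor.shape for key, tensor in collated.items()})

Each example then forwards this function to oneshot(..., data_collator=data_collator), as shown in the hunks below.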

File tree

6 files changed, +48 -69 lines changed


examples/multimodal_vision/llava_example.py

Lines changed: 9 additions & 2 deletions
@@ -1,11 +1,11 @@
 import requests
+import torch
 from PIL import Image
 from transformers import AutoProcessor

 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration
-from llmcompressor.transformers.utils.data_collator import llava_data_collator

 # Load model.
 model_id = "llava-hf/llava-1.5-7b-hf"
@@ -20,6 +20,13 @@
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048

+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
 # Recipe
 recipe = [
     GPTQModifier(
@@ -40,7 +47,7 @@
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=llava_data_collator,
+    data_collator=data_collator,
 )

 # Confirm generations of the quantized model look sane.

examples/multimodal_vision/mllama_example.py

Lines changed: 9 additions & 2 deletions
@@ -1,11 +1,11 @@
 import requests
+import torch
 from PIL import Image
 from transformers import AutoProcessor

 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.tracing import TraceableMllamaForConditionalGeneration
-from llmcompressor.transformers.utils.data_collator import mllama_data_collator

 # Load model.
 model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
@@ -20,6 +20,13 @@
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048

+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
 # Recipe
 recipe = [
     GPTQModifier(
@@ -39,7 +46,7 @@
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=mllama_data_collator,
+    data_collator=data_collator,
 )

 # Confirm generations of the quantized model look sane.

examples/multimodal_vision/phi3_vision_example.py

Lines changed: 8 additions & 2 deletions
@@ -1,10 +1,10 @@
+import torch
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoProcessor

 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
 from llmcompressor.transformers import oneshot
-from llmcompressor.transformers.utils.data_collator import phi3_vision_data_collator

 # Load model.
 model_id = "microsoft/Phi-3-vision-128k-instruct"
@@ -60,6 +60,12 @@ def tokenize(sample):
 ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names)


+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
 # Recipe
 recipe = [
     SmoothQuantModifier(smoothing_strength=0.8),
@@ -79,7 +85,7 @@ def tokenize(sample):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=phi3_vision_data_collator,
+    data_collator=data_collator,
 )

 # Confirm generations of the quantized model look sane.

examples/multimodal_vision/pixtral_example.py

Lines changed: 13 additions & 2 deletions
@@ -1,11 +1,11 @@
 import requests
+import torch
 from PIL import Image
 from transformers import AutoProcessor

 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration
-from llmcompressor.transformers.utils.data_collator import pixtral_data_collator

 # Load model.
 model_id = "mgoin/pixtral-12b"
@@ -20,6 +20,17 @@
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048

+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {
+        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
+        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
+        "pixel_values": torch.tensor(batch[0]["pixel_values"])[0],
+    }
+
+
 # Recipe
 recipe = [
     GPTQModifier(
@@ -40,7 +51,7 @@
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=pixtral_data_collator,
+    data_collator=data_collator,
 )

 # Confirm generations of the quantized model look sane.

examples/multimodal_vision/qwen2_vl_example.py

Lines changed: 9 additions & 2 deletions
@@ -1,14 +1,14 @@
 import base64
 from io import BytesIO

+import torch
 from datasets import load_dataset
 from qwen_vl_utils import process_vision_info
 from transformers import AutoProcessor

 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.tracing import TraceableQwen2VLForConditionalGeneration
-from llmcompressor.transformers.utils.data_collator import qwen2_vl_data_collator

 # Load model.
 model_id = "Qwen/Qwen2-VL-2B-Instruct"
@@ -65,6 +65,13 @@ def preprocess_and_tokenize(example):

 ds = ds.map(preprocess_and_tokenize, remove_columns=ds["calibration"].column_names)

+
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
 # Recipe
 recipe = [
     GPTQModifier(
@@ -84,7 +91,7 @@ def preprocess_and_tokenize(example):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=qwen2_vl_data_collator,
+    data_collator=data_collator,
 )

 # Confirm generations of the quantized model look sane.

src/llmcompressor/transformers/utils/data_collator.py

Lines changed: 0 additions & 59 deletions
This file was deleted.
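The deleted module exported the per-model collators removed from the example imports above (llava_data_collator, mllama_data_collator, phi3_vision_data_collator, pixtral_data_collator, qwen2_vl_data_collator). Its exact contents are not shown on this page; the sketch below is a rough reconstruction inferred from the example-side replacements, not the deleted source itself.

import torch


# Rough reconstruction (assumption): most of the removed collators shared
# the single-sample pattern now inlined in the example scripts.
def llava_data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}


# The mllama, phi3_vision, and qwen2_vl collators presumably followed the
# same pattern, while the pixtral collator built its tensor dict field by
# field, as in the pixtral_example.py hunk above.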

0 commit comments