
Commit 84f0186

Authored by ArthurZucker, saurabhdash2512, and yonigozlan
Add aya (#36521)
* initial commit
* small fix
* move stuff to image processing file
* remove stuff in validate turn and fix return tensor
* remove liquid stuff
* in the process of addressing comments
* changes to get the right tokenization
* new __init__ works
* fixing default std and mean
* works
* small testing script -- to be deleted before merge
* remove redundant code
* addressing comments
* fix inits, add docs templates
* refactor processor, switch to gotocr image processor
* remove image proc from init
* refactor to working llava-style architecture
* Change AyaVisionModel to AyaVisionForConditionalGeneration
* add tests
* fixups
* update doc
* Adding logits_to_keep explicitly in ayavision forward to enable compatibility with cohere model
* better variable names + remove code paths
* Updates to aya_vision.md
* address comments
* adding copied from
* make style and remove unused projector_hidden_act from config
* sort init
* include usage of fast image proc and proc on cuda in doc
* update checkpoint in test processor
* update checkpoint in test processor 2
* remove test_model and update docstring
* skip failing tests

---------

Co-authored-by: Saurabh Dash <[email protected]>
Co-authored-by: yonigozlan <[email protected]>
1 parent c0f8d05 commit 84f0186

17 files changed: +1928 −1 lines changed

docs/source/en/_toctree.yml

+2
@@ -874,6 +874,8 @@
       title: AltCLIP
     - local: model_doc/aria
       title: Aria
+    - local: model_doc/aya_vision
+      title: AyaVision
     - local: model_doc/blip
       title: BLIP
     - local: model_doc/blip-2

docs/source/en/model_doc/aya_vision.md

+243

@@ -0,0 +1,243 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# AyaVision

## Overview

The Aya Vision 8B and 32B models are state-of-the-art multilingual multimodal models developed by Cohere For AI. They build on the Aya Expanse recipe to handle both visual and textual information without compromising on the strong multilingual textual performance of the original model.

Aya Vision 8B combines the `Siglip2-so400-384-14` vision encoder with the Cohere CommandR-7B language model further post-trained with the Aya Expanse recipe, creating a powerful vision-language model capable of understanding images and generating text across 23 languages. Aya Vision 32B instead uses Aya Expanse 32B as the language model.

Key features of Aya Vision include:
- Multimodal capabilities in 23 languages
- Strong text-only multilingual capabilities inherited from CommandR-7B (post-trained with the Aya Expanse recipe) and from Aya Expanse 32B
- High-quality visual understanding using the Siglip2-so400-384-14 vision encoder
- Seamless integration of visual and textual information in 23 languages
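
The architecture follows the llava-style layout used elsewhere in Transformers: a vision encoder, a multimodal projector, and a Cohere language model wrapped in `AyaVisionForConditionalGeneration`. As a rough sketch of how to inspect the two sub-configurations — the `vision_config`/`text_config` attribute names are assumed from the usual convention, so check them against the `AyaVisionConfig` autodoc below:

```python
from transformers import AyaVisionConfig

# Sketch only: vision_config / text_config follow the usual llava-style naming
# convention in Transformers; verify against the AyaVisionConfig autodoc below.
config = AyaVisionConfig.from_pretrained("CohereForAI/aya-vision-8b")
print(type(config.vision_config).__name__)  # SigLIP2-based vision encoder config
print(type(config.text_config).__name__)    # Cohere language model config
```
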
<!-- <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/aya_vision_architecture.webp"
alt="drawing" width="600"/>

<small> Aya Vision architecture. </small> -->

Tips:

- Aya Vision is a multimodal model that takes images and text as input and produces text as output.
- Images are represented using the `<image>` tag in the templated input (see the short sketch after these tips).
- For best results, use the `apply_chat_template` method of the processor to format your inputs correctly.
- The model can process multiple images in a single conversation.
- Aya Vision can understand and generate text in 23 languages, making it suitable for multilingual multimodal applications.
This model was contributed by [saurabhdash](https://huggingface.co/saurabhdash) and [yonigozlan](https://huggingface.co/yonigozlan).
## Usage

Here's how to use Aya Vision for inference:

```python
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch

model_id = "CohereForAI/aya-vision-8b"
torch_device = "cuda:0"

# Use fast image processor
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
model = AutoModelForImageTextToText.from_pretrained(
    model_id, device_map=torch_device, torch_dtype=torch.float16
)

# Format message with the aya-vision chat template
messages = [
    {"role": "user",
     "content": [
        {"type": "image", "url": "https://pbs.twimg.com/media/Fx7YvfQWYAIp6rZ?format=jpg&name=medium"},
        # Hindi: "What does the text in the image say?"
        {"type": "text", "text": "चित्र में लिखा पाठ क्या कहता है?"},
     ]},
]

# Process image on CUDA
inputs = processor.apply_chat_template(
    messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", device=torch_device
).to(model.device)

gen_tokens = model.generate(
    **inputs,
    max_new_tokens=300,
    do_sample=True,
    temperature=0.3,
)

# Decode only the newly generated tokens
gen_text = processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(gen_text)
```

### Pipeline

```python
from transformers import pipeline

pipe = pipeline(model="CohereForAI/aya-vision-8b", task="image-text-to-text", device_map="auto")

# Format message with the aya-vision chat template
messages = [
    {"role": "user",
     "content": [
        {"type": "image", "url": "https://media.istockphoto.com/id/458012057/photo/istanbul-turkey.jpg?s=612x612&w=0&k=20&c=qogAOVvkpfUyqLUMr_XJQyq-HkACXyYUSZbKhBlPrxo="},
        # Turkish: "Which monument is shown in this image?"
        {"type": "text", "text": "Bu resimde hangi anıt gösterilmektedir?"},
     ]},
]
outputs = pipe(text=messages, max_new_tokens=300, return_full_text=False)

print(outputs)
```

### Multiple Images and Batched Inputs

Aya Vision can process multiple images in a single conversation. Here's how to use it with multiple images:

```python
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch

model_id = "CohereForAI/aya-vision-8b"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id, device_map="cuda:0", torch_dtype=torch.float16
)

# Example with multiple images in a single message
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
            },
            {
                "type": "image",
                "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
            },
            {
                "type": "text",
                "text": "These images depict two different landmarks. Can you identify them?",
            },
        ],
    },
]

inputs = processor.apply_chat_template(
    messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)

gen_tokens = model.generate(
    **inputs,
    max_new_tokens=300,
    do_sample=True,
    temperature=0.3,
)

gen_text = processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(gen_text)
```

For processing batched inputs (multiple conversations at once):

```python
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch

model_id = "CohereForAI/aya-vision-8b"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id, device_map="cuda:0", torch_dtype=torch.float16
)

# Prepare two different conversations
batch_messages = [
    # First conversation with a single image
    [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
                {"type": "text", "text": "Write a haiku for this image"},
            ],
        },
    ],
    # Second conversation with multiple images
    [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
                },
                {
                    "type": "image",
                    "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
                },
                {
                    "type": "text",
                    "text": "These images depict two different landmarks. Can you identify them?",
                },
            ],
        },
    ],
]

# Process the conversations and combine them into a single padded batch
batch_inputs = processor.apply_chat_template(
    batch_messages,
    padding=True,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt"
).to(model.device)

# Generate responses for the batch
batch_outputs = model.generate(
    **batch_inputs,
    max_new_tokens=300,
    do_sample=True,
    temperature=0.3,
)

# Decode the generated responses
for i, output in enumerate(batch_outputs):
    response = processor.tokenizer.decode(
        output[batch_inputs.input_ids.shape[1]:],
        skip_special_tokens=True
    )
    print(f"Response {i+1}:\n{response}\n")
```
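
Equivalently, the whole batch can be decoded in one call with `batch_decode`; a small sketch using the same tensors as above (the padded prompt length is the same for every row, so slicing strips the prompt tokens from each sequence):

```python
# Minimal sketch: decode all generated sequences at once instead of looping.
responses = processor.tokenizer.batch_decode(
    batch_outputs[:, batch_inputs.input_ids.shape[1]:],
    skip_special_tokens=True,
)
for i, response in enumerate(responses):
    print(f"Response {i+1}:\n{response}\n")
```
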

## AyaVisionProcessor

[[autodoc]] AyaVisionProcessor

## AyaVisionConfig

[[autodoc]] AyaVisionConfig

## AyaVisionForConditionalGeneration

[[autodoc]] AyaVisionForConditionalGeneration
    - forward

src/transformers/__init__.py

+7
@@ -194,6 +194,7 @@
         "AutoTokenizer",
     ],
     "models.autoformer": ["AutoformerConfig"],
+    "models.aya_vision": ["AyaVisionConfig", "AyaVisionProcessor"],
     "models.bamba": ["BambaConfig"],
     "models.bark": [
         "BarkCoarseConfig",
@@ -1600,6 +1601,7 @@
             "AutoformerPreTrainedModel",
         ]
     )
+    _import_structure["models.aya_vision"].extend(["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"])
     _import_structure["models.bamba"].extend(
         [
             "BambaForCausalLM",
@@ -5320,6 +5322,10 @@
     from .models.autoformer import (
         AutoformerConfig,
     )
+    from .models.aya_vision import (
+        AyaVisionConfig,
+        AyaVisionProcessor,
+    )
     from .models.bamba import BambaConfig
     from .models.bark import (
         BarkCoarseConfig,
@@ -6765,6 +6771,7 @@
         AutoformerModel,
         AutoformerPreTrainedModel,
     )
+    from .models.aya_vision import AyaVisionForConditionalGeneration, AyaVisionPreTrainedModel
     from .models.bamba import BambaForCausalLM, BambaModel, BambaPreTrainedModel
     from .models.bark import (
         BarkCausalModel,
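
These hunks expose the new classes at the top level of the package, so they can be imported directly; a quick sanity check:

```python
# Added to transformers' top-level namespace by the import-structure entries above.
from transformers import (
    AyaVisionConfig,
    AyaVisionForConditionalGeneration,
    AyaVisionProcessor,
)

print(AyaVisionForConditionalGeneration.__name__)
```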

src/transformers/models/__init__.py

+1
@@ -20,6 +20,7 @@
     audio_spectrogram_transformer,
     auto,
     autoformer,
+    aya_vision,
     bamba,
     bark,
     bart,

src/transformers/models/auto/configuration_auto.py

+2
@@ -39,6 +39,7 @@
     ("aria_text", "AriaTextConfig"),
     ("audio-spectrogram-transformer", "ASTConfig"),
     ("autoformer", "AutoformerConfig"),
+    ("aya_vision", "AyaVisionConfig"),
     ("bamba", "BambaConfig"),
     ("bark", "BarkConfig"),
     ("bart", "BartConfig"),
@@ -359,6 +360,7 @@
     ("aria_text", "AriaText"),
     ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
     ("autoformer", "Autoformer"),
+    ("aya_vision", "AyaVision"),
     ("bamba", "Bamba"),
     ("bark", "Bark"),
     ("bart", "BART"),

src/transformers/models/auto/modeling_auto.py

+1
@@ -818,6 +818,7 @@
 MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
     [
         ("aria", "AriaForConditionalGeneration"),
+        ("aya_vision", "AyaVisionForConditionalGeneration"),
         ("blip", "BlipForConditionalGeneration"),
         ("blip-2", "Blip2ForConditionalGeneration"),
         ("chameleon", "ChameleonForConditionalGeneration"),

src/transformers/models/auto/processing_auto.py

+1
@@ -48,6 +48,7 @@
     ("align", "AlignProcessor"),
     ("altclip", "AltCLIPProcessor"),
     ("aria", "AriaProcessor"),
+    ("aya_vision", "AyaVisionProcessor"),
     ("bark", "BarkProcessor"),
     ("blip", "BlipProcessor"),
     ("blip-2", "Blip2Processor"),

src/transformers/models/auto/tokenization_auto.py

+1
@@ -69,6 +69,7 @@
     ),
     ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
     ("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+    ("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
     ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
     ("bart", ("BartTokenizer", "BartTokenizerFast")),
     (
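
With this mapping registered, `AutoTokenizer` should resolve aya_vision checkpoints to the fast Cohere tokenizer (no slow tokenizer is registered); a small sketch:

```python
from transformers import AutoTokenizer

# The mapping above registers (None, "CohereTokenizerFast") for aya_vision.
tokenizer = AutoTokenizer.from_pretrained("CohereForAI/aya-vision-8b")
print(type(tokenizer).__name__)  # expected: CohereTokenizerFast
```
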
src/transformers/models/aya_vision/__init__.py

@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_aya_vision import *
    from .modeling_aya_vision import *
    from .processing_aya_vision import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
