Skip to content

Commit 0d4c45c

Browse files
Add Onnx Config for ImageGPT (#19868)
* add Onnx Config for ImageGPT * add generate_dummy_inputs for onnx config * add TYPE_CHECKING clause * Update doc for generate_dummy_inputs Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]>
1 parent 9b1dcba commit 0d4c45c

File tree

5 files changed

+69
-2
lines changed

5 files changed

+69
-2
lines changed

docs/source/en/serialization.mdx

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ Ready-made configurations include the following architectures:
7474
- GPT-J
7575
- GroupViT
7676
- I-BERT
77+
- ImageGPT
7778
- LayoutLM
7879
- LayoutLMv3
7980
- LeViT

src/transformers/models/imagegpt/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
2222

2323

24-
_import_structure = {"configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig"]}
24+
_import_structure = {
25+
"configuration_imagegpt": ["IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ImageGPTConfig", "ImageGPTOnnxConfig"]
26+
}
2527

2628
try:
2729
if not is_vision_available():
@@ -48,7 +50,7 @@
4850

4951

5052
if TYPE_CHECKING:
51-
from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig
53+
from .configuration_imagegpt import IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP, ImageGPTConfig, ImageGPTOnnxConfig
5254

5355
try:
5456
if not is_vision_available():

src/transformers/models/imagegpt/configuration_imagegpt.py

+60
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,17 @@
1414
# limitations under the License.
1515
""" OpenAI ImageGPT configuration"""
1616

17+
from collections import OrderedDict
18+
from typing import TYPE_CHECKING, Any, Mapping, Optional
19+
1720
from ...configuration_utils import PretrainedConfig
21+
from ...onnx import OnnxConfig
1822
from ...utils import logging
1923

2024

25+
if TYPE_CHECKING:
26+
from ... import FeatureExtractionMixin, TensorType
27+
2128
logger = logging.get_logger(__name__)
2229

2330
IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -140,3 +147,56 @@ def __init__(
140147
self.tie_word_embeddings = tie_word_embeddings
141148

142149
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
150+
151+
152+
class ImageGPTOnnxConfig(OnnxConfig):
153+
@property
154+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
155+
return OrderedDict(
156+
[
157+
("input_ids", {0: "batch", 1: "sequence"}),
158+
]
159+
)
160+
161+
def generate_dummy_inputs(
162+
self,
163+
preprocessor: "FeatureExtractionMixin",
164+
batch_size: int = 1,
165+
seq_length: int = -1,
166+
is_pair: bool = False,
167+
framework: Optional["TensorType"] = None,
168+
num_channels: int = 3,
169+
image_width: int = 32,
170+
image_height: int = 32,
171+
) -> Mapping[str, Any]:
172+
"""
173+
Generate inputs to provide to the ONNX exporter for the specific framework
174+
175+
Args:
176+
preprocessor ([`PreTrainedTokenizerBase`] or [`FeatureExtractionMixin`]):
177+
The preprocessor associated with this model configuration.
178+
batch_size (`int`, *optional*, defaults to -1):
179+
The batch size to export the model for (-1 means dynamic axis).
180+
num_choices (`int`, *optional*, defaults to -1):
181+
The number of candidate answers provided for multiple choice task (-1 means dynamic axis).
182+
seq_length (`int`, *optional*, defaults to -1):
183+
The sequence length to export the model for (-1 means dynamic axis).
184+
is_pair (`bool`, *optional*, defaults to `False`):
185+
Indicate if the input is a pair (sentence 1, sentence 2)
186+
framework (`TensorType`, *optional*, defaults to `None`):
187+
The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for.
188+
num_channels (`int`, *optional*, defaults to 3):
189+
The number of channels of the generated images.
190+
image_width (`int`, *optional*, defaults to 40):
191+
The width of the generated images.
192+
image_height (`int`, *optional*, defaults to 40):
193+
The height of the generated images.
194+
195+
Returns:
196+
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
197+
"""
198+
199+
input_image = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
200+
inputs = dict(preprocessor(input_image, framework))
201+
202+
return inputs

src/transformers/onnx/features.py

+3
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,9 @@ class FeaturesManager:
341341
"question-answering",
342342
onnx_config_cls="models.ibert.IBertOnnxConfig",
343343
),
344+
"imagegpt": supported_features_mapping(
345+
"default", "image-classification", onnx_config_cls="models.imagegpt.ImageGPTOnnxConfig"
346+
),
344347
"layoutlm": supported_features_mapping(
345348
"default",
346349
"masked-lm",

tests/onnx/test_onnx_v2.py

+1
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def test_values_override(self):
193193
("detr", "facebook/detr-resnet-50"),
194194
("distilbert", "distilbert-base-cased"),
195195
("electra", "google/electra-base-generator"),
196+
("imagegpt", "openai/imagegpt-small"),
196197
("resnet", "microsoft/resnet-50"),
197198
("roberta", "roberta-base"),
198199
("roformer", "junnyu/roformer_chinese_base"),

0 commit comments

Comments
 (0)