This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit 7c04f76

bfineran, KSGulin, dhuangnm, and dhuang authored and Benjamin committed
Update base to transformers v4.30.2 (#81)
* Add recipe_name to default file names
* Upgrade to transformers release v4.30.2 (#62)
  * Update trainer and model flows to accommodate SparseML
  * Disable FP16 on QAT start (#12)
    * Override LRScheduler when using LRModifiers
    * Disable FP16 on QAT start
    * Keep wrapped scaler object for training after disabling
  * Use QATMatMul in DistilBERT model class (#41)
  * Remove double quantization of output of context layer (#45)
  * Fix DataParallel validation forward signatures (#47)
    * Fix: DataParallel validation forward signatures
    * Update: generalize forward_fn selection
  * Best model after epoch (#46)
  * Fix scaler check for non-FP16 mode in trainer (#38)
  * MobileBERT QAT (#55)
    * Remove duplicate quantization of vocabulary
  * Enable a QATWrapper for non-parameterized matmuls in BERT self-attention (#9)
  * Utils and auxiliary changes
  * Update Zoo stub loading for the SparseZoo 1.1 refactor (#54)
  * Add flag to signal NM integration is active (#32)
  * Add recipe_name to file names
* Fix errors introduced in manual cherry-pick upgrade
* Update build versions for NM fork PyPI push (#74)
* Fix nightly package name (#75)
* Add make build command (#76)
* Add GHA workflow files to build nightly and release packages (#77)
  * Fix name
* Bump up version to 1.6.0 (#79)

Co-authored-by: Benjamin Fineran <[email protected]>
Co-authored-by: Konstantin <[email protected]>
Co-authored-by: Konstantin Gulin <[email protected]>
Co-authored-by: dhuangnm <[email protected]>
Co-authored-by: dhuang <[email protected]>
1 parent acc394c commit 7c04f76

File tree

14 files changed, +290 -26 lines changed

.github/workflows/build-nightly.yml

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+name: build-nightly
+run-name: ${{ github.workflow }} is to create nightly wheel file for pypi
+on:
+  push:
+    branches:
+      - 'main'
+  schedule:
+    - cron: '0 0 * * *'
+jobs:
+  build-nightly:
+    runs-on: ubuntu-22.04
+    permissions:
+      id-token: write
+      contents: read
+    steps:
+      - uses: aws-actions/configure-aws-credentials@v2
+        with:
+          role-to-assume: arn:aws:iam::498127099666:role/WebIdentity-nm-github-actions-only
+          aws-region: us-east-1
+      - uses: actions/checkout@v3
+      - run: |
+          pwd
+          sudo apt-get install python3-pip
+          pip3 --version
+          sudo pip3 install virtualenv
+          virtualenv venv
+          source venv/bin/activate
+          pip install -e .
+          make -B build
+          deactivate
+          ls dist/
+          aws s3 ls s3://nm-github-actions/
+          aws s3 cp dist/*nightly*.whl s3://nm-github-actions/transformers/

.github/workflows/build-release.yml

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+name: build-release
+run-name: ${{ github.workflow }} is to create release wheel file for pypi
+on:
+  push:
+    branches:
+      - 'release/[0-9]+.[0-9]+'
+
+jobs:
+  build-release:
+    runs-on: ubuntu-22.04
+    permissions:
+      id-token: write
+      contents: read
+    steps:
+      - uses: aws-actions/configure-aws-credentials@v2
+        with:
+          role-to-assume: arn:aws:iam::498127099666:role/WebIdentity-nm-github-actions-only
+          aws-region: us-east-1
+      - uses: actions/checkout@v3
+      - run: |
+          pwd
+          sudo apt-get install python3-pip
+          pip3 --version
+          sudo pip3 install virtualenv
+          virtualenv venv
+          source venv/bin/activate
+          pip install -e .
+          sed -i 's/is_release = False/is_release = True/g' src/transformers/version.py
+          make -B build
+          deactivate
+          ls dist/
+          aws s3 ls s3://nm-github-actions/
+          aws s3 cp dist/*.whl s3://nm-github-actions/transformers/

Makefile

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,4 @@
-.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples
+.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples build
 
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src
@@ -119,3 +119,7 @@ build-release:
 	python setup.py bdist_wheel
 	python setup.py sdist
 	python utils/check_build.py
+
+# neuralmagic: creates wheel file
+build:
+	python3 setup.py sdist bdist_wheel

setup.py

Lines changed: 8 additions & 3 deletions
@@ -423,17 +423,22 @@ def run(self):
     deps["tqdm"],  # progress bars in model download and training scripts
 ]
 
+# default variable to be overwritten by the version.py file
+version = "unknown"
+# load and overwrite version and release info from version.py
+exec(open(os.path.join("src", "transformers", "version.py")).read())
+
 setup(
-    name="transformers",
-    version="4.34.1",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    name="nm-transformers" if is_release else "nm-transformers-nightly",
+    version=version,  # major.minor.patch to match NM repos, fourth entry is either transformers base version or nightly date
     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
     author_email="[email protected]",
    description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
     keywords="NLP vision speech deep learning transformer pytorch tensorflow jax BERT GPT-2 Wav2Vec2 ViT",
     license="Apache 2.0 License",
-    url="https://github.com/huggingface/transformers",
+    url="https://github.com/neuralmagic/transformers",
     package_dir={"": "src"},
     packages=find_packages("src"),
     include_package_data=True,
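Note on versioning: setup.py no longer hard-codes the package name and version; it exec's src/transformers/version.py to pick up `version` and `is_release`, and build-release.yml flips `is_release` to True with sed before building. That file is not part of this diff, so the following is only a minimal sketch of what it plausibly contains, assuming the conventions stated in the comments above (NM major.minor.patch, fourth entry either the transformers base version or a nightly date); the exact names and values are assumptions.

# Hypothetical sketch of src/transformers/version.py (file not shown in this diff).
# build-release.yml rewrites `is_release = False` to `True` via sed before building.
is_release = False

# NM major.minor.patch, fourth entry is either the transformers base version
# (release builds) or a nightly date -- the values below are illustrative only.
version = "1.6.0.4.30.2" if is_release else "1.6.0.20230702"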

src/transformers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).
 
-__version__ = "4.34.1"
+from .version import *
 
 from typing import TYPE_CHECKING
 
src/transformers/hf_argparser.py

Lines changed: 43 additions & 3 deletions
@@ -23,8 +23,16 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints
 
+import os
 import yaml
 
+from sparsezoo import Model
+
+from .utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
 
 DataClass = NewType("DataClass", Any)
 DataClassType = NewType("DataClassType", Any)
@@ -341,12 +349,17 @@ def parse_args_into_dataclasses(
             # additional namespace.
             outputs.append(namespace)
         if return_remaining_strings:
-            return (*outputs, remaining_args)
+            return tuple(
+                *[_download_dataclass_zoo_stub_files(output) for output in outputs],
+                remaining_args,
+            )
         else:
             if remaining_args:
                 raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
 
-            return (*outputs,)
+            return tuple(
+                [_download_dataclass_zoo_stub_files(output) for output in outputs]
+            )
 
     def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
         """
@@ -374,7 +387,9 @@ def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tu
             outputs.append(obj)
         if not allow_extra_keys and unused_keys:
             raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
-        return tuple(outputs)
+        return tuple(
+            [_download_dataclass_zoo_stub_files(output) for output in outputs]
+        )
 
     def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
         """
@@ -417,3 +432,28 @@ def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tup
         """
         outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
         return tuple(outputs)
+
+def _download_dataclass_zoo_stub_files(data_class: DataClass):
+    for name, val in data_class.__dict__.items():
+        if not isinstance(val, str) or "recipe" in name or not val.startswith("zoo:"):
+            continue
+
+        logger.info(f"Downloading framework files for SparseZoo stub: {val}")
+
+        zoo_model = Model(val)
+        framework_file_paths = [file.path for file in zoo_model.training.default.files]
+        assert framework_file_paths, "Unable to download any framework files for SparseZoo stub {val}"
+        framework_file_names = [os.path.basename(path) for path in framework_file_paths]
+        if "pytorch_model.bin" not in framework_file_names or ("config.json" not in framework_file_names):
+            raise RuntimeError(
+                "Unable to find 'pytorch_model.bin' and 'config.json' in framework "
+                f"files downloaded from {val}. Found {framework_file_names}. Check "
+                "if the given stub is for a transformers repo model"
+            )
+        framework_dir_path = Path(framework_file_paths[0]).parent.absolute()
+
+        logger.info(f"Overwriting argument {name} to downloaded {framework_dir_path}")
+
+        data_class.__dict__[name] = str(framework_dir_path)
+
+    return data_class
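Note: the new `_download_dataclass_zoo_stub_files` helper post-processes every parsed dataclass. Any non-"recipe" string field starting with `zoo:` is resolved through `sparsezoo.Model`, the framework files (`pytorch_model.bin`, `config.json`) are downloaded, and the field is overwritten with the local directory containing them. A minimal usage sketch, assuming this fork plus the `sparsezoo` package is installed; the `ModelArguments` dataclass and the stub string are illustrative examples, not taken from this diff.

# Illustrative only: ModelArguments and the stub below are made-up examples.
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="zoo:nlp/example/sparse-bert/stub")
    recipe: str = field(default="zoo:nlp/example/sparse-bert/stub")  # skipped: name contains "recipe"

parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(args=[])

# model_name_or_path has been rewritten to the local directory holding the
# downloaded pytorch_model.bin and config.json; the recipe field keeps its stub.
print(model_args.model_name_or_path)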

src/transformers/models/bert/modeling_bert.py

Lines changed: 23 additions & 2 deletions
@@ -241,6 +241,22 @@ def forward(
         return embeddings
 
 
+class QATMatMul(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+
 class BertSelfAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
@@ -258,6 +274,11 @@ def __init__(self, config, position_embedding_type=None):
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
 
+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATMatMul()
+        self.context_layer_matmul = QATMatMul()
+
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.position_embedding_type = position_embedding_type or getattr(
             config, "position_embedding_type", "absolute"
@@ -322,7 +343,7 @@ def forward(
             past_key_value = (key_layer, value_layer)
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = self.attention_scores_matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             query_length, key_length = query_layer.shape[2], key_layer.shape[2]
@@ -362,7 +383,7 @@ def forward(
         if head_mask is not None:
             attention_probs = attention_probs * head_mask
 
-        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = self.context_layer_matmul(attention_probs, value_layer)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
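Note: `QATMatMul` is a thin wrapper around `torch.matmul`; `wrap_qat` and `qat_wrapper_kwargs` are markers that a SparseML QuantizationModifier can use to insert fake-quantization around these non-parameterized matmuls during QAT. Without SparseML in the loop the module is numerically identical to `torch.matmul`, which the following sketch (assuming this fork and plain PyTorch) checks.

import torch
from transformers.models.bert.modeling_bert import QATMatMul

# Without a SparseML QuantizationModifier the wrapper is just torch.matmul.
a = torch.randn(2, 12, 128, 64)
b = torch.randn(2, 12, 64, 128)
qat_matmul = QATMatMul()
assert torch.equal(qat_matmul(a, b), torch.matmul(a, b))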

src/transformers/models/distilbert/modeling_distilbert.py

Lines changed: 39 additions & 3 deletions
@@ -88,6 +88,38 @@ def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
     out.detach_()
 
 
+class QATAttentionScores(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+class QATContextLayer(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "num_outputs": 0,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+
 class Embeddings(nn.Module):
     def __init__(self, config: PretrainedConfig):
         super().__init__()
@@ -159,6 +191,11 @@ def __init__(self, config: PretrainedConfig):
         self.pruned_heads: Set[int] = set()
         self.attention_head_size = self.dim // self.n_heads
 
+        # non-parameterized matmuls will behave as normal torch.matmul ops unless
+        # Quantization-Aware-Training is invoked
+        self.attention_scores_matmul = QATAttentionScores()
+        self.context_layer_matmul = QATContextLayer()
+
     def prune_heads(self, heads: List[int]):
         if len(heads) == 0:
             return
@@ -217,7 +254,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
         v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
 
         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
-        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
+        scores = self.attention_scores_matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
         scores = scores.masked_fill(
             mask, torch.tensor(torch.finfo(scores.dtype).min)
@@ -230,7 +267,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
         if head_mask is not None:
             weights = weights * head_mask
 
-        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
+        context = self.context_layer_matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
         context = unshape(context)  # (bs, q_length, dim)
         context = self.out_lin(context)  # (bs, q_length, dim)
 
@@ -688,7 +725,6 @@ def forward(
             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         dlbrt_output = self.distilbert(
             input_ids=input_ids,
             attention_mask=attention_mask,

src/transformers/models/mobilebert/modeling_mobilebert.py

Lines changed: 18 additions & 1 deletion
@@ -169,6 +169,23 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
 
 NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}
 
+class QATEmbeddingTransformation(nn.Module):
+    def __init__(self, embedded_input_size, hidden_size):
+        super().__init__()
+
+        # Behaves like normal Linear module unless a SparseML QuantizationModifier
+        # is initialized.
+        # When initialized, does not quantize inputs.
+        # Only weights are quantized (inputs come quantized from embeddings)
+        self.linear = nn.Linear(embedded_input_size, hidden_size)
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 0,
+            "num_outputs": 1,
+        }
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
 
 class MobileBertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""
@@ -185,7 +202,7 @@ def __init__(self, config):
 
         embed_dim_multiplier = 3 if self.trigram_input else 1
         embedded_input_size = self.embedding_size * embed_dim_multiplier
-        self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)
+        self.embedding_transformation = QATEmbeddingTransformation(embedded_input_size, config.hidden_size)
 
         self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
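Note: `QATEmbeddingTransformation` is a drop-in replacement for the embedding-transformation `nn.Linear`; its `qat_wrapper_kwargs` (`num_inputs: 0`, `num_outputs: 1`) signal a SparseML QuantizationModifier to quantize only the weights and the output, since the inputs already arrive quantized from the embeddings. Without SparseML it simply forwards to the wrapped Linear, as in this sketch (assuming this fork and plain PyTorch; the shapes are illustrative, with `embedded_input_size = embedding_size * 3` reflecting the trigram input path).

import torch
from transformers.models.mobilebert.modeling_mobilebert import QATEmbeddingTransformation

# Illustrative shapes only.
layer = QATEmbeddingTransformation(embedded_input_size=384, hidden_size=512)
x = torch.randn(1, 128, 384)
out = layer(x)
assert out.shape == (1, 128, 512)
assert torch.equal(out, layer.linear(x))  # plain nn.Linear behaviour without SparseML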
