
Commit c73a92f

XuehaiPan authored and pytorchmergebot committed
[BE][CI] bump ruff to 0.9.2: multiline assert statements (pytorch#144546)
Reference: https://docs.astral.sh/ruff/formatter/black/#assert-statements

> Unlike Black, Ruff prefers breaking the message over breaking the assertion, similar to how both Ruff and Black prefer breaking the assignment value over breaking the assignment target:
>
> ```python
> # Input
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
>
> # Black
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
> # Ruff
> assert len(policy_types) >= priority + num_duplicates, (
>     f"This tests needs at least {priority + num_duplicates} many types."
> )
> ```

Pull Request resolved: pytorch#144546
Approved by: https://github.com/malfet
1 parent f0d0042 · commit c73a92f
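For readers unfamiliar with the style change: the two assert layouts are semantically identical, since the formatter only rearranges tokens and never alters the parsed program. A minimal sanity check of that claim (an illustrative snippet, not code from this commit):

```python
import ast

# Old (Black-style) layout: the assertion expression is wrapped.
old = '''
assert (
    x > 0
), "x must be positive"
'''

# New (ruff 0.9) layout: the message is wrapped instead.
new = '''
assert x > 0, (
    "x must be positive"
)
'''

# Both parse to the exact same Assert node, so behavior is unchanged.
assert ast.dump(ast.parse(old).body[0]) == ast.dump(ast.parse(new).body[0])
```

To reproduce the new layout locally, running `ruff format` with ruff 0.9.2 installed should give the same result as the pinned formatter that PyTorch invokes through `lintrunner`.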

File tree: 84 files changed, +634 -622 lines (large commit; only a subset of the changed files is shown below)


.github/scripts/label_utils.py

+3 -3

```diff
@@ -63,9 +63,9 @@ def gh_get_labels(org: str, repo: str) -> list[str]:
     update_labels(labels, info)
 
     last_page = get_last_page_num_from_header(header)
-    assert (
-        last_page > 0
-    ), "Error reading header info to determine total number of pages of labels"
+    assert last_page > 0, (
+        "Error reading header info to determine total number of pages of labels"
+    )
     for page_number in range(2, last_page + 1):  # skip page 1
         _, info = request_for_labels(prefix + f"&page={page_number}")
         update_labels(labels, info)
```

.lintrunner.toml

+2 -2

```diff
@@ -1476,7 +1476,7 @@ init_command = [
     'black==23.12.1',
     'usort==1.0.8.post1',
     'isort==5.13.2',
-    'ruff==0.8.4', # sync with RUFF
+    'ruff==0.9.2', # sync with RUFF
 ]
 is_formatter = true
 
@@ -1561,7 +1561,7 @@ init_command = [
     'python3',
     'tools/linter/adapters/pip_init.py',
     '--dry-run={{DRYRUN}}',
-    'ruff==0.8.4', # sync with PYFMT
+    'ruff==0.9.2', # sync with PYFMT
 ]
 is_formatter = true
```
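Because the same pin appears in two `init_command` blocks, the trailing `# sync with RUFF` / `# sync with PYFMT` comments are the only guard against the two entries drifting apart. A hypothetical consistency check (the script and approach are illustrative, not part of this commit):

```python
# check_ruff_pins.py -- hypothetical helper: fail if the ruff pins in
# .lintrunner.toml ever disagree with each other.
import re

with open(".lintrunner.toml") as f:
    pins = re.findall(r"ruff==([0-9.]+)", f.read())

assert pins, "no ruff pins found"
assert len(set(pins)) == 1, f"ruff pins out of sync: {pins}"
print(f"all {len(pins)} ruff pins agree: {pins[0]}")  # expected: 0.9.2
```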

benchmarks/dynamo/cachebench.py

+3 -3

```diff
@@ -101,9 +101,9 @@ def _run_torchbench_model(
     torchbench_file = os.path.join(
         os.path.dirname(cur_file), BENCHMARK_FILE[cmd_args.benchmark]
     )
-    assert os.path.exists(
-        torchbench_file
-    ), f"Torchbench does not exist at {torchbench_file}"
+    assert os.path.exists(torchbench_file), (
+        f"Torchbench does not exist at {torchbench_file}"
+    )
 
     dynamic = cmd_args.dynamic
     dynamic_args = ["--dynamic-shapes", "--dynamic-batch-only"] if dynamic else []
```

benchmarks/dynamo/common.py

+15 -15

```diff
@@ -1006,9 +1006,9 @@ def latency_experiment_summary(suite_name, args, model, timings, **kwargs):
         row,
     )
     c_headers, c_data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True)
-    assert (
-        output_filename.find(".csv") > 0
-    ), f"expected output_filename to be a .csv, but got {output_filename}"
+    assert output_filename.find(".csv") > 0, (
+        f"expected output_filename to be a .csv, but got {output_filename}"
+    )
     write_outputs(
         output_filename[:-4] + "_compilation_metrics.csv",
         first_headers + c_headers,
@@ -1182,9 +1182,9 @@ def maybe_mark_profile(*args, **kwargs):
         row,
     )
     c_headers, c_data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True)
-    assert (
-        output_filename.find(".csv") > 0
-    ), f"expected output_filename to be a .csv, but got {output_filename}"
+    assert output_filename.find(".csv") > 0, (
+        f"expected output_filename to be a .csv, but got {output_filename}"
+    )
     write_outputs(
         output_filename[:-4] + "_compilation_metrics.csv",
         first_headers + c_headers,
@@ -1997,16 +1997,16 @@ def get_fsdp_auto_wrap_policy(self, model_name: str):
     def deepcopy_and_maybe_parallelize(self, model):
         model = self.deepcopy_model(model)
         if self.args.ddp:
-            assert (
-                torch.distributed.is_available()
-            ), "Can't use DDP without a distributed enabled build"
+            assert torch.distributed.is_available(), (
+                "Can't use DDP without a distributed enabled build"
+            )
             from torch.nn.parallel import DistributedDataParallel as DDP
 
             model = DDP(model, find_unused_parameters=True)
         elif self.args.fsdp:
-            assert (
-                torch.distributed.is_available()
-            ), "Can't use FSDP without a distributed enabled build"
+            assert torch.distributed.is_available(), (
+                "Can't use FSDP without a distributed enabled build"
+            )
             from torch.distributed.fsdp import (
                 FullyShardedDataParallel as FSDP,
                 MixedPrecision,
@@ -2375,9 +2375,9 @@ def run_performance_test_non_alternate(
         self, name, model, example_inputs, optimize_ctx, experiment, tag=None
     ):
         "Run performance test in non-alternately."
-        assert (
-            experiment.func is latency_experiment
-        ), "Must run with latency_experiment."
+        assert experiment.func is latency_experiment, (
+            "Must run with latency_experiment."
+        )
 
         def warmup(fn, model, example_inputs, mode, niters=10):
             peak_mem = 0
```

benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py

+3 -3

```diff
@@ -81,9 +81,9 @@ def fn():
         torch._dynamo.reset()
         torch._inductor.metrics.reset()
         triton_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)
-        assert (
-            torch._inductor.metrics.generated_kernel_count == 1
-        ), "codegen #kernel != 1"
+        assert torch._inductor.metrics.generated_kernel_count == 1, (
+            "codegen #kernel != 1"
+        )
         row.extend([tflops(torch_mm_ms), tflops(triton_mm_ms)])
 
     p.add_row(row)
```

benchmarks/dynamo/microbenchmarks/operator_inp_utils.py

+6 -6

```diff
@@ -265,9 +265,9 @@ def __init__(self, json_file_path):
     def get_inputs_for_operator(
         self, operator, dtype=None, device="cuda"
     ) -> Generator[tuple[Iterable[Any], dict[str, Any]], None, None]:
-        assert (
-            str(operator) in self.operator_db
-        ), f"Could not find {operator}, must provide overload"
+        assert str(operator) in self.operator_db, (
+            f"Could not find {operator}, must provide overload"
+        )
 
         if "embedding" in str(operator):
             log.warning("Embedding inputs NYI, input data cannot be randomized")
@@ -302,9 +302,9 @@ def get_all_ops(self):
             yield op
 
     def get_call_frequency(self, op):
-        assert (
-            str(op) in self.operator_db
-        ), f"Could not find {op}, must provide overload"
+        assert str(op) in self.operator_db, (
+            f"Could not find {op}, must provide overload"
+        )
 
         count = 0
         for counter in self.operator_db[str(op)].values():
```

benchmarks/dynamo/runner.py

+3 -3

```diff
@@ -711,9 +711,9 @@ def clean_batch_sizes(self, frames):
             for idx, (batch_a, batch_b) in enumerate(
                 zip(batch_sizes, frame_batch_sizes)
             ):
-                assert (
-                    batch_a == batch_b or batch_a == 0 or batch_b == 0
-                ), f"a={batch_a}, b={batch_b}"
+                assert batch_a == batch_b or batch_a == 0 or batch_b == 0, (
+                    f"a={batch_a}, b={batch_b}"
+                )
                 batch_sizes[idx] = max(batch_a, batch_b)
         for frame in frames:
             frame["batch_size"] = batch_sizes
```

benchmarks/functional_autograd_benchmark/torchaudio_models.py

+12 -12

```diff
@@ -538,21 +538,21 @@ def forward(
             query.size(-1),
         )
         q, k, v = self.in_proj_container(query, key, value)
-        assert (
-            q.size(-1) % self.nhead == 0
-        ), "query's embed_dim must be divisible by the number of heads"
+        assert q.size(-1) % self.nhead == 0, (
+            "query's embed_dim must be divisible by the number of heads"
+        )
         head_dim = q.size(-1) // self.nhead
         q = q.reshape(tgt_len, bsz * self.nhead, head_dim)
 
-        assert (
-            k.size(-1) % self.nhead == 0
-        ), "key's embed_dim must be divisible by the number of heads"
+        assert k.size(-1) % self.nhead == 0, (
+            "key's embed_dim must be divisible by the number of heads"
+        )
         head_dim = k.size(-1) // self.nhead
         k = k.reshape(src_len, bsz * self.nhead, head_dim)
 
-        assert (
-            v.size(-1) % self.nhead == 0
-        ), "value's embed_dim must be divisible by the number of heads"
+        assert v.size(-1) % self.nhead == 0, (
+            "value's embed_dim must be divisible by the number of heads"
+        )
         head_dim = v.size(-1) // self.nhead
         v = v.reshape(src_len, bsz * self.nhead, head_dim)
 
@@ -629,9 +629,9 @@ def forward(
             attn_mask = torch.nn.functional.pad(_attn_mask, [0, 1])
 
         tgt_len, head_dim = query.size(-3), query.size(-1)
-        assert (
-            query.size(-1) == key.size(-1) == value.size(-1)
-        ), "The feature dim of query, key, value must be equal."
+        assert query.size(-1) == key.size(-1) == value.size(-1), (
+            "The feature dim of query, key, value must be equal."
+        )
         assert key.size() == value.size(), "Shape of key, value must match"
         src_len = key.size(-3)
         batch_heads = max(query.size(-2), key.size(-2))
```

benchmarks/functional_autograd_benchmark/torchvision_models.py

+3 -3

```diff
@@ -884,9 +884,9 @@ def __init__(
         self.cost_class = cost_class
         self.cost_bbox = cost_bbox
         self.cost_giou = cost_giou
-        assert (
-            cost_class != 0 or cost_bbox != 0 or cost_giou != 0
-        ), "all costs cant be 0"
+        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
+            "all costs cant be 0"
+        )
 
     @torch.no_grad()
     def forward(self, outputs, targets):
```

benchmarks/gpt_fast/model.py

+3 -3

```diff
@@ -51,9 +51,9 @@ def from_name(cls, name: str):
         # take longer name (as it have more symbols matched)
         if len(config) > 1:
             config.sort(key=len, reverse=True)
-            assert len(config[0]) != len(
-                config[1]
-            ), name  # make sure only one 'best' match
+            assert len(config[0]) != len(config[1]), (
+                name
+            )  # make sure only one 'best' match
 
         return cls(**transformer_configs[config[0]])
 
```

benchmarks/instruction_counts/core/expand.py

+3 -3

```diff
@@ -80,9 +80,9 @@ def _generate_torchscript_file(model_src: str, name: str) -> Optional[str]:
 
     # And again, the type checker has no way of knowing that this line is valid.
     jit_model = module.jit_model  # type: ignore[attr-defined]
-    assert isinstance(
-        jit_model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)
-    ), f"Expected ScriptFunction or ScriptModule, got: {type(jit_model)}"
+    assert isinstance(jit_model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)), (
+        f"Expected ScriptFunction or ScriptModule, got: {type(jit_model)}"
+    )
     jit_model.save(artifact_path)  # type: ignore[call-arg]
 
     # Cleanup now that we have the actual serialized model.
```

benchmarks/operator_benchmark/benchmark_core.py

+3 -3

```diff
@@ -276,9 +276,9 @@ def split(s):
         if c in open_to_close.keys():
             curr_brackets.append(c)
         elif c in open_to_close.values():
-            assert (
-                curr_brackets and open_to_close[curr_brackets[-1]] == c
-            ), "ERROR: not able to parse the string!"
+            assert curr_brackets and open_to_close[curr_brackets[-1]] == c, (
+                "ERROR: not able to parse the string!"
+            )
             curr_brackets.pop()
         elif c == "," and (not curr_brackets):
             break_idxs.append(i)
```

benchmarks/sparse/triton_ops.py

+3 -3

```diff
@@ -3,9 +3,9 @@
 
 
 def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
-    assert (
-        sparsity <= 1.0 and sparsity >= 0.0
-    ), "sparsity should be a value between 0 and 1"
+    assert sparsity <= 1.0 and sparsity >= 0.0, (
+        "sparsity should be a value between 0 and 1"
+    )
     assert M % blocksize[0] == 0
     assert N % blocksize[1] == 0
     shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :]
```

benchmarks/transformer/attention_bias_benchmarks.py

+3 -3

```diff
@@ -84,9 +84,9 @@ def __init__(self, num_heads, embed_dim, device=None, dtype=None):
 
         self.head_dim = embed_dim // num_heads
         self.embed_dim = embed_dim
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, (
+            "embed_dim must be divisible by num_heads"
+        )
 
         self.q_proj_weight = Parameter(
             torch.empty((embed_dim, embed_dim), **factory_kwargs)
```

benchmarks/transformer/score_mod.py

+3 -3

```diff
@@ -49,9 +49,9 @@ class ExperimentConfig:
     backends: list[str]
 
     def __post_init__(self):
-        assert (
-            len(self.shape) == 6
-        ), "Shape must be of length 6"  # [B, Hq, M, Hkv, N, D]
+        assert len(self.shape) == 6, (
+            "Shape must be of length 6"
+        )  # [B, Hq, M, Hkv, N, D]
 
     def asdict(self):
         # Convert the dataclass instance to a dictionary
```

functorch/dim/reference.py

+6 -6

```diff
@@ -625,9 +625,9 @@ def split(self, split_size_or_sections, dim=0):
             unbound.append(i)
 
     if unbound:
-        assert (
-            total_bound_size <= size
-        ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
+        assert total_bound_size <= size, (
+            f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
+        )
         remaining_size = size - total_bound_size
         chunk_size = -(-remaining_size // len(unbound))
         for u in unbound:
@@ -636,9 +636,9 @@ def split(self, split_size_or_sections, dim=0):
             sizes[u] = sz
             remaining_size -= sz
     else:
-        assert (
-            total_bound_size == size
-        ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
+        assert total_bound_size == size, (
+            f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
+        )
     return tuple(
         t.index(dim, d)
         for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
```

scripts/compile_tests/download_reports.py

+3 -3

```diff
@@ -62,9 +62,9 @@ def subdir_path(config):
     for config in configs:
         required_jobs.extend(list(CONFIGS[config]))
     for job in required_jobs:
-        assert (
-            job in workflow_jobs
-        ), f"{job} not found, is the commit_sha correct? has the job finished running? The GitHub API may take a couple minutes to update."
+        assert job in workflow_jobs, (
+            f"{job} not found, is the commit_sha correct? has the job finished running? The GitHub API may take a couple minutes to update."
+        )
 
     # This page lists all artifacts.
     listings = requests.get(
```

scripts/export/update_schema.py

+6 -4

```diff
@@ -23,9 +23,9 @@
 )
 args = parser.parse_args()
 
-assert os.path.exists(
-    args.prefix
-), f"Assuming path {args.prefix} is the root of pytorch directory, but it doesn't exist."
+assert os.path.exists(args.prefix), (
+    f"Assuming path {args.prefix} is the root of pytorch directory, but it doesn't exist."
+)
 
 commit = schema_check.update_schema()
 
@@ -40,7 +40,9 @@
         f"Treespec version downgraded from {commit.base['TREESPEC_VERSION']} to {commit.result['TREESPEC_VERSION']}."
     )
 else:
-    assert args.force_unsafe, "Existing schema yaml file not found, please use --force-unsafe to try again."
+    assert args.force_unsafe, (
+        "Existing schema yaml file not found, please use --force-unsafe to try again."
+    )
 
 next_version, reason = schema_check.check(commit, args.force_unsafe)
```

test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py

+3 -3

```diff
@@ -182,9 +182,9 @@ def _test_model_numerically(
                     baseline_param.grad, param.grad, atol=atol, rtol=rtol
                 )
         else:
-            assert (
-                test_backward is False
-            ), "Calculating backward with multiple outputs is not supported yet."
+            assert test_backward is False, (
+                "Calculating backward with multiple outputs is not supported yet."
+            )
             for baseline_elem, result_elem in zip(baseline_result, result):
                 torch.testing.assert_close(
                     baseline_elem, result_elem, atol=atol, rtol=rtol
```

test/onnx/exporter/test_hf_models_e2e.py

+7 -9

```diff
@@ -98,15 +98,13 @@ def test_onnx_export_with_custom_axis_names_in_dynamic_shapes(self):
                 self.assertEqual(dim.value, custom_name)
 
 
-def _prepare_llm_model_gptj_to_test() -> (
-    tuple[
-        torch.nn.Module,
-        dict[str, Any],
-        dict[str, dict[int, str]],
-        list[str],
-        list[str],
-    ]
-):
+def _prepare_llm_model_gptj_to_test() -> tuple[
+    torch.nn.Module,
+    dict[str, Any],
+    dict[str, dict[int, str]],
+    list[str],
+    list[str],
+]:
     model = transformers.GPTJForCausalLM.from_pretrained(
         "hf-internal-testing/tiny-random-gptj"
     )
```
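The same AST-equivalence argument extends to the whole commit: every hunk is formatter output, so each touched file should parse identically before and after. A reviewer-side sketch (assuming a PyTorch checkout containing this commit; the file path is just one example of the 84):

```python
import ast
import subprocess

path = "benchmarks/sparse/triton_ops.py"  # any file touched by this commit

def file_at(rev: str) -> str:
    # Read the file contents as of a given git revision.
    return subprocess.run(
        ["git", "show", f"{rev}:{path}"], capture_output=True, text=True, check=True
    ).stdout

before, after = file_at("c73a92f~1"), file_at("c73a92f")
assert ast.dump(ast.parse(before)) == ast.dump(ast.parse(after))
print(f"{path}: formatting-only change confirmed")
```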
