Skip to content

Commit 569060f

Browse files
[flake8-pytest-style] Rewrite references to .exception (PT027) (#15680)
1 parent 15394a8 commit 569060f

File tree

5 files changed

+303
-52
lines changed

5 files changed

+303
-52
lines changed

crates/ruff_linter/resources/test/fixtures/flake8_pytest_style/PT027_1.py

+29
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,32 @@ def test_pytest_raises(self):
1010
def test_errors(self):
1111
with self.assertRaises(ValueError):
1212
raise ValueError
13+
14+
def test_rewrite_references(self):
15+
with self.assertRaises(ValueError) as e:
16+
raise ValueError
17+
18+
print(e.foo)
19+
print(e.exception)
20+
21+
def test_rewrite_references_multiple_items(self):
22+
with self.assertRaises(ValueError) as e1, \
23+
self.assertRaises(ValueError) as e2:
24+
raise ValueError
25+
26+
print(e1.foo)
27+
print(e1.exception)
28+
29+
print(e2.foo)
30+
print(e2.exception)
31+
32+
def test_rewrite_references_multiple_items_nested(self):
33+
with self.assertRaises(ValueError) as e1, \
34+
foo(self.assertRaises(ValueError)) as e2:
35+
raise ValueError
36+
37+
print(e1.foo)
38+
print(e1.exception)
39+
40+
print(e2.foo)
41+
print(e2.exception)

crates/ruff_linter/src/checkers/ast/analyze/bindings.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use ruff_text_size::Ranged;
44
use crate::checkers::ast::Checker;
55
use crate::codes::Rule;
66
use crate::rules::{
7-
flake8_import_conventions, flake8_pyi, flake8_type_checking, pyflakes, pylint, ruff,
7+
flake8_import_conventions, flake8_pyi, flake8_pytest_style, flake8_type_checking, pyflakes,
8+
pylint, ruff,
89
};
910

1011
/// Run lint rules over the [`Binding`]s.
@@ -20,6 +21,7 @@ pub(crate) fn bindings(checker: &mut Checker) {
2021
Rule::UnusedVariable,
2122
Rule::UnquotedTypeAlias,
2223
Rule::UsedDummyVariable,
24+
Rule::PytestUnittestRaisesAssertion,
2325
]) {
2426
return;
2527
}
@@ -100,5 +102,12 @@ pub(crate) fn bindings(checker: &mut Checker) {
100102
checker.diagnostics.push(diagnostic);
101103
}
102104
}
105+
if checker.enabled(Rule::PytestUnittestRaisesAssertion) {
106+
if let Some(diagnostic) =
107+
flake8_pytest_style::rules::unittest_raises_assertion_binding(checker, binding)
108+
{
109+
checker.diagnostics.push(diagnostic);
110+
}
111+
}
103112
}
104113
}

crates/ruff_linter/src/checkers/ast/analyze/expression.rs

+2-10
Original file line numberDiff line numberDiff line change
@@ -940,18 +940,10 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) {
940940
flake8_pytest_style::rules::parametrize(checker, call);
941941
}
942942
if checker.enabled(Rule::PytestUnittestAssertion) {
943-
if let Some(diagnostic) = flake8_pytest_style::rules::unittest_assertion(
944-
checker, expr, func, args, keywords,
945-
) {
946-
checker.diagnostics.push(diagnostic);
947-
}
943+
flake8_pytest_style::rules::unittest_assertion(checker, expr, func, args, keywords);
948944
}
949945
if checker.enabled(Rule::PytestUnittestRaisesAssertion) {
950-
if let Some(diagnostic) =
951-
flake8_pytest_style::rules::unittest_raises_assertion(checker, call)
952-
{
953-
checker.diagnostics.push(diagnostic);
954-
}
946+
flake8_pytest_style::rules::unittest_raises_assertion_call(checker, call);
955947
}
956948
if checker.enabled(Rule::SubprocessPopenPreexecFn) {
957949
pylint::rules::subprocess_popen_preexec_fn(checker, call);

crates/ruff_linter/src/rules/flake8_pytest_style/rules/assertion.rs

+136-40
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::borrow::Cow;
2+
use std::iter;
23

34
use anyhow::Result;
45
use anyhow::{bail, Context};
@@ -13,10 +14,11 @@ use ruff_python_ast::helpers::Truthiness;
1314
use ruff_python_ast::parenthesize::parenthesized_range;
1415
use ruff_python_ast::visitor::Visitor;
1516
use ruff_python_ast::{
16-
self as ast, Arguments, BoolOp, ExceptHandler, Expr, Keyword, Stmt, UnaryOp,
17+
self as ast, AnyNodeRef, Arguments, BoolOp, ExceptHandler, Expr, Keyword, Stmt, UnaryOp,
1718
};
1819
use ruff_python_ast::{visitor, whitespace};
1920
use ruff_python_codegen::Stylist;
21+
use ruff_python_semantic::{Binding, BindingKind};
2022
use ruff_source_file::LineRanges;
2123
use ruff_text_size::Ranged;
2224

@@ -266,47 +268,48 @@ fn check_assert_in_except(name: &str, body: &[Stmt]) -> Vec<Diagnostic> {
266268

267269
/// PT009
268270
pub(crate) fn unittest_assertion(
269-
checker: &Checker,
271+
checker: &mut Checker,
270272
expr: &Expr,
271273
func: &Expr,
272274
args: &[Expr],
273275
keywords: &[Keyword],
274-
) -> Option<Diagnostic> {
275-
match func {
276-
Expr::Attribute(ast::ExprAttribute { attr, .. }) => {
277-
if let Ok(unittest_assert) = UnittestAssert::try_from(attr.as_str()) {
278-
let mut diagnostic = Diagnostic::new(
279-
PytestUnittestAssertion {
280-
assertion: unittest_assert.to_string(),
281-
},
282-
func.range(),
283-
);
284-
// We're converting an expression to a statement, so avoid applying the fix if
285-
// the assertion is part of a larger expression.
286-
if checker.semantic().current_statement().is_expr_stmt()
287-
&& checker.semantic().current_expression_parent().is_none()
288-
&& !checker.comment_ranges().intersects(expr.range())
289-
{
290-
if let Ok(stmt) = unittest_assert.generate_assert(args, keywords) {
291-
diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement(
292-
checker.generator().stmt(&stmt),
293-
parenthesized_range(
294-
expr.into(),
295-
checker.semantic().current_statement().into(),
296-
checker.comment_ranges(),
297-
checker.locator().contents(),
298-
)
299-
.unwrap_or(expr.range()),
300-
)));
301-
}
302-
}
303-
Some(diagnostic)
304-
} else {
305-
None
306-
}
276+
) {
277+
let Expr::Attribute(ast::ExprAttribute { attr, .. }) = func else {
278+
return;
279+
};
280+
281+
let Ok(unittest_assert) = UnittestAssert::try_from(attr.as_str()) else {
282+
return;
283+
};
284+
285+
let mut diagnostic = Diagnostic::new(
286+
PytestUnittestAssertion {
287+
assertion: unittest_assert.to_string(),
288+
},
289+
func.range(),
290+
);
291+
292+
// We're converting an expression to a statement, so avoid applying the fix if
293+
// the assertion is part of a larger expression.
294+
if checker.semantic().current_statement().is_expr_stmt()
295+
&& checker.semantic().current_expression_parent().is_none()
296+
&& !checker.comment_ranges().intersects(expr.range())
297+
{
298+
if let Ok(stmt) = unittest_assert.generate_assert(args, keywords) {
299+
diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement(
300+
checker.generator().stmt(&stmt),
301+
parenthesized_range(
302+
expr.into(),
303+
checker.semantic().current_statement().into(),
304+
checker.comment_ranges(),
305+
checker.locator().contents(),
306+
)
307+
.unwrap_or(expr.range()),
308+
)));
307309
}
308-
_ => None,
309310
}
311+
312+
checker.diagnostics.push(diagnostic);
310313
}
311314

312315
/// ## What it does
@@ -364,9 +367,96 @@ impl Violation for PytestUnittestRaisesAssertion {
364367
}
365368

366369
/// PT027
367-
pub(crate) fn unittest_raises_assertion(
370+
pub(crate) fn unittest_raises_assertion_call(checker: &mut Checker, call: &ast::ExprCall) {
371+
// Bindings in `with` statements are handled by `unittest_raises_assertion_bindings`.
372+
if let Stmt::With(ast::StmtWith { items, .. }) = checker.semantic().current_statement() {
373+
let call_ref = AnyNodeRef::from(call);
374+
375+
if items.iter().any(|item| {
376+
AnyNodeRef::from(&item.context_expr).ptr_eq(call_ref) && item.optional_vars.is_some()
377+
}) {
378+
return;
379+
}
380+
}
381+
382+
if let Some(diagnostic) = unittest_raises_assertion(call, vec![], checker) {
383+
checker.diagnostics.push(diagnostic);
384+
}
385+
}
386+
387+
/// PT027
388+
pub(crate) fn unittest_raises_assertion_binding(
368389
checker: &Checker,
390+
binding: &Binding,
391+
) -> Option<Diagnostic> {
392+
if !matches!(binding.kind, BindingKind::WithItemVar) {
393+
return None;
394+
}
395+
396+
let semantic = checker.semantic();
397+
398+
let Stmt::With(with) = binding.statement(semantic)? else {
399+
return None;
400+
};
401+
402+
let Expr::Call(call) = corresponding_context_expr(binding, with)? else {
403+
return None;
404+
};
405+
406+
let mut edits = vec![];
407+
408+
// Rewrite all references to `.exception` to `.value`:
409+
// ```py
410+
// # Before
411+
// with self.assertRaises(Exception) as e:
412+
// ...
413+
// print(e.exception)
414+
//
415+
// # After
416+
// with pytest.raises(Exception) as e:
417+
// ...
418+
// print(e.value)
419+
// ```
420+
for reference_id in binding.references() {
421+
let reference = semantic.reference(reference_id);
422+
let node_id = reference.expression_id()?;
423+
424+
let mut ancestors = semantic.expressions(node_id).skip(1);
425+
426+
let Expr::Attribute(ast::ExprAttribute { attr, .. }) = ancestors.next()? else {
427+
continue;
428+
};
429+
430+
if attr.as_str() == "exception" {
431+
edits.push(Edit::range_replacement("value".to_string(), attr.range));
432+
}
433+
}
434+
435+
unittest_raises_assertion(call, edits, checker)
436+
}
437+
438+
fn corresponding_context_expr<'a>(binding: &Binding, with: &'a ast::StmtWith) -> Option<&'a Expr> {
439+
with.items.iter().find_map(|item| {
440+
let Some(optional_var) = &item.optional_vars else {
441+
return None;
442+
};
443+
444+
let Expr::Name(name) = optional_var.as_ref() else {
445+
return None;
446+
};
447+
448+
if name.range == binding.range {
449+
Some(&item.context_expr)
450+
} else {
451+
None
452+
}
453+
})
454+
}
455+
456+
fn unittest_raises_assertion(
369457
call: &ast::ExprCall,
458+
extra_edits: Vec<Edit>,
459+
checker: &Checker,
370460
) -> Option<Diagnostic> {
371461
let Expr::Attribute(ast::ExprAttribute { attr, .. }) = call.func.as_ref() else {
372462
return None;
@@ -385,19 +475,25 @@ pub(crate) fn unittest_raises_assertion(
385475
},
386476
call.func.range(),
387477
);
478+
388479
if !checker
389480
.comment_ranges()
390481
.has_comments(call, checker.source())
391482
{
392483
if let Some(args) = to_pytest_raises_args(checker, attr.as_str(), &call.arguments) {
393484
diagnostic.try_set_fix(|| {
394-
let (import_edit, binding) = checker.importer().get_or_import_symbol(
485+
let (import_pytest_raises, binding) = checker.importer().get_or_import_symbol(
395486
&ImportRequest::import("pytest", "raises"),
396487
call.func.start(),
397488
checker.semantic(),
398489
)?;
399-
let edit = Edit::range_replacement(format!("{binding}({args})"), call.range());
400-
Ok(Fix::unsafe_edits(import_edit, [edit]))
490+
let replace_call =
491+
Edit::range_replacement(format!("{binding}({args})"), call.range());
492+
493+
Ok(Fix::unsafe_edits(
494+
import_pytest_raises,
495+
iter::once(replace_call).chain(extra_edits),
496+
))
401497
});
402498
}
403499
}

0 commit comments

Comments
 (0)