1
1
use std:: borrow:: Cow ;
2
+ use std:: iter;
2
3
3
4
use anyhow:: Result ;
4
5
use anyhow:: { bail, Context } ;
@@ -13,10 +14,11 @@ use ruff_python_ast::helpers::Truthiness;
13
14
use ruff_python_ast:: parenthesize:: parenthesized_range;
14
15
use ruff_python_ast:: visitor:: Visitor ;
15
16
use ruff_python_ast:: {
16
- self as ast, Arguments , BoolOp , ExceptHandler , Expr , Keyword , Stmt , UnaryOp ,
17
+ self as ast, AnyNodeRef , Arguments , BoolOp , ExceptHandler , Expr , Keyword , Stmt , UnaryOp ,
17
18
} ;
18
19
use ruff_python_ast:: { visitor, whitespace} ;
19
20
use ruff_python_codegen:: Stylist ;
21
+ use ruff_python_semantic:: { Binding , BindingKind } ;
20
22
use ruff_source_file:: LineRanges ;
21
23
use ruff_text_size:: Ranged ;
22
24
@@ -266,47 +268,48 @@ fn check_assert_in_except(name: &str, body: &[Stmt]) -> Vec<Diagnostic> {
266
268
267
269
/// PT009
268
270
pub ( crate ) fn unittest_assertion (
269
- checker : & Checker ,
271
+ checker : & mut Checker ,
270
272
expr : & Expr ,
271
273
func : & Expr ,
272
274
args : & [ Expr ] ,
273
275
keywords : & [ Keyword ] ,
274
- ) -> Option < Diagnostic > {
275
- match func {
276
- Expr :: Attribute ( ast :: ExprAttribute { attr , .. } ) => {
277
- if let Ok ( unittest_assert ) = UnittestAssert :: try_from ( attr . as_str ( ) ) {
278
- let mut diagnostic = Diagnostic :: new (
279
- PytestUnittestAssertion {
280
- assertion : unittest_assert . to_string ( ) ,
281
- } ,
282
- func . range ( ) ,
283
- ) ;
284
- // We're converting an expression to a statement, so avoid applying the fix if
285
- // the assertion is part of a larger expression.
286
- if checker . semantic ( ) . current_statement ( ) . is_expr_stmt ( )
287
- && checker . semantic ( ) . current_expression_parent ( ) . is_none ( )
288
- && !checker . comment_ranges ( ) . intersects ( expr . range ( ) )
289
- {
290
- if let Ok ( stmt ) = unittest_assert . generate_assert ( args , keywords ) {
291
- diagnostic . set_fix ( Fix :: unsafe_edit ( Edit :: range_replacement (
292
- checker. generator ( ) . stmt ( & stmt ) ,
293
- parenthesized_range (
294
- expr. into ( ) ,
295
- checker . semantic ( ) . current_statement ( ) . into ( ) ,
296
- checker . comment_ranges ( ) ,
297
- checker . locator ( ) . contents ( ) ,
298
- )
299
- . unwrap_or ( expr . range ( ) ) ,
300
- ) ) ) ;
301
- }
302
- }
303
- Some ( diagnostic )
304
- } else {
305
- None
306
- }
276
+ ) {
277
+ let Expr :: Attribute ( ast :: ExprAttribute { attr , .. } ) = func else {
278
+ return ;
279
+ } ;
280
+
281
+ let Ok ( unittest_assert ) = UnittestAssert :: try_from ( attr . as_str ( ) ) else {
282
+ return ;
283
+ } ;
284
+
285
+ let mut diagnostic = Diagnostic :: new (
286
+ PytestUnittestAssertion {
287
+ assertion : unittest_assert . to_string ( ) ,
288
+ } ,
289
+ func . range ( ) ,
290
+ ) ;
291
+
292
+ // We're converting an expression to a statement, so avoid applying the fix if
293
+ // the assertion is part of a larger expression.
294
+ if checker. semantic ( ) . current_statement ( ) . is_expr_stmt ( )
295
+ && checker . semantic ( ) . current_expression_parent ( ) . is_none ( )
296
+ && !checker . comment_ranges ( ) . intersects ( expr. range ( ) )
297
+ {
298
+ if let Ok ( stmt ) = unittest_assert . generate_assert ( args , keywords ) {
299
+ diagnostic . set_fix ( Fix :: unsafe_edit ( Edit :: range_replacement (
300
+ checker . generator ( ) . stmt ( & stmt ) ,
301
+ parenthesized_range (
302
+ expr . into ( ) ,
303
+ checker . semantic ( ) . current_statement ( ) . into ( ) ,
304
+ checker . comment_ranges ( ) ,
305
+ checker . locator ( ) . contents ( ) ,
306
+ )
307
+ . unwrap_or ( expr . range ( ) ) ,
308
+ ) ) ) ;
307
309
}
308
- _ => None ,
309
310
}
311
+
312
+ checker. diagnostics . push ( diagnostic) ;
310
313
}
311
314
312
315
/// ## What it does
@@ -364,9 +367,96 @@ impl Violation for PytestUnittestRaisesAssertion {
364
367
}
365
368
366
369
/// PT027
367
- pub ( crate ) fn unittest_raises_assertion (
370
+ pub ( crate ) fn unittest_raises_assertion_call ( checker : & mut Checker , call : & ast:: ExprCall ) {
371
+ // Bindings in `with` statements are handled by `unittest_raises_assertion_bindings`.
372
+ if let Stmt :: With ( ast:: StmtWith { items, .. } ) = checker. semantic ( ) . current_statement ( ) {
373
+ let call_ref = AnyNodeRef :: from ( call) ;
374
+
375
+ if items. iter ( ) . any ( |item| {
376
+ AnyNodeRef :: from ( & item. context_expr ) . ptr_eq ( call_ref) && item. optional_vars . is_some ( )
377
+ } ) {
378
+ return ;
379
+ }
380
+ }
381
+
382
+ if let Some ( diagnostic) = unittest_raises_assertion ( call, vec ! [ ] , checker) {
383
+ checker. diagnostics . push ( diagnostic) ;
384
+ }
385
+ }
386
+
387
+ /// PT027
388
+ pub ( crate ) fn unittest_raises_assertion_binding (
368
389
checker : & Checker ,
390
+ binding : & Binding ,
391
+ ) -> Option < Diagnostic > {
392
+ if !matches ! ( binding. kind, BindingKind :: WithItemVar ) {
393
+ return None ;
394
+ }
395
+
396
+ let semantic = checker. semantic ( ) ;
397
+
398
+ let Stmt :: With ( with) = binding. statement ( semantic) ? else {
399
+ return None ;
400
+ } ;
401
+
402
+ let Expr :: Call ( call) = corresponding_context_expr ( binding, with) ? else {
403
+ return None ;
404
+ } ;
405
+
406
+ let mut edits = vec ! [ ] ;
407
+
408
+ // Rewrite all references to `.exception` to `.value`:
409
+ // ```py
410
+ // # Before
411
+ // with self.assertRaises(Exception) as e:
412
+ // ...
413
+ // print(e.exception)
414
+ //
415
+ // # After
416
+ // with pytest.raises(Exception) as e:
417
+ // ...
418
+ // print(e.value)
419
+ // ```
420
+ for reference_id in binding. references ( ) {
421
+ let reference = semantic. reference ( reference_id) ;
422
+ let node_id = reference. expression_id ( ) ?;
423
+
424
+ let mut ancestors = semantic. expressions ( node_id) . skip ( 1 ) ;
425
+
426
+ let Expr :: Attribute ( ast:: ExprAttribute { attr, .. } ) = ancestors. next ( ) ? else {
427
+ continue ;
428
+ } ;
429
+
430
+ if attr. as_str ( ) == "exception" {
431
+ edits. push ( Edit :: range_replacement ( "value" . to_string ( ) , attr. range ) ) ;
432
+ }
433
+ }
434
+
435
+ unittest_raises_assertion ( call, edits, checker)
436
+ }
437
+
438
+ fn corresponding_context_expr < ' a > ( binding : & Binding , with : & ' a ast:: StmtWith ) -> Option < & ' a Expr > {
439
+ with. items . iter ( ) . find_map ( |item| {
440
+ let Some ( optional_var) = & item. optional_vars else {
441
+ return None ;
442
+ } ;
443
+
444
+ let Expr :: Name ( name) = optional_var. as_ref ( ) else {
445
+ return None ;
446
+ } ;
447
+
448
+ if name. range == binding. range {
449
+ Some ( & item. context_expr )
450
+ } else {
451
+ None
452
+ }
453
+ } )
454
+ }
455
+
456
+ fn unittest_raises_assertion (
369
457
call : & ast:: ExprCall ,
458
+ extra_edits : Vec < Edit > ,
459
+ checker : & Checker ,
370
460
) -> Option < Diagnostic > {
371
461
let Expr :: Attribute ( ast:: ExprAttribute { attr, .. } ) = call. func . as_ref ( ) else {
372
462
return None ;
@@ -385,19 +475,25 @@ pub(crate) fn unittest_raises_assertion(
385
475
} ,
386
476
call. func . range ( ) ,
387
477
) ;
478
+
388
479
if !checker
389
480
. comment_ranges ( )
390
481
. has_comments ( call, checker. source ( ) )
391
482
{
392
483
if let Some ( args) = to_pytest_raises_args ( checker, attr. as_str ( ) , & call. arguments ) {
393
484
diagnostic. try_set_fix ( || {
394
- let ( import_edit , binding) = checker. importer ( ) . get_or_import_symbol (
485
+ let ( import_pytest_raises , binding) = checker. importer ( ) . get_or_import_symbol (
395
486
& ImportRequest :: import ( "pytest" , "raises" ) ,
396
487
call. func . start ( ) ,
397
488
checker. semantic ( ) ,
398
489
) ?;
399
- let edit = Edit :: range_replacement ( format ! ( "{binding}({args})" ) , call. range ( ) ) ;
400
- Ok ( Fix :: unsafe_edits ( import_edit, [ edit] ) )
490
+ let replace_call =
491
+ Edit :: range_replacement ( format ! ( "{binding}({args})" ) , call. range ( ) ) ;
492
+
493
+ Ok ( Fix :: unsafe_edits (
494
+ import_pytest_raises,
495
+ iter:: once ( replace_call) . chain ( extra_edits) ,
496
+ ) )
401
497
} ) ;
402
498
}
403
499
}
0 commit comments