@@ -5,15 +5,15 @@ use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
5
5
use rustc_hir:: lang_items;
6
6
use rustc_middle:: hir;
7
7
use rustc_middle:: ich:: StableHashingContext ;
8
- use rustc_middle:: mir:: interpret:: Scalar ;
8
+ use rustc_middle:: mir:: interpret:: { ConstValue , Scalar } ;
9
9
use rustc_middle:: mir:: {
10
10
self , traversal, BasicBlock , BasicBlockData , CoverageData , Operand , Place , SourceInfo ,
11
11
StatementKind , Terminator , TerminatorKind , START_BLOCK ,
12
12
} ;
13
13
use rustc_middle:: ty;
14
14
use rustc_middle:: ty:: query:: Providers ;
15
- use rustc_middle:: ty:: FnDef ;
16
15
use rustc_middle:: ty:: TyCtxt ;
16
+ use rustc_middle:: ty:: { ConstKind , FnDef } ;
17
17
use rustc_span:: def_id:: DefId ;
18
18
use rustc_span:: Span ;
19
19
@@ -26,16 +26,36 @@ pub struct InstrumentCoverage;
26
26
pub ( crate ) fn provide ( providers : & mut Providers < ' _ > ) {
27
27
providers. coverage_data = |tcx, def_id| {
28
28
let mir_body = tcx. optimized_mir ( def_id) ;
29
+ // FIXME(richkadel): The current implementation assumes the MIR for the given DefId
30
+ // represents a single function. Validate and/or correct if inlining and/or monomorphization
31
+ // invalidates these assumptions.
29
32
let count_code_region_fn =
30
33
tcx. require_lang_item ( lang_items:: CountCodeRegionFnLangItem , None ) ;
31
34
let mut num_counters: u32 = 0 ;
35
+ // The `num_counters` argument to `llvm.instrprof.increment` is the number of injected
36
+ // counters, with each counter having an index from `0..num_counters-1`. MIR optimization
37
+ // may split and duplicate some BasicBlock sequences. Simply counting the calls may not
38
+ // not work; but computing the num_counters by adding `1` to the highest index (for a given
39
+ // instrumented function) is valid.
32
40
for ( _, data) in traversal:: preorder ( mir_body) {
33
41
if let Some ( terminator) = & data. terminator {
34
- if let TerminatorKind :: Call { func : Operand :: Constant ( func) , .. } = & terminator. kind
42
+ if let TerminatorKind :: Call { func : Operand :: Constant ( func) , args, .. } =
43
+ & terminator. kind
35
44
{
36
45
if let FnDef ( called_fn_def_id, _) = func. literal . ty . kind {
37
46
if called_fn_def_id == count_code_region_fn {
38
- num_counters += 1 ;
47
+ if let Operand :: Constant ( constant) =
48
+ args. get ( 0 ) . expect ( "count_code_region has at least one arg" )
49
+ {
50
+ if let ConstKind :: Value ( ConstValue :: Scalar ( value) ) =
51
+ constant. literal . val
52
+ {
53
+ let index = value
54
+ . to_u32 ( )
55
+ . expect ( "count_code_region index at arg0 is u32" ) ;
56
+ num_counters = std:: cmp:: max ( num_counters, index + 1 ) ;
57
+ }
58
+ }
39
59
}
40
60
}
41
61
}
0 commit comments