Skip to content

Commit eeb33d7

Browse files
authored
Merge pull request fortran-lang#132 from jvdp1/mask_mean_dev
mask_mean_dev: addition of a mask in mean API
2 parents ecfbfb0 + 4dfb3b0 commit eeb33d7

File tree

4 files changed

+196
-21
lines changed

4 files changed

+196
-21
lines changed

src/stdlib_experimental_stats.fypp

+54-4
Original file line numberDiff line numberDiff line change
@@ -15,42 +15,92 @@ module stdlib_experimental_stats
1515

1616
#:for k1, t1 in REAL_KINDS_TYPES
1717
#:for rank in RANKS
18-
module function mean_${rank}$_all_${k1}$_${k1}$(x) result(res)
18+
module function mean_${rank}$_all_${k1}$_${k1}$(x, mask) result(res)
1919
${t1}$, intent(in) :: x${ranksuffix(rank)}$
20+
logical, intent(in), optional :: mask
2021
${t1}$ :: res
2122
end function mean_${rank}$_all_${k1}$_${k1}$
2223
#:endfor
2324
#:endfor
2425

2526
#:for k1, t1 in INT_KINDS_TYPES
2627
#:for rank in RANKS
27-
module function mean_${rank}$_all_${k1}$_dp(x) result(res)
28+
module function mean_${rank}$_all_${k1}$_dp(x, mask) result(res)
2829
${t1}$, intent(in) :: x${ranksuffix(rank)}$
30+
logical, intent(in), optional :: mask
2931
real(dp) :: res
3032
end function mean_${rank}$_all_${k1}$_dp
3133
#:endfor
3234
#:endfor
3335

3436
#:for k1, t1 in REAL_KINDS_TYPES
3537
#:for rank in RANKS
36-
module function mean_${rank}$_${k1}$_${k1}$(x, dim) result(res)
38+
module function mean_${rank}$_${k1}$_${k1}$(x, dim, mask) result(res)
3739
${t1}$, intent(in) :: x${ranksuffix(rank)}$
3840
integer, intent(in) :: dim
41+
logical, intent(in), optional :: mask
3942
${t1}$ :: res${reduced_shape('x', rank, 'dim')}$
4043
end function mean_${rank}$_${k1}$_${k1}$
4144
#:endfor
4245
#:endfor
4346

4447
#:for k1, t1 in INT_KINDS_TYPES
4548
#:for rank in RANKS
46-
module function mean_${rank}$_${k1}$_dp(x, dim) result(res)
49+
module function mean_${rank}$_${k1}$_dp(x, dim, mask) result(res)
4750
${t1}$, intent(in) :: x${ranksuffix(rank)}$
4851
integer, intent(in) :: dim
52+
logical, intent(in), optional :: mask
4953
real(dp) :: res${reduced_shape('x', rank, 'dim')}$
5054
end function mean_${rank}$_${k1}$_dp
5155
#:endfor
5256
#:endfor
5357

58+
59+
#:for k1, t1 in REAL_KINDS_TYPES
60+
#:for rank in RANKS
61+
module function mean_${rank}$_mask_all_${k1}$_${k1}$(x, mask) result(res)
62+
${t1}$, intent(in) :: x${ranksuffix(rank)}$
63+
logical, intent(in) :: mask${ranksuffix(rank)}$
64+
${t1}$ :: res
65+
end function mean_${rank}$_mask_all_${k1}$_${k1}$
66+
#:endfor
67+
#:endfor
68+
69+
70+
#:for k1, t1 in INT_KINDS_TYPES
71+
#:for rank in RANKS
72+
module function mean_${rank}$_mask_all_${k1}$_dp(x, mask) result(res)
73+
${t1}$, intent(in) :: x${ranksuffix(rank)}$
74+
logical, intent(in) :: mask${ranksuffix(rank)}$
75+
real(dp) :: res
76+
end function mean_${rank}$_mask_all_${k1}$_dp
77+
#:endfor
78+
#:endfor
79+
80+
81+
#:for k1, t1 in REAL_KINDS_TYPES
82+
#:for rank in RANKS
83+
module function mean_${rank}$_mask_${k1}$_${k1}$(x, dim, mask) result(res)
84+
${t1}$, intent(in) :: x${ranksuffix(rank)}$
85+
integer, intent(in) :: dim
86+
logical, intent(in) :: mask${ranksuffix(rank)}$
87+
${t1}$ :: res${reduced_shape('x', rank, 'dim')}$
88+
end function mean_${rank}$_mask_${k1}$_${k1}$
89+
#:endfor
90+
#:endfor
91+
92+
93+
#:for k1, t1 in INT_KINDS_TYPES
94+
#:for rank in RANKS
95+
module function mean_${rank}$_mask_${k1}$_dp(x, dim, mask) result(res)
96+
${t1}$, intent(in) :: x${ranksuffix(rank)}$
97+
integer, intent(in) :: dim
98+
logical, intent(in) :: mask${ranksuffix(rank)}$
99+
real(dp) :: res${reduced_shape('x', rank, 'dim')}$
100+
end function mean_${rank}$_mask_${k1}$_dp
101+
#:endfor
102+
#:endfor
103+
54104
end interface mean
55105

56106
end module stdlib_experimental_stats

src/stdlib_experimental_stats.md

+12-6
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,42 @@
88

99
### Description
1010

11-
Returns the mean of all the elements of `array`, or of the elements of `array` along dimension `dim` if provided.
11+
Returns the mean of all the elements of `array`, or of the elements of `array` along dimension `dim` if provided, and if the corresponding element in `mask` is `true`.
1212

1313
### Syntax
1414

15-
`result = mean(array)`
15+
`result = mean(array [, mask])`
1616

17-
`result = mean(array, dim)`
17+
`result = mean(array, dim [, mask])`
1818

1919
### Arguments
2020

2121
`array`: Shall be an array of type `integer`, or `real`.
2222

2323
`dim`: Shall be a scalar of type `integer` with a value in the range from 1 to n, where n is the rank of `array`.
2424

25+
`mask` (optional): Shall be of type `logical` and either by a scalar or an array of the same shape as `array`.
26+
2527
### Return value
2628

2729
If `array` is of type `real`, the result is of the same type as `array`.
2830
If `array` is of type `integer`, the result is of type `double precision`.
2931

3032
If `dim` is absent, a scalar with the mean of all elements in `array` is returned. Otherwise, an array of rank n-1, where n equals the rank of `array`, and a shape similar to that of `array` with dimension `dim` dropped is returned.
3133

34+
If `mask` is specified, the result is the mean of all elements of `array` corresponding to `true` elements of `mask`. If every element of `mask` is `false`, the result is IEEE `NaN`.
35+
3236
### Example
3337

3438
```fortran
3539
program demo_mean
3640
use stdlib_experimental_stats, only: mean
3741
implicit none
3842
real :: x(1:6) = [ 1., 2., 3., 4., 5., 6. ]
39-
print *, mean(x) !returns 21.
40-
print *, mean( reshape(x, [ 2, 3 ] )) !returns 21.
41-
print *, mean( reshape(x, [ 2, 3 ] ), 1) !returns [ 3., 7., 11. ]
43+
print *, mean(x) !returns 3.5
44+
print *, mean( reshape(x, [ 2, 3 ] )) !returns 3.5
45+
print *, mean( reshape(x, [ 2, 3 ] ), 1) !returns [ 1.5, 3.5, 5.5 ]
46+
print *, mean( reshape(x, [ 2, 3 ] ), 1,&
47+
reshape(x, [ 2, 3 ] ) > 3.) !returns [ NaN, 4.0, 5.5 ]
4248
end program demo_mean
4349
```

src/stdlib_experimental_stats_mean.fypp

+97-5
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,25 @@
55

66
submodule (stdlib_experimental_stats) stdlib_experimental_stats_mean
77

8+
use, intrinsic:: ieee_arithmetic, only: ieee_value, ieee_quiet_nan
89
use stdlib_experimental_error, only: error_stop
10+
use stdlib_experimental_optval, only: optval
911
implicit none
1012

1113
contains
1214

1315
#:for k1, t1 in REAL_KINDS_TYPES
1416
#:for rank in RANKS
15-
module function mean_${rank}$_all_${k1}$_${k1}$(x) result(res)
17+
module function mean_${rank}$_all_${k1}$_${k1}$(x, mask) result(res)
1618
${t1}$, intent(in) :: x${ranksuffix(rank)}$
19+
logical, intent(in), optional :: mask
1720
${t1}$ :: res
1821

22+
if (.not.optval(mask, .true.)) then
23+
res = ieee_value(res, ieee_quiet_nan)
24+
return
25+
end if
26+
1927
res = sum(x) / real(size(x, kind = int64), ${k1}$)
2028

2129
end function mean_${rank}$_all_${k1}$_${k1}$
@@ -25,10 +33,16 @@ contains
2533

2634
#:for k1, t1 in INT_KINDS_TYPES
2735
#:for rank in RANKS
28-
module function mean_${rank}$_all_${k1}$_dp(x) result(res)
36+
module function mean_${rank}$_all_${k1}$_dp(x, mask) result(res)
2937
${t1}$, intent(in) :: x${ranksuffix(rank)}$
38+
logical, intent(in), optional :: mask
3039
real(dp) :: res
3140

41+
if (.not.optval(mask, .true.)) then
42+
res = ieee_value(res, ieee_quiet_nan)
43+
return
44+
end if
45+
3246
res = sum(real(x, dp)) / real(size(x, kind = int64), dp)
3347

3448
end function mean_${rank}$_all_${k1}$_dp
@@ -38,11 +52,17 @@ contains
3852

3953
#:for k1, t1 in REAL_KINDS_TYPES
4054
#:for rank in RANKS
41-
module function mean_${rank}$_${k1}$_${k1}$(x, dim) result(res)
55+
module function mean_${rank}$_${k1}$_${k1}$(x, dim, mask) result(res)
4256
${t1}$, intent(in) :: x${ranksuffix(rank)}$
4357
integer, intent(in) :: dim
58+
logical, intent(in), optional :: mask
4459
${t1}$ :: res${reduced_shape('x', rank, 'dim')}$
4560

61+
if (.not.optval(mask, .true.)) then
62+
res = ieee_value(res, ieee_quiet_nan)
63+
return
64+
end if
65+
4666
if (dim >= 1 .and. dim <= ${rank}$) then
4767
res = sum(x, dim) / real(size(x, dim), ${k1}$)
4868
else
@@ -56,13 +76,19 @@ contains
5676

5777
#:for k1, t1 in INT_KINDS_TYPES
5878
#:for rank in RANKS
59-
module function mean_${rank}$_${k1}$_dp(x, dim) result(res)
79+
module function mean_${rank}$_${k1}$_dp(x, dim, mask) result(res)
6080
${t1}$, intent(in) :: x${ranksuffix(rank)}$
6181
integer, intent(in) :: dim
82+
logical, intent(in), optional :: mask
6283
real(dp) :: res${reduced_shape('x', rank, 'dim')}$
6384

85+
if (.not.optval(mask, .true.)) then
86+
res = ieee_value(res, ieee_quiet_nan)
87+
return
88+
end if
89+
6490
if (dim >= 1 .and. dim <= ${rank}$) then
65-
res = sum(x, dim) / real(size(x, dim), dp)
91+
res = sum(real(x, dp), dim) / real(size(x, dim), dp)
6692
else
6793
call error_stop("ERROR (mean): wrong dimension")
6894
end if
@@ -71,4 +97,70 @@ contains
7197
#:endfor
7298
#:endfor
7399

100+
101+
#:for k1, t1 in REAL_KINDS_TYPES
102+
#:for rank in RANKS
103+
module function mean_${rank}$_mask_all_${k1}$_${k1}$(x, mask) result(res)
104+
${t1}$, intent(in) :: x${ranksuffix(rank)}$
105+
logical, intent(in) :: mask${ranksuffix(rank)}$
106+
${t1}$ :: res
107+
108+
res = sum(x, mask) / real(count(mask, kind = int64), ${k1}$)
109+
110+
end function mean_${rank}$_mask_all_${k1}$_${k1}$
111+
#:endfor
112+
#:endfor
113+
114+
115+
#:for k1, t1 in INT_KINDS_TYPES
116+
#:for rank in RANKS
117+
module function mean_${rank}$_mask_all_${k1}$_dp(x, mask) result(res)
118+
${t1}$, intent(in) :: x${ranksuffix(rank)}$
119+
logical, intent(in) :: mask${ranksuffix(rank)}$
120+
real(dp) :: res
121+
122+
res = sum(real(x, dp), mask) / real(count(mask, kind = int64), dp)
123+
124+
end function mean_${rank}$_mask_all_${k1}$_dp
125+
#:endfor
126+
#:endfor
127+
128+
129+
#:for k1, t1 in REAL_KINDS_TYPES
130+
#:for rank in RANKS
131+
module function mean_${rank}$_mask_${k1}$_${k1}$(x, dim, mask) result(res)
132+
${t1}$, intent(in) :: x${ranksuffix(rank)}$
133+
integer, intent(in) :: dim
134+
logical, intent(in) :: mask${ranksuffix(rank)}$
135+
${t1}$ :: res${reduced_shape('x', rank, 'dim')}$
136+
137+
if (dim >= 1 .and. dim <= ${rank}$) then
138+
res = sum(x, dim, mask) / real(count(mask, dim), ${k1}$)
139+
else
140+
call error_stop("ERROR (mean): wrong dimension")
141+
end if
142+
143+
end function mean_${rank}$_mask_${k1}$_${k1}$
144+
#:endfor
145+
#:endfor
146+
147+
148+
#:for k1, t1 in INT_KINDS_TYPES
149+
#:for rank in RANKS
150+
module function mean_${rank}$_mask_${k1}$_dp(x, dim, mask) result(res)
151+
${t1}$, intent(in) :: x${ranksuffix(rank)}$
152+
integer, intent(in) :: dim
153+
logical, intent(in) :: mask${ranksuffix(rank)}$
154+
real(dp) :: res${reduced_shape('x', rank, 'dim')}$
155+
156+
if (dim >= 1 .and. dim <= ${rank}$) then
157+
res = sum(real(x, dp), dim, mask) / real(count(mask, dim), dp)
158+
else
159+
call error_stop("ERROR (mean): wrong dimension")
160+
end if
161+
162+
end function mean_${rank}$_mask_${k1}$_dp
163+
#:endfor
164+
#:endfor
165+
74166
end submodule

src/tests/stats/test_mean.f90

+33-6
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ program test_mean
3737
call assert( sum( abs( mean(d,2) - sum(d,2)/real(size(d,2), dp) )) < dptol)
3838

3939

40+
! check mask = .false.
41+
42+
call assert( isnan(mean(d, .false.)))
43+
call assert( any(isnan(mean(d, 1, .false.))))
44+
call assert( any(isnan(mean(d, 2, .false.))))
45+
46+
! check mask of the same shape as input
47+
call assert( abs(mean(d, d > 0) - sum(d, d > 0)/real(count(d > 0), dp)) < dptol)
48+
call assert( sum(abs(mean(d, 1, d > 0) - sum(d, 1, d > 0)/real(count(d > 0, 1), dp))) < dptol)
49+
call assert( sum(abs(mean(d, 2, d > 0) - sum(d, 2, d > 0)/real(count(d > 0, 2), dp))) < dptol)
50+
4051
!int32
4152
call loadtxt("array3.dat", d)
4253

@@ -56,8 +67,8 @@ program test_mean
5667
!dp rank 3
5768
allocate(d3(size(d,1),size(d,2),3))
5869
d3(:,:,1)=d;
59-
d3(:,:,2)=d*1.5_dp;
60-
d3(:,:,3)=d*4._dp;
70+
d3(:,:,2)=d*1.5;
71+
d3(:,:,3)=d*4;
6172

6273
call assert( abs(mean(d3) - sum(d3)/real(size(d3), dp)) < dptol)
6374
call assert( sum( abs( mean(d3,1) - sum(d3,1)/real(size(d3,1), dp) )) < dptol)
@@ -67,16 +78,32 @@ program test_mean
6778

6879
!dp rank 4
6980
allocate(d4(size(d,1),size(d,2),3,9))
70-
d4 = 1.
81+
d4 = -1
7182
d4(:,:,1,1)=d;
72-
d4(:,:,2,1)=d*1.5_dp;
73-
d4(:,:,3,1)=d*4._dp;
74-
d4(:,:,3,9)=d*4._dp;
83+
d4(:,:,2,1)=d*1.5;
84+
d4(:,:,3,1)=d*4;
85+
d4(:,:,3,9)=d*4;
7586

7687
call assert( abs(mean(d4) - sum(d4)/real(size(d4), dp)) < dptol)
7788
call assert( sum( abs( mean(d4,1) - sum(d4,1)/real(size(d4,1), dp) )) < dptol)
7889
call assert( sum( abs( mean(d4,2) - sum(d4,2)/real(size(d4,2), dp) )) < dptol)
7990
call assert( sum( abs( mean(d4,3) - sum(d4,3)/real(size(d4,3), dp) )) < dptol)
8091
call assert( sum( abs( mean(d4,4) - sum(d4,4)/real(size(d4,4), dp) )) < dptol)
8192

93+
! check mask = .false.
94+
95+
call assert( isnan(mean(d4, .false.)))
96+
call assert( any(isnan(mean(d4, 1, .false.))))
97+
call assert( any(isnan(mean(d4, 2, .false.))))
98+
call assert( any(isnan(mean(d4, 3, .false.))))
99+
call assert( any(isnan(mean(d4, 4, .false.))))
100+
101+
102+
! check mask of the same shape as input
103+
call assert( abs(mean(d4, d4 > 0) - sum(d4, d4 > 0)/real(count(d4 > 0), dp)) < dptol)
104+
call assert( any(isnan(mean(d4, 1, d4 > 0))) )
105+
call assert( any(isnan(mean(d4, 2, d4 > 0))) )
106+
call assert( any(isnan(mean(d4, 3, d4 > 0))) )
107+
call assert( sum(abs(mean(d4, 4, d4 > 0) - sum(d4, 4, d4 > 0)/real(count(d4 > 0, 4), dp))) < dptol)
108+
82109
end program

0 commit comments

Comments
 (0)