1
1
use crate :: { bf16, f16} ;
2
2
3
- use rand:: { distributions :: Distribution , Rng } ;
3
+ use rand:: { distr :: Distribution , Rng } ;
4
4
use rand_distr:: uniform:: UniformFloat ;
5
5
6
6
macro_rules! impl_distribution_via_f32 {
@@ -13,13 +13,13 @@ macro_rules! impl_distribution_via_f32 {
13
13
} ;
14
14
}
15
15
16
- impl_distribution_via_f32 ! ( f16, rand_distr:: Standard ) ;
16
+ impl_distribution_via_f32 ! ( f16, rand_distr:: StandardUniform ) ;
17
17
impl_distribution_via_f32 ! ( f16, rand_distr:: StandardNormal ) ;
18
18
impl_distribution_via_f32 ! ( f16, rand_distr:: Exp1 ) ;
19
19
impl_distribution_via_f32 ! ( f16, rand_distr:: Open01 ) ;
20
20
impl_distribution_via_f32 ! ( f16, rand_distr:: OpenClosed01 ) ;
21
21
22
- impl_distribution_via_f32 ! ( bf16, rand_distr:: Standard ) ;
22
+ impl_distribution_via_f32 ! ( bf16, rand_distr:: StandardUniform ) ;
23
23
impl_distribution_via_f32 ! ( bf16, rand_distr:: StandardNormal ) ;
24
24
impl_distribution_via_f32 ! ( bf16, rand_distr:: Exp1 ) ;
25
25
impl_distribution_via_f32 ! ( bf16, rand_distr:: Open01 ) ;
@@ -34,25 +34,25 @@ impl rand_distr::uniform::SampleUniform for f16 {
34
34
35
35
impl rand_distr:: uniform:: UniformSampler for Float16Sampler {
36
36
type X = f16 ;
37
- fn new < B1 , B2 > ( low : B1 , high : B2 ) -> Self
37
+ fn new < B1 , B2 > ( low : B1 , high : B2 ) -> Result < Self , rand_distr :: uniform :: Error >
38
38
where
39
39
B1 : rand_distr:: uniform:: SampleBorrow < Self :: X > + Sized ,
40
40
B2 : rand_distr:: uniform:: SampleBorrow < Self :: X > + Sized ,
41
41
{
42
- Self ( UniformFloat :: new (
42
+ Ok ( Self ( UniformFloat :: new (
43
43
low. borrow ( ) . to_f32 ( ) ,
44
44
high. borrow ( ) . to_f32 ( ) ,
45
- ) )
45
+ ) ? ) )
46
46
}
47
- fn new_inclusive < B1 , B2 > ( low : B1 , high : B2 ) -> Self
47
+ fn new_inclusive < B1 , B2 > ( low : B1 , high : B2 ) -> Result < Self , rand_distr :: uniform :: Error >
48
48
where
49
49
B1 : rand_distr:: uniform:: SampleBorrow < Self :: X > + Sized ,
50
50
B2 : rand_distr:: uniform:: SampleBorrow < Self :: X > + Sized ,
51
51
{
52
- Self ( UniformFloat :: new_inclusive (
52
+ Ok ( Self ( UniformFloat :: new_inclusive (
53
53
low. borrow ( ) . to_f32 ( ) ,
54
54
high. borrow ( ) . to_f32 ( ) ,
55
- ) )
55
+ ) ? ) )
56
56
}
57
57
fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Self :: X {
58
58
f16:: from_f32 ( self . 0 . sample ( rng) )
@@ -68,25 +68,25 @@ impl rand_distr::uniform::SampleUniform for bf16 {
68
68
69
69
impl rand_distr:: uniform:: UniformSampler for BFloat16Sampler {
70
70
type X = bf16 ;
71
- fn new < B1 , B2 > ( low : B1 , high : B2 ) -> Self
71
+ fn new < B1 , B2 > ( low : B1 , high : B2 ) -> Result < Self , rand_distr :: uniform :: Error >
72
72
where
73
73
B1 : rand_distr:: uniform:: SampleBorrow < Self :: X > + Sized ,
74
74
B2 : rand_distr:: uniform:: SampleBorrow < Self :: X > + Sized ,
75
75
{
76
- Self ( UniformFloat :: new (
76
+ Ok ( Self ( UniformFloat :: new (
77
77
low. borrow ( ) . to_f32 ( ) ,
78
78
high. borrow ( ) . to_f32 ( ) ,
79
- ) )
79
+ ) ? ) )
80
80
}
81
- fn new_inclusive < B1 , B2 > ( low : B1 , high : B2 ) -> Self
81
+ fn new_inclusive < B1 , B2 > ( low : B1 , high : B2 ) -> Result < Self , rand_distr :: uniform :: Error >
82
82
where
83
83
B1 : rand_distr:: uniform:: SampleBorrow < Self :: X > + Sized ,
84
84
B2 : rand_distr:: uniform:: SampleBorrow < Self :: X > + Sized ,
85
85
{
86
- Self ( UniformFloat :: new_inclusive (
86
+ Ok ( Self ( UniformFloat :: new_inclusive (
87
87
low. borrow ( ) . to_f32 ( ) ,
88
88
high. borrow ( ) . to_f32 ( ) ,
89
- ) )
89
+ ) ? ) )
90
90
}
91
91
fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Self :: X {
92
92
bf16:: from_f32 ( self . 0 . sample ( rng) )
@@ -98,26 +98,26 @@ mod tests {
98
98
use super :: * ;
99
99
100
100
#[ allow( unused_imports) ]
101
- use rand:: { thread_rng , Rng } ;
102
- use rand_distr:: { Standard , StandardNormal , Uniform } ;
101
+ use rand:: { rng , Rng } ;
102
+ use rand_distr:: { StandardNormal , StandardUniform , Uniform } ;
103
103
104
104
#[ test]
105
105
fn test_sample_f16 ( ) {
106
- let mut rng = thread_rng ( ) ;
107
- let _: f16 = rng. sample ( Standard ) ;
106
+ let mut rng = rng ( ) ;
107
+ let _: f16 = rng. sample ( StandardUniform ) ;
108
108
let _: f16 = rng. sample ( StandardNormal ) ;
109
- let _: f16 = rng. sample ( Uniform :: new ( f16:: from_f32 ( 0.0 ) , f16:: from_f32 ( 1.0 ) ) ) ;
109
+ let _: f16 = rng. sample ( Uniform :: new ( f16:: from_f32 ( 0.0 ) , f16:: from_f32 ( 1.0 ) ) . unwrap ( ) ) ;
110
110
#[ cfg( feature = "num-traits" ) ]
111
111
let _: f16 =
112
112
rng. sample ( rand_distr:: Normal :: new ( f16:: from_f32 ( 0.0 ) , f16:: from_f32 ( 1.0 ) ) . unwrap ( ) ) ;
113
113
}
114
114
115
115
#[ test]
116
116
fn test_sample_bf16 ( ) {
117
- let mut rng = thread_rng ( ) ;
118
- let _: bf16 = rng. sample ( Standard ) ;
117
+ let mut rng = rng ( ) ;
118
+ let _: bf16 = rng. sample ( StandardUniform ) ;
119
119
let _: bf16 = rng. sample ( StandardNormal ) ;
120
- let _: bf16 = rng. sample ( Uniform :: new ( bf16:: from_f32 ( 0.0 ) , bf16:: from_f32 ( 1.0 ) ) ) ;
120
+ let _: bf16 = rng. sample ( Uniform :: new ( bf16:: from_f32 ( 0.0 ) , bf16:: from_f32 ( 1.0 ) ) . unwrap ( ) ) ;
121
121
#[ cfg( feature = "num-traits" ) ]
122
122
let _: bf16 =
123
123
rng. sample ( rand_distr:: Normal :: new ( bf16:: from_f32 ( 0.0 ) , bf16:: from_f32 ( 1.0 ) ) . unwrap ( ) ) ;
0 commit comments