@@ -50,16 +50,32 @@ pub enum DiffActivity {
50
50
/// with it.
51
51
Dual ,
52
52
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
53
+ /// with it. It expects the shadow argument to be `width` times larger than the original
54
+ /// input/output.
55
+ Dualv ,
56
+ /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
53
57
/// with it. Drop the code which updates the original input/output for maximum performance.
54
58
DualOnly ,
59
+ /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
60
+ /// with it. Drop the code which updates the original input/output for maximum performance.
61
+ /// It expects the shadow argument to be `width` times larger than the original input/output.
62
+ DualvOnly ,
55
63
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
56
64
Duplicated ,
57
65
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
58
66
/// Drop the code which updates the original input for maximum performance.
59
67
DuplicatedOnly ,
60
68
/// All Integers must be Const, but these are used to mark the integer which represents the
61
69
/// length of a slice/vec. This is used for safety checks on slices.
62
- FakeActivitySize ,
70
+ /// The integer (if given) specifies the size of the slice element in bytes.
71
+ FakeActivitySize ( Option < u32 > ) ,
72
+ }
73
+
74
+ impl DiffActivity {
75
+ pub fn is_dual_or_const ( & self ) -> bool {
76
+ use DiffActivity :: * ;
77
+ matches ! ( self , |Dual | DualOnly | Dualv | DualvOnly | Const )
78
+ }
63
79
}
64
80
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
65
81
#[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
@@ -131,11 +147,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
131
147
match mode {
132
148
DiffMode :: Error => false ,
133
149
DiffMode :: Source => false ,
134
- DiffMode :: Forward => {
135
- activity == DiffActivity :: Dual
136
- || activity == DiffActivity :: DualOnly
137
- || activity == DiffActivity :: Const
138
- }
150
+ DiffMode :: Forward => activity. is_dual_or_const ( ) ,
139
151
DiffMode :: Reverse => {
140
152
activity == DiffActivity :: Const
141
153
|| activity == DiffActivity :: Active
@@ -153,10 +165,8 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
153
165
pub fn valid_ty_for_activity ( ty : & P < Ty > , activity : DiffActivity ) -> bool {
154
166
use DiffActivity :: * ;
155
167
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
156
- if matches ! ( activity, Const ) {
157
- return true ;
158
- }
159
- if matches ! ( activity, Dual | DualOnly ) {
168
+ // Dual variants also support all types.
169
+ if activity. is_dual_or_const ( ) {
160
170
return true ;
161
171
}
162
172
// FIXME(ZuseZ4) We should make this more robust to also
@@ -172,9 +182,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
172
182
return match mode {
173
183
DiffMode :: Error => false ,
174
184
DiffMode :: Source => false ,
175
- DiffMode :: Forward => {
176
- matches ! ( activity, Dual | DualOnly | Const )
177
- }
185
+ DiffMode :: Forward => activity. is_dual_or_const ( ) ,
178
186
DiffMode :: Reverse => {
179
187
matches ! ( activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const )
180
188
}
@@ -189,10 +197,12 @@ impl Display for DiffActivity {
189
197
DiffActivity :: Active => write ! ( f, "Active" ) ,
190
198
DiffActivity :: ActiveOnly => write ! ( f, "ActiveOnly" ) ,
191
199
DiffActivity :: Dual => write ! ( f, "Dual" ) ,
200
+ DiffActivity :: Dualv => write ! ( f, "Dualv" ) ,
192
201
DiffActivity :: DualOnly => write ! ( f, "DualOnly" ) ,
202
+ DiffActivity :: DualvOnly => write ! ( f, "DualvOnly" ) ,
193
203
DiffActivity :: Duplicated => write ! ( f, "Duplicated" ) ,
194
204
DiffActivity :: DuplicatedOnly => write ! ( f, "DuplicatedOnly" ) ,
195
- DiffActivity :: FakeActivitySize => write ! ( f, "FakeActivitySize" ) ,
205
+ DiffActivity :: FakeActivitySize ( s ) => write ! ( f, "FakeActivitySize({:?})" , s ) ,
196
206
}
197
207
}
198
208
}
@@ -220,7 +230,9 @@ impl FromStr for DiffActivity {
220
230
"ActiveOnly" => Ok ( DiffActivity :: ActiveOnly ) ,
221
231
"Const" => Ok ( DiffActivity :: Const ) ,
222
232
"Dual" => Ok ( DiffActivity :: Dual ) ,
233
+ "Dualv" => Ok ( DiffActivity :: Dualv ) ,
223
234
"DualOnly" => Ok ( DiffActivity :: DualOnly ) ,
235
+ "DualvOnly" => Ok ( DiffActivity :: DualvOnly ) ,
224
236
"Duplicated" => Ok ( DiffActivity :: Duplicated ) ,
225
237
"DuplicatedOnly" => Ok ( DiffActivity :: DuplicatedOnly ) ,
226
238
_ => Err ( ( ) ) ,
0 commit comments