Skip to content

Commit ce3ab30

Browse files
committed
working dupvonly for fwd mode
1 parent 2831ce7 commit ce3ab30

File tree

4 files changed

+33
-19
lines changed

4 files changed

+33
-19
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

+9-2
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ pub enum DiffActivity {
5656
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
5757
/// with it. Drop the code which updates the original input/output for maximum performance.
5858
DualOnly,
59+
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
60+
/// with it. Drop the code which updates the original input/output for maximum performance.
61+
/// It expects the shadow argument to be `width` times larger than the original input/output.
62+
DualvOnly,
5963
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
6064
Duplicated,
6165
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
@@ -133,6 +137,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
133137
activity == DiffActivity::Dual
134138
|| activity == DiffActivity::Dualv
135139
|| activity == DiffActivity::DualOnly
140+
|| activity == DiffActivity::DualvOnly
136141
|| activity == DiffActivity::Const
137142
}
138143
DiffMode::Reverse => {
@@ -155,7 +160,7 @@ pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
155160
if matches!(activity, Const) {
156161
return true;
157162
}
158-
if matches!(activity, Dual | DualOnly | Dualv) {
163+
if matches!(activity, Dual | DualOnly | Dualv | DualvOnly) {
159164
return true;
160165
}
161166
// FIXME(ZuseZ4) We should make this more robust to also
@@ -172,7 +177,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
172177
DiffMode::Error => false,
173178
DiffMode::Source => false,
174179
DiffMode::Forward => {
175-
matches!(activity, Dual | DualOnly | Dualv | Const)
180+
matches!(activity, Dual | DualOnly | Dualv | DualvOnly | Const)
176181
}
177182
DiffMode::Reverse => {
178183
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
@@ -190,6 +195,7 @@ impl Display for DiffActivity {
190195
DiffActivity::Dual => write!(f, "Dual"),
191196
DiffActivity::Dualv => write!(f, "Dualv"),
192197
DiffActivity::DualOnly => write!(f, "DualOnly"),
198+
DiffActivity::DualvOnly => write!(f, "DualvOnly"),
193199
DiffActivity::Duplicated => write!(f, "Duplicated"),
194200
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
195201
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
@@ -222,6 +228,7 @@ impl FromStr for DiffActivity {
222228
"Dual" => Ok(DiffActivity::Dual),
223229
"Dualv" => Ok(DiffActivity::Dualv),
224230
"DualOnly" => Ok(DiffActivity::DualOnly),
231+
"DualvOnly" => Ok(DiffActivity::DualvOnly),
225232
"Duplicated" => Ok(DiffActivity::Duplicated),
226233
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
227234
_ => Err(()),

compiler/rustc_builtin_macros/src/autodiff.rs

+13-7
Original file line numberDiff line numberDiff line change
@@ -770,12 +770,18 @@ mod llvm_enzyme {
770770
d_inputs.push(shadow_arg.clone());
771771
}
772772
}
773-
DiffActivity::Dual | DiffActivity::DualOnly | DiffActivity::Dualv => {
774-
let iterations = if matches!(activity, DiffActivity::Dualv) {
775-
1
776-
} else {
777-
x.width
778-
};
773+
DiffActivity::Dual
774+
| DiffActivity::DualOnly
775+
| DiffActivity::Dualv
776+
| DiffActivity::DualvOnly => {
777+
// The *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which make
778+
// Enzyme expect one (larger) shadow argument instead of N separate ones.
779+
let iterations =
780+
if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
781+
1
782+
} else {
783+
x.width
784+
};
779785
for i in 0..iterations {
780786
let mut shadow_arg = arg.clone();
781787
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
@@ -879,7 +885,7 @@ mod llvm_enzyme {
879885
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
880886
d_decl.output = FnRetTy::Ty(ty);
881887
}
882-
if let DiffActivity::DualOnly = x.ret_activity {
888+
if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
883889
// No need to change the return type,
884890
// we will just return the shadow in place of the primal return.
885891
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+8-10
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ fn has_sret(fnc: &Value) -> bool {
5151
// using iterators and peek()?
5252
fn match_args_from_caller_to_enzyme<'ll>(
5353
cx: &SimpleCx<'ll>,
54-
builder: &SBuilder<'ll,'ll>,
54+
builder: &SBuilder<'ll, 'll>,
5555
width: u32,
5656
args: &mut Vec<&'ll llvm::Value>,
5757
inputs: &[DiffActivity],
@@ -81,6 +81,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
8181
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
8282
let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
8383
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
84+
let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap();
8485

8586
while activity_pos < inputs.len() {
8687
let diff_activity = inputs[activity_pos as usize];
@@ -94,6 +95,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
9495
DiffActivity::Dual => (enzyme_dup, true),
9596
DiffActivity::Dualv => (enzyme_dupv, true),
9697
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
98+
DiffActivity::DualvOnly => (enzyme_dupnoneedv, true),
9799
DiffActivity::Duplicated => (enzyme_dup, true),
98100
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
99101
DiffActivity::FakeActivitySize => (enzyme_const, false),
@@ -106,10 +108,9 @@ fn match_args_from_caller_to_enzyme<'ll>(
106108
// T=f32 => 4 bytes
107109
// n_elems is the next integer.
108110
// Now we multiply `4 * next_outer_arg` to get the stride.
109-
//let mul = builder
110-
// .build_mul(cx.get_const_i64(4), next_outer_arg)
111-
// .unwrap();
112-
let mul = unsafe {llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED)};
111+
let mul = unsafe {
112+
llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED)
113+
};
113114
args.push(mul);
114115
}
115116
args.push(outer_arg);
@@ -140,11 +141,8 @@ fn match_args_from_caller_to_enzyme<'ll>(
140141
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
141142
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
142143

143-
let iterations = if matches!(diff_activity, DiffActivity::Dualv) {
144-
1
145-
} else {
146-
width as usize
147-
};
144+
let iterations =
145+
if matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize };
148146

149147
for i in 0..iterations {
150148
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];

compiler/rustc_monomorphize/src/partitioning/autodiff.rs

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
4040
new_activities.push(activity);
4141
new_positions.push(i + 1);
4242
}
43+
// Now we need to figure out the size of each slice element in memory.
44+
// Can we actually do that here?
45+
4346
continue;
4447
}
4548
}

0 commit comments

Comments
 (0)