Skip to content

Autodiff batching2 #139351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,32 @@ pub enum DiffActivity {
/// with it.
Dual,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it. It expects the shadow argument to be `width` times larger than the original
/// input/output.
Dualv,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it. Drop the code which updates the original input/output for maximum performance.
DualOnly,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it. Drop the code which updates the original input/output for maximum performance.
/// It expects the shadow argument to be `width` times larger than the original input/output.
DualvOnly,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
Duplicated,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
/// Drop the code which updates the original input for maximum performance.
DuplicatedOnly,
/// All Integers must be Const, but these are used to mark the integer which represents the
/// length of a slice/vec. This is used for safety checks on slices.
FakeActivitySize,
/// The integer (if given) specifies the size of the slice element in bytes.
FakeActivitySize(Option<u32>),
}

impl DiffActivity {
pub fn is_dual_or_const(&self) -> bool {
use DiffActivity::*;
matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const)
}
}
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
Expand Down Expand Up @@ -131,11 +147,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
match mode {
DiffMode::Error => false,
DiffMode::Source => false,
DiffMode::Forward => {
activity == DiffActivity::Dual
|| activity == DiffActivity::DualOnly
|| activity == DiffActivity::Const
}
DiffMode::Forward => activity.is_dual_or_const(),
DiffMode::Reverse => {
activity == DiffActivity::Const
|| activity == DiffActivity::Active
Expand All @@ -153,10 +165,8 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
use DiffActivity::*;
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
if matches!(activity, Const) {
return true;
}
if matches!(activity, Dual | DualOnly) {
// Dual variants also support all types.
if activity.is_dual_or_const() {
return true;
}
// FIXME(ZuseZ4) We should make this more robust to also
Expand All @@ -172,9 +182,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
return match mode {
DiffMode::Error => false,
DiffMode::Source => false,
DiffMode::Forward => {
matches!(activity, Dual | DualOnly | Const)
}
DiffMode::Forward => activity.is_dual_or_const(),
DiffMode::Reverse => {
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
}
Expand All @@ -189,10 +197,12 @@ impl Display for DiffActivity {
DiffActivity::Active => write!(f, "Active"),
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
DiffActivity::Dual => write!(f, "Dual"),
DiffActivity::Dualv => write!(f, "Dualv"),
DiffActivity::DualOnly => write!(f, "DualOnly"),
DiffActivity::DualvOnly => write!(f, "DualvOnly"),
DiffActivity::Duplicated => write!(f, "Duplicated"),
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
DiffActivity::FakeActivitySize(s) => write!(f, "FakeActivitySize({:?})", s),
}
}
}
Expand Down Expand Up @@ -220,7 +230,9 @@ impl FromStr for DiffActivity {
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
"Const" => Ok(DiffActivity::Const),
"Dual" => Ok(DiffActivity::Dual),
"Dualv" => Ok(DiffActivity::Dualv),
"DualOnly" => Ok(DiffActivity::DualOnly),
"DualvOnly" => Ok(DiffActivity::DualvOnly),
"Duplicated" => Ok(DiffActivity::Duplicated),
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
_ => Err(()),
Expand Down
23 changes: 17 additions & 6 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,19 @@ mod llvm_enzyme {
d_inputs.push(shadow_arg.clone());
}
}
DiffActivity::Dual | DiffActivity::DualOnly => {
for i in 0..x.width {
DiffActivity::Dual
| DiffActivity::DualOnly
| DiffActivity::Dualv
| DiffActivity::DualvOnly => {
// the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
// Enzyme to not expect N arguments, but one argument (which is instead larger).
let iterations =
if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
1
} else {
x.width
};
for i in 0..iterations {
let mut shadow_arg = arg.clone();
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
ident.name
Expand All @@ -823,7 +834,7 @@ mod llvm_enzyme {
DiffActivity::Const => {
// Nothing to do here.
}
DiffActivity::None | DiffActivity::FakeActivitySize => {
DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
panic!("Should not happen");
}
}
Expand Down Expand Up @@ -887,8 +898,8 @@ mod llvm_enzyme {
}
};

if let DiffActivity::Dual = x.ret_activity {
let kind = if x.width == 1 {
if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
// Dual can only be used for f32/f64 ret.
// In that case we return now a tuple with two floats.
TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
Expand All @@ -903,7 +914,7 @@ mod llvm_enzyme {
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
d_decl.output = FnRetTy::Ty(ty);
}
if let DiffActivity::DualOnly = x.ret_activity {
if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
// No need to change the return type,
// we will just return the shadow in place of the primal return.
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
/// Empty string, to be used where LLVM expects an instruction name, indicating
/// that the instruction is to be left unnamed (i.e. numbered, in textual IR).
// FIXME(eddyb) pass `&CStr` directly to FFI once it's a thin pointer.
const UNNAMED: *const c_char = c"".as_ptr();
pub(crate) const UNNAMED: *const c_char = c"".as_ptr();

impl<'ll, CX: Borrow<SCx<'ll>>> BackendTypes for GenericBuilder<'_, 'll, CX> {
type Value = <GenericCx<'ll, CX> as BackendTypes>::Value;
Expand Down
38 changes: 33 additions & 5 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rustc_middle::bug;
use tracing::{debug, trace};

use crate::back::write::llvm_err;
use crate::builder::SBuilder;
use crate::builder::{SBuilder, UNNAMED};
use crate::context::SimpleCx;
use crate::declare::declare_simple_fn;
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
Expand Down Expand Up @@ -51,6 +51,7 @@ fn has_sret(fnc: &Value) -> bool {
// using iterators and peek()?
fn match_args_from_caller_to_enzyme<'ll>(
cx: &SimpleCx<'ll>,
builder: &SBuilder<'ll, 'll>,
width: u32,
args: &mut Vec<&'ll llvm::Value>,
inputs: &[DiffActivity],
Expand Down Expand Up @@ -78,7 +79,9 @@ fn match_args_from_caller_to_enzyme<'ll>(
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap();

while activity_pos < inputs.len() {
let diff_activity = inputs[activity_pos as usize];
Expand All @@ -90,13 +93,34 @@ fn match_args_from_caller_to_enzyme<'ll>(
DiffActivity::Active => (enzyme_out, false),
DiffActivity::ActiveOnly => (enzyme_out, false),
DiffActivity::Dual => (enzyme_dup, true),
DiffActivity::Dualv => (enzyme_dupv, true),
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
DiffActivity::DualvOnly => (enzyme_dupnoneedv, true),
DiffActivity::Duplicated => (enzyme_dup, true),
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
DiffActivity::FakeActivitySize => (enzyme_const, false),
DiffActivity::FakeActivitySize(_) => (enzyme_const, false),
};
let outer_arg = outer_args[outer_pos];
args.push(cx.get_metadata_value(activity));
if matches!(diff_activity, DiffActivity::Dualv) {
let next_outer_arg = outer_args[outer_pos + 1];
let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
DiffActivity::FakeActivitySize(Some(s)) => s.into(),
_ => bug!("incorrect Dualv handling recognized."),
};
// stride: sizeof(T) * n_elems.
// n_elems is the next integer.
// Now we multiply `elem_bytes_size * next_outer_arg` to get the stride in bytes.
let mul = unsafe {
llvm::LLVMBuildMul(
builder.llbuilder,
cx.get_const_i64(elem_bytes_size),
next_outer_arg,
UNNAMED,
)
};
args.push(mul);
}
args.push(outer_arg);
if duplicated {
// We know that duplicated args by construction have a following argument,
Expand All @@ -114,7 +138,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
} else {
let next_activity = inputs[activity_pos + 1];
// We analyze the MIR types and add this dummy activity if we visit a slice.
next_activity == DiffActivity::FakeActivitySize
matches!(next_activity, DiffActivity::FakeActivitySize(_))
}
};
if slice {
Expand All @@ -125,7 +149,10 @@ fn match_args_from_caller_to_enzyme<'ll>(
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);

for i in 0..(width as usize) {
let iterations =
if matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize };

for i in 0..iterations {
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
Expand All @@ -136,7 +163,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
}
args.push(cx.get_metadata_value(enzyme_const));
args.push(next_outer_arg);
outer_pos += 2 + 2 * width as usize;
outer_pos += 2 + 2 * iterations;
activity_pos += 2;
} else {
// A duplicated pointer will have the following two outer_fn arguments:
Expand Down Expand Up @@ -360,6 +387,7 @@ fn generate_enzyme_call<'ll>(
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
match_args_from_caller_to_enzyme(
&cx,
&builder,
attrs.width,
&mut args,
&attrs.input_activity,
Expand Down
32 changes: 30 additions & 2 deletions compiler/rustc_monomorphize/src/partitioning/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
use rustc_hir::def_id::LOCAL_CRATE;
use rustc_middle::bug;
use rustc_middle::mir::mono::MonoItem;
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
use tracing::{debug, trace};

Expand All @@ -22,23 +22,51 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
for (i, ty) in sig.inputs().iter().enumerate() {
if let Some(inner_ty) = ty.builtin_deref(true) {
if inner_ty.is_slice() {
// Now we need to figure out the size of each slice element in memory to allow
// safety checks and usability improvements in the backend.
let sty = match inner_ty.builtin_index() {
Some(sty) => sty,
None => {
panic!("slice element type unknown");
}
};
let pci = PseudoCanonicalInput {
typing_env: TypingEnv::fully_monomorphized(),
value: sty,
};

let layout = tcx.layout_of(pci);
let elem_size = match layout {
Ok(layout) => layout.size,
Err(_) => {
bug!("autodiff failed to compute slice element size");
}
};
let elem_size: u32 = elem_size.bytes() as u32;

// We know that the length will be passed as extra arg.
if !da.is_empty() {
// We are looking at a slice. The length of that slice will become an
// extra integer on llvm level. Integers are always const.
// However, if the slice gets duplicated, we want to know to later check the
// size. So we mark the new size argument as FakeActivitySize.
// There is one FakeActivitySize per slice, so for convenience we store the
// slice element size in bytes in it. We will use the size in the backend.
let activity = match da[i] {
DiffActivity::DualOnly
| DiffActivity::Dual
| DiffActivity::Dualv
| DiffActivity::DuplicatedOnly
| DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
| DiffActivity::Duplicated => {
DiffActivity::FakeActivitySize(Some(elem_size))
}
DiffActivity::Const => DiffActivity::Const,
_ => bug!("unexpected activity for ptr/ref"),
};
new_activities.push(activity);
new_positions.push(i + 1);
}

continue;
}
}
Expand Down
Loading
Loading