Skip to content

Commit 75f86e6

Browse files
committed
fix LooseTypes flag and PrintMod behaviour, add debug helper
1 parent e643f59 commit 75f86e6

File tree

6 files changed

+68
-21
lines changed

6 files changed

+68
-21
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

+22-18
Original file line numberDiff line numberDiff line change
@@ -584,12 +584,10 @@ fn thin_lto(
584584
}
585585
}
586586

587-
fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
587+
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
588588
for &val in ad {
589+
// We intentionally don't use a wildcard, to not forget handling anything new.
589590
match val {
590-
config::AutoDiff::PrintModBefore => {
591-
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
592-
}
593591
config::AutoDiff::PrintPerf => {
594592
llvm::set_print_perf(true);
595593
}
@@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
603601
llvm::set_inline(true);
604602
}
605603
config::AutoDiff::LooseTypes => {
606-
llvm::set_loose_types(false);
604+
llvm::set_loose_types(true);
607605
}
608606
config::AutoDiff::PrintSteps => {
609607
llvm::set_print(true);
610608
}
611-
// We handle this below
609+
// We handle this in the PassWrapper.cpp
610+
config::AutoDiff::PrintPasses => {}
611+
// We handle this in the PassWrapper.cpp
612+
config::AutoDiff::PrintModBefore => {}
613+
// We handle this in the PassWrapper.cpp
612614
config::AutoDiff::PrintModAfter => {}
613-
// We handle this below
615+
// We handle this in the PassWrapper.cpp
614616
config::AutoDiff::PrintModFinal => {}
615617
// This is required and already checked
616618
config::AutoDiff::Enable => {}
619+
// We handle this below
620+
config::AutoDiff::NoPostopt => {}
617621
}
618622
}
619623
// This helps with handling enums for now.
@@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager(
647651
// We then run the llvm_optimize function a second time, to optimize the code which we generated
648652
// in the enzyme differentiation pass.
649653
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
650-
let stage =
651-
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
654+
let stage = if thin {
655+
write::AutodiffStage::PreAD
656+
} else {
657+
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
658+
};
652659

653660
if enable_ad {
654-
enable_autodiff_settings(&config.autodiff, module);
661+
enable_autodiff_settings(&config.autodiff);
655662
}
656663

657664
unsafe {
658665
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
659666
}
660667

661-
if cfg!(llvm_enzyme) && enable_ad {
662-
// This is the post-autodiff IR, mainly used for testing and educational purposes.
663-
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
664-
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
665-
}
666-
668+
if cfg!(llvm_enzyme) && enable_ad && !thin {
667669
let opt_stage = llvm::OptStage::FatLTO;
668670
let stage = write::AutodiffStage::PostAD;
669-
unsafe {
670-
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
671+
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
672+
unsafe {
673+
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
674+
}
671675
}
672676

673677
// This is the final IR, so people should be able to inspect the optimized autodiff output,

compiler/rustc_codegen_llvm/src/back/write.rs

+6
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,9 @@ pub(crate) unsafe fn llvm_optimize(
565565

566566
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
567567
let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
568+
let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore);
569+
let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter);
570+
let print_passes = config.autodiff.contains(&config::AutoDiff::PrintPasses);
568571
let unroll_loops;
569572
let vectorize_slp;
570573
let vectorize_loop;
@@ -663,6 +666,9 @@ pub(crate) unsafe fn llvm_optimize(
663666
config.no_builtins,
664667
config.emit_lifetime_markers,
665668
run_enzyme,
669+
print_before_enzyme,
670+
print_after_enzyme,
671+
print_passes,
666672
sanitizer_options.as_ref(),
667673
pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
668674
pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+3
Original file line numberDiff line numberDiff line change
@@ -2454,6 +2454,9 @@ unsafe extern "C" {
24542454
DisableSimplifyLibCalls: bool,
24552455
EmitLifetimeMarkers: bool,
24562456
RunEnzyme: bool,
2457+
PrintBeforeEnzyme: bool,
2458+
PrintAfterEnzyme: bool,
2459+
PrintPasses: bool,
24572460
SanitizerOptions: Option<&SanitizerOptions>,
24582461
PGOGenPath: *const c_char,
24592462
PGOUsePath: *const c_char,

compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp

+28-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "llvm/IR/LegacyPassManager.h"
1515
#include "llvm/IR/PassManager.h"
1616
#include "llvm/IR/Verifier.h"
17+
#include "llvm/IRPrinter/IRPrintingPasses.h"
1718
#include "llvm/LTO/LTO.h"
1819
#include "llvm/MC/MCSubtargetInfo.h"
1920
#include "llvm/MC/TargetRegistry.h"
@@ -703,7 +704,8 @@ extern "C" LLVMRustResult LLVMRustOptimize(
703704
bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
704705
bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
705706
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
706-
bool EmitLifetimeMarkers, bool RunEnzyme,
707+
bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
708+
bool PrintAfterEnzyme, bool PrintPasses,
707709
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
708710
const char *PGOUsePath, bool InstrumentCoverage,
709711
const char *InstrProfileOutput, const char *PGOSampleUsePath,
@@ -1048,14 +1050,38 @@ extern "C" LLVMRustResult LLVMRustOptimize(
10481050
// now load "-enzyme" pass:
10491051
#ifdef ENZYME
10501052
if (RunEnzyme) {
1051-
registerEnzymeAndPassPipeline(PB, true);
1053+
1054+
if (PrintBeforeEnzyme) {
1055+
// Handle the Rust flag `-Zautodiff=PrintModBefore`.
1056+
std::string Banner = "Module before EnzymeNewPM";
1057+
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
1058+
}
1059+
1060+
registerEnzymeAndPassPipeline(PB, false);
10521061
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
10531062
std::string ErrMsg = toString(std::move(Err));
10541063
LLVMRustSetLastError(ErrMsg.c_str());
10551064
return LLVMRustResult::Failure;
10561065
}
1066+
1067+
if (PrintAfterEnzyme) {
1068+
// Handle the Rust flag `-Zautodiff=PrintModAfter`.
1069+
std::string Banner = "Module after EnzymeNewPM";
1070+
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
1071+
}
10571072
}
10581073
#endif
1074+
if (PrintPasses) {
1075+
// Print all passes from the PM:
1076+
std::string Pipeline;
1077+
raw_string_ostream SOS(Pipeline);
1078+
MPM.printPipeline(SOS, [&PIC](StringRef ClassName) {
1079+
auto PassName = PIC.getPassNameForClassName(ClassName);
1080+
return PassName.empty() ? ClassName : PassName;
1081+
});
1082+
outs() << Pipeline;
1083+
outs() << "\n";
1084+
}
10591085

10601086
// Upgrade all calls to old intrinsics first.
10611087
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)

compiler/rustc_session/src/config.rs

+4
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,10 @@ pub enum AutoDiff {
244244
/// Print the module after running autodiff and optimizations.
245245
PrintModFinal,
246246

247+
/// Print all passes scheduled by LLVM
248+
PrintPasses,
249+
/// Disable extra opt run after running autodiff
250+
NoPostopt,
247251
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
248252
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
249253
LooseTypes,

compiler/rustc_session/src/options.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ mod desc {
711711
pub(crate) const parse_list: &str = "a space-separated list of strings";
712712
pub(crate) const parse_list_with_polarity: &str =
713713
"a comma-separated list of strings, with elements beginning with + or -";
714-
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
714+
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
715715
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
716716
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
717717
pub(crate) const parse_number: &str = "a number";
@@ -1360,6 +1360,8 @@ pub mod parse {
13601360
"PrintModBefore" => AutoDiff::PrintModBefore,
13611361
"PrintModAfter" => AutoDiff::PrintModAfter,
13621362
"PrintModFinal" => AutoDiff::PrintModFinal,
1363+
"NoPostopt" => AutoDiff::NoPostopt,
1364+
"PrintPasses" => AutoDiff::PrintPasses,
13631365
"LooseTypes" => AutoDiff::LooseTypes,
13641366
"Inline" => AutoDiff::Inline,
13651367
_ => {
@@ -2095,6 +2097,8 @@ options! {
20952097
`=PrintModBefore`
20962098
`=PrintModAfter`
20972099
`=PrintModFinal`
2100+
`=PrintPasses`,
2101+
`=NoPostopt`
20982102
`=LooseTypes`
20992103
`=Inline`
21002104
Multiple options can be combined with commas."),

0 commit comments

Comments
 (0)