Skip to content

Commit 5d967a5

Browse files
authored
Permit inactive register arguments on an indirect call (rust-lang#830)
1 parent cf7b24e commit 5d967a5

File tree

2 files changed

+53
-9
lines changed

2 files changed

+53
-9
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12062,16 +12062,18 @@ class AdjointGenerator
1206212062
Value *diffeadd = Builder2.CreateExtractValue(diffes, {structidx});
1206312063
++structidx;
1206412064

12065-
size_t size = 1;
12066-
if (orig->getArgOperand(i)->getType()->isSized())
12067-
size =
12068-
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
12069-
orig->getArgOperand(i)->getType()) +
12070-
7) /
12071-
8;
12065+
if (!gutils->isConstantValue(orig->getArgOperand(i))) {
12066+
size_t size = 1;
12067+
if (orig->getArgOperand(i)->getType()->isSized())
12068+
size = (gutils->newFunc->getParent()
12069+
->getDataLayout()
12070+
.getTypeSizeInBits(orig->getArgOperand(i)->getType()) +
12071+
7) /
12072+
8;
1207212073

12073-
addToDiffe(orig->getArgOperand(i), diffeadd, Builder2,
12074-
TR.addingType(size, orig->getArgOperand(i)));
12074+
addToDiffe(orig->getArgOperand(i), diffeadd, Builder2,
12075+
TR.addingType(size, orig->getArgOperand(i)));
12076+
}
1207512077
}
1207612078
}
1207712079

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s
2+
3+
declare dso_local double @__enzyme_autodiff(...)
4+
5+
define double @square(double (double, double)* %add, double %x) {
6+
entry:
7+
%mul = call double %add(double %x, double 1.0000000e+00)
8+
ret double %mul
9+
}
10+
11+
define double @dsquare(double (double, double)* %add, double (double, double)* %dadd, double %x) local_unnamed_addr {
12+
entry:
13+
%call = tail call double (...) @__enzyme_autodiff(i8* bitcast (double (double (double, double)*, double)* @square to i8*),
14+
metadata !"enzyme_dup", double (double, double)* %add, double (double, double)* %dadd, double %x)
15+
ret double %call
16+
}
17+
18+
; CHECK: define internal { double } @diffesquare(double (double, double)* %add, double (double, double)* %"add'", double %x, double %differeturn)
19+
; CHECK-NEXT: entry:
20+
; CHECK-NEXT: %0 = bitcast double (double, double)* %add to i8*
21+
; CHECK-NEXT: %1 = bitcast double (double, double)* %"add'" to i8*
22+
; CHECK-NEXT: %2 = icmp eq i8* %0, %1
23+
; CHECK-NEXT: br i1 %2, label %error.i, label %__enzyme_runtimeinactiveerr.exit
24+
25+
; CHECK: error.i: ; preds = %entry
26+
; CHECK-NEXT: %3 = call i32 @puts(i8* getelementptr inbounds ([79 x i8], [79 x i8]* @.str, i32 0, i32 0))
27+
; CHECK-NEXT: call void @exit(i32 1)
28+
; CHECK-NEXT: unreachable
29+
30+
; CHECK: __enzyme_runtimeinactiveerr.exit: ; preds = %entry
31+
; CHECK-NEXT: %4 = bitcast double (double, double)* %"add'" to { i8*, double } (double, double)**
32+
; CHECK-NEXT: %5 = load { i8*, double } (double, double)*, { i8*, double } (double, double)** %4
33+
; CHECK-NEXT: %mul_augmented = call { i8*, double } %5(double %x, double 1.000000e+00)
34+
; CHECK-NEXT: %subcache = extractvalue { i8*, double } %mul_augmented, 0
35+
; CHECK-NEXT: %6 = bitcast double (double, double)* %"add'" to { double, double } (double, double, double, i8*)**
36+
; CHECK-NEXT: %7 = getelementptr { double, double } (double, double, double, i8*)*, { double, double } (double, double, double, i8*)** %6, i64 1
37+
; CHECK-NEXT: %8 = load { double, double } (double, double, double, i8*)*, { double, double } (double, double, double, i8*)** %7
38+
; CHECK-NEXT: %9 = call { double, double } %8(double %x, double 1.000000e+00, double %differeturn, i8* %subcache)
39+
; CHECK-NEXT: %10 = extractvalue { double, double } %9, 0
40+
; CHECK-NEXT: %11 = insertvalue { double } undef, double %10, 0
41+
; CHECK-NEXT: ret { double } %11
42+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)