Skip to content

Commit 659bf5b

Browse files
authored
use original powi function (rust-lang#471)
1 parent 5e347a1 commit 659bf5b

File tree

2 files changed

+158
-8
lines changed

2 files changed

+158
-8
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3254,9 +3254,22 @@ class AdjointGenerator
32543254
orig_ops[1]->getType()
32553255
#endif
32563256
};
3257-
Function *PowF = Intrinsic::getDeclaration(M, Intrinsic::powi, tys);
3258-
auto cal = cast<CallInst>(Builder2.CreateCall(PowF, args));
3259-
cal->setCallingConv(PowF->getCallingConv());
3257+
auto &CI = cast<CallInst>(I);
3258+
#if LLVM_VERSION_MAJOR >= 11
3259+
auto *PowF = CI.getCalledOperand();
3260+
#else
3261+
auto *PowF = CI.getCalledValue();
3262+
#endif
3263+
if (!PowF)
3264+
PowF = Intrinsic::getDeclaration(M, Intrinsic::powi, tys);
3265+
3266+
auto FT = FunctionType::get(
3267+
I.getType(), {orig_ops[0]->getType(), orig_ops[1]->getType()},
3268+
false);
3269+
auto cal = cast<CallInst>(Builder2.CreateCall(FT, PowF, args));
3270+
if (auto F = dyn_cast<Function>(PowF))
3271+
cal->setCallingConv(F->getCallingConv());
3272+
32603273
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
32613274
Value *dif0 = Builder2.CreateFMul(
32623275
Builder2.CreateFMul(vdiff, cal),
@@ -3268,7 +3281,19 @@ class AdjointGenerator
32683281
}
32693282
case Intrinsic::pow: {
32703283
Type *tys[] = {orig_ops[0]->getType()};
3271-
Function *PowF = Intrinsic::getDeclaration(M, Intrinsic::pow, tys);
3284+
auto &CI = cast<CallInst>(I);
3285+
#if LLVM_VERSION_MAJOR >= 11
3286+
auto *PowF = CI.getCalledOperand();
3287+
#else
3288+
auto *PowF = CI.getCalledValue();
3289+
#endif
3290+
if (!PowF)
3291+
PowF = Intrinsic::getDeclaration(M, Intrinsic::pow, tys);
3292+
3293+
auto FT = FunctionType::get(
3294+
I.getType(), {orig_ops[0]->getType(), orig_ops[1]->getType()},
3295+
false);
3296+
32723297
if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
32733298

32743299
Value *op0 = gutils->getNewFromOriginal(orig_ops[0]);
@@ -3284,8 +3309,10 @@ class AdjointGenerator
32843309
lookup(op0, Builder2),
32853310
Builder2.CreateFSub(lookup(op1, Builder2),
32863311
ConstantFP::get(I.getType(), 1.0))};
3287-
auto cal = cast<CallInst>(Builder2.CreateCall(PowF, args));
3288-
cal->setCallingConv(PowF->getCallingConv());
3312+
auto cal = cast<CallInst>(Builder2.CreateCall(FT, PowF, args));
3313+
if (auto F = dyn_cast<Function>(PowF))
3314+
cal->setCallingConv(F->getCallingConv());
3315+
32893316
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
32903317

32913318
Value *dif0 = Builder2.CreateFMul(Builder2.CreateFMul(vdiff, cal),
@@ -3301,8 +3328,10 @@ class AdjointGenerator
33013328
lookup(gutils->getNewFromOriginal(orig_ops[0]), Builder2),
33023329
lookup(gutils->getNewFromOriginal(orig_ops[1]), Builder2)};
33033330

3304-
cal = cast<CallInst>(Builder2.CreateCall(PowF, args));
3305-
cal->setCallingConv(PowF->getCallingConv());
3331+
cal = cast<CallInst>(Builder2.CreateCall(FT, PowF, args));
3332+
if (auto F = dyn_cast<Function>(PowF))
3333+
cal->setCallingConv(F->getCallingConv());
3334+
33063335
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
33073336
}
33083337

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
; RUN: if [ %llvmver -ge 13 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
2+
source_filename = "text"
3+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
4+
target triple = "x86_64-pc-linux-gnu"
5+
6+
define private fastcc double @julia___2797(double %0, i64 signext %1) unnamed_addr #0 !dbg !7 {
7+
top:
8+
switch i64 %1, label %L20 [
9+
i64 -1, label %L3
10+
i64 0, label %L7
11+
i64 1, label %L7.fold.split
12+
i64 2, label %L13
13+
i64 3, label %L17
14+
], !dbg !9
15+
16+
L3: ; preds = %top
17+
%2 = fdiv double 1.000000e+00, %0, !dbg !10
18+
ret double %2, !dbg !9
19+
20+
L7.fold.split: ; preds = %top
21+
br label %L7, !dbg !16
22+
23+
L7: ; preds = %top, %L7.fold.split
24+
%merge = phi double [ 1.000000e+00, %top ], [ %0, %L7.fold.split ]
25+
ret double %merge, !dbg !16
26+
27+
L13: ; preds = %top
28+
%3 = fmul double %0, %0, !dbg !17
29+
ret double %3, !dbg !19
30+
31+
L17: ; preds = %top
32+
%4 = fmul double %0, %0, !dbg !20
33+
%5 = fmul double %4, %0, !dbg !20
34+
ret double %5, !dbg !24
35+
36+
L20: ; preds = %top
37+
%6 = sitofp i64 %1 to double, !dbg !25
38+
%7 = call double @llvm.pow.f64(double %0, double %6), !dbg !27
39+
ret double %7, !dbg !27
40+
}
41+
42+
; Function Attrs: nofree nosync nounwind readnone speculatable willreturn
43+
declare double @llvm.pow.f64(double, double) #1
44+
45+
; Function Attrs: nounwind
46+
declare double @__enzyme_autodiff(double (double, i64)*, ...)
47+
48+
; Function Attrs: alwaysinline nosync readnone
49+
define double @julia_f_2794(double %0, i64 signext %1) local_unnamed_addr #2 !dbg !28 {
50+
entry:
51+
%2 = call fastcc double @julia___2797(double %0, i64 signext %1) #5, !dbg !29
52+
ret double %2
53+
}
54+
55+
define double @test_derivative(double %x, i64 %y) {
56+
entry:
57+
%0 = tail call double (double (double, i64)*, ...) @__enzyme_autodiff(double (double, i64)* nonnull @julia_f_2794, double %x, i64 %y)
58+
ret double %0
59+
}
60+
61+
; CHECK: define internal { double } @diffejulia_f_2794(double %0, i64 signext %1, double %differeturn) local_unnamed_addr #5 !dbg !35 {
62+
; CHECK-NEXT: entry:
63+
; CHECK-NEXT: %2 = sub i64 %1, 1
64+
; CHECK-NEXT: %3 = call fast fastcc double @julia___2797(double %0, i64 %2), !dbg !36
65+
; CHECK-NEXT: %4 = sitofp i64 %1 to double
66+
; CHECK-NEXT: %5 = fmul fast double %differeturn, %3
67+
; CHECK-NEXT: %6 = fmul fast double %5, %4
68+
; CHECK-NEXT: %7 = insertvalue { double } undef, double %6, 0
69+
; CHECK-NEXT: ret { double } %7
70+
; CHECK-NEXT: }
71+
72+
; Function Attrs: inaccessiblemem_or_argmemonly
73+
declare void @jl_gc_queue_root({} addrspace(10)*) #3
74+
75+
; Function Attrs: allocsize(1)
76+
declare noalias nonnull {} addrspace(10)* @jl_gc_pool_alloc(i8*, i32, i32) #4
77+
78+
; Function Attrs: allocsize(1)
79+
declare noalias nonnull {} addrspace(10)* @jl_gc_big_alloc(i8*, i64) #4
80+
81+
attributes #0 = { noinline nosync readnone "enzyme_math"="powi" "enzyme_shouldrecompute"="powi" "probe-stack"="inline-asm" }
82+
attributes #1 = { nofree nosync nounwind readnone speculatable willreturn }
83+
attributes #2 = { alwaysinline nosync readnone "probe-stack"="inline-asm" }
84+
attributes #3 = { inaccessiblemem_or_argmemonly }
85+
attributes #4 = { allocsize(1) }
86+
attributes #5 = { "probe-stack"="inline-asm" }
87+
88+
!llvm.module.flags = !{!0, !1}
89+
!llvm.dbg.cu = !{!2, !5}
90+
91+
!0 = !{i32 2, !"Dwarf Version", i32 4}
92+
!1 = !{i32 2, !"Debug Info Version", i32 3}
93+
!2 = distinct !DICompileUnit(language: DW_LANG_Julia, file: !3, producer: "julia", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly, enums: !4, nameTableKind: None)
94+
!3 = !DIFile(filename: "math.jl", directory: ".")
95+
!4 = !{}
96+
!5 = distinct !DICompileUnit(language: DW_LANG_Julia, file: !6, producer: "julia", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly, enums: !4, nameTableKind: None)
97+
!6 = !DIFile(filename: "REPL[3]", directory: ".")
98+
!7 = distinct !DISubprogram(name: "^", linkageName: "julia_^_2797", scope: null, file: !3, line: 922, type: !8, scopeLine: 922, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
99+
!8 = !DISubroutineType(types: !4)
100+
!9 = !DILocation(line: 923, scope: !7)
101+
!10 = !DILocation(line: 408, scope: !11, inlinedAt: !13)
102+
!11 = distinct !DISubprogram(name: "/;", linkageName: "/", scope: !12, file: !12, type: !8, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
103+
!12 = !DIFile(filename: "float.jl", directory: ".")
104+
!13 = !DILocation(line: 243, scope: !14, inlinedAt: !9)
105+
!14 = distinct !DISubprogram(name: "inv;", linkageName: "inv", scope: !15, file: !15, type: !8, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
106+
!15 = !DIFile(filename: "number.jl", directory: ".")
107+
!16 = !DILocation(line: 924, scope: !7)
108+
!17 = !DILocation(line: 405, scope: !18, inlinedAt: !19)
109+
!18 = distinct !DISubprogram(name: "*;", linkageName: "*", scope: !12, file: !12, type: !8, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
110+
!19 = !DILocation(line: 926, scope: !7)
111+
!20 = !DILocation(line: 405, scope: !18, inlinedAt: !21)
112+
!21 = !DILocation(line: 655, scope: !22, inlinedAt: !24)
113+
!22 = distinct !DISubprogram(name: "*;", linkageName: "*", scope: !23, file: !23, type: !8, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
114+
!23 = !DIFile(filename: "operators.jl", directory: ".")
115+
!24 = !DILocation(line: 927, scope: !7)
116+
!25 = !DILocation(line: 146, scope: !26, inlinedAt: !27)
117+
!26 = distinct !DISubprogram(name: "Float64;", linkageName: "Float64", scope: !12, file: !12, type: !8, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !4)
118+
!27 = !DILocation(line: 928, scope: !7)
119+
!28 = distinct !DISubprogram(name: "f", linkageName: "julia_f_2794", scope: null, file: !6, line: 1, type: !8, scopeLine: 1, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !5, retainedNodes: !4)
120+
!29 = !DILocation(line: 1, scope: !28, inlinedAt: !30)
121+
!30 = distinct !DILocation(line: 0, scope: !28)

0 commit comments

Comments
 (0)