Skip to content

Commit 3dab78c

Browse files
authored
Fix struct float virtual fn (rust-lang#842)
1 parent 72b7ad0 commit 3dab78c

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3796,7 +3796,10 @@ Constant *GradientUtils::GetOrCreateShadowFunction(
37963796
std::vector<DIFFE_TYPE> types;
37973797
for (auto &a : fn->args()) {
37983798
uncacheable_args[&a] = !a.getType()->isFPOrFPVectorTy();
3799-
type_args.Arguments.insert(std::pair<Argument *, TypeTree>(&a, {}));
3799+
TypeTree TT;
3800+
if (a.getType()->isFPOrFPVectorTy())
3801+
TT.insert({-1}, ConcreteType(a.getType()->getScalarType()));
3802+
type_args.Arguments.insert(std::pair<Argument *, TypeTree>(&a, TT));
38003803
type_args.KnownValues.insert(
38013804
std::pair<Argument *, std::set<int64_t>>(&a, {}));
38023805
DIFFE_TYPE typ;
@@ -3818,11 +3821,26 @@ Constant *GradientUtils::GetOrCreateShadowFunction(
38183821
mode != DerivativeMode::ForwardMode
38193822
? DIFFE_TYPE::OUT_DIFF
38203823
: DIFFE_TYPE::DUP_ARG;
3824+
38213825
if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() ||
38223826
(fn->getReturnType()->isIntegerTy() &&
38233827
cast<IntegerType>(fn->getReturnType())->getBitWidth() < 16))
38243828
retType = DIFFE_TYPE::CONSTANT;
38253829

3830+
if (mode != DerivativeMode::ForwardMode && retType == DIFFE_TYPE::DUP_ARG) {
3831+
if (auto ST = dyn_cast<StructType>(fn->getReturnType())) {
3832+
size_t numflt = 0;
3833+
3834+
for (unsigned i = 0; i < ST->getNumElements(); ++i) {
3835+
auto midTy = ST->getElementType(i);
3836+
if (midTy->isFPOrFPVectorTy())
3837+
numflt++;
3838+
}
3839+
if (numflt == ST->getNumElements())
3840+
retType = DIFFE_TYPE::OUT_DIFF;
3841+
}
3842+
}
3843+
38263844
switch (mode) {
38273845
case DerivativeMode::ForwardMode: {
38283846
Constant *newf = Logic.CreateForwardDiff(
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; ModuleID = 'ptr12.ll'
4+
source_filename = "ld-temp.o"
5+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
6+
target triple = "x86_64-unknown-linux-gnu"
7+
8+
define { double } @foo2(double %i13) {
9+
bb:
10+
%i151 = insertvalue { double } undef, double %i13, 0
11+
ret { double } %i151
12+
}
13+
14+
declare i8* @_Z17__enzyme_virtualreversePv(...)
15+
16+
define void @caller() {
17+
%1 = tail call i8* (...) @_Z17__enzyme_virtualreversePv({ double } (double)* nonnull @foo2)
18+
ret void
19+
}
20+
21+
; CHECK: define internal { i8*, { double } } @augmented_foo2(double %i13)
22+
; CHECK-NEXT: bb:
23+
; CHECK-NEXT: %0 = alloca { i8*, { double } }
24+
; CHECK-NEXT: %1 = getelementptr inbounds { i8*, { double } }, { i8*, { double } }* %0, i32 0, i32 0
25+
; CHECK-NEXT: store i8* null, i8** %1
26+
; CHECK-NEXT: %i151 = insertvalue { double } undef, double %i13, 0
27+
; CHECK-NEXT: %2 = getelementptr inbounds { i8*, { double } }, { i8*, { double } }* %0, i32 0, i32 1
28+
; CHECK-NEXT: store { double } %i151, { double }* %2
29+
; CHECK-NEXT: %3 = load { i8*, { double } }, { i8*, { double } }* %0
30+
; CHECK-NEXT: ret { i8*, { double } } %3
31+
; CHECK-NEXT: }
32+
33+
; CHECK: define internal { double } @diffefoo2(double %i13, { double } %differeturn, i8* %tapeArg)
34+
; CHECK-NEXT: bb:
35+
; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg)
36+
; CHECK-NEXT: ret { double } %differeturn
37+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)