Skip to content

Commit 2ccc0b3

Browse files
committed
Fixes memory leaks in autodiff linear map context allocation builtins
Fixes swiftlang#67323
1 parent 62fce8a commit 2ccc0b3

13 files changed

+245
-102
lines changed

include/swift/AST/Builtins.def

+2-2
Original file line numberDiff line numberDiff line change
@@ -985,13 +985,13 @@ BUILTIN_MISC_OPERATION_WITH_SILGEN(CreateAsyncTaskInGroup,
985985
/// is a pure value and therefore we can consider it as readnone).
986986
BUILTIN_MISC_OPERATION_WITH_SILGEN(GlobalStringTablePointer, "globalStringTablePointer", "n", Special)
987987

988-
// autoDiffCreateLinearMapContext: (Builtin.Word) -> Builtin.NativeObject
988+
// autoDiffCreateLinearMapContext: (T.Type) -> Builtin.NativeObject
989989
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffCreateLinearMapContext, "autoDiffCreateLinearMapContext", "", Special)
990990

991991
// autoDiffProjectTopLevelSubcontext: (Builtin.NativeObject) -> Builtin.RawPointer
992992
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffProjectTopLevelSubcontext, "autoDiffProjectTopLevelSubcontext", "n", Special)
993993

994-
// autoDiffAllocateSubcontext: (Builtin.NativeObject, Builtin.Word) -> Builtin.RawPointer
994+
// autoDiffAllocateSubcontext: (Builtin.NativeObject, T.Type) -> Builtin.RawPointer
995995
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffAllocateSubcontext, "autoDiffAllocateSubcontext", "", Special)
996996

997997
/// Build a Builtin.Executor value from an "ordinary" serial executor

include/swift/Runtime/RuntimeFunctions.def

+4-4
Original file line numberDiff line numberDiff line change
@@ -2273,12 +2273,12 @@ FUNCTION(TaskGroupDestroy,
22732273
ATTRS(NoUnwind),
22742274
EFFECT(Concurrency))
22752275

2276-
// AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContext(size_t);
2276+
// AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContext(const Metadata *);
22772277
FUNCTION(AutoDiffCreateLinearMapContext,
22782278
swift_autoDiffCreateLinearMapContext, SwiftCC,
22792279
DifferentiationAvailability,
22802280
RETURNS(RefCountedPtrTy),
2281-
ARGS(SizeTy),
2281+
ARGS(TypeMetadataPtrTy),
22822282
ATTRS(NoUnwind, ArgMemOnly),
22832283
EFFECT(AutoDiff))
22842284

@@ -2291,12 +2291,12 @@ FUNCTION(AutoDiffProjectTopLevelSubcontext,
22912291
ATTRS(NoUnwind, ArgMemOnly),
22922292
EFFECT(AutoDiff))
22932293

2294-
// void *swift_autoDiffAllocateSubcontext(AutoDiffLinearMapContext *, size_t);
2294+
// void *swift_autoDiffAllocateSubcontext(AutoDiffLinearMapContext *, const Metadata *);
22952295
FUNCTION(AutoDiffAllocateSubcontext,
22962296
swift_autoDiffAllocateSubcontext, SwiftCC,
22972297
DifferentiationAvailability,
22982298
RETURNS(Int8PtrTy),
2299-
ARGS(RefCountedPtrTy, SizeTy),
2299+
ARGS(RefCountedPtrTy, TypeMetadataPtrTy),
23002300
ATTRS(NoUnwind, ArgMemOnly),
23012301
EFFECT(AutoDiff))
23022302

lib/AST/Builtins.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -1608,8 +1608,9 @@ static ValueDecl *getBuildComplexEqualitySerialExecutorRef(ASTContext &ctx,
16081608

16091609
static ValueDecl *getAutoDiffCreateLinearMapContext(ASTContext &ctx,
16101610
Identifier id) {
1611-
return getBuiltinFunction(
1612-
id, {BuiltinIntegerType::getWordType(ctx)}, ctx.TheNativeObjectType);
1611+
return getBuiltinFunction(ctx, id, _thin, _generics(_unrestricted),
1612+
_parameters(_metatype(_typeparam(0))),
1613+
_nativeObject);
16131614
}
16141615

16151616
static ValueDecl *getAutoDiffProjectTopLevelSubcontext(ASTContext &ctx,
@@ -1621,8 +1622,8 @@ static ValueDecl *getAutoDiffProjectTopLevelSubcontext(ASTContext &ctx,
16211622
static ValueDecl *getAutoDiffAllocateSubcontext(ASTContext &ctx,
16221623
Identifier id) {
16231624
return getBuiltinFunction(
1624-
id, {ctx.TheNativeObjectType, BuiltinIntegerType::getWordType(ctx)},
1625-
ctx.TheRawPointerType);
1625+
ctx, id, _thin, _generics(_unrestricted),
1626+
_parameters(_nativeObject, _metatype(_typeparam(0))), _rawPointer);
16261627
}
16271628

16281629
static ValueDecl *getPoundAssert(ASTContext &Context, Identifier Id) {

lib/IRGen/GenBuiltin.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -1308,8 +1308,8 @@ void irgen::emitBuiltinCall(IRGenFunction &IGF, const BuiltinInfo &Builtin,
13081308
}
13091309

13101310
if (Builtin.ID == BuiltinValueKind::AutoDiffCreateLinearMapContext) {
1311-
auto topLevelSubcontextSize = args.claimNext();
1312-
out.add(emitAutoDiffCreateLinearMapContext(IGF, topLevelSubcontextSize)
1311+
auto topLevelSubcontextMetaType = args.claimNext();
1312+
out.add(emitAutoDiffCreateLinearMapContext(IGF, topLevelSubcontextMetaType)
13131313
.getAddress());
13141314
return;
13151315
}
@@ -1325,9 +1325,10 @@ void irgen::emitBuiltinCall(IRGenFunction &IGF, const BuiltinInfo &Builtin,
13251325
if (Builtin.ID == BuiltinValueKind::AutoDiffAllocateSubcontext) {
13261326
Address allocatorAddr(args.claimNext(), IGF.IGM.RefCountedStructTy,
13271327
IGF.IGM.getPointerAlignment());
1328-
auto size = args.claimNext();
1328+
auto subcontextMetatype = args.claimNext();
13291329
out.add(
1330-
emitAutoDiffAllocateSubcontext(IGF, allocatorAddr, size).getAddress());
1330+
emitAutoDiffAllocateSubcontext(IGF, allocatorAddr, subcontextMetatype)
1331+
.getAddress());
13311332
return;
13321333
}
13331334

lib/IRGen/GenCall.cpp

+10-5
Original file line numberDiff line numberDiff line change
@@ -5480,10 +5480,12 @@ IRGenFunction::getFunctionPointerForResumeIntrinsic(llvm::Value *resume) {
54805480
}
54815481

54825482
Address irgen::emitAutoDiffCreateLinearMapContext(
5483-
IRGenFunction &IGF, llvm::Value *topLevelSubcontextSize) {
5483+
IRGenFunction &IGF, llvm::Value *topLevelSubcontextMetatype) {
5484+
topLevelSubcontextMetatype = IGF.Builder.CreateBitCast(
5485+
topLevelSubcontextMetatype, IGF.IGM.TypeMetadataPtrTy);
54845486
auto *call = IGF.Builder.CreateCall(
54855487
IGF.IGM.getAutoDiffCreateLinearMapContextFunctionPointer(),
5486-
{topLevelSubcontextSize});
5488+
{topLevelSubcontextMetatype});
54875489
call->setDoesNotThrow();
54885490
call->setCallingConv(IGF.IGM.SwiftCC);
54895491
return Address(call, IGF.IGM.RefCountedStructTy,
@@ -5500,11 +5502,14 @@ Address irgen::emitAutoDiffProjectTopLevelSubcontext(
55005502
return Address(call, IGF.IGM.Int8Ty, IGF.IGM.getPointerAlignment());
55015503
}
55025504

5503-
Address irgen::emitAutoDiffAllocateSubcontext(
5504-
IRGenFunction &IGF, Address context, llvm::Value *size) {
5505+
Address irgen::emitAutoDiffAllocateSubcontext(IRGenFunction &IGF,
5506+
Address context,
5507+
llvm::Value *subcontextMetatype) {
5508+
subcontextMetatype =
5509+
IGF.Builder.CreateBitCast(subcontextMetatype, IGF.IGM.TypeMetadataPtrTy);
55055510
auto *call = IGF.Builder.CreateCall(
55065511
IGF.IGM.getAutoDiffAllocateSubcontextFunctionPointer(),
5507-
{context.getAddress(), size});
5512+
{context.getAddress(), subcontextMetatype});
55085513
call->setDoesNotThrow();
55095514
call->setCallingConv(IGF.IGM.SwiftCC);
55105515
return Address(call, IGF.IGM.Int8Ty, IGF.IGM.getPointerAlignment());

lib/IRGen/GenCall.h

+5-4
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,13 @@ namespace irgen {
261261
CanSILFunctionType fnType, Explosion &result,
262262
Explosion &error);
263263

264-
Address emitAutoDiffCreateLinearMapContext(
265-
IRGenFunction &IGF, llvm::Value *topLevelSubcontextSize);
264+
Address
265+
emitAutoDiffCreateLinearMapContext(IRGenFunction &IGF,
266+
llvm::Value *topLevelSubcontextMetatype);
266267
Address emitAutoDiffProjectTopLevelSubcontext(
267268
IRGenFunction &IGF, Address context);
268-
Address emitAutoDiffAllocateSubcontext(
269-
IRGenFunction &IGF, Address context, llvm::Value *size);
269+
Address emitAutoDiffAllocateSubcontext(IRGenFunction &IGF, Address context,
270+
llvm::Value *subcontextMetatype);
270271

271272
FunctionPointer getFunctionPointerForDispatchCall(IRGenModule &IGM,
272273
const FunctionPointer &fn);

lib/SILOptimizer/Differentiation/VJPCloner.cpp

+22-7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
#define DEBUG_TYPE "differentiation"
1919

20+
#include "swift/AST/Types.h"
21+
2022
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
2123
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2224
#include "swift/SILOptimizer/Differentiation/ADContext.h"
@@ -118,15 +120,21 @@ class VJPCloner::Implementation final
118120
auto pullbackTupleType =
119121
remapASTType(pullbackInfo.getLinearMapTupleType(returnBB)->getCanonicalType());
120122
Builder.setInsertionPoint(vjp->getEntryBlock());
121-
auto topLevelSubcontextSize = emitMemoryLayoutSize(
122-
Builder, original->getLocation(), pullbackTupleType);
123+
124+
auto pbTupleMetatypeType =
125+
CanMetatypeType::get(pullbackTupleType, MetatypeRepresentation::Thick);
126+
auto pbTupleMetatypeSILType =
127+
SILType::getPrimitiveObjectType(pbTupleMetatypeType);
128+
auto pbTupleMetatype =
129+
Builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType);
130+
123131
// Create an context.
124132
pullbackContextValue = Builder.createBuiltin(
125133
original->getLocation(),
126134
getASTContext().getIdentifier(
127135
getBuiltinName(BuiltinValueKind::AutoDiffCreateLinearMapContext)),
128-
SILType::getNativeObjectType(getASTContext()),
129-
SubstitutionMap(), {topLevelSubcontextSize});
136+
SILType::getNativeObjectType(getASTContext()), SubstitutionMap(),
137+
{pbTupleMetatype});
130138
borrowedPullbackContextValue = Builder.createBeginBorrow(
131139
original->getLocation(), pullbackContextValue);
132140
LLVM_DEBUG(getADDebugStream()
@@ -1067,14 +1075,21 @@ EnumInst *VJPCloner::Implementation::buildPredecessorEnumValue(
10671075
assert(enumEltType == rawPtrType);
10681076
auto pbTupleType =
10691077
remapASTType(pullbackInfo.getLinearMapTupleType(predBB)->getCanonicalType());
1070-
SILValue pbTupleSize =
1071-
emitMemoryLayoutSize(Builder, loc, pbTupleType);
1078+
1079+
auto pbTupleMetatypeType =
1080+
CanMetatypeType::get(pbTupleType, MetatypeRepresentation::Thick);
1081+
auto pbTupleMetatypeSILType =
1082+
SILType::getPrimitiveObjectType(pbTupleMetatypeType);
1083+
auto pbTupleMetatype =
1084+
Builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType);
1085+
10721086
auto rawBufferValue = builder.createBuiltin(
10731087
loc,
10741088
getASTContext().getIdentifier(
10751089
getBuiltinName(BuiltinValueKind::AutoDiffAllocateSubcontext)),
10761090
rawPtrType, SubstitutionMap(),
1077-
{borrowedPullbackContextValue, pbTupleSize});
1091+
{borrowedPullbackContextValue, pbTupleMetatype});
1092+
10781093
auto typedBufferValue =
10791094
builder.createPointerToAddress(
10801095
loc, rawBufferValue, pbTupleVal->getType().getAddressType(),

stdlib/public/runtime/AutoDiffSupport.cpp

+47-23
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,21 @@
1313
#include "AutoDiffSupport.h"
1414
#include "swift/ABI/Metadata.h"
1515
#include "swift/Runtime/HeapObject.h"
16-
16+
#include "llvm/ADT/SmallVector.h"
1717
#include <new>
1818

1919
using namespace swift;
2020
using namespace llvm;
2121

2222
SWIFT_CC(swift)
2323
static void destroyLinearMapContext(SWIFT_CONTEXT HeapObject *obj) {
24-
static_cast<AutoDiffLinearMapContext *>(obj)->~AutoDiffLinearMapContext();
24+
auto *linearMapContext = static_cast<AutoDiffLinearMapContext *>(obj);
25+
26+
for (auto *heapObjectPtr : linearMapContext->getAllocatedHeapObjects()) {
27+
swift_release(heapObjectPtr);
28+
}
29+
30+
linearMapContext->~AutoDiffLinearMapContext();
2531
free(obj);
2632
}
2733

@@ -43,36 +49,54 @@ static FullMetadata<HeapMetadata> linearMapContextHeapMetadata = {
4349
}
4450
};
4551

46-
AutoDiffLinearMapContext::AutoDiffLinearMapContext()
52+
AutoDiffLinearMapContext::AutoDiffLinearMapContext(
53+
OpaqueValue *const topLevelLinearMapContextProjection)
4754
: HeapObject(&linearMapContextHeapMetadata) {
55+
this->topLevelLinearMapContextProjection = topLevelLinearMapContextProjection;
4856
}
4957

50-
void *AutoDiffLinearMapContext::projectTopLevelSubcontext() const {
51-
auto offset = alignTo(
52-
sizeof(AutoDiffLinearMapContext), alignof(AutoDiffLinearMapContext));
53-
return const_cast<uint8_t *>(
54-
reinterpret_cast<const uint8_t *>(this) + offset);
55-
}
58+
AutoDiffLinearMapContext *swift::swift_autoDiffCreateLinearMapContext(
59+
const Metadata *topLevelLinearMapContextMetadata) {
60+
// Linear map context metadata must have non-null value witnesses
61+
assert(topLevelLinearMapContextMetadata->getValueWitnesses());
5662

57-
void *AutoDiffLinearMapContext::allocate(size_t size) {
58-
return allocator.Allocate(size, alignof(AutoDiffLinearMapContext));
59-
}
63+
// Allocate a box for the top-level linear map context
64+
auto [topLevelContextHeapObjectPtr, toplevelContextProjection] =
65+
swift_allocBox(topLevelLinearMapContextMetadata);
6066

61-
AutoDiffLinearMapContext *swift::swift_autoDiffCreateLinearMapContext(
62-
size_t topLevelLinearMapStructSize) {
63-
auto allocationSize = alignTo(
64-
sizeof(AutoDiffLinearMapContext), alignof(AutoDiffLinearMapContext))
65-
+ topLevelLinearMapStructSize;
66-
auto *buffer = (AutoDiffLinearMapContext *)malloc(allocationSize);
67-
return ::new (buffer) AutoDiffLinearMapContext;
67+
// Create a linear map context object that stores the projection
68+
// for the top level context
69+
auto linearMapContext =
70+
new AutoDiffLinearMapContext(toplevelContextProjection);
71+
72+
// Stash away the `HeapObject` pointer for the allocated context
73+
// for proper "release" during clean up.
74+
linearMapContext->storeAllocatedHeapObjectPtr(topLevelContextHeapObjectPtr);
75+
76+
// Return the newly created linear map context object
77+
return linearMapContext;
6878
}
6979

7080
void *swift::swift_autoDiffProjectTopLevelSubcontext(
71-
AutoDiffLinearMapContext *allocator) {
72-
return allocator->projectTopLevelSubcontext();
81+
AutoDiffLinearMapContext *linearMapContext) {
82+
return static_cast<void *>(
83+
linearMapContext->getTopLevelLinearMapContextProjection());
7384
}
7485

7586
void *swift::swift_autoDiffAllocateSubcontext(
76-
AutoDiffLinearMapContext *allocator, size_t size) {
77-
return allocator->allocate(size);
87+
AutoDiffLinearMapContext *linearMapContext,
88+
const Metadata *linearMapSubcontextMetadata) {
89+
// Linear map context metadata must have non-null value witnesses
90+
assert(linearMapSubcontextMetadata->getValueWitnesses());
91+
92+
// Allocate a box for the linear map subcontext
93+
auto [subcontextHeapObjectPtr, subcontextProjection] =
94+
swift_allocBox(linearMapSubcontextMetadata);
95+
96+
// Stash away the `HeapObject` pointer for the allocated context
97+
// for proper "release" during clean up.
98+
linearMapContext->storeAllocatedHeapObjectPtr(subcontextHeapObjectPtr);
99+
100+
// Return the subcontext projection
101+
return static_cast<void *>(subcontextProjection);
78102
}

stdlib/public/runtime/AutoDiffSupport.h

+46-19
Original file line numberDiff line numberDiff line change
@@ -13,44 +13,71 @@
1313
#ifndef SWIFT_RUNTIME_AUTODIFF_SUPPORT_H
1414
#define SWIFT_RUNTIME_AUTODIFF_SUPPORT_H
1515

16-
#include "swift/Runtime/HeapObject.h"
1716
#include "swift/Runtime/Config.h"
18-
#include "llvm/Support/Allocator.h"
17+
#include "swift/Runtime/HeapObject.h"
18+
#include "llvm/ADT/ArrayRef.h"
19+
#include "llvm/ADT/SmallVector.h"
1920

2021
namespace swift {
21-
2222
/// A data structure responsible for efficiently allocating closure contexts for
2323
/// linear maps such as pullbacks, including recursive branching trace enum
2424
/// case payloads.
2525
class AutoDiffLinearMapContext : public HeapObject {
2626
private:
27-
/// The underlying allocator.
28-
// TODO: Use a custom allocator so that the initial slab can be
29-
// tail-allocated.
30-
llvm::BumpPtrAllocator allocator;
27+
// TODO: Commenting out BumpPtrAllocator temporarily
28+
// until we move away from the interim solution of allocating
29+
// boxes for linear map contexts/subcontexts.
30+
//
31+
// /// The underlying allocator.
32+
// // TODO: Use a custom allocator so that the initial slab can be
33+
// // tail-allocated.
34+
// llvm::BumpPtrAllocator allocator;
35+
36+
/// A projection/pointer to the memory storing the
37+
/// top-level linear map context object.
38+
OpaqueValue *topLevelLinearMapContextProjection;
39+
40+
/// Storage for `HeapObject` pointers to the linear map
41+
/// context and subcontexts allocated for derivatives with
42+
/// loops.
43+
llvm::SmallVector<HeapObject *, 4> allocatedHeapObjects;
3144

3245
public:
3346
/// Creates a linear map context.
34-
AutoDiffLinearMapContext();
35-
/// Returns the address of the tail-allocated top-level subcontext.
36-
void *projectTopLevelSubcontext() const;
37-
/// Allocates memory for a new subcontext.
38-
void *allocate(size_t size);
47+
AutoDiffLinearMapContext(OpaqueValue *const);
48+
49+
// TODO: Commenting out BumpPtrAllocator temporarily
50+
// until we move away from the interim solution of allocating
51+
// boxes for linear map contexts/subcontexts.
52+
//
53+
// llvm::BumpPtrAllocator& getAllocator() const {
54+
// return const_cast<llvm::BumpPtrAllocator&>(this->allocator);
55+
// }
56+
57+
OpaqueValue *getTopLevelLinearMapContextProjection() const {
58+
return this->topLevelLinearMapContextProjection;
59+
}
60+
61+
llvm::ArrayRef<HeapObject *> getAllocatedHeapObjects() const {
62+
return this->allocatedHeapObjects;
63+
}
64+
65+
void storeAllocatedHeapObjectPtr(HeapObject *allocatedHeapObjectPtr) {
66+
this->allocatedHeapObjects.push_back(allocatedHeapObjectPtr);
67+
}
3968
};
4069

4170
/// Creates a linear map context with a tail-allocated top-level subcontext.
4271
SWIFT_RUNTIME_EXPORT SWIFT_CC(swift)
43-
AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContext(
44-
size_t topLevelSubcontextSize);
72+
AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContext(
73+
const Metadata *topLevelLinearMapContextMetadata);
4574

4675
/// Returns the address of the tail-allocated top-level subcontext.
4776
SWIFT_RUNTIME_EXPORT SWIFT_CC(swift)
4877
void *swift_autoDiffProjectTopLevelSubcontext(AutoDiffLinearMapContext *);
4978

5079
/// Allocates memory for a new subcontext.
51-
SWIFT_RUNTIME_EXPORT SWIFT_CC(swift)
52-
void *swift_autoDiffAllocateSubcontext(AutoDiffLinearMapContext *, size_t size);
53-
54-
}
55-
80+
SWIFT_RUNTIME_EXPORT SWIFT_CC(swift) void *swift_autoDiffAllocateSubcontext(
81+
AutoDiffLinearMapContext *, const Metadata *linearMapSubcontextMetadata);
82+
} // namespace swift
5683
#endif /* SWIFT_RUNTIME_AUTODIFF_SUPPORT_H */

0 commit comments

Comments
 (0)