Skip to content

[AutoDiff] Bump-pointer allocate pullback structs in loops. #34886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,9 @@ class ASTContext final {
/// Get the runtime availability of support for concurrency.
AvailabilityContext getConcurrencyAvailability();

/// Get the runtime availability of support for differentiation.
AvailabilityContext getDifferentiationAvailability();

/// Get the runtime availability of features introduced in the Swift 5.2
/// compiler for the target platform.
AvailabilityContext getSwift52Availability();
Expand Down
9 changes: 9 additions & 0 deletions include/swift/AST/Builtins.def
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,15 @@ BUILTIN_MISC_OPERATION_WITH_SILGEN(CreateAsyncTaskFuture,
/// is a pure value and therefore we can consider it as readnone).
BUILTIN_MISC_OPERATION_WITH_SILGEN(GlobalStringTablePointer, "globalStringTablePointer", "n", Special)

// autoDiffCreateLinearMapContext: (Builtin.Word) -> Builtin.NativeObject
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffCreateLinearMapContext, "autoDiffCreateLinearMapContext", "n", Special)

// autoDiffProjectTopLevelSubcontext: (Builtin.NativeObject) -> Builtin.RawPointer
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffProjectTopLevelSubcontext, "autoDiffProjectTopLevelSubcontext", "n", Special)

// autoDiffAllocateSubcontext: (Builtin.NativeObject, Builtin.Word) -> Builtin.RawPointer
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffAllocateSubcontext, "autoDiffAllocateSubcontext", "", Special)

#undef BUILTIN_MISC_OPERATION_WITH_SILGEN

#undef BUILTIN_MISC_OPERATION
Expand Down
24 changes: 24 additions & 0 deletions include/swift/Runtime/RuntimeFunctions.def
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,30 @@ FUNCTION(TaskCreateFutureFunc,
TaskContinuationFunctionPtrTy, SizeTy),
ATTRS(NoUnwind, ArgMemOnly))

// AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContext(size_t);
FUNCTION(AutoDiffCreateLinearMapContext,
swift_autoDiffCreateLinearMapContext, SwiftCC,
DifferentiationAvailability,
RETURNS(RefCountedPtrTy),
ARGS(SizeTy),
ATTRS(NoUnwind, ArgMemOnly))

// void *swift_autoDiffProjectTopLevelSubcontext(AutoDiffLinearMapContext *);
FUNCTION(AutoDiffProjectTopLevelSubcontext,
swift_autoDiffProjectTopLevelSubcontext, SwiftCC,
DifferentiationAvailability,
RETURNS(Int8PtrTy),
ARGS(RefCountedPtrTy),
ATTRS(NoUnwind, ArgMemOnly))

// void *swift_autoDiffAllocateSubcontext(AutoDiffLinearMapContext *, size_t);
FUNCTION(AutoDiffAllocateSubcontext,
swift_autoDiffAllocateSubcontext, SwiftCC,
DifferentiationAvailability,
RETURNS(Int8PtrTy),
ARGS(RefCountedPtrTy, SizeTy),
ATTRS(NoUnwind, ArgMemOnly))

#undef RETURNS
#undef ARGS
#undef ATTRS
Expand Down
10 changes: 10 additions & 0 deletions include/swift/SILOptimizer/Differentiation/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,16 @@ void extractAllElements(SILValue value, SILBuilder &builder,
void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
SILValue bufferAccess, SILLocation loc);

/// Emit a `Builtin.Word` value that represents the given type's memory layout
/// size.
SILValue emitMemoryLayoutSize(
SILBuilder &builder, SILLocation loc, CanType type);

/// Emit a projection of the top-level subcontext from the context object.
SILValue emitProjectTopLevelSubcontext(
SILBuilder &builder, SILLocation loc, SILValue context,
SILType subcontextType);

//===----------------------------------------------------------------------===//
// Utilities for looking up derivatives of functions
//===----------------------------------------------------------------------===//
Expand Down
21 changes: 18 additions & 3 deletions include/swift/SILOptimizer/Differentiation/LinearMapInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ class LinearMapInfo {
/// Activity info of the original function.
const DifferentiableActivityInfo &activityInfo;

/// The original function's loop info.
SILLoopInfo *loopInfo;

/// Differentiation indices of the function.
const SILAutoDiffIndices indices;

Expand All @@ -86,6 +89,9 @@ class LinearMapInfo {
/// Mapping from linear map structs to their branching trace enum fields.
llvm::DenseMap<StructDecl *, VarDecl *> linearMapStructEnumFields;

/// Blocks in a loop.
llvm::SmallSetVector<SILBasicBlock *, 4> blocksInLoop;

/// A synthesized file unit.
SynthesizedFileUnit &synthesizedFile;

Expand Down Expand Up @@ -144,7 +150,8 @@ class LinearMapInfo {
explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
SILFunction *original, SILFunction *derivative,
SILAutoDiffIndices indices,
const DifferentiableActivityInfo &activityInfo);
const DifferentiableActivityInfo &activityInfo,
SILLoopInfo *loopInfo);

/// Returns the linear map struct associated with the given original block.
StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const {
Expand Down Expand Up @@ -200,20 +207,28 @@ class LinearMapInfo {

/// Returns the branching trace enum field for the linear map struct of the
/// given original block.
VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) {
VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) const {
auto *linearMapStruct = getLinearMapStruct(origBB);
return linearMapStructEnumFields.lookup(linearMapStruct);
}

/// Finds the linear map declaration in the pullback struct for the given
/// `apply` instruction in the original function.
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) {
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) const {
assert(ai->getFunction() == original);
auto lookup = linearMapFieldMap.find(ai);
assert(lookup != linearMapFieldMap.end() &&
"No linear map field corresponding to the given `apply`");
return lookup->getSecond();
}

bool hasLoops() const {
return !blocksInLoop.empty();
}

ArrayRef<SILBasicBlock *> getBlocksInLoop() const {
return blocksInLoop.getArrayRef();
}
};

} // end namespace autodiff
Expand Down
2 changes: 2 additions & 0 deletions include/swift/SILOptimizer/Differentiation/VJPCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
#include "swift/SIL/LoopInfo.h"

namespace swift {
namespace autodiff {
Expand Down Expand Up @@ -52,6 +53,7 @@ class VJPCloner final {
const SILAutoDiffIndices getIndices() const;
DifferentiationInvoker getInvoker() const;
LinearMapInfo &getPullbackInfo() const;
SILLoopInfo *getLoopInfo() const;
const DifferentiableActivityInfo &getActivityInfo() const;

/// Performs VJP generation on the empty VJP function. Returns true if any
Expand Down
10 changes: 5 additions & 5 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,8 @@ static StringRef getPrivateDiscriminatorIfNecessary(const ValueDecl *decl) {

// Mangle non-local private declarations with a textual discriminator
// based on their enclosing file.
auto topLevelContext = decl->getDeclContext()->getModuleScopeContext();
auto fileUnit = cast<FileUnit>(topLevelContext);
auto topLevelSubcontext = decl->getDeclContext()->getModuleScopeContext();
auto fileUnit = cast<FileUnit>(topLevelSubcontext);

Identifier discriminator =
fileUnit->getDiscriminatorForPrivateValue(decl);
Expand Down Expand Up @@ -2900,17 +2900,17 @@ void ASTMangler::appendEntity(const ValueDecl *decl) {
void
ASTMangler::appendProtocolConformance(const ProtocolConformance *conformance) {
GenericSignature contextSig;
auto topLevelContext =
auto topLevelSubcontext =
conformance->getDeclContext()->getModuleScopeContext();
Mod = topLevelContext->getParentModule();
Mod = topLevelSubcontext->getParentModule();

auto conformingType = conformance->getType();
appendType(conformingType->getCanonicalType());

appendProtocolName(conformance->getProtocol());

bool needsModule = true;
if (auto *file = dyn_cast<FileUnit>(topLevelContext)) {
if (auto *file = dyn_cast<FileUnit>(topLevelSubcontext)) {
if (file->getKind() == FileUnitKind::ClangModule ||
file->getKind() == FileUnitKind::DWARFModule) {
if (conformance->getProtocol()->hasClangNode())
Expand Down
8 changes: 4 additions & 4 deletions lib/AST/ASTVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class Verifier : public ASTWalker {
typedef llvm::PointerIntPair<DeclContext *, 1, bool> ClosureDiscriminatorKey;
llvm::DenseMap<ClosureDiscriminatorKey, SmallBitVector>
ClosureDiscriminators;
DeclContext *CanonicalTopLevelContext = nullptr;
DeclContext *CanonicalTopLevelSubcontext = nullptr;

Verifier(PointerUnion<ModuleDecl *, SourceFile *> M, DeclContext *DC)
: M(M),
Expand Down Expand Up @@ -898,9 +898,9 @@ class Verifier : public ASTWalker {
DeclContext *getCanonicalDeclContext(DeclContext *DC) {
// All we really need to do is use a single TopLevelCodeDecl.
if (auto topLevel = dyn_cast<TopLevelCodeDecl>(DC)) {
if (!CanonicalTopLevelContext)
CanonicalTopLevelContext = topLevel;
return CanonicalTopLevelContext;
if (!CanonicalTopLevelSubcontext)
CanonicalTopLevelSubcontext = topLevel;
return CanonicalTopLevelSubcontext;
}

// TODO: check for uniqueness of initializer contexts?
Expand Down
4 changes: 4 additions & 0 deletions lib/AST/Availability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,10 @@ AvailabilityContext ASTContext::getConcurrencyAvailability() {
return getSwiftFutureAvailability();
}

AvailabilityContext ASTContext::getDifferentiationAvailability() {
return getSwiftFutureAvailability();
}

AvailabilityContext ASTContext::getSwift52Availability() {
auto target = LangOpts.Target;

Expand Down
28 changes: 28 additions & 0 deletions lib/AST/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,25 @@ static ValueDecl *getCreateAsyncTaskFuture(ASTContext &ctx, Identifier id) {
return builder.build(id);
}

static ValueDecl *getAutoDiffCreateLinearMapContext(ASTContext &ctx,
Identifier id) {
return getBuiltinFunction(
id, {BuiltinIntegerType::getWordType(ctx)}, ctx.TheNativeObjectType);
}

static ValueDecl *getAutoDiffProjectTopLevelSubcontext(ASTContext &ctx,
Identifier id) {
return getBuiltinFunction(
id, {ctx.TheNativeObjectType}, ctx.TheRawPointerType);
}

static ValueDecl *getAutoDiffAllocateSubcontext(ASTContext &ctx,
Identifier id) {
return getBuiltinFunction(
id, {ctx.TheNativeObjectType, BuiltinIntegerType::getWordType(ctx)},
ctx.TheRawPointerType);
}

static ValueDecl *getPoundAssert(ASTContext &Context, Identifier Id) {
auto int1Type = BuiltinIntegerType::get(1, Context);
auto optionalRawPointerType = BoundGenericEnumType::get(
Expand Down Expand Up @@ -2549,6 +2568,15 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {

case BuiltinValueKind::TriggerFallbackDiagnostic:
return getTriggerFallbackDiagnosticOperation(Context, Id);

case BuiltinValueKind::AutoDiffCreateLinearMapContext:
return getAutoDiffCreateLinearMapContext(Context, Id);

case BuiltinValueKind::AutoDiffProjectTopLevelSubcontext:
return getAutoDiffProjectTopLevelSubcontext(Context, Id);

case BuiltinValueKind::AutoDiffAllocateSubcontext:
return getAutoDiffAllocateSubcontext(Context, Id);
}

llvm_unreachable("bad builtin value!");
Expand Down
4 changes: 2 additions & 2 deletions lib/IDE/CodeCompletion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1677,7 +1677,7 @@ class CodeCompletionCallbacksImpl : public CodeCompletionCallbacks {
} // end anonymous namespace

namespace {
static bool isTopLevelContext(const DeclContext *DC) {
static bool isTopLevelSubcontext(const DeclContext *DC) {
for (; DC && DC->isLocalContext(); DC = DC->getParent()) {
switch (DC->getContextKind()) {
case DeclContextKind::TopLevelCodeDecl:
Expand Down Expand Up @@ -2139,7 +2139,7 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
if (CurrDeclContext && D->getModuleContext() == CurrModule) {
// Treat global variables from the same source file as local when
// completing at top-level.
if (isa<VarDecl>(D) && isTopLevelContext(CurrDeclContext) &&
if (isa<VarDecl>(D) && isTopLevelSubcontext(CurrDeclContext) &&
D->getDeclContext()->getParentSourceFile() ==
CurrDeclContext->getParentSourceFile()) {
return SemanticContextKind::Local;
Expand Down
22 changes: 22 additions & 0 deletions lib/IRGen/GenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1115,5 +1115,27 @@ if (Builtin.ID == BuiltinValueKind::id) { \
return;
}

if (Builtin.ID == BuiltinValueKind::AutoDiffCreateLinearMapContext) {
auto topLevelSubcontextSize = args.claimNext();
out.add(emitAutoDiffCreateLinearMapContext(IGF, topLevelSubcontextSize)
.getAddress());
return;
}

if (Builtin.ID == BuiltinValueKind::AutoDiffProjectTopLevelSubcontext) {
Address allocatorAddr(args.claimNext(), IGF.IGM.getPointerAlignment());
out.add(
emitAutoDiffProjectTopLevelSubcontext(IGF, allocatorAddr).getAddress());
return;
}

if (Builtin.ID == BuiltinValueKind::AutoDiffAllocateSubcontext) {
Address allocatorAddr(args.claimNext(), IGF.IGM.getPointerAlignment());
auto size = args.claimNext();
out.add(
emitAutoDiffAllocateSubcontext(IGF, allocatorAddr, size).getAddress());
return;
}

llvm_unreachable("IRGen unimplemented for this builtin!");
}
29 changes: 29 additions & 0 deletions lib/IRGen/GenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4595,3 +4595,32 @@ IRGenFunction::getFunctionPointerForResumeIntrinsic(llvm::Value *resume) {
PointerAuthInfo(), signature);
return fnPtr;
}

Address irgen::emitAutoDiffCreateLinearMapContext(
IRGenFunction &IGF, llvm::Value *topLevelSubcontextSize) {
auto *call = IGF.Builder.CreateCall(
IGF.IGM.getAutoDiffCreateLinearMapContextFn(), {topLevelSubcontextSize});
call->setDoesNotThrow();
call->setCallingConv(IGF.IGM.SwiftCC);
return Address(call, IGF.IGM.getPointerAlignment());
}

Address irgen::emitAutoDiffProjectTopLevelSubcontext(
IRGenFunction &IGF, Address context) {
auto *call = IGF.Builder.CreateCall(
IGF.IGM.getAutoDiffProjectTopLevelSubcontextFn(),
{context.getAddress()});
call->setDoesNotThrow();
call->setCallingConv(IGF.IGM.SwiftCC);
return Address(call, IGF.IGM.getPointerAlignment());
}

Address irgen::emitAutoDiffAllocateSubcontext(
IRGenFunction &IGF, Address context, llvm::Value *size) {
auto *call = IGF.Builder.CreateCall(
IGF.IGM.getAutoDiffAllocateSubcontextFn(),
{context.getAddress(), size});
call->setDoesNotThrow();
call->setCallingConv(IGF.IGM.SwiftCC);
return Address(call, IGF.IGM.getPointerAlignment());
}
7 changes: 7 additions & 0 deletions lib/IRGen/GenCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,13 @@ namespace irgen {

void emitAsyncReturn(IRGenFunction &IGF, AsyncContextLayout &layout,
CanSILFunctionType fnType);

Address emitAutoDiffCreateLinearMapContext(
IRGenFunction &IGF, llvm::Value *topLevelSubcontextSize);
Address emitAutoDiffProjectTopLevelSubcontext(
IRGenFunction &IGF, Address context);
Address emitAutoDiffAllocateSubcontext(
IRGenFunction &IGF, Address context, llvm::Value *size);
} // end namespace irgen
} // end namespace swift

Expand Down
8 changes: 8 additions & 0 deletions lib/IRGen/IRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,14 @@ namespace RuntimeConstants {
}
return RuntimeAvailability::AlwaysAvailable;
}

RuntimeAvailability DifferentiationAvailability(ASTContext &context) {
auto featureAvailability = context.getDifferentiationAvailability();
if (!isDeploymentAvailabilityContainedIn(context, featureAvailability)) {
return RuntimeAvailability::ConditionallyAvailable;
}
return RuntimeAvailability::AlwaysAvailable;
}
} // namespace RuntimeConstants

// We don't use enough attributes to justify generalizing the
Expand Down
3 changes: 3 additions & 0 deletions lib/SIL/IR/OperandOwnership.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,9 @@ CONSTANT_OWNERSHIP_BUILTIN(Owned, LifetimeEnding, UnsafeGuaranteed)
CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, CancelAsyncTask)
CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, CreateAsyncTask)
CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, CreateAsyncTaskFuture)
CONSTANT_OWNERSHIP_BUILTIN(None, NonLifetimeEnding, AutoDiffCreateLinearMapContext)
CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, AutoDiffAllocateSubcontext)
CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, AutoDiffProjectTopLevelSubcontext)

#undef CONSTANT_OWNERSHIP_BUILTIN

Expand Down
3 changes: 3 additions & 0 deletions lib/SIL/IR/ValueOwnership.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,9 @@ CONSTANT_OWNERSHIP_BUILTIN(None, GetCurrentAsyncTask)
CONSTANT_OWNERSHIP_BUILTIN(None, CancelAsyncTask)
CONSTANT_OWNERSHIP_BUILTIN(Owned, CreateAsyncTask)
CONSTANT_OWNERSHIP_BUILTIN(Owned, CreateAsyncTaskFuture)
CONSTANT_OWNERSHIP_BUILTIN(Owned, AutoDiffCreateLinearMapContext)
CONSTANT_OWNERSHIP_BUILTIN(None, AutoDiffProjectTopLevelSubcontext)
CONSTANT_OWNERSHIP_BUILTIN(None, AutoDiffAllocateSubcontext)

#undef CONSTANT_OWNERSHIP_BUILTIN

Expand Down
2 changes: 2 additions & 0 deletions lib/SIL/Utils/MemAccessUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1805,6 +1805,8 @@ static void visitBuiltinAddress(BuiltinInst *builtin,
case BuiltinValueKind::CancelAsyncTask:
case BuiltinValueKind::CreateAsyncTask:
case BuiltinValueKind::CreateAsyncTaskFuture:
case BuiltinValueKind::AutoDiffCreateLinearMapContext:
case BuiltinValueKind::AutoDiffAllocateSubcontext:
return;

// General memory access to a pointer in first operand position.
Expand Down
Loading