Skip to content

Commit 4b6adec

Browse files
committed
[Serialization] Implement serialization for @differentiable attribute. (#17155)
Implement (de)serialization for all components of `@differentiable` attribute except the trailing where clause (which needs to be type-checked). This is a necessary step for the `#adjoint` expression to look up `@differentiable` attributes declared on functions in other modules correctly. Addresses SR-7977.
1 parent 786d586 commit 4b6adec

File tree

5 files changed

+164
-8
lines changed

5 files changed

+164
-8
lines changed

include/swift/AST/Attr.def

+1-3
Original file line numberDiff line numberDiff line change
@@ -379,10 +379,8 @@ SIMPLE_DECL_ATTR(_usableFromInline, UsableFromInlineImport,
379379
77)
380380

381381
// SWIFT_ENABLE_TENSORFLOW
382-
// FIXME: Make it serialized
383382
DECL_ATTR(differentiable, Differentiable,
384-
OnFunc | LongAttribute | NotSerialized,
385-
/* Not serialized */ 78)
383+
OnFunc | LongAttribute, 78)
386384

387385
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
388386
OnFunc | OnConstructor, /* Not serialized */ 79)

include/swift/Serialization/ModuleFormat.h

+12-3
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const uint16_t VERSION_MAJOR = 0;
5555
/// describe what change you made. The content of this comment isn't important;
5656
/// it just ensures a conflict if two people change the module format.
5757
/// Don't worry about adhering to the 80-column limit for this line.
58-
const uint16_t VERSION_MINOR = 416; // SWIFT_ENABLE_TENSORFLOW: graph_op.
58+
const uint16_t VERSION_MINOR = 417; // SWIFT_ENABLE_TENSORFLOW: serialize @differentiable.
5959

6060
using DeclIDField = BCFixed<31>;
6161

@@ -1464,8 +1464,6 @@ namespace decls_block {
14641464
= BCRecordLayout<RestatedObjCConformance_DECL_ATTR>;
14651465
using ClangImporterSynthesizedTypeDeclAttrLayout
14661466
= BCRecordLayout<ClangImporterSynthesizedType_DECL_ATTR>;
1467-
// SWIFT_ENABLE_TENSORFLOW
1468-
using DifferentiableDeclAttrLayout = BCRecordLayout<Differentiable_DECL_ATTR>;
14691467

14701468
using InlineDeclAttrLayout = BCRecordLayout<
14711469
Inline_DECL_ATTR,
@@ -1522,6 +1520,17 @@ namespace decls_block {
15221520
BCFixed<1> // specialization kind
15231521
>;
15241522

1523+
// SWIFT_ENABLE_TENSORFLOW
1524+
using DifferentiableDeclAttrLayout = BCRecordLayout<
1525+
Differentiable_DECL_ATTR,
1526+
BCFixed<1>, // Differentiation mode ('forward' or 'reverse').
1527+
IdentifierIDField, // Primal name.
1528+
DeclIDField, // Primal function declaration.
1529+
IdentifierIDField, // Adjoint name.
1530+
DeclIDField, // Adjoint function declaration.
1531+
BCArray<BCFixed<32>> // Differentiation parameters.
1532+
>;
1533+
15251534
#define SIMPLE_DECL_ATTR(X, CLASS, ...) \
15261535
using CLASS##DeclAttrLayout = BCRecordLayout< \
15271536
CLASS##_DECL_ATTR, \

lib/Serialization/Deserialization.cpp

+46
Original file line numberDiff line numberDiff line change
@@ -2597,6 +2597,52 @@ ModuleFile::getDeclCheckedImpl(DeclID DID, Optional<DeclContext *> ForcedContext
25972597
break;
25982598
}
25992599

2600+
// SWIFT_ENABLE_TENSORFLOW
2601+
case decls_block::Differentiable_DECL_ATTR: {
2602+
AutoDiffMode autodiffMode = AutoDiffMode::Reverse;
2603+
unsigned autodiffModeValue;
2604+
uint64_t primalNameId;
2605+
DeclID primalDeclId;
2606+
uint64_t adjointNameId;
2607+
DeclID adjointDeclId;
2608+
ArrayRef<uint64_t> paramValues;
2609+
2610+
serialization::decls_block::DifferentiableDeclAttrLayout::readRecord(
2611+
scratch, autodiffModeValue, primalNameId, primalDeclId, adjointNameId,
2612+
adjointDeclId, paramValues);
2613+
autodiffMode = autodiffModeValue
2614+
? AutoDiffMode::Reverse
2615+
: AutoDiffMode::Forward;
2616+
2617+
using FuncSpecifier = DifferentiableAttr::FunctionSpecifier;
2618+
Optional<FuncSpecifier> primal;
2619+
FuncDecl *primalDecl = nullptr;
2620+
if (primalNameId != 0 && primalDeclId != 0) {
2621+
primal = { getIdentifier(primalNameId), DeclNameLoc() };
2622+
primalDecl = cast<FuncDecl>(getDecl(primalDeclId));
2623+
}
2624+
FuncSpecifier adjoint = { getIdentifier(adjointNameId), DeclNameLoc() };
2625+
FuncDecl *adjointDecl = cast<FuncDecl>(getDecl(adjointDeclId));
2626+
2627+
SmallVector<AutoDiffParameter, 4> parameters;
2628+
SourceLoc loc;
2629+
for (auto paramValue : paramValues) {
2630+
auto parameter = paramValue & 0x01
2631+
? AutoDiffParameter::getSelfParameter(loc)
2632+
: AutoDiffParameter::getIndexParameter(loc, paramValue >> 1);
2633+
parameters.push_back(parameter);
2634+
}
2635+
// TODO: Deserialize trailing where clause.
2636+
auto diffAttr =
2637+
DifferentiableAttr::create(ctx, loc, SourceRange(), autodiffMode,
2638+
loc, parameters, primal, adjoint,
2639+
/*TrailingWhereClause*/ nullptr);
2640+
diffAttr->setPrimalFunction(primalDecl);
2641+
diffAttr->setAdjointFunction(adjointDecl);
2642+
Attr = diffAttr;
2643+
break;
2644+
}
2645+
26002646
#define SIMPLE_DECL_ATTR(NAME, CLASS, ...) \
26012647
case decls_block::CLASS##_DECL_ATTR: { \
26022648
bool isImplicit; \

lib/Serialization/Serialization.cpp

+37-2
Original file line numberDiff line numberDiff line change
@@ -2222,8 +2222,6 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) {
22222222
case DAK_ObjCRuntimeName:
22232223
case DAK_RestatedObjCConformance:
22242224
case DAK_ClangImporterSynthesizedType:
2225-
// SWIFT_ENABLE_TENSORFLOW
2226-
case DAK_Differentiable:
22272225
llvm_unreachable("cannot serialize attribute");
22282226

22292227
case DAK_Count:
@@ -2381,6 +2379,43 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) {
23812379
writeGenericRequirements(SA->getRequirements(), DeclTypeAbbrCodes);
23822380
return;
23832381
}
2382+
2383+
// SWIFT_ENABLE_TENSORFLOW
2384+
case DAK_Differentiable: {
2385+
auto abbrCode = DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
2386+
auto attr = cast<DifferentiableAttr>(DA);
2387+
2388+
IdentifierID primalName = 0;
2389+
DeclID primalRef = 0;
2390+
if (auto primal = attr->getPrimal()) {
2391+
primalName = addDeclBaseNameRef(primal->Name.getBaseName());
2392+
primalRef = addDeclRef(attr->getPrimalFunction());
2393+
}
2394+
auto adjointName = addDeclBaseNameRef(attr->getAdjoint().Name.getBaseName());
2395+
auto adjointRef = addDeclRef(attr->getAdjointFunction());
2396+
2397+
SmallVector<uint32_t, 4> parameters;
2398+
for (auto param : attr->getParameters()) {
2399+
switch (param.getKind()) {
2400+
// The self parameter is uniquely identified by 0x01.
2401+
case AutoDiffParameter::Kind::Self:
2402+
parameters.push_back(1);
2403+
break;
2404+
// Index parameters are left-shifted by 1.
2405+
case AutoDiffParameter::Kind::Index:
2406+
parameters.push_back(param.getIndex() << 1);
2407+
break;
2408+
}
2409+
}
2410+
2411+
DifferentiableDeclAttrLayout::emitRecord(
2412+
Out, ScratchRecord, abbrCode, (unsigned) attr->getMode(), primalName,
2413+
primalRef, adjointName, adjointRef, parameters);
2414+
// TODO: Serialize trailing where clause.
2415+
// Type-checking where clause should be done first (mimicking the
2416+
// @_specialize attribute).
2417+
return;
2418+
}
23842419
}
23852420
}
23862421

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// SWIFT_ENABLE_TENSORFLOW
2+
// TODO: Handle trailing where clause in @differentiable attribute.
3+
4+
// RUN: %empty-directory(%t)
5+
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
6+
// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s
7+
8+
struct CheckpointsFoo {}
9+
func pfoo(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) {
10+
return (CheckpointsFoo(), x * x)
11+
}
12+
func dfoo_checkpointed(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float {
13+
return 2 * x
14+
}
15+
// CHECK-DAG: @differentiable(reverse, primal: pfoo, adjoint: dfoo_checkpointed)
16+
// CHECK-DAG: func foo_checkpointed(_ x: Float) -> Float
17+
@differentiable(reverse, primal: pfoo(_:), adjoint: dfoo_checkpointed(_:checkpoints:originalValue:seed:))
18+
func foo_checkpointed(_ x: Float) -> Float {
19+
return x * x
20+
}
21+
22+
struct S<T> {
23+
struct Checkpoints {
24+
let s: S
25+
}
26+
func primal(x: Float) -> (Checkpoints, Float) {
27+
return (Checkpoints(s: self), x)
28+
}
29+
func adjoint_checkpointed(x: Float, _: Checkpoints, _: Float, _: Float) -> S {
30+
return self
31+
}
32+
33+
// CHECK-DAG: @differentiable(reverse, (self), primal: primal, adjoint: adjoint_checkpointed)
34+
// CHECK-DAG: func original(x: Float) -> Float
35+
@differentiable(reverse, withRespectTo: (self), primal: primal, adjoint: adjoint_checkpointed)
36+
func original(x: Float) -> Float {
37+
return x
38+
}
39+
}
40+
41+
func pbaz1<T>(_ x: T, _ y: T) -> ((T, T), T) {
42+
return ((y, y), x)
43+
}
44+
func dbaz1_checkpointed<T>(_ x: T, _ y: T, primal: (T, T), originalValue: T, seed: T) -> (T, T) {
45+
return (y, x)
46+
}
47+
// CHECK-DAG: @differentiable(reverse, primal: pbaz1, adjoint: dbaz1_checkpointed)
48+
// CHECK-DAG: func baz1_checkpointed<T>(_ x: T, _ y: T) -> T
49+
@differentiable(reverse, primal: pbaz1(_:_:), adjoint: dbaz1_checkpointed(_:_:primal:originalValue:seed:))
50+
func baz1_checkpointed<T>(_ x: T, _ y: T) -> T {
51+
return x
52+
}
53+
54+
struct CheckpointsFP<T : FloatingPoint> {
55+
let meow: T
56+
}
57+
func pbaz2<T : FloatingPoint>(_ x: T, _ y: T) -> (CheckpointsFP<T>, T) {
58+
return (CheckpointsFP(meow: 1), x + y)
59+
}
60+
func dbaz2_checkpointed<T : FloatingPoint>(_ x: T, _ y: T, primal: CheckpointsFP<T>, originalValue: T, seed: T) -> (T, T) {
61+
return (1, 1)
62+
}
63+
// CHECK-DAG: @differentiable(reverse, primal: pbaz2, adjoint: dbaz2_checkpointed)
64+
// CHECK-DAG: func baz2_checkpointed<T>(_ x: T, _ y: T) -> T where T : FloatingPoint
65+
@differentiable(reverse, primal: pbaz2(_:_:), adjoint: dbaz2_checkpointed(_:_:primal:originalValue:seed:))
66+
func baz2_checkpointed<T : FloatingPoint>(_ x: T, _ y: T) -> T {
67+
return x
68+
}

0 commit comments

Comments
 (0)