17
17
18
18
#define DEBUG_TYPE " differentiation"
19
19
20
+ #include " swift/AST/Types.h"
21
+
20
22
#include " swift/SILOptimizer/Differentiation/VJPCloner.h"
21
23
#include " swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
22
24
#include " swift/SILOptimizer/Differentiation/ADContext.h"
@@ -118,15 +120,21 @@ class VJPCloner::Implementation final
118
120
auto pullbackTupleType =
119
121
remapASTType (pullbackInfo.getLinearMapTupleType (returnBB)->getCanonicalType ());
120
122
Builder.setInsertionPoint (vjp->getEntryBlock ());
121
- auto topLevelSubcontextSize = emitMemoryLayoutSize (
122
- Builder, original->getLocation (), pullbackTupleType);
123
+
124
+ auto pbTupleMetatypeType =
125
+ CanMetatypeType::get (pullbackTupleType, MetatypeRepresentation::Thick);
126
+ auto pbTupleMetatypeSILType =
127
+ SILType::getPrimitiveObjectType (pbTupleMetatypeType);
128
+ auto pbTupleMetatype =
129
+ Builder.createMetatype (original->getLocation (), pbTupleMetatypeSILType);
130
+
123
131
// Create an context.
124
132
pullbackContextValue = Builder.createBuiltin (
125
133
original->getLocation (),
126
- getASTContext ().getIdentifier (
127
- getBuiltinName ( BuiltinValueKind::AutoDiffCreateLinearMapContext )),
128
- SILType::getNativeObjectType (getASTContext ()),
129
- SubstitutionMap (), {topLevelSubcontextSize });
134
+ getASTContext ().getIdentifier (getBuiltinName (
135
+ BuiltinValueKind::AutoDiffCreateLinearMapContextWithType )),
136
+ SILType::getNativeObjectType (getASTContext ()), SubstitutionMap (),
137
+ {pbTupleMetatype });
130
138
borrowedPullbackContextValue = Builder.createBeginBorrow (
131
139
original->getLocation (), pullbackContextValue);
132
140
LLVM_DEBUG (getADDebugStream ()
@@ -148,8 +156,8 @@ class VJPCloner::Implementation final
148
156
return builtinAutoDiffAllocateSubcontextGenericSignature;
149
157
auto &ctx = getASTContext ();
150
158
auto *decl = cast<FuncDecl>(getBuiltinValueDecl (
151
- ctx, ctx.getIdentifier (
152
- getBuiltinName ( BuiltinValueKind::AutoDiffAllocateSubcontext ))));
159
+ ctx, ctx.getIdentifier (getBuiltinName (
160
+ BuiltinValueKind::AutoDiffAllocateSubcontextWithType ))));
153
161
builtinAutoDiffAllocateSubcontextGenericSignature =
154
162
decl->getGenericSignature ();
155
163
assert (builtinAutoDiffAllocateSubcontextGenericSignature);
@@ -1067,14 +1075,21 @@ EnumInst *VJPCloner::Implementation::buildPredecessorEnumValue(
1067
1075
assert (enumEltType == rawPtrType);
1068
1076
auto pbTupleType =
1069
1077
remapASTType (pullbackInfo.getLinearMapTupleType (predBB)->getCanonicalType ());
1070
- SILValue pbTupleSize =
1071
- emitMemoryLayoutSize (Builder, loc, pbTupleType);
1078
+
1079
+ auto pbTupleMetatypeType =
1080
+ CanMetatypeType::get (pbTupleType, MetatypeRepresentation::Thick);
1081
+ auto pbTupleMetatypeSILType =
1082
+ SILType::getPrimitiveObjectType (pbTupleMetatypeType);
1083
+ auto pbTupleMetatype =
1084
+ Builder.createMetatype (original->getLocation (), pbTupleMetatypeSILType);
1085
+
1072
1086
auto rawBufferValue = builder.createBuiltin (
1073
1087
loc,
1074
- getASTContext ().getIdentifier (
1075
- getBuiltinName ( BuiltinValueKind::AutoDiffAllocateSubcontext )),
1088
+ getASTContext ().getIdentifier (getBuiltinName (
1089
+ BuiltinValueKind::AutoDiffAllocateSubcontextWithType )),
1076
1090
rawPtrType, SubstitutionMap (),
1077
- {borrowedPullbackContextValue, pbTupleSize});
1091
+ {borrowedPullbackContextValue, pbTupleMetatype});
1092
+
1078
1093
auto typedBufferValue =
1079
1094
builder.createPointerToAddress (
1080
1095
loc, rawBufferValue, pbTupleVal->getType ().getAddressType (),
0 commit comments