29
29
#include "swift/Basic/Range.h"
30
30
#include "swift/SIL/Consumption.h"
31
31
#include "swift/SIL/SILAllocated.h"
32
+ // SWIFT_ENABLE_TENSORFLOW
33
+ #include "swift/SIL/SILConstants.h"
32
34
#include "swift/SIL/SILDeclRef.h"
33
35
#include "swift/SIL/SILFunctionConventions.h"
34
36
#include "swift/SIL/SILLocation.h"
@@ -51,6 +53,8 @@ class MultipleValueInstruction;
51
53
class MultipleValueInstructionResult;
52
54
class DestructureTupleInst;
53
55
class DestructureStructInst;
56
+ // SWIFT_ENABLE_TENSORFLOW
57
+ class GraphOperationInst;
54
58
class NonValueInstruction;
55
59
class SILBasicBlock;
56
60
class SILBuilder;
@@ -8141,6 +8145,128 @@ inline DestructureTupleInst *DestructureTupleResult::getParent() {
8141
8145
return cast<DestructureTupleInst>(Parent);
8142
8146
}
8143
8147
8148
+ /// SWIFT_ENABLE_TENSORFLOW
8149
+ /// A result for the graph_op instruction. See documentation for
8150
+ /// graph_op for more information.
8151
+ class GraphOperationResult final : public MultipleValueInstructionResult {
8152
+ public:
8153
+ GraphOperationResult(unsigned Index, SILType Type,
8154
+ ValueOwnershipKind OwnershipKind)
8155
+ : MultipleValueInstructionResult(ValueKind::GraphOperationResult, Index,
8156
+ Type, OwnershipKind) {}
8157
+
8158
+ static bool classof(const SILNode *N) {
8159
+ return N->getKind() == SILNodeKind::GraphOperationResult;
8160
+ }
8161
+
8162
+ GraphOperationInst *getParent() {
8163
+ auto *Parent = MultipleValueInstructionResult::getParent();
8164
+ return cast<GraphOperationInst>(Parent);
8165
+ };
8166
+
8167
+ const GraphOperationInst *getParent() const {
8168
+ return const_cast<GraphOperationResult *>(this)->getParent();
8169
+ }
8170
+ };
8171
+
8172
+ /// SWIFT_ENABLE_TENSORFLOW
8173
+ /// A graph operation attribute. Attributes have a name and a constant value.
8174
+ struct GraphOperationAttribute {
8175
+ Identifier name;
8176
+ SymbolicValue value;
8177
+ };
8178
+
8179
+ /// SWIFT_ENABLE_TENSORFLOW
8180
+ /// A graph operation. This instruction will be extracted to a graph program
8181
+ /// via graph program extraction passes.
8182
+ class GraphOperationInst final
8183
+ : public InstructionBase<
8184
+ SILInstructionKind::GraphOperationInst,
8185
+ MultipleValueInstruction>,
8186
+ public MultipleValueInstructionTrailingObjects<
8187
+ GraphOperationInst, GraphOperationResult,
8188
+ InitialTrailingObjects<>,
8189
+ FinalTrailingObjects<Operand, GraphOperationAttribute>> {
8190
+ friend TrailingObjects;
8191
+
8192
+ /// The name of the graph operation.
8193
+ Identifier Name;
8194
+ /// The number of operands.
8195
+ unsigned NumOperands;
8196
+ /// The number of attributes.
8197
+ unsigned NumAttributes;
8198
+
8199
+ GraphOperationInst(SILModule &M, SILDebugLocation loc, Identifier name,
8200
+ ArrayRef<SILValue> arguments,
8201
+ ArrayRef<GraphOperationAttribute> attrs,
8202
+ ArrayRef<SILType> resultTypes,
8203
+ ArrayRef<ValueOwnershipKind> resultOwnerships) :
8204
+ InstructionBase(loc),
8205
+ MultipleValueInstructionTrailingObjects(this, resultTypes,
8206
+ resultOwnerships),
8207
+ Name(name), NumOperands(arguments.size()), NumAttributes(attrs.size()) {
8208
+ auto allOperands = getAllOperands();
8209
+ for (unsigned i : indices(arguments))
8210
+ new (&allOperands[i]) Operand(this, arguments[i]);
8211
+ std::uninitialized_copy(attrs.begin(), attrs.end(),
8212
+ getAttributes().data());
8213
+ }
8214
+
8215
+ public:
8216
+ using MultipleValueInstructionTrailingObjects::numTrailingObjects;
8217
+ using MultipleValueInstructionTrailingObjects::totalSizeToAlloc;
8218
+
8219
+ ~GraphOperationInst() {
8220
+ for (auto &operand : getAllOperands())
8221
+ operand.~Operand();
8222
+ }
8223
+
8224
+ static GraphOperationInst *create(SILModule &M, SILDebugLocation loc,
8225
+ Identifier name,
8226
+ ArrayRef<SILValue> arguments,
8227
+ ArrayRef<GraphOperationAttribute> attrs,
8228
+ ArrayRef<SILType> resultTypes);
8229
+
8230
+ Identifier getName() const { return Name; }
8231
+ unsigned getNumOperands() const { return NumOperands; }
8232
+ unsigned getNumAttributes() const { return NumAttributes; }
8233
+
8234
+ unsigned numTrailingObjects(OverloadToken<Operand>) const {
8235
+ return NumOperands;
8236
+ }
8237
+
8238
+ ArrayRef<Operand> getAllOperands() const {
8239
+ return { getTrailingObjects<Operand>(), NumOperands };
8240
+ }
8241
+
8242
+ MutableArrayRef<Operand> getAllOperands() {
8243
+ return { getTrailingObjects<Operand>(), NumOperands };
8244
+ }
8245
+
8246
+ OperandValueArrayRef getArguments() const {
8247
+ return OperandValueArrayRef(getAllOperands());
8248
+ }
8249
+
8250
+ ArrayRef<GraphOperationAttribute> getAttributes() const {
8251
+ return { getTrailingObjects<GraphOperationAttribute>(), NumAttributes };
8252
+ }
8253
+
8254
+ MutableArrayRef<GraphOperationAttribute> getAttributes() {
8255
+ return { getTrailingObjects<GraphOperationAttribute>(), NumAttributes };
8256
+ }
8257
+
8258
+ Optional<GraphOperationAttribute> getAttribute(StringRef name) {
8259
+ for (auto attr : getAttributes())
8260
+ if (attr.name.is(name))
8261
+ return attr;
8262
+ return None;
8263
+ };
8264
+
8265
+ static bool classof(const SILNode *N) {
8266
+ return N->getKind() == SILNodeKind::GraphOperationInst;
8267
+ }
8268
+ };
8269
+
8144
8270
inline SILType *AllocRefInstBase::getTypeStorage() {
8145
8271
// If the size of the subclasses are equal, then all of this compiles away.
8146
8272
if (auto I = dyn_cast<AllocRefInst>(this))
0 commit comments