Skip to content

Commit 786d586

Browse files
committed
[SIL] Add GraphOperationInst. (#17077)
* [SIL] Add `GraphOperationInst`. Add `GraphOperationInst` to represent graph operations. A graph operation may have: - Zero or more SILValue operands - Zero or more attributes (consisting of a name and a `SymbolicValue` constant value) - One or more result types `GraphOperationInst` will replace `BuiltinInst` as the representation for TensorFlow operations in the TFDeabstraction/TFPartition passes. Todos: - @lattner: Edit deabstraction/partitioning to use `GraphOperationInst`. - Low priority: Implement SIL serialization/deserialization for `GraphOperationInst`.
1 parent 71ce7ba commit 786d586

18 files changed

+665
-7
lines changed

docs/SIL.rst

+32-1
Original file line numberDiff line numberDiff line change
@@ -5341,7 +5341,7 @@ gradient
53415341
sil-autodiff-preserving-result ::= '[' 'preserving_result' ']'
53425342

53435343
%original = function_ref @original : $(Float, Float) -> Float
5344-
%original_grad = gradient [wrt 0, 1] [preserving_result]
5344+
%original_grad = gradient [wrt 0, 1] [preserving_result] \
53455345
%original : $(Float, Float) -> Float
53465346

53475347
Computes the gradient function of a value ``%original`` using reverse-mode
@@ -5352,6 +5352,37 @@ automatic differentiation.
53525352
This instruction is only valid in raw SIL and is rewritten by the automatic
53535353
differentiation pass.
53545354

5355+
.. SWIFT_ENABLE_TENSORFLOW
5356+
5357+
Graph Program Extraction
5358+
~~~~~~~~~~~~~~~~~~~~~~~~
5359+
5360+
graph_op
5361+
````````
5362+
::
5363+
5364+
sil-instruction ::= 'graph_op' string-literal
5365+
'(' (sil-operand (',' sil-operand)*)? ')'
5366+
('{' (sil-graph-op-attr (',' sil-graph-op-attr)*)? '}')?
5367+
':' sil-type (',' sil-type)*
5368+
sil-graph-op-attr ::= sil-identifier ':' sil-symbolic-value
5369+
sil-symbolic-value ::= i[0-9]+ int-literal |
5370+
f(32|64) float-literal |
5371+
sil-type |
5372+
'[' (sil-symbolic-value (',' sil-symbolic-value)*)? ']'
5373+
5374+
%add = graph_op "tf.Add"(%x : $Tensor<Float>, %y : $Tensor<Float>) \
5375+
{T: $Float} : $Tensor<Float>
5376+
5377+
Represents a graph program operation.
5378+
5379+
``graph_op`` instructions have a name, zero or more operands, zero or more
5380+
attributes (an identifier and SIL constant value), and one or more result
5381+
types.
5382+
5383+
This instruction is only valid in raw SIL and is rewritten and extracted by the
5384+
graph program extraction passes (debastraction, partitioning, graph lowering).
5385+
53555386
Assertion configuration
53565387
~~~~~~~~~~~~~~~~~~~~~~~
53575388

include/swift/AST/DiagnosticsParse.def

+30
Original file line numberDiff line numberDiff line change
@@ -1674,6 +1674,36 @@ ERROR(expr_expected_function_to_differentiate,none,
16741674
ERROR(gradient_expr_expected_parameter,none,
16751675
"expected a parameter, which can be the index of a function parameter with a leading dot (e.g. '.0'), or 'self'", ())
16761676

1677+
// SWIFT_ENABLE_TENSORFLOW
1678+
//------------------------------------------------------------------------------
1679+
// Graph operation related parsing diagnostics
1680+
//------------------------------------------------------------------------------
1681+
ERROR(sil_graph_op_expected_attr_name,PointsToFirstBadToken,
1682+
"expected 'graph_op' attribute name", ())
1683+
ERROR(sil_graph_op_expected_attr_value,PointsToFirstBadToken,
1684+
"expected 'graph_op' attribute value", ())
1685+
ERROR(sil_graph_op_unhandled_attr_value,PointsToFirstBadToken,
1686+
"unhandled 'graph_op' attribute value", ())
1687+
ERROR(sil_graph_op_expected_rparen,PointsToFirstBadToken,
1688+
"expected ')' in 'graph_op' argument list", ())
1689+
ERROR(sil_graph_op_expected_rbrace,PointsToFirstBadToken,
1690+
"expected '}' in 'graph_op' attribute list", ())
1691+
ERROR(sil_graph_op_expected_colon_after_attr_name,PointsToFirstBadToken,
1692+
"expected ':' after 'graph_op' attribute name", ())
1693+
ERROR(sil_graph_op_expected_colon_before_result_types,PointsToFirstBadToken,
1694+
"expected ':' before 'graph_op' result types", ())
1695+
1696+
ERROR(sil_const_expected_int_datatype,PointsToFirstBadToken,
1697+
"expected integer datatype ('i[0-9]+', e.g. 'i32')", ())
1698+
ERROR(sil_const_expected_int_value,PointsToFirstBadToken,
1699+
"expected integer value in SIL constant value", ())
1700+
ERROR(sil_const_expected_fp_datatype,PointsToFirstBadToken,
1701+
"expected floating point datatype ('f32' or 'f64')", ())
1702+
ERROR(sil_const_expected_fp_value,PointsToFirstBadToken,
1703+
"expected floating point value in SIL constant value", ())
1704+
ERROR(sil_const_aggregate_expected_rsquare,PointsToFirstBadToken,
1705+
"expected ']' at end of aggregate 'SymbolicValue'", ())
1706+
16771707
#ifndef DIAG_NO_UNDEF
16781708
# if defined(DIAG)
16791709
# undef DIAG

include/swift/SIL/SILBuilder.h

+8
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,14 @@ class SILBuilder {
13001300
SILLocation Loc, SILValue Operand,
13011301
llvm::SmallVectorImpl<SILValue> &Result);
13021302

1303+
GraphOperationInst *createGraphOperation(
1304+
SILLocation loc, Identifier name, ArrayRef<SILValue> operands,
1305+
ArrayRef<GraphOperationAttribute> attrs, ArrayRef<SILType> resultTypes) {
1306+
return insert(GraphOperationInst::create(
1307+
getModule(), getSILDebugLocation(loc), name, operands, attrs,
1308+
resultTypes));
1309+
}
1310+
13031311
ClassMethodInst *createClassMethod(SILLocation Loc, SILValue Operand,
13041312
SILDeclRef Member, SILType MethodTy) {
13051313
return insert(new (getModule()) ClassMethodInst(

include/swift/SIL/SILCloner.h

+15
Original file line numberDiff line numberDiff line change
@@ -1607,6 +1607,21 @@ void SILCloner<ImplClass>::visitDestructureTupleInst(
16071607
getOpValue(Inst->getOperand())));
16081608
}
16091609

1610+
// SWIFT_ENABLE_TENSORFLOW
1611+
template <typename ImplClass>
1612+
void SILCloner<ImplClass>::visitGraphOperationInst(GraphOperationInst *Inst) {
1613+
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
1614+
auto arguments =
1615+
getOpValueArray<4>(OperandValueArrayRef(Inst->getArguments()));
1616+
SmallVector<SILType, 4> resultTypes;
1617+
for (auto result : Inst->getResults())
1618+
resultTypes.push_back(result->getType());
1619+
doPostProcess(Inst,
1620+
getBuilder().createGraphOperation(getOpLocation(Inst->getLoc()),
1621+
Inst->getName(), arguments,
1622+
Inst->getAttributes(), resultTypes));
1623+
}
1624+
16101625
template <typename ImplClass>
16111626
void SILCloner<ImplClass>::visitClassMethodInst(ClassMethodInst *Inst) {
16121627
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));

include/swift/SIL/SILInstruction.h

+126
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include "swift/Basic/Range.h"
3030
#include "swift/SIL/Consumption.h"
3131
#include "swift/SIL/SILAllocated.h"
32+
// SWIFT_ENABLE_TENSORFLOW
33+
#include "swift/SIL/SILConstants.h"
3234
#include "swift/SIL/SILDeclRef.h"
3335
#include "swift/SIL/SILFunctionConventions.h"
3436
#include "swift/SIL/SILLocation.h"
@@ -51,6 +53,8 @@ class MultipleValueInstruction;
5153
class MultipleValueInstructionResult;
5254
class DestructureTupleInst;
5355
class DestructureStructInst;
56+
// SWIFT_ENABLE_TENSORFLOW
57+
class GraphOperationInst;
5458
class NonValueInstruction;
5559
class SILBasicBlock;
5660
class SILBuilder;
@@ -8141,6 +8145,128 @@ inline DestructureTupleInst *DestructureTupleResult::getParent() {
81418145
return cast<DestructureTupleInst>(Parent);
81428146
}
81438147

8148+
/// SWIFT_ENABLE_TENSORFLOW
8149+
/// A result for the graph_op instruction. See documentation for
8150+
/// graph_op for more information.
8151+
class GraphOperationResult final : public MultipleValueInstructionResult {
8152+
public:
8153+
GraphOperationResult(unsigned Index, SILType Type,
8154+
ValueOwnershipKind OwnershipKind)
8155+
: MultipleValueInstructionResult(ValueKind::GraphOperationResult, Index,
8156+
Type, OwnershipKind) {}
8157+
8158+
static bool classof(const SILNode *N) {
8159+
return N->getKind() == SILNodeKind::GraphOperationResult;
8160+
}
8161+
8162+
GraphOperationInst *getParent() {
8163+
auto *Parent = MultipleValueInstructionResult::getParent();
8164+
return cast<GraphOperationInst>(Parent);
8165+
};
8166+
8167+
const GraphOperationInst *getParent() const {
8168+
return const_cast<GraphOperationResult *>(this)->getParent();
8169+
}
8170+
};
8171+
8172+
/// SWIFT_ENABLE_TENSORFLOW
8173+
/// A graph operation attribute. Attributes have a name and a constant value.
8174+
struct GraphOperationAttribute {
8175+
Identifier name;
8176+
SymbolicValue value;
8177+
};
8178+
8179+
/// SWIFT_ENABLE_TENSORFLOW
8180+
/// A graph operation. This instruction will be extracted to a graph program
8181+
/// via graph program extraction passes.
8182+
class GraphOperationInst final
8183+
: public InstructionBase<
8184+
SILInstructionKind::GraphOperationInst,
8185+
MultipleValueInstruction>,
8186+
public MultipleValueInstructionTrailingObjects<
8187+
GraphOperationInst, GraphOperationResult,
8188+
InitialTrailingObjects<>,
8189+
FinalTrailingObjects<Operand, GraphOperationAttribute>> {
8190+
friend TrailingObjects;
8191+
8192+
/// The name of the graph operation.
8193+
Identifier Name;
8194+
/// The number of operands.
8195+
unsigned NumOperands;
8196+
/// The number of attributes.
8197+
unsigned NumAttributes;
8198+
8199+
GraphOperationInst(SILModule &M, SILDebugLocation loc, Identifier name,
8200+
ArrayRef<SILValue> arguments,
8201+
ArrayRef<GraphOperationAttribute> attrs,
8202+
ArrayRef<SILType> resultTypes,
8203+
ArrayRef<ValueOwnershipKind> resultOwnerships) :
8204+
InstructionBase(loc),
8205+
MultipleValueInstructionTrailingObjects(this, resultTypes,
8206+
resultOwnerships),
8207+
Name(name), NumOperands(arguments.size()), NumAttributes(attrs.size()) {
8208+
auto allOperands = getAllOperands();
8209+
for (unsigned i : indices(arguments))
8210+
new (&allOperands[i]) Operand(this, arguments[i]);
8211+
std::uninitialized_copy(attrs.begin(), attrs.end(),
8212+
getAttributes().data());
8213+
}
8214+
8215+
public:
8216+
using MultipleValueInstructionTrailingObjects::numTrailingObjects;
8217+
using MultipleValueInstructionTrailingObjects::totalSizeToAlloc;
8218+
8219+
~GraphOperationInst() {
8220+
for (auto &operand : getAllOperands())
8221+
operand.~Operand();
8222+
}
8223+
8224+
static GraphOperationInst *create(SILModule &M, SILDebugLocation loc,
8225+
Identifier name,
8226+
ArrayRef<SILValue> arguments,
8227+
ArrayRef<GraphOperationAttribute> attrs,
8228+
ArrayRef<SILType> resultTypes);
8229+
8230+
Identifier getName() const { return Name; }
8231+
unsigned getNumOperands() const { return NumOperands; }
8232+
unsigned getNumAttributes() const { return NumAttributes; }
8233+
8234+
unsigned numTrailingObjects(OverloadToken<Operand>) const {
8235+
return NumOperands;
8236+
}
8237+
8238+
ArrayRef<Operand> getAllOperands() const {
8239+
return { getTrailingObjects<Operand>(), NumOperands };
8240+
}
8241+
8242+
MutableArrayRef<Operand> getAllOperands() {
8243+
return { getTrailingObjects<Operand>(), NumOperands };
8244+
}
8245+
8246+
OperandValueArrayRef getArguments() const {
8247+
return OperandValueArrayRef(getAllOperands());
8248+
}
8249+
8250+
ArrayRef<GraphOperationAttribute> getAttributes() const {
8251+
return { getTrailingObjects<GraphOperationAttribute>(), NumAttributes };
8252+
}
8253+
8254+
MutableArrayRef<GraphOperationAttribute> getAttributes() {
8255+
return { getTrailingObjects<GraphOperationAttribute>(), NumAttributes };
8256+
}
8257+
8258+
Optional<GraphOperationAttribute> getAttribute(StringRef name) {
8259+
for (auto attr : getAttributes())
8260+
if (attr.name.is(name))
8261+
return attr;
8262+
return None;
8263+
};
8264+
8265+
static bool classof(const SILNode *N) {
8266+
return N->getKind() == SILNodeKind::GraphOperationInst;
8267+
}
8268+
};
8269+
81448270
inline SILType *AllocRefInstBase::getTypeStorage() {
81458271
// If the size of the subclasses are equal, then all of this compiles away.
81468272
if (auto I = dyn_cast<AllocRefInst>(this))

include/swift/SIL/SILNodes.def

+9-4
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,9 @@ ABSTRACT_VALUE(MultipleValueInstructionResult, ValueBase)
218218
MULTIPLE_VALUE_INST_RESULT(BeginApplyResult, MultipleValueInstructionResult)
219219
MULTIPLE_VALUE_INST_RESULT(DestructureStructResult, MultipleValueInstructionResult)
220220
MULTIPLE_VALUE_INST_RESULT(DestructureTupleResult, MultipleValueInstructionResult)
221-
VALUE_RANGE(MultipleValueInstructionResult, BeginApplyResult, DestructureTupleResult)
221+
// SWIFT_ENABLE_TENSORFLOW
222+
MULTIPLE_VALUE_INST_RESULT(GraphOperationResult, MultipleValueInstructionResult)
223+
VALUE_RANGE(MultipleValueInstructionResult, BeginApplyResult, GraphOperationResult)
222224

223225
VALUE(SILUndef, ValueBase)
224226

@@ -668,10 +670,13 @@ MULTIPLE_VALUE_INST(DestructureStructInst, destructure_struct,
668670
SILInstruction, None, DoesNotRelease)
669671
MULTIPLE_VALUE_INST(DestructureTupleInst, destructure_tuple,
670672
SILInstruction, None, DoesNotRelease)
671-
INST_RANGE(MultipleValueInstruction, BeginApplyInst, DestructureTupleInst)
673+
// SWIFT_ENABLE_TENSORFLOW
674+
MULTIPLE_VALUE_INST(GraphOperationInst, graph_op,
675+
SILInstruction, None, DoesNotRelease)
676+
INST_RANGE(MultipleValueInstruction, BeginApplyInst, GraphOperationInst)
672677

673-
NODE_RANGE(SILInstruction, AllocStackInst, DestructureTupleInst)
674-
NODE_RANGE(SILNode, SILPHIArgument, DestructureTupleInst)
678+
NODE_RANGE(SILInstruction, AllocStackInst, GraphOperationInst)
679+
NODE_RANGE(SILNode, SILPHIArgument, GraphOperationInst)
675680

676681
#undef SINGLE_VALUE_INST_RANGE
677682
#undef INST_RANGE

include/swift/Serialization/ModuleFormat.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const uint16_t VERSION_MAJOR = 0;
5555
/// describe what change you made. The content of this comment isn't important;
5656
/// it just ensures a conflict if two people change the module format.
5757
/// Don't worry about adhering to the 80-column limit for this line.
58-
const uint16_t VERSION_MINOR = 415; // SWIFT_ENABLE_TENSORFLOW: gradient.
58+
const uint16_t VERSION_MINOR = 416; // SWIFT_ENABLE_TENSORFLOW: graph_op.
5959

6060
using DeclIDField = BCFixed<31>;
6161

lib/IRGen/IRGenSIL.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,11 @@ class IRGenSILFunction :
907907
llvm_unreachable("gradient is not valid in canonical SIL");
908908
}
909909

910+
// SWIFT_ENABLE_TENSORFLOW
911+
void visitGraphOperationInst(GraphOperationInst *i) {
912+
llvm_unreachable("graph_op is not valid in canonical SIL");
913+
}
914+
910915
void visitFunctionRefInst(FunctionRefInst *i);
911916
void visitAllocGlobalInst(AllocGlobalInst *i);
912917
void visitGlobalAddrInst(GlobalAddrInst *i);

0 commit comments

Comments
 (0)