@@ -1747,6 +1747,84 @@ class DifferentiableAttr final
1747
1747
}
1748
1748
};
1749
1749
1750
+ // / Attribute that registers a function as a derivative of another function.
1751
+ // /
1752
+ // / Examples:
1753
+ // / @derivative(of: sin(_:))
1754
+ // / @derivative(of: +, wrt: (lhs, rhs))
1755
+ class DerivativeAttr final
1756
+ : public DeclAttribute,
1757
+ private llvm::TrailingObjects<DerivativeAttr, ParsedAutoDiffParameter> {
1758
+ friend TrailingObjects;
1759
+
1760
+ // / The original function name.
1761
+ DeclNameWithLoc OriginalFunctionName;
1762
+ // / The original function declaration, resolved by the type checker.
1763
+ AbstractFunctionDecl *OriginalFunction = nullptr ;
1764
+ // / The number of parsed parameters specified in 'wrt:'.
1765
+ unsigned NumParsedParameters = 0 ;
1766
+ // / The differentiation parameters' indices, resolved by the type checker.
1767
+ IndexSubset *ParameterIndices = nullptr ;
1768
+ // / The derivative function kind (JVP or VJP), resolved by the type checker.
1769
+ Optional<AutoDiffDerivativeFunctionKind> Kind = None;
1770
+
1771
+ explicit DerivativeAttr (bool implicit, SourceLoc atLoc, SourceRange baseRange,
1772
+ DeclNameWithLoc original,
1773
+ ArrayRef<ParsedAutoDiffParameter> params);
1774
+
1775
+ explicit DerivativeAttr (bool implicit, SourceLoc atLoc, SourceRange baseRange,
1776
+ DeclNameWithLoc original, IndexSubset *indices);
1777
+
1778
+ public:
1779
+ static DerivativeAttr *create (ASTContext &context, bool implicit,
1780
+ SourceLoc atLoc, SourceRange baseRange,
1781
+ DeclNameWithLoc original,
1782
+ ArrayRef<ParsedAutoDiffParameter> params);
1783
+
1784
+ static DerivativeAttr *create (ASTContext &context, bool implicit,
1785
+ SourceLoc atLoc, SourceRange baseRange,
1786
+ DeclNameWithLoc original, IndexSubset *indices);
1787
+
1788
+ DeclNameWithLoc getOriginalFunctionName () const {
1789
+ return OriginalFunctionName;
1790
+ }
1791
+ AbstractFunctionDecl *getOriginalFunction () const {
1792
+ return OriginalFunction;
1793
+ }
1794
+ void setOriginalFunction (AbstractFunctionDecl *decl) {
1795
+ OriginalFunction = decl;
1796
+ }
1797
+
1798
+ AutoDiffDerivativeFunctionKind getDerivativeKind () const {
1799
+ assert (Kind && " Derivative function kind has not yet been resolved" );
1800
+ return *Kind;
1801
+ }
1802
+ void setDerivativeKind (AutoDiffDerivativeFunctionKind kind) { Kind = kind; }
1803
+
1804
+ // / The parsed differentiation parameters, i.e. the list of parameters
1805
+ // / specified in 'wrt:'.
1806
+ ArrayRef<ParsedAutoDiffParameter> getParsedParameters () const {
1807
+ return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1808
+ }
1809
+ MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters () {
1810
+ return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1811
+ }
1812
+ size_t numTrailingObjects (OverloadToken<ParsedAutoDiffParameter>) const {
1813
+ return NumParsedParameters;
1814
+ }
1815
+
1816
+ IndexSubset *getParameterIndices () const {
1817
+ return ParameterIndices;
1818
+ }
1819
+ void setParameterIndices (IndexSubset *parameterIndices) {
1820
+ ParameterIndices = parameterIndices;
1821
+ }
1822
+
1823
+ static bool classof (const DeclAttribute *DA) {
1824
+ return DA->getKind () == DAK_Derivative;
1825
+ }
1826
+ };
1827
+
1750
1828
// / Attributes that may be applied to declarations.
1751
1829
class DeclAttributes {
1752
1830
// / Linked list of declaration attributes.
0 commit comments