@@ -1585,6 +1585,141 @@ class OriginallyDefinedInAttr: public DeclAttribute {
1585
1585
}
1586
1586
};
1587
1587
1588
+ // / A declaration name with location.
1589
+ struct DeclNameWithLoc {
1590
+ DeclName Name;
1591
+ DeclNameLoc Loc;
1592
+ };
1593
+
1594
+ // / Attribute that marks a function as differentiable and optionally specifies
1595
+ // / custom associated derivative functions: 'jvp' and 'vjp'.
1596
+ // /
1597
+ // / Examples:
1598
+ // / @differentiable(jvp: jvpFoo where T : FloatingPoint)
1599
+ // / @differentiable(wrt: (self, x, y), jvp: jvpFoo)
1600
+ class DifferentiableAttr final
1601
+ : public DeclAttribute,
1602
+ private llvm::TrailingObjects<DifferentiableAttr,
1603
+ ParsedAutoDiffParameter> {
1604
+ friend TrailingObjects;
1605
+
1606
+ // / Whether this function is linear (optional).
1607
+ bool Linear;
1608
+ // / The number of parsed parameters specified in 'wrt:'.
1609
+ unsigned NumParsedParameters = 0 ;
1610
+ // / The JVP function.
1611
+ Optional<DeclNameWithLoc> JVP;
1612
+ // / The VJP function.
1613
+ Optional<DeclNameWithLoc> VJP;
1614
+ // / The JVP function (optional), resolved by the type checker if JVP name is
1615
+ // / specified.
1616
+ FuncDecl *JVPFunction = nullptr ;
1617
+ // / The VJP function (optional), resolved by the type checker if VJP name is
1618
+ // / specified.
1619
+ FuncDecl *VJPFunction = nullptr ;
1620
+ // / The differentiation parameters' indices, resolved by the type checker.
1621
+ IndexSubset *ParameterIndices = nullptr ;
1622
+ // / The trailing where clause (optional).
1623
+ TrailingWhereClause *WhereClause = nullptr ;
1624
+ // / The generic signature for autodiff associated functions. Resolved by the
1625
+ // / type checker based on the original function's generic signature and the
1626
+ // / attribute's where clause requirements. This is set only if the attribute
1627
+ // / has a where clause.
1628
+ GenericSignature DerivativeGenericSignature;
1629
+
1630
+ explicit DifferentiableAttr (bool implicit, SourceLoc atLoc,
1631
+ SourceRange baseRange, bool linear,
1632
+ ArrayRef<ParsedAutoDiffParameter> parameters,
1633
+ Optional<DeclNameWithLoc> jvp,
1634
+ Optional<DeclNameWithLoc> vjp,
1635
+ TrailingWhereClause *clause);
1636
+
1637
+ explicit DifferentiableAttr (Decl *original, bool implicit, SourceLoc atLoc,
1638
+ SourceRange baseRange, bool linear,
1639
+ IndexSubset *parameterIndices,
1640
+ Optional<DeclNameWithLoc> jvp,
1641
+ Optional<DeclNameWithLoc> vjp,
1642
+ GenericSignature derivativeGenericSignature);
1643
+
1644
+ public:
1645
+ static DifferentiableAttr *create (ASTContext &context, bool implicit,
1646
+ SourceLoc atLoc, SourceRange baseRange,
1647
+ bool linear,
1648
+ ArrayRef<ParsedAutoDiffParameter> params,
1649
+ Optional<DeclNameWithLoc> jvp,
1650
+ Optional<DeclNameWithLoc> vjp,
1651
+ TrailingWhereClause *clause);
1652
+
1653
+ static DifferentiableAttr *create (AbstractFunctionDecl *original,
1654
+ bool implicit, SourceLoc atLoc,
1655
+ SourceRange baseRange, bool linear,
1656
+ IndexSubset *parameterIndices,
1657
+ Optional<DeclNameWithLoc> jvp,
1658
+ Optional<DeclNameWithLoc> vjp,
1659
+ GenericSignature derivativeGenSig);
1660
+
1661
+ // / Get the optional 'jvp:' function name and location.
1662
+ // / Use this instead of `getJVPFunction` to check whether the attribute has a
1663
+ // / registered JVP.
1664
+ Optional<DeclNameWithLoc> getJVP () const { return JVP; }
1665
+
1666
+ // / Get the optional 'vjp:' function name and location.
1667
+ // / Use this instead of `getVJPFunction` to check whether the attribute has a
1668
+ // / registered VJP.
1669
+ Optional<DeclNameWithLoc> getVJP () const { return VJP; }
1670
+
1671
+ IndexSubset *getParameterIndices () const {
1672
+ return ParameterIndices;
1673
+ }
1674
+ void setParameterIndices (IndexSubset *parameterIndices) {
1675
+ ParameterIndices = parameterIndices;
1676
+ }
1677
+
1678
+ // / The parsed differentiation parameters, i.e. the list of parameters
1679
+ // / specified in 'wrt:'.
1680
+ ArrayRef<ParsedAutoDiffParameter> getParsedParameters () const {
1681
+ return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1682
+ }
1683
+ MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters () {
1684
+ return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1685
+ }
1686
+ size_t numTrailingObjects (OverloadToken<ParsedAutoDiffParameter>) const {
1687
+ return NumParsedParameters;
1688
+ }
1689
+
1690
+ bool isLinear () const { return Linear; }
1691
+
1692
+ TrailingWhereClause *getWhereClause () const { return WhereClause; }
1693
+
1694
+ GenericSignature getDerivativeGenericSignature () const {
1695
+ return DerivativeGenericSignature;
1696
+ }
1697
+ void setDerivativeGenericSignature (GenericSignature derivativeGenSig) {
1698
+ DerivativeGenericSignature = derivativeGenSig;
1699
+ }
1700
+
1701
+ FuncDecl *getJVPFunction () const { return JVPFunction; }
1702
+ void setJVPFunction (FuncDecl *decl);
1703
+ FuncDecl *getVJPFunction () const { return VJPFunction; }
1704
+ void setVJPFunction (FuncDecl *decl);
1705
+
1706
+ // / Get the derivative generic environment for the given `@differentiable`
1707
+ // / attribute and original function.
1708
+ GenericEnvironment *
1709
+ getDerivativeGenericEnvironment (AbstractFunctionDecl *original) const ;
1710
+
1711
+ // Print the attribute to the given stream.
1712
+ // If `omitWrtClause` is true, omit printing the `wrt:` clause.
1713
+ // If `omitAssociatedFunctions` is true, omit printing associated functions.
1714
+ void print (llvm::raw_ostream &OS, const Decl *D,
1715
+ bool omitWrtClause = false ,
1716
+ bool omitAssociatedFunctions = false ) const ;
1717
+
1718
+ static bool classof (const DeclAttribute *DA) {
1719
+ return DA->getKind () == DAK_Differentiable;
1720
+ }
1721
+ };
1722
+
1588
1723
// / Attributes that may be applied to declarations.
1589
1724
class DeclAttributes {
1590
1725
// / Linked list of declaration attributes.
@@ -1764,148 +1899,6 @@ class DeclAttributes {
1764
1899
SourceLoc getStartLoc (bool forModifiers = false ) const ;
1765
1900
};
1766
1901
1767
- // / A declaration name with location.
1768
- struct DeclNameWithLoc {
1769
- DeclName Name;
1770
- DeclNameLoc Loc;
1771
- };
1772
-
1773
- // / Attribute that marks a function as differentiable and optionally specifies
1774
- // / custom associated derivative functions: 'jvp' and 'vjp'.
1775
- // /
1776
- // / Examples:
1777
- // / @differentiable(jvp: jvpFoo where T : FloatingPoint)
1778
- // / @differentiable(wrt: (self, x, y), jvp: jvpFoo)
1779
- class DifferentiableAttr final
1780
- : public DeclAttribute,
1781
- private llvm::TrailingObjects<DifferentiableAttr,
1782
- ParsedAutoDiffParameter> {
1783
- friend TrailingObjects;
1784
-
1785
- // / Whether this function is linear (optional).
1786
- bool linear;
1787
- // / The number of parsed parameters specified in 'wrt:'.
1788
- unsigned NumParsedParameters = 0 ;
1789
- // / The JVP function.
1790
- Optional<DeclNameWithLoc> JVP;
1791
- // / The VJP function.
1792
- Optional<DeclNameWithLoc> VJP;
1793
- // / The JVP function (optional), resolved by the type checker if JVP name is
1794
- // / specified.
1795
- FuncDecl *JVPFunction = nullptr ;
1796
- // / The VJP function (optional), resolved by the type checker if VJP name is
1797
- // / specified.
1798
- FuncDecl *VJPFunction = nullptr ;
1799
- // / The differentiation parameters' indices, resolved by the type checker.
1800
- IndexSubset *ParameterIndices = nullptr ;
1801
- // / The trailing where clause (optional).
1802
- TrailingWhereClause *WhereClause = nullptr ;
1803
- // / The generic signature for autodiff associated functions. Resolved by the
1804
- // / type checker based on the original function's generic signature and the
1805
- // / attribute's where clause requirements. This is set only if the attribute
1806
- // / has a where clause.
1807
- GenericSignature DerivativeGenericSignature;
1808
-
1809
- explicit DifferentiableAttr (ASTContext &context, bool implicit,
1810
- SourceLoc atLoc, SourceRange baseRange,
1811
- bool linear,
1812
- ArrayRef<ParsedAutoDiffParameter> parameters,
1813
- Optional<DeclNameWithLoc> jvp,
1814
- Optional<DeclNameWithLoc> vjp,
1815
- TrailingWhereClause *clause);
1816
-
1817
- explicit DifferentiableAttr (ASTContext &context, bool implicit,
1818
- SourceLoc atLoc, SourceRange baseRange,
1819
- bool linear, IndexSubset *indices,
1820
- Optional<DeclNameWithLoc> jvp,
1821
- Optional<DeclNameWithLoc> vjp,
1822
- GenericSignature derivativeGenericSignature);
1823
-
1824
- public:
1825
- static DifferentiableAttr *create (ASTContext &context, bool implicit,
1826
- SourceLoc atLoc, SourceRange baseRange,
1827
- bool linear,
1828
- ArrayRef<ParsedAutoDiffParameter> params,
1829
- Optional<DeclNameWithLoc> jvp,
1830
- Optional<DeclNameWithLoc> vjp,
1831
- TrailingWhereClause *clause);
1832
-
1833
- static DifferentiableAttr *create (ASTContext &context, bool implicit,
1834
- SourceLoc atLoc, SourceRange baseRange,
1835
- bool linear, IndexSubset *indices,
1836
- Optional<DeclNameWithLoc> jvp,
1837
- Optional<DeclNameWithLoc> vjp,
1838
- GenericSignature derivativeGenSig);
1839
-
1840
- // / Get the optional 'jvp:' function name and location.
1841
- // / Use this instead of `getJVPFunction` to check whether the attribute has a
1842
- // / registered JVP.
1843
- Optional<DeclNameWithLoc> getJVP () const { return JVP; }
1844
-
1845
- // / Get the optional 'vjp:' function name and location.
1846
- // / Use this instead of `getVJPFunction` to check whether the attribute has a
1847
- // / registered VJP.
1848
- Optional<DeclNameWithLoc> getVJP () const { return VJP; }
1849
-
1850
- IndexSubset *getParameterIndices () const {
1851
- return ParameterIndices;
1852
- }
1853
- void setParameterIndices (IndexSubset *pi) {
1854
- ParameterIndices = pi;
1855
- }
1856
-
1857
- // / The parsed differentiation parameters, i.e. the list of parameters
1858
- // / specified in 'wrt:'.
1859
- ArrayRef<ParsedAutoDiffParameter> getParsedParameters () const {
1860
- return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1861
- }
1862
- MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters () {
1863
- return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1864
- }
1865
- size_t numTrailingObjects (OverloadToken<ParsedAutoDiffParameter>) const {
1866
- return NumParsedParameters;
1867
- }
1868
-
1869
- bool isLinear () const { return linear; }
1870
-
1871
- TrailingWhereClause *getWhereClause () const { return WhereClause; }
1872
-
1873
- GenericSignature getDerivativeGenericSignature () const {
1874
- return DerivativeGenericSignature;
1875
- }
1876
- void setDerivativeGenericSignature (ASTContext &context,
1877
- GenericSignature derivativeGenSig) {
1878
- DerivativeGenericSignature = derivativeGenSig;
1879
- }
1880
-
1881
- FuncDecl *getJVPFunction () const { return JVPFunction; }
1882
- void setJVPFunction (FuncDecl *decl);
1883
- FuncDecl *getVJPFunction () const { return VJPFunction; }
1884
- void setVJPFunction (FuncDecl *decl);
1885
-
1886
- bool parametersMatch (const DifferentiableAttr &other) const {
1887
- assert (ParameterIndices && other.ParameterIndices );
1888
- return ParameterIndices == other.ParameterIndices ;
1889
- }
1890
-
1891
- // / Get the derivative generic environment for the given `@differentiable`
1892
- // / attribute and original function.
1893
- GenericEnvironment *
1894
- getDerivativeGenericEnvironment (AbstractFunctionDecl *original) const ;
1895
-
1896
- // Print the attribute to the given stream.
1897
- // If `omitWrtClause` is true, omit printing the `wrt:` clause.
1898
- // If `omitAssociatedFunctions` is true, omit printing associated functions.
1899
- void print (llvm::raw_ostream &OS, const Decl *D,
1900
- bool omitWrtClause = false ,
1901
- bool omitAssociatedFunctions = false ) const ;
1902
-
1903
- static bool classof (const DeclAttribute *DA) {
1904
- return DA->getKind () == DAK_Differentiable;
1905
- }
1906
- };
1907
-
1908
-
1909
1902
void simple_display (llvm::raw_ostream &out, const DeclAttribute *attr);
1910
1903
1911
1904
inline SourceLoc extractNearestSourceLoc (const DeclAttribute *attr) {
0 commit comments