Skip to content

Commit a5a9c5b

Browse files
committed
[AutoDiff] NFC: @differentiable attribute gardening.
- Move `DifferentiableAttr` definition above `DeclAttributes` in include/swift/AST/Attr.h, like other attributes. - Remove unnecessary arguments from `DifferentiableAttr::DifferentiableAttr` and `DifferentiableAttr::setDerivativeGenericSignature`. - Add libSyntax test for `@differentiable` attributes.
1 parent 9cbc761 commit a5a9c5b

File tree

4 files changed

+207
-160
lines changed

4 files changed

+207
-160
lines changed

include/swift/AST/Attr.h

Lines changed: 135 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,6 +1585,141 @@ class OriginallyDefinedInAttr: public DeclAttribute {
15851585
}
15861586
};
15871587

1588+
/// A declaration name with location.
1589+
struct DeclNameWithLoc {
1590+
DeclName Name;
1591+
DeclNameLoc Loc;
1592+
};
1593+
1594+
/// Attribute that marks a function as differentiable and optionally specifies
1595+
/// custom associated derivative functions: 'jvp' and 'vjp'.
1596+
///
1597+
/// Examples:
1598+
/// @differentiable(jvp: jvpFoo where T : FloatingPoint)
1599+
/// @differentiable(wrt: (self, x, y), jvp: jvpFoo)
1600+
class DifferentiableAttr final
1601+
: public DeclAttribute,
1602+
private llvm::TrailingObjects<DifferentiableAttr,
1603+
ParsedAutoDiffParameter> {
1604+
friend TrailingObjects;
1605+
1606+
/// Whether this function is linear (optional).
1607+
bool Linear;
1608+
/// The number of parsed parameters specified in 'wrt:'.
1609+
unsigned NumParsedParameters = 0;
1610+
/// The JVP function.
1611+
Optional<DeclNameWithLoc> JVP;
1612+
/// The VJP function.
1613+
Optional<DeclNameWithLoc> VJP;
1614+
/// The JVP function (optional), resolved by the type checker if JVP name is
1615+
/// specified.
1616+
FuncDecl *JVPFunction = nullptr;
1617+
/// The VJP function (optional), resolved by the type checker if VJP name is
1618+
/// specified.
1619+
FuncDecl *VJPFunction = nullptr;
1620+
/// The differentiation parameters' indices, resolved by the type checker.
1621+
IndexSubset *ParameterIndices = nullptr;
1622+
/// The trailing where clause (optional).
1623+
TrailingWhereClause *WhereClause = nullptr;
1624+
/// The generic signature for autodiff associated functions. Resolved by the
1625+
/// type checker based on the original function's generic signature and the
1626+
/// attribute's where clause requirements. This is set only if the attribute
1627+
/// has a where clause.
1628+
GenericSignature DerivativeGenericSignature;
1629+
1630+
explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
1631+
SourceRange baseRange, bool linear,
1632+
ArrayRef<ParsedAutoDiffParameter> parameters,
1633+
Optional<DeclNameWithLoc> jvp,
1634+
Optional<DeclNameWithLoc> vjp,
1635+
TrailingWhereClause *clause);
1636+
1637+
explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
1638+
SourceRange baseRange, bool linear,
1639+
IndexSubset *parameterIndices,
1640+
Optional<DeclNameWithLoc> jvp,
1641+
Optional<DeclNameWithLoc> vjp,
1642+
GenericSignature derivativeGenericSignature);
1643+
1644+
public:
1645+
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1646+
SourceLoc atLoc, SourceRange baseRange,
1647+
bool linear,
1648+
ArrayRef<ParsedAutoDiffParameter> params,
1649+
Optional<DeclNameWithLoc> jvp,
1650+
Optional<DeclNameWithLoc> vjp,
1651+
TrailingWhereClause *clause);
1652+
1653+
static DifferentiableAttr *create(AbstractFunctionDecl *original,
1654+
bool implicit, SourceLoc atLoc,
1655+
SourceRange baseRange, bool linear,
1656+
IndexSubset *parameterIndices,
1657+
Optional<DeclNameWithLoc> jvp,
1658+
Optional<DeclNameWithLoc> vjp,
1659+
GenericSignature derivativeGenSig);
1660+
1661+
/// Get the optional 'jvp:' function name and location.
1662+
/// Use this instead of `getJVPFunction` to check whether the attribute has a
1663+
/// registered JVP.
1664+
Optional<DeclNameWithLoc> getJVP() const { return JVP; }
1665+
1666+
/// Get the optional 'vjp:' function name and location.
1667+
/// Use this instead of `getVJPFunction` to check whether the attribute has a
1668+
/// registered VJP.
1669+
Optional<DeclNameWithLoc> getVJP() const { return VJP; }
1670+
1671+
IndexSubset *getParameterIndices() const {
1672+
return ParameterIndices;
1673+
}
1674+
void setParameterIndices(IndexSubset *parameterIndices) {
1675+
ParameterIndices = parameterIndices;
1676+
}
1677+
1678+
/// The parsed differentiation parameters, i.e. the list of parameters
1679+
/// specified in 'wrt:'.
1680+
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1681+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1682+
}
1683+
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1684+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1685+
}
1686+
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1687+
return NumParsedParameters;
1688+
}
1689+
1690+
bool isLinear() const { return Linear; }
1691+
1692+
TrailingWhereClause *getWhereClause() const { return WhereClause; }
1693+
1694+
GenericSignature getDerivativeGenericSignature() const {
1695+
return DerivativeGenericSignature;
1696+
}
1697+
void setDerivativeGenericSignature(GenericSignature derivativeGenSig) {
1698+
DerivativeGenericSignature = derivativeGenSig;
1699+
}
1700+
1701+
FuncDecl *getJVPFunction() const { return JVPFunction; }
1702+
void setJVPFunction(FuncDecl *decl);
1703+
FuncDecl *getVJPFunction() const { return VJPFunction; }
1704+
void setVJPFunction(FuncDecl *decl);
1705+
1706+
/// Get the derivative generic environment for the given `@differentiable`
1707+
/// attribute and original function.
1708+
GenericEnvironment *
1709+
getDerivativeGenericEnvironment(AbstractFunctionDecl *original) const;
1710+
1711+
// Print the attribute to the given stream.
1712+
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
1713+
// If `omitAssociatedFunctions` is true, omit printing associated functions.
1714+
void print(llvm::raw_ostream &OS, const Decl *D,
1715+
bool omitWrtClause = false,
1716+
bool omitAssociatedFunctions = false) const;
1717+
1718+
static bool classof(const DeclAttribute *DA) {
1719+
return DA->getKind() == DAK_Differentiable;
1720+
}
1721+
};
1722+
15881723
/// Attributes that may be applied to declarations.
15891724
class DeclAttributes {
15901725
/// Linked list of declaration attributes.
@@ -1764,148 +1899,6 @@ class DeclAttributes {
17641899
SourceLoc getStartLoc(bool forModifiers = false) const;
17651900
};
17661901

1767-
/// A declaration name with location.
1768-
struct DeclNameWithLoc {
1769-
DeclName Name;
1770-
DeclNameLoc Loc;
1771-
};
1772-
1773-
/// Attribute that marks a function as differentiable and optionally specifies
1774-
/// custom associated derivative functions: 'jvp' and 'vjp'.
1775-
///
1776-
/// Examples:
1777-
/// @differentiable(jvp: jvpFoo where T : FloatingPoint)
1778-
/// @differentiable(wrt: (self, x, y), jvp: jvpFoo)
1779-
class DifferentiableAttr final
1780-
: public DeclAttribute,
1781-
private llvm::TrailingObjects<DifferentiableAttr,
1782-
ParsedAutoDiffParameter> {
1783-
friend TrailingObjects;
1784-
1785-
/// Whether this function is linear (optional).
1786-
bool linear;
1787-
/// The number of parsed parameters specified in 'wrt:'.
1788-
unsigned NumParsedParameters = 0;
1789-
/// The JVP function.
1790-
Optional<DeclNameWithLoc> JVP;
1791-
/// The VJP function.
1792-
Optional<DeclNameWithLoc> VJP;
1793-
/// The JVP function (optional), resolved by the type checker if JVP name is
1794-
/// specified.
1795-
FuncDecl *JVPFunction = nullptr;
1796-
/// The VJP function (optional), resolved by the type checker if VJP name is
1797-
/// specified.
1798-
FuncDecl *VJPFunction = nullptr;
1799-
/// The differentiation parameters' indices, resolved by the type checker.
1800-
IndexSubset *ParameterIndices = nullptr;
1801-
/// The trailing where clause (optional).
1802-
TrailingWhereClause *WhereClause = nullptr;
1803-
/// The generic signature for autodiff associated functions. Resolved by the
1804-
/// type checker based on the original function's generic signature and the
1805-
/// attribute's where clause requirements. This is set only if the attribute
1806-
/// has a where clause.
1807-
GenericSignature DerivativeGenericSignature;
1808-
1809-
explicit DifferentiableAttr(ASTContext &context, bool implicit,
1810-
SourceLoc atLoc, SourceRange baseRange,
1811-
bool linear,
1812-
ArrayRef<ParsedAutoDiffParameter> parameters,
1813-
Optional<DeclNameWithLoc> jvp,
1814-
Optional<DeclNameWithLoc> vjp,
1815-
TrailingWhereClause *clause);
1816-
1817-
explicit DifferentiableAttr(ASTContext &context, bool implicit,
1818-
SourceLoc atLoc, SourceRange baseRange,
1819-
bool linear, IndexSubset *indices,
1820-
Optional<DeclNameWithLoc> jvp,
1821-
Optional<DeclNameWithLoc> vjp,
1822-
GenericSignature derivativeGenericSignature);
1823-
1824-
public:
1825-
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1826-
SourceLoc atLoc, SourceRange baseRange,
1827-
bool linear,
1828-
ArrayRef<ParsedAutoDiffParameter> params,
1829-
Optional<DeclNameWithLoc> jvp,
1830-
Optional<DeclNameWithLoc> vjp,
1831-
TrailingWhereClause *clause);
1832-
1833-
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1834-
SourceLoc atLoc, SourceRange baseRange,
1835-
bool linear, IndexSubset *indices,
1836-
Optional<DeclNameWithLoc> jvp,
1837-
Optional<DeclNameWithLoc> vjp,
1838-
GenericSignature derivativeGenSig);
1839-
1840-
/// Get the optional 'jvp:' function name and location.
1841-
/// Use this instead of `getJVPFunction` to check whether the attribute has a
1842-
/// registered JVP.
1843-
Optional<DeclNameWithLoc> getJVP() const { return JVP; }
1844-
1845-
/// Get the optional 'vjp:' function name and location.
1846-
/// Use this instead of `getVJPFunction` to check whether the attribute has a
1847-
/// registered VJP.
1848-
Optional<DeclNameWithLoc> getVJP() const { return VJP; }
1849-
1850-
IndexSubset *getParameterIndices() const {
1851-
return ParameterIndices;
1852-
}
1853-
void setParameterIndices(IndexSubset *pi) {
1854-
ParameterIndices = pi;
1855-
}
1856-
1857-
/// The parsed differentiation parameters, i.e. the list of parameters
1858-
/// specified in 'wrt:'.
1859-
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1860-
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1861-
}
1862-
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1863-
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1864-
}
1865-
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1866-
return NumParsedParameters;
1867-
}
1868-
1869-
bool isLinear() const { return linear; }
1870-
1871-
TrailingWhereClause *getWhereClause() const { return WhereClause; }
1872-
1873-
GenericSignature getDerivativeGenericSignature() const {
1874-
return DerivativeGenericSignature;
1875-
}
1876-
void setDerivativeGenericSignature(ASTContext &context,
1877-
GenericSignature derivativeGenSig) {
1878-
DerivativeGenericSignature = derivativeGenSig;
1879-
}
1880-
1881-
FuncDecl *getJVPFunction() const { return JVPFunction; }
1882-
void setJVPFunction(FuncDecl *decl);
1883-
FuncDecl *getVJPFunction() const { return VJPFunction; }
1884-
void setVJPFunction(FuncDecl *decl);
1885-
1886-
bool parametersMatch(const DifferentiableAttr &other) const {
1887-
assert(ParameterIndices && other.ParameterIndices);
1888-
return ParameterIndices == other.ParameterIndices;
1889-
}
1890-
1891-
/// Get the derivative generic environment for the given `@differentiable`
1892-
/// attribute and original function.
1893-
GenericEnvironment *
1894-
getDerivativeGenericEnvironment(AbstractFunctionDecl *original) const;
1895-
1896-
// Print the attribute to the given stream.
1897-
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
1898-
// If `omitAssociatedFunctions` is true, omit printing associated functions.
1899-
void print(llvm::raw_ostream &OS, const Decl *D,
1900-
bool omitWrtClause = false,
1901-
bool omitAssociatedFunctions = false) const;
1902-
1903-
static bool classof(const DeclAttribute *DA) {
1904-
return DA->getKind() == DAK_Differentiable;
1905-
}
1906-
};
1907-
1908-
19091902
void simple_display(llvm::raw_ostream &out, const DeclAttribute *attr);
19101903

19111904
inline SourceLoc extractNearestSourceLoc(const DeclAttribute *attr) {

lib/AST/Attr.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,31 +1352,30 @@ SpecializeAttr *SpecializeAttr::create(ASTContext &Ctx, SourceLoc atLoc,
13521352
specializedSignature);
13531353
}
13541354

1355-
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
1356-
SourceLoc atLoc, SourceRange baseRange,
1357-
bool linear,
1355+
DifferentiableAttr::DifferentiableAttr(bool implicit, SourceLoc atLoc,
1356+
SourceRange baseRange, bool linear,
13581357
ArrayRef<ParsedAutoDiffParameter> params,
13591358
Optional<DeclNameWithLoc> jvp,
13601359
Optional<DeclNameWithLoc> vjp,
13611360
TrailingWhereClause *clause)
13621361
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1363-
linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
1362+
Linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
13641363
VJP(std::move(vjp)), WhereClause(clause) {
13651364
std::copy(params.begin(), params.end(),
13661365
getTrailingObjects<ParsedAutoDiffParameter>());
13671366
}
13681367

1369-
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
1368+
DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
13701369
SourceLoc atLoc, SourceRange baseRange,
13711370
bool linear,
1372-
IndexSubset *indices,
1371+
IndexSubset *parameterIndices,
13731372
Optional<DeclNameWithLoc> jvp,
13741373
Optional<DeclNameWithLoc> vjp,
13751374
GenericSignature derivativeGenSig)
13761375
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1377-
linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
1378-
ParameterIndices(indices) {
1379-
setDerivativeGenericSignature(context, derivativeGenSig);
1376+
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)) {
1377+
setParameterIndices(parameterIndices);
1378+
setDerivativeGenericSignature(derivativeGenSig);
13801379
}
13811380

13821381
DifferentiableAttr *
@@ -1389,22 +1388,23 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
13891388
TrailingWhereClause *clause) {
13901389
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(parameters.size());
13911390
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
1392-
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
1393-
linear, parameters, std::move(jvp),
1391+
return new (mem) DifferentiableAttr(implicit, atLoc, baseRange, linear,
1392+
parameters, std::move(jvp),
13941393
std::move(vjp), clause);
13951394
}
13961395

13971396
DifferentiableAttr *
1398-
DifferentiableAttr::create(ASTContext &context, bool implicit,
1399-
SourceLoc atLoc, SourceRange baseRange,
1400-
bool linear, IndexSubset *indices,
1397+
DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit,
1398+
SourceLoc atLoc, SourceRange baseRange, bool linear,
1399+
IndexSubset *parameterIndices,
14011400
Optional<DeclNameWithLoc> jvp,
14021401
Optional<DeclNameWithLoc> vjp,
14031402
GenericSignature derivativeGenSig) {
1404-
void *mem = context.Allocate(sizeof(DifferentiableAttr),
1405-
alignof(DifferentiableAttr));
1406-
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
1407-
linear, indices, std::move(jvp),
1403+
auto &ctx = original->getASTContext();
1404+
void *mem = ctx.Allocate(sizeof(DifferentiableAttr),
1405+
alignof(DifferentiableAttr));
1406+
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
1407+
linear, parameterIndices, std::move(jvp),
14081408
std::move(vjp), derivativeGenSig);
14091409
}
14101410

0 commit comments

Comments
 (0)