Skip to content

Commit 654aa92

Browse files
authored
__attrs_init__() (#731)
1 parent 467e28b commit 654aa92

File tree

6 files changed

+118
-14
lines changed

6 files changed

+118
-14
lines changed

changelog.d/731.change.rst

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
``__attrs__init__()`` will now be injected if ``init==False`` or if ``auto_detect=True`` and a user-defined ``__init__()`` exists.
2+
3+
This enables users to do "pre-init" work in their ``__init__()`` (such as ``super().__init__()``).
4+
5+
``__init__()`` can then delegate constructor argument processing to ``__attrs_init__(*args, **kwargs)``.

src/attr/_make.py

+80-13
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,26 @@ def add_init(self):
897897
self._is_exc,
898898
self._on_setattr is not None
899899
and self._on_setattr is not setters.NO_OP,
900+
attrs_init=False,
901+
)
902+
)
903+
904+
return self
905+
906+
def add_attrs_init(self):
907+
self._cls_dict["__attrs_init__"] = self._add_method_dunders(
908+
_make_init(
909+
self._cls,
910+
self._attrs,
911+
self._has_post_init,
912+
self._frozen,
913+
self._slots,
914+
self._cache_hash,
915+
self._base_attr_map,
916+
self._is_exc,
917+
self._on_setattr is not None
918+
and self._on_setattr is not setters.NO_OP,
919+
attrs_init=True,
900920
)
901921
)
902922

@@ -1160,6 +1180,11 @@ def attrs(
11601180
``attrs`` attributes. Leading underscores are stripped for the
11611181
argument name. If a ``__attrs_post_init__`` method exists on the
11621182
class, it will be called after the class is fully initialized.
1183+
1184+
If ``init`` is ``False``, an ``__attrs_init__`` method will be
1185+
injected instead. This allows you to define a custom ``__init__``
1186+
method that can do pre-init work such as ``super().__init__()``,
1187+
and then call ``__attrs_init__()`` and ``__attrs_post_init__()``.
11631188
:param bool slots: Create a `slotted class <slotted classes>` that's more
11641189
memory-efficient. Slotted classes are generally superior to the default
11651190
dict classes, but have some gotchas you should know about, so we
@@ -1299,6 +1324,8 @@ def attrs(
12991324
.. versionadded:: 20.1.0 *getstate_setstate*
13001325
.. versionadded:: 20.1.0 *on_setattr*
13011326
.. versionadded:: 20.3.0 *field_transformer*
1327+
.. versionchanged:: 21.1.0
1328+
``init=False`` injects ``__attrs_init__``
13021329
"""
13031330
if auto_detect and PY2:
13041331
raise PythonTooOldError(
@@ -1408,6 +1435,7 @@ def wrap(cls):
14081435
):
14091436
builder.add_init()
14101437
else:
1438+
builder.add_attrs_init()
14111439
if cache_hash:
14121440
raise TypeError(
14131441
"Invalid value for cache_hash. To use hash caching,"
@@ -1872,6 +1900,7 @@ def _make_init(
18721900
base_attr_map,
18731901
is_exc,
18741902
has_global_on_setattr,
1903+
attrs_init,
18751904
):
18761905
if frozen and has_global_on_setattr:
18771906
raise ValueError("Frozen classes can't use on_setattr.")
@@ -1908,6 +1937,7 @@ def _make_init(
19081937
is_exc,
19091938
needs_cached_setattr,
19101939
has_global_on_setattr,
1940+
attrs_init,
19111941
)
19121942
locs = {}
19131943
bytecode = compile(script, unique_filename, "exec")
@@ -1929,10 +1959,10 @@ def _make_init(
19291959
unique_filename,
19301960
)
19311961

1932-
__init__ = locs["__init__"]
1933-
__init__.__annotations__ = annotations
1962+
init = locs["__attrs_init__"] if attrs_init else locs["__init__"]
1963+
init.__annotations__ = annotations
19341964

1935-
return __init__
1965+
return init
19361966

19371967

19381968
def _setattr(attr_name, value_var, has_on_setattr):
@@ -2047,6 +2077,7 @@ def _attrs_to_init_script(
20472077
is_exc,
20482078
needs_cached_setattr,
20492079
has_global_on_setattr,
2080+
attrs_init,
20502081
):
20512082
"""
20522083
Return a script of an initializer for *attrs* and a dict of globals.
@@ -2317,10 +2348,12 @@ def fmt_setter_with_converter(
23172348
)
23182349
return (
23192350
"""\
2320-
def __init__(self, {args}):
2351+
def {init_name}(self, {args}):
23212352
{lines}
23222353
""".format(
2323-
args=args, lines="\n ".join(lines) if lines else "pass"
2354+
init_name=("__attrs_init__" if attrs_init else "__init__"),
2355+
args=args,
2356+
lines="\n ".join(lines) if lines else "pass",
23242357
),
23252358
names_for_globals,
23262359
annotations,
@@ -2666,7 +2699,6 @@ def default(self, meth):
26662699
_CountingAttr = _add_eq(_add_repr(_CountingAttr))
26672700

26682701

2669-
@attrs(slots=True, init=False, hash=True)
26702702
class Factory(object):
26712703
"""
26722704
Stores a factory callable.
@@ -2682,8 +2714,7 @@ class Factory(object):
26822714
.. versionadded:: 17.1.0 *takes_self*
26832715
"""
26842716

2685-
factory = attrib()
2686-
takes_self = attrib()
2717+
__slots__ = ("factory", "takes_self")
26872718

26882719
def __init__(self, factory, takes_self=False):
26892720
"""
@@ -2693,6 +2724,38 @@ def __init__(self, factory, takes_self=False):
26932724
self.factory = factory
26942725
self.takes_self = takes_self
26952726

2727+
def __getstate__(self):
2728+
"""
2729+
Play nice with pickle.
2730+
"""
2731+
return tuple(getattr(self, name) for name in self.__slots__)
2732+
2733+
def __setstate__(self, state):
2734+
"""
2735+
Play nice with pickle.
2736+
"""
2737+
for name, value in zip(self.__slots__, state):
2738+
setattr(self, name, value)
2739+
2740+
2741+
_f = [
2742+
Attribute(
2743+
name=name,
2744+
default=NOTHING,
2745+
validator=None,
2746+
repr=True,
2747+
cmp=None,
2748+
eq=True,
2749+
order=False,
2750+
hash=True,
2751+
init=True,
2752+
inherited=False,
2753+
)
2754+
for name in Factory.__slots__
2755+
]
2756+
2757+
Factory = _add_hash(_add_eq(_add_repr(Factory, attrs=_f), attrs=_f), attrs=_f)
2758+
26962759

26972760
def make_class(name, attrs, bases=(object,), **attributes_arguments):
26982761
"""
@@ -2727,11 +2790,15 @@ def make_class(name, attrs, bases=(object,), **attributes_arguments):
27272790
raise TypeError("attrs argument must be a dict or a list.")
27282791

27292792
post_init = cls_dict.pop("__attrs_post_init__", None)
2730-
type_ = type(
2731-
name,
2732-
bases,
2733-
{} if post_init is None else {"__attrs_post_init__": post_init},
2734-
)
2793+
user_init = cls_dict.pop("__init__", None)
2794+
2795+
body = {}
2796+
if post_init is not None:
2797+
body["__attrs_post_init__"] = post_init
2798+
if user_init is not None:
2799+
body["__init__"] = user_init
2800+
2801+
type_ = type(name, bases, body)
27352802
# For pickling to work, the __module__ variable needs to be set to the
27362803
# frame where the class is created. Bypass this step in environments where
27372804
# sys._getframe is not defined (Jython for example) or sys._getframe is not

tests/strategies.py

+10
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,29 @@ class HypClass:
154154

155155
cls_dict = dict(zip(attr_names, attrs))
156156
post_init_flag = draw(st.booleans())
157+
init_flag = draw(st.booleans())
158+
157159
if post_init_flag:
158160

159161
def post_init(self):
160162
pass
161163

162164
cls_dict["__attrs_post_init__"] = post_init
163165

166+
if not init_flag:
167+
168+
def init(self, *args, **kwargs):
169+
self.__attrs_init__(*args, **kwargs)
170+
171+
cls_dict["__init__"] = init
172+
164173
return make_class(
165174
"HypClass",
166175
cls_dict,
167176
slots=slots_flag if slots is None else slots,
168177
frozen=frozen_flag if frozen is None else frozen,
169178
weakref_slot=weakref_flag if weakref_slot is None else weakref_slot,
179+
init=init_flag,
170180
)
171181

172182

tests/test_dunders.py

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def _add_init(cls, frozen):
6969
base_attr_map={},
7070
is_exc=False,
7171
has_global_on_setattr=False,
72+
attrs_init=False,
7273
)
7374
return cls
7475

tests/test_funcs.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,14 @@ def test_unknown(self, C):
544544
# No generated class will have a four letter attribute.
545545
with pytest.raises(TypeError) as e:
546546
evolve(C(), aaaa=2)
547-
expected = "__init__() got an unexpected keyword argument 'aaaa'"
547+
548+
if hasattr(C, "__attrs_init__"):
549+
expected = (
550+
"__attrs_init__() got an unexpected keyword argument 'aaaa'"
551+
)
552+
else:
553+
expected = "__init__() got an unexpected keyword argument 'aaaa'"
554+
548555
assert (expected,) == e.value.args
549556

550557
def test_validator_failure(self):

tests/test_make.py

+14
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,19 @@ class C(object):
569569

570570
assert sentinel == getattr(C, method_name)
571571

572+
@pytest.mark.parametrize("init", [True, False])
573+
def test_respects_init_attrs_init(self, init):
574+
"""
575+
If init=False, adds __attrs_init__ to the class.
576+
Otherwise, it does not.
577+
"""
578+
579+
class C(object):
580+
x = attr.ib()
581+
582+
C = attr.s(init=init)(C)
583+
assert hasattr(C, "__attrs_init__") != init
584+
572585
@pytest.mark.skipif(PY2, reason="__qualname__ is PY3-only.")
573586
@given(slots_outer=booleans(), slots_inner=booleans())
574587
def test_repr_qualname(self, slots_outer, slots_inner):
@@ -1527,6 +1540,7 @@ class C(object):
15271540
.add_order()
15281541
.add_hash()
15291542
.add_init()
1543+
.add_attrs_init()
15301544
.add_repr("ns")
15311545
.add_str()
15321546
.build_class()

0 commit comments

Comments
 (0)