Skip to content

Commit e2db11f

Browse files
committed
WIP
1 parent d45b4c3 commit e2db11f

File tree

3 files changed

+112
-34
lines changed

3 files changed

+112
-34
lines changed

changelog.d/324.change.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
``attr.s`` now auto-detects user-written methods and does not overwrite them. # is missing the __hash__ == None if it shouldn't be hashable

src/attr/_make.py

+45-27
Original file line numberDiff line numberDiff line change
@@ -555,56 +555,74 @@ def slots_setstate(self, state):
555555
return cls
556556

557557
def add_repr(self, ns):
558-
self._cls_dict["__repr__"] = self._add_method_dunders(
559-
_make_repr(self._attrs, ns=ns)
560-
)
558+
if "__repr__" in self._cls.__dict__:
559+
meth = getattr(self._cls, "__repr__")
560+
else:
561+
meth = self._add_method_dunders(_make_repr(self._attrs, ns=ns))
562+
563+
self._cls_dict["__repr__"] = meth
561564
return self
562565

563566
def add_str(self):
564-
repr = self._cls_dict.get("__repr__")
565-
if repr is None:
566-
raise ValueError(
567-
"__str__ can only be generated if a __repr__ exists."
568-
)
567+
if "__str__" in self._cls.__dict__:
568+
meth = getattr(self._cls, "__str__")
569+
else:
570+
repr = self._cls_dict.get("__repr__")
571+
if repr is None:
572+
raise ValueError(
573+
"__str__ can only be generated if a __repr__ exists."
574+
)
569575

570-
def __str__(self):
571-
return self.__repr__()
576+
def __str__(self):
577+
return self.__repr__()
572578

573-
self._cls_dict["__str__"] = self._add_method_dunders(__str__)
579+
meth = self._add_method_dunders(__str__)
580+
581+
self._cls_dict["__str__"] = meth
574582
return self
575583

576584
def make_unhashable(self):
577585
self._cls_dict["__hash__"] = None
578586
return self
579587

580588
def add_hash(self):
581-
self._cls_dict["__hash__"] = self._add_method_dunders(
582-
_make_hash(self._attrs)
583-
)
589+
if "__hash__" in self._cls.__dict__:
590+
meth = getattr(self._cls, "__hash__")
591+
else:
592+
meth = self._add_method_dunders(_make_hash(self._attrs))
584593

594+
self._cls_dict["__hash__"] = meth
585595
return self
586596

587597
def add_init(self):
588-
self._cls_dict["__init__"] = self._add_method_dunders(
589-
_make_init(
590-
self._attrs,
591-
self._has_post_init,
592-
self._frozen,
593-
self._slots,
594-
self._super_attr_map,
598+
if "__init__" in self._cls.__dict__:
599+
meth = getattr(self._cls, "__init__")
600+
else:
601+
meth = self._add_method_dunders(
602+
_make_init(
603+
self._attrs,
604+
self._has_post_init,
605+
self._frozen,
606+
self._slots,
607+
self._super_attr_map,
608+
)
595609
)
596-
)
597610

611+
self._cls_dict["__init__"] = meth
598612
return self
599613

600614
def add_cmp(self):
601615
cd = self._cls_dict
602616

603-
cd["__eq__"], cd["__ne__"], cd["__lt__"], cd["__le__"], cd[
604-
"__gt__"
605-
], cd["__ge__"] = (
606-
self._add_method_dunders(meth) for meth in _make_cmp(self._attrs)
607-
)
617+
for meth in _make_cmp(self._attrs):
618+
method_name = meth.__name__
619+
620+
if method_name in self._cls.__dict__:
621+
meth = getattr(self._cls, method_name)
622+
else:
623+
meth = self._add_method_dunders(meth)
624+
625+
cd[method_name] = meth
608626

609627
return self
610628

tests/test_make.py

+66-7
Original file line numberDiff line numberDiff line change
@@ -422,21 +422,21 @@ def test_adds_all_by_default(self, method_name):
422422
If no further arguments are supplied, all add_XXX functions except
423423
add_hash are applied. __hash__ is set to None.
424424
"""
425-
# Set the method name to a sentinel and check whether it has been
426-
# overwritten afterwards.
427-
sentinel = object()
428425

429426
class C(object):
430427
x = attr.ib()
431428

432-
setattr(C, method_name, sentinel)
429+
# Assert that the method does not exist yet.
430+
assert method_name not in C.__dict__
433431

434432
C = attr.s(C)
435-
meth = getattr(C, method_name)
436433

437-
assert sentinel != meth
434+
method = getattr(C, method_name)
435+
438436
if method_name == "__hash__":
439-
assert meth is None
437+
assert method is None
438+
else:
439+
assert method is not None
440440

441441
@pytest.mark.parametrize(
442442
"arg_name, method_name",
@@ -1270,6 +1270,65 @@ class C2(C):
12701270

12711271
assert [C2] == C.__subclasses__()
12721272

1273+
@pytest.mark.parametrize(
1274+
"method_name",
1275+
[
1276+
"__init__",
1277+
"__hash__",
1278+
"__repr__",
1279+
"__str__",
1280+
"__eq__",
1281+
"__ne__",
1282+
"__lt__",
1283+
"__le__",
1284+
"__gt__",
1285+
"__ge__",
1286+
],
1287+
)
1288+
def test_respect_user_defined_methods(self, method_name):
1289+
"""
1290+
Does not replace methods provided by the original class.
1291+
"""
1292+
1293+
# Set the method name to a sentinel and check that it has not been
1294+
# overwritten afterwards.
1295+
def sentinel():
1296+
pass
1297+
1298+
# add_cmp relies on __name__.
1299+
sentinel.__name__ = method_name
1300+
1301+
class C(object):
1302+
x = attr.ib()
1303+
1304+
setattr(C, method_name, sentinel)
1305+
1306+
# set sentinel to unbound method of C otherwise assertion for py27
1307+
# doesn't work
1308+
sentinel = getattr(C, method_name)
1309+
1310+
C = attr.s(C, hash=True, str=True)
1311+
1312+
assert sentinel == getattr(C, method_name)
1313+
1314+
def test_ignores_user_defined_hash_method(self):
1315+
"""
1316+
if it should be unhashable, the '__hash__' method is replaced.
1317+
"""
1318+
1319+
# Set the method name to a sentinel and check that it has not been
1320+
# overwritten afterwards.
1321+
sentinel = object()
1322+
1323+
class C(object):
1324+
x = attr.ib()
1325+
1326+
setattr(C, "__hash__", sentinel)
1327+
1328+
C = attr.s(C)
1329+
1330+
assert getattr(C, "__hash__") is None
1331+
12731332

12741333
class TestMakeCmp:
12751334
"""

0 commit comments

Comments
 (0)