diff --git a/changelog.d/324.change.rst b/changelog.d/324.change.rst new file mode 100644 index 000000000..11bfd21aa --- /dev/null +++ b/changelog.d/324.change.rst @@ -0,0 +1 @@ +``attr.s`` now auto-detects user-written methods and does not overwrite them. # is missing the __hash__ == None if it shouldn't be hashable diff --git a/src/attr/_make.py b/src/attr/_make.py index a8d9c70c7..1328a4037 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -555,12 +555,18 @@ def slots_setstate(self, state): return cls def add_repr(self, ns): + if "__repr__" in self._cls.__dict__: + return self + self._cls_dict["__repr__"] = self._add_method_dunders( _make_repr(self._attrs, ns=ns) ) return self def add_str(self): + if "__str__" in self._cls.__dict__: + return self + repr = self._cls_dict.get("__repr__") if repr is None: raise ValueError( @@ -578,13 +584,18 @@ def make_unhashable(self): return self def add_hash(self): + if "__hash__" in self._cls.__dict__: + return self + self._cls_dict["__hash__"] = self._add_method_dunders( _make_hash(self._attrs) ) - return self def add_init(self): + if "__init__" in self._cls.__dict__: + return self + self._cls_dict["__init__"] = self._add_method_dunders( _make_init( self._attrs, @@ -600,11 +611,13 @@ def add_init(self): def add_cmp(self): cd = self._cls_dict - cd["__eq__"], cd["__ne__"], cd["__lt__"], cd["__le__"], cd[ - "__gt__" - ], cd["__ge__"] = ( - self._add_method_dunders(meth) for meth in _make_cmp(self._attrs) - ) + for meth in _make_cmp(self._attrs): + method_name = meth.__name__ + + if method_name in self._cls.__dict__: + continue + + cd[method_name] = self._add_method_dunders(meth) return self diff --git a/tests/test_make.py b/tests/test_make.py index 20d13b6a9..1a9c646ff 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -422,21 +422,21 @@ def test_adds_all_by_default(self, method_name): If no further arguments are supplied, all add_XXX functions except add_hash are applied. __hash__ is set to None. """ - # Set the method name to a sentinel and check whether it has been - # overwritten afterwards. - sentinel = object() class C(object): x = attr.ib() - setattr(C, method_name, sentinel) + # Assert that the method does not exist yet. + assert method_name not in C.__dict__ C = attr.s(C) - meth = getattr(C, method_name) - assert sentinel != meth + method = getattr(C, method_name) + if method_name == "__hash__": - assert meth is None + assert method is None + else: + assert method is not None @pytest.mark.parametrize( "arg_name, method_name", @@ -1270,6 +1270,65 @@ class C2(C): assert [C2] == C.__subclasses__() + @pytest.mark.parametrize( + "method_name", + [ + "__init__", + "__hash__", + "__repr__", + "__str__", + "__eq__", + "__ne__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + ], + ) + def test_respect_user_defined_methods(self, method_name): + """ + Does not replace methods provided by the original class. + """ + + # Set the method name to a sentinel and check that it has not been + # overwritten afterwards. + def sentinel(): + pass + + # add_cmp relies on __name__. + sentinel.__name__ = method_name + + class C(object): + x = attr.ib() + + setattr(C, method_name, sentinel) + + # set sentinel to unbound method of C otherwise assertion for py27 + # doesn't work + sentinel = getattr(C, method_name) + + C = attr.s(C, hash=True, str=(method_name == "__str__")) + + assert sentinel == getattr(C, method_name) + + def test_ignores_user_defined_hash_method(self): + """ + if it should be unhashable, the '__hash__' method is replaced. + """ + + # Set the method name to a sentinel and check that it has not been + # overwritten afterwards. + sentinel = object() + + class C(object): + x = attr.ib() + + setattr(C, "__hash__", sentinel) + + C = attr.s(C) + + assert getattr(C, "__hash__") is None + class TestMakeCmp: """