Skip to content

Commit d5d91ab

Browse files
Bugfix: BaseEstimator __getstate__ in Python 3.11
Since Python 3.11, objects have a __getstate__ method by default: python/cpython#70766 Therefore, the exception in BaseEstimator.__getstate__ will no longer be raised, thus not falling back on using the object's __dict__: https://github.com/scikit-learn/scikit-learn/blob/dc580a8ef5ee2a8aea80498388690e2213118efd/sklearn/base.py#L274-L280 If the instance dict of the object is empty, the return value will, however, be None. Therefore, the line below calling state.items() results in an error. In this bugfix, it is checked if the state is None and if it is, the object's __dict__ is used (which should always be empty). Not addressed in this PR is how to deal with slots (see also discussion in scikit-learn#10079). When there are __slots__, __getstate__ will actually return a tuple, as documented here: https://docs.python.org/3/library/pickle.html#object.__getstate__ The user would thus still get an indiscriptive error message.
1 parent c0eb3d3 commit d5d91ab

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

sklearn/base.py

+4
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,11 @@ def __repr__(self, N_CHAR_MAX=700):
273273
def __getstate__(self):
274274
try:
275275
state = super().__getstate__()
276+
if state is None:
277+
state = self.__dict__.copy()
276278
except AttributeError:
279+
# TODO: Remove once Python < 3.11 is dropped, as there will never be
280+
# an AttributeError
277281
state = self.__dict__.copy()
278282

279283
if type(self).__module__.startswith("sklearn."):

sklearn/tests/test_base.py

+24
Original file line numberDiff line numberDiff line change
@@ -675,3 +675,27 @@ def test_clone_keeps_output_config():
675675
ss_clone = clone(ss)
676676
config_clone = _get_output_config("transform", ss_clone)
677677
assert config == config_clone
678+
679+
680+
def test_parent_object_empty_instance_dict():
681+
# Since Python 3.11, Python objects have a __getstate__ method by default
682+
# that returns None if the instance dict is empty
683+
class Empty:
684+
pass
685+
686+
class Estimator(Empty, BaseEstimator):
687+
pass
688+
689+
state = Estimator().__getstate__()
690+
expected = {"_sklearn_version": sklearn.__version__}
691+
assert state == expected
692+
693+
694+
def test_base_estimator_empty_instance_dict():
695+
# Since Python 3.11, Python objects have a __getstate__ method by default
696+
# that returns None if the instance dict is empty
697+
698+
# this should not raise
699+
state = BaseEstimator().__getstate__()
700+
expected = {"_sklearn_version": sklearn.__version__}
701+
assert state == expected

0 commit comments

Comments
 (0)