
Commit 2bc3fff

viirya authored and HyukjinKwon committed
[SPARK-29341][PYTHON] Upgrade cloudpickle to 1.0.0
### What changes were proposed in this pull request?

This patch upgrades cloudpickle to version 1.0.0. Main changes:

1. Clean up unused functions: cloudpipe/cloudpickle@936f16f
2. Fix relative imports inside function bodies: cloudpipe/cloudpickle@31ecdd6
3. Write keyword-only arguments to the pickle: cloudpipe/cloudpickle@6cb4718

### Why are the changes needed?

We should include new bug fixes such as cloudpipe/cloudpickle@6cb4718, because users might use such Python functions in PySpark.

Before:

```python
>>> def f(a, *, b=1):
...     return a + b
...
>>> rdd = sc.parallelize([1, 2, 3])
>>> rdd.map(f).collect()
[Stage 0:> (0 + 12) / 12]19/10/03 00:42:24 ERROR Executor: Exception in task 3.0 in stage 0.0 (TID 3)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/spark/python/lib/pyspark.zip/pyspark/worker.py", line 598, in main
    process()
  File "/spark/python/lib/pyspark.zip/pyspark/worker.py", line 590, in process
    serializer.dump_stream(out_iter, outfile)
  File "/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 513, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/spark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
TypeError: f() missing 1 required keyword-only argument: 'b'
```

After:

```python
>>> def f(a, *, b=1):
...     return a + b
...
>>> rdd = sc.parallelize([1, 2, 3])
>>> rdd.map(f).collect()
[2, 3, 4]
```

### Does this PR introduce any user-facing change?

Yes. This fixes two bugs when pickling Python functions.

### How was this patch tested?

Existing tests.

Closes #26009 from viirya/upgrade-cloudpickle.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
1 parent 858bf76 commit 2bc3fff
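
The keyword-only-argument fix above can be sanity-checked without a cluster, since the executor failure comes from the function not surviving a cloudpickle round trip. A minimal sketch, assuming a local PySpark installation that carries this upgraded vendored cloudpickle; the function `f` mirrors the one from the commit message:

```python
# Minimal round-trip check (our own sketch, not part of the patch): with
# cloudpickle 1.0.0, __kwdefaults__ is written to the pickle and restored,
# so the keyword-only default 'b' survives serialization.
import pickle
from pyspark import cloudpickle

def f(a, *, b=1):
    return a + b

clone = pickle.loads(cloudpickle.dumps(f))
assert clone.__kwdefaults__ == {'b': 1}
assert clone(2) == 3  # raised "missing 1 required keyword-only argument: 'b'" before
```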

4 files changed: +22 -50 lines


python/pyspark/broadcast.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -21,10 +21,9 @@
 from tempfile import NamedTemporaryFile
 import threading

-from pyspark.cloudpickle import print_exec
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import ChunkedStream, pickle_protocol
-from pyspark.util import _exception_message
+from pyspark.util import _exception_message, print_exec

 if sys.version < '3':
     import cPickle as pickle
```

python/pyspark/cloudpickle.py

Lines changed: 13 additions & 46 deletions
```diff
@@ -591,6 +591,8 @@ def save_function_tuple(self, func):
             state['annotations'] = func.__annotations__
         if hasattr(func, '__qualname__'):
             state['qualname'] = func.__qualname__
+        if hasattr(func, '__kwdefaults__'):
+            state['kwdefaults'] = func.__kwdefaults__
         save(state)
         write(pickle.TUPLE)
         write(pickle.REDUCE)  # applies _fill_function on the tuple
@@ -666,6 +668,15 @@ def extract_func_data(self, func):
         # multiple invokations are bound to the same Cloudpickler.
         base_globals = self.globals_ref.setdefault(id(func.__globals__), {})

+        if base_globals == {}:
+            # Add module attributes used to resolve relative imports
+            # instructions inside func.
+            for k in ["__package__", "__name__", "__path__", "__file__"]:
+                # Some built-in functions/methods such as object.__new__ have
+                # their __globals__ set to None in PyPy
+                if func.__globals__ is not None and k in func.__globals__:
+                    base_globals[k] = func.__globals__[k]
+
         return (code, f_globals, defaults, closure, dct, base_globals)

     def save_builtin_function(self, obj):
@@ -979,43 +990,6 @@ def _restore_attr(obj, attr):
     return obj


-def _get_module_builtins():
-    return pickle.__builtins__
-
-
-def print_exec(stream):
-    ei = sys.exc_info()
-    traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
-
-
-def _modules_to_main(modList):
-    """Force every module in modList to be placed into main"""
-    if not modList:
-        return
-
-    main = sys.modules['__main__']
-    for modname in modList:
-        if type(modname) is str:
-            try:
-                mod = __import__(modname)
-            except Exception:
-                sys.stderr.write('warning: could not import %s\n. '
-                                 'Your function may unexpectedly error due to this import failing;'
-                                 'A version mismatch is likely. Specific error was:\n' % modname)
-                print_exec(sys.stderr)
-            else:
-                setattr(main, mod.__name__, mod)
-
-
-# object generators:
-def _genpartial(func, args, kwds):
-    if not args:
-        args = ()
-    if not kwds:
-        kwds = {}
-    return partial(func, *args, **kwds)
-
-
 def _gen_ellipsis():
     return Ellipsis

@@ -1103,6 +1077,8 @@ def _fill_function(*args):
         func.__module__ = state['module']
     if 'qualname' in state:
         func.__qualname__ = state['qualname']
+    if 'kwdefaults' in state:
+        func.__kwdefaults__ = state['kwdefaults']

     cells = func.__closure__
     if cells is not None:
@@ -1188,15 +1164,6 @@ def _is_dynamic(module):
     return False


-"""Constructors for 3rd party libraries
-Note: These can never be renamed due to client compatibility issues"""
-
-
-def _getobject(modname, attribute):
-    mod = __import__(modname, fromlist=[attribute])
-    return mod.__dict__[attribute]
-
-
 """ Use copy_reg to extend global pickle definitions """

 if sys.version_info < (3, 4):  # pragma: no branch
```
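
The extract_func_data change above also seeds the rebuilt globals of a dynamically pickled function with the defining module's identity attributes, which is what lets relative imports inside the function body resolve after unpickling. A small probe against the same vendored module; the `probe` function is ours, not from the patch:

```python
# Hedged probe (ours): after unpickling, the function's globals should carry
# __name__/__package__ (and __path__/__file__ when the defining module has
# them), copied in by the seeding loop shown in the diff above.
import pickle
from pyspark import cloudpickle

def probe():
    return "ok"

clone = pickle.loads(cloudpickle.dumps(probe))
print({k: clone.__globals__.get(k) for k in ("__name__", "__package__")})
```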

python/pyspark/serializers.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -69,7 +69,7 @@
 pickle_protocol = pickle.HIGHEST_PROTOCOL

 from pyspark import cloudpickle
-from pyspark.util import _exception_message
+from pyspark.util import _exception_message, print_exec


 __all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"]
@@ -716,7 +716,7 @@ def dumps(self, obj):
                 msg = "Object too large to serialize: %s" % emsg
             else:
                 msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
-            cloudpickle.print_exec(sys.stderr)
+            print_exec(sys.stderr)
             raise pickle.PicklingError(msg)
```
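CloudPickleSerializer.dumps is the path that serializes user closures shipped to executors; after this change its failure branch logs through print_exec from pyspark.util instead of the helper that previously lived in the vendored cloudpickle. A quick usage sketch, assuming a local PySpark install:

```python
# Hedged sketch (ours): round-trip a small function through the serializer
# PySpark uses for closures. On a pickling failure, dumps() prints the
# traceback via print_exec and raises pickle.PicklingError.
from pyspark.serializers import CloudPickleSerializer

ser = CloudPickleSerializer()
payload = ser.dumps(lambda x: x + 1)  # bytes produced with cloudpickle
restored = ser.loads(payload)
assert restored(1) == 2
```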
python/pyspark/util.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@

 import re
 import sys
+import traceback
 import inspect
 from py4j.protocol import Py4JJavaError

@@ -62,6 +63,11 @@ def _get_argspec(f):
     return argspec


+def print_exec(stream):
+    ei = sys.exc_info()
+    traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
+
+
 class VersionUtils(object):
     """
     Provides utility method to determine Spark versions with given input string.
```
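
print_exec is a small helper: it writes the exception currently being handled to the given stream, so it is meant to be called from inside an except block. A minimal usage sketch:

```python
# Minimal usage sketch (ours): print_exec forwards the active exception to
# traceback.print_exception, writing it to the supplied stream.
import sys
from pyspark.util import print_exec

try:
    1 / 0
except ZeroDivisionError:
    print_exec(sys.stderr)  # traceback for the ZeroDivisionError goes to stderr
```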
