Skip to content

Commit

Permalink
Import PyCapsule if it already exists and add descriptive warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
anivegesana committed May 17, 2022
1 parent 44095bc commit 361f048
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 13 deletions.
48 changes: 38 additions & 10 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,15 +1059,45 @@ def _create_namedtuple(name, fieldnames, modulename, defaults=None):
t = collections.namedtuple(name, fieldnames, defaults=defaults, module=modulename)
return t

# TODO: Remove this copy when #450 is pulled
# https://github.com/python/cpython/blob/a8912a0f8d9eba6d502c37d522221f9933e976db/Lib/pickle.py#L322-L333
def _getattribute(obj, name):
for subpath in name.split('.'):
if subpath == '<locals>':
raise AttributeError("Can't get local attribute {!r} on {!r}"
.format(name, obj))
try:
parent = obj
obj = getattr(obj, subpath)
except AttributeError:
raise AttributeError("Can't get attribute {!r} on {!r}"
.format(name, obj)) # In Py3: from None
return obj,

def _create_capsule(pointer, name, context, destructor):
try:
# TODO: Somehow check this condition?
return _PyCapsule_Import(name, False)
# based on https://github.com/python/cpython/blob/f4095e53ab708d95e019c909d5928502775ba68f/Objects/capsule.c#L209-L231
for i in range(1, name.count(b'.')+1):
names = name.rsplit(b'.', i)
try:
module = __import__(names[0])
except:
continue
obj = module
for attr in names[1:]:
obj = getattr(obj, attr)
capsule = obj
break
except:
warnings.warn('Creating a new PyCapsule %s for a C data structure that may not be present in memory. Segmentation faults or other memory errors are possible.' % (name,), UnpicklingWarning)
capsule = _PyCapsule_New(pointer, name, destructor)
_PyCapsule_SetContext(capsule, context)
return capsule

if _PyCapsule_IsValid(capsule, name):
return capsule
raise UnpicklingError("%s object exists at %s but a PyCapsule object was expected." % (type(capsule), name))

def _getattr(objclass, name, repr_str):
# hack to grab the reference directly
try: #XXX: works only for __builtin__ ?
Expand Down Expand Up @@ -2030,9 +2060,9 @@ def save_function(pickler, obj):
_PyCapsule_GetName = ctypes.pythonapi.PyCapsule_GetName
_PyCapsule_GetName.argtypes = (ctypes.py_object,)
_PyCapsule_GetName.restype = ctypes.c_char_p
_PyCapsule_Import = ctypes.pythonapi.PyCapsule_Import
_PyCapsule_Import.argtypes = (ctypes.c_char_p, ctypes.c_bool)
_PyCapsule_Import.restype = ctypes.py_object
_PyCapsule_IsValid = ctypes.pythonapi.PyCapsule_IsValid
_PyCapsule_IsValid.argtypes = (ctypes.py_object, ctypes.c_char_p)
_PyCapsule_IsValid.restype = ctypes.c_bool
_PyCapsule_SetContext = ctypes.pythonapi.PyCapsule_SetContext
_PyCapsule_SetContext.argtypes = (ctypes.py_object, ctypes.c_void_p)
_PyCapsule_SetDestructor = ctypes.pythonapi.PyCapsule_SetDestructor
Expand All @@ -2041,20 +2071,18 @@ def save_function(pickler, obj):
_PyCapsule_SetName.argtypes = (ctypes.py_object, ctypes.c_char_p)
_PyCapsule_SetPointer = ctypes.pythonapi.PyCapsule_SetPointer
_PyCapsule_SetPointer.argtypes = (ctypes.py_object, ctypes.c_void_p)
# _PyCapsule_CheckExact = ctypes.pythonapi.PyCapsule_CheckExact
_PyCapsule_IsValid = ctypes.pythonapi.PyCapsule_IsValid
PyCapsuleType = type(
_PyCapsule_New(
ctypes.cast(_PyCapsule_New, ctypes.c_void_p),
ctypes.create_string_buffer(b'dill'),
ctypes.create_string_buffer(b'dill._testcapsule'),
None
)
)
@register(PyCapsuleType) # TODO
@register(PyCapsuleType)
def save_capsule(pickler, obj):
log.info("Cap: %s", obj)
name = _PyCapsule_GetName(obj)
warnings.warn('Pickling a PyCapsule does not pickle any C data structures and could cause segmentation faults or other memory errors when unpickling.' % (name,), PicklingWarning)
warnings.warn('Pickling a PyCapsule (%s) does not pickle any C data structures and could cause segmentation faults or other memory errors when unpickling.' % (name,), PicklingWarning)
pointer = _PyCapsule_GetPointer(obj, name)
context = _PyCapsule_GetContext(obj)
destructor = _PyCapsule_GetDestructor(obj)
Expand Down
20 changes: 17 additions & 3 deletions tests/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from dill import load_types, objects, extend
load_types(pickleable=True,unpickleable=False)

import warnings

# uncomment the next two lines to test cloudpickle
#extend(False)
#import cloudpickle as pickle
Expand Down Expand Up @@ -63,14 +65,26 @@ def test_objects():
import ctypes
if hasattr(ctypes, 'pythonapi'):
def test_pycapsule():
name = ctypes.create_string_buffer(b'dill._testcapsule')
capsule = pickle._dill._PyCapsule_New(
ctypes.cast(pickle._dill._PyCapsule_New, ctypes.c_void_p),
ctypes.create_string_buffer(b'dill'),
name,
None
)
pickle.copy(capsule)
with warnings.catch_warnings('ignore'):
pickle.copy(capsule)
pickle._testcapsule = capsule
with warnings.catch_warnings('ignore'):
pickle.copy(capsule)
pickle._testcapsule = None
try:
pickle.copy(capsule)
except pickle.UnpicklingError:
pass
else:
raise AssertionError("Expected a different error")

if __name__ == '__main__':
test_objects()
if test_pycapsule is not test_pycapsule:
if test_pycapsule is not None:
test_pycapsule()

0 comments on commit 361f048

Please sign in to comment.