Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support PyCapsule #477

Merged
merged 5 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,39 @@ def _create_namedtuple(name, fieldnames, modulename, defaults=None):
t = collections.namedtuple(name, fieldnames, defaults=defaults, module=modulename)
return t

def _create_capsule(pointer, name, context, destructor):
attr_found = False
try:
# based on https://github.com/python/cpython/blob/f4095e53ab708d95e019c909d5928502775ba68f/Objects/capsule.c#L209-L231
if PY3:
uname = name.decode('utf8')
else:
uname = name
for i in range(1, uname.count('.')+1):
names = uname.rsplit('.', i)
try:
module = __import__(names[0])
except:
pass
obj = module
for attr in names[1:]:
obj = getattr(obj, attr)
capsule = obj
attr_found = True
break
except:
pass

if attr_found:
if _PyCapsule_IsValid(capsule, name):
return capsule
raise UnpicklingError("%s object exists at %s but a PyCapsule object was expected." % (type(capsule), name))
else:
warnings.warn('Creating a new PyCapsule %s for a C data structure that may not be present in memory. Segmentation faults or other memory errors are possible.' % (name,), UnpicklingWarning)
mmckerns marked this conversation as resolved.
Show resolved Hide resolved
capsule = _PyCapsule_New(pointer, name, destructor)
_PyCapsule_SetContext(capsule, context)
return capsule

def _getattr(objclass, name, repr_str):
# hack to grab the reference directly
try: #XXX: works only for __builtin__ ?
Expand Down Expand Up @@ -2177,6 +2210,52 @@ def save_function(pickler, obj):
log.info("# F2")
return

if HAS_CTYPES and hasattr(ctypes, 'pythonapi'):
_PyCapsule_New = ctypes.pythonapi.PyCapsule_New
_PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p)
_PyCapsule_New.restype = ctypes.py_object
_PyCapsule_GetPointer = ctypes.pythonapi.PyCapsule_GetPointer
_PyCapsule_GetPointer.argtypes = (ctypes.py_object, ctypes.c_char_p)
_PyCapsule_GetPointer.restype = ctypes.c_void_p
_PyCapsule_GetDestructor = ctypes.pythonapi.PyCapsule_GetDestructor
_PyCapsule_GetDestructor.argtypes = (ctypes.py_object,)
_PyCapsule_GetDestructor.restype = ctypes.c_void_p
_PyCapsule_GetContext = ctypes.pythonapi.PyCapsule_GetContext
_PyCapsule_GetContext.argtypes = (ctypes.py_object,)
_PyCapsule_GetContext.restype = ctypes.c_void_p
_PyCapsule_GetName = ctypes.pythonapi.PyCapsule_GetName
_PyCapsule_GetName.argtypes = (ctypes.py_object,)
_PyCapsule_GetName.restype = ctypes.c_char_p
_PyCapsule_IsValid = ctypes.pythonapi.PyCapsule_IsValid
_PyCapsule_IsValid.argtypes = (ctypes.py_object, ctypes.c_char_p)
_PyCapsule_IsValid.restype = ctypes.c_bool
_PyCapsule_SetContext = ctypes.pythonapi.PyCapsule_SetContext
_PyCapsule_SetContext.argtypes = (ctypes.py_object, ctypes.c_void_p)
_PyCapsule_SetDestructor = ctypes.pythonapi.PyCapsule_SetDestructor
_PyCapsule_SetDestructor.argtypes = (ctypes.py_object, ctypes.c_void_p)
_PyCapsule_SetName = ctypes.pythonapi.PyCapsule_SetName
_PyCapsule_SetName.argtypes = (ctypes.py_object, ctypes.c_char_p)
_PyCapsule_SetPointer = ctypes.pythonapi.PyCapsule_SetPointer
_PyCapsule_SetPointer.argtypes = (ctypes.py_object, ctypes.c_void_p)
_testcapsule = _PyCapsule_New(
ctypes.cast(_PyCapsule_New, ctypes.c_void_p),
ctypes.create_string_buffer(b'dill._dill._testcapsule'),
None
)
PyCapsuleType = type(_testcapsule)
@register(PyCapsuleType)
def save_capsule(pickler, obj):
log.info("Cap: %s", obj)
name = _PyCapsule_GetName(obj)
warnings.warn('Pickling a PyCapsule (%s) does not pickle any C data structures and could cause segmentation faults or other memory errors when unpickling.' % (name,), PicklingWarning)
pointer = _PyCapsule_GetPointer(obj, name)
context = _PyCapsule_GetContext(obj)
destructor = _PyCapsule_GetDestructor(obj)
pickler.save_reduce(_create_capsule, (pointer, name, context, destructor), obj=obj)
log.info("# Cap")
else:
_testcapsule = None

# quick sanity checking
def pickles(obj,exact=False,safe=False,**kwds):
"""
Expand Down
5 changes: 5 additions & 0 deletions dill/_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,11 @@ class _Struct(ctypes.Structure):
else:
x['BufferType'] = buffer('')

from dill._dill import _testcapsule
if _testcapsule is not None:
x['PyCapsuleType'] = _testcapsule
del _testcapsule

# -- cleanup ----------------------------------------------------------------
a.update(d) # registered also succeed
if sys.platform[:3] == 'win':
Expand Down
2 changes: 1 addition & 1 deletion dill/_shims.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
#
# Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
# Author: Anirudh Vegesana (avegesan@stanford.edu)
# Author: Anirudh Vegesana (avegesan@cs.stanford.edu)
# Copyright (c) 2021-2022 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dictviews.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/usr/bin/env python
#
# Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
# Copyright (c) 2008-2016 California Institute of Technology.
# Copyright (c) 2016-2021 The Uncertainty Quantification Foundation.
# Author: Anirudh Vegesana (avegesan@cs.stanford.edu)
# Copyright (c) 2021-2022 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE

Expand Down
1 change: 0 additions & 1 deletion tests/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,5 @@ def test_objects():
#pickles(member, exact=True)
pickles(member, exact=False)


if __name__ == '__main__':
test_objects()
45 changes: 45 additions & 0 deletions tests/test_pycapsule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/usr/bin/env python
#
# Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
mmckerns marked this conversation as resolved.
Show resolved Hide resolved
# Author: Anirudh Vegesana (avegesan@cs.stanford.edu)
# Copyright (c) 2022 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE
"""
test pickling a PyCapsule object
"""

import dill
import warnings

test_pycapsule = None

if dill._dill._testcapsule is not None:
import ctypes
def test_pycapsule():
name = ctypes.create_string_buffer(b'dill._testcapsule')
capsule = dill._dill._PyCapsule_New(
ctypes.cast(dill._dill._PyCapsule_New, ctypes.c_void_p),
name,
None
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
dill.copy(capsule)
dill._testcapsule = capsule
with warnings.catch_warnings():
warnings.simplefilter("ignore")
dill.copy(capsule)
dill._testcapsule = None
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", dill.PicklingWarning)
dill.copy(capsule)
except dill.UnpicklingError:
pass
else:
raise AssertionError("Expected a different error")

if __name__ == '__main__':
if test_pycapsule is not None:
test_pycapsule()