diff --git a/hickle/hickle.py b/hickle/hickle.py index 6a986eff..ed0c1ea8 100644 --- a/hickle/hickle.py +++ b/hickle/hickle.py @@ -51,15 +51,8 @@ if PY3: file = io.TextIOWrapper -# Import a default 'pickler' -# Not the nicest import code, but should work on Py2/Py3 -try: - import dill as pickle -except ImportError: - try: - import cPickle as pickle - except ImportError: - import pickle +# Import dill as pickle +import dill as pickle try: from pathlib import Path diff --git a/hickle/loaders/load_python.py b/hickle/loaders/load_python.py index 58de921e..f53aee0f 100644 --- a/hickle/loaders/load_python.py +++ b/hickle/loaders/load_python.py @@ -102,6 +102,11 @@ def load_unicode_dataset(h_node): def load_none_dataset(h_node): return None +def load_pickled_data(h_node): + py_type, data = get_type_and_data(h_node) + import dill as pickle + return pickle.loads(data[0]) + def load_python_dtype_dataset(h_node): py_type, data = get_type_and_data(h_node) subtype = h_node.attrs["python_subdtype"] @@ -136,6 +141,7 @@ def load_python_dtype_dataset(h_node): "python_dtype" : load_python_dtype_dataset, "string" : load_string_dataset, "unicode" : load_unicode_dataset, + "pickle" : load_pickled_data, "none" : load_none_dataset } diff --git a/hickle/loaders/load_python3.py b/hickle/loaders/load_python3.py index c6b173fd..f532f324 100644 --- a/hickle/loaders/load_python3.py +++ b/hickle/loaders/load_python3.py @@ -153,10 +153,7 @@ def load_none_dataset(h_node): def load_pickled_data(h_node): py_type, data = get_type_and_data(h_node) - try: - import cPickle as pickle - except ModuleNotFoundError: - import pickle + import dill as pickle return pickle.loads(data[0]) diff --git a/hickle/tests/test_hickle.py b/hickle/tests/test_hickle.py index 54910542..57686306 100644 --- a/hickle/tests/test_hickle.py +++ b/hickle/tests/test_hickle.py @@ -14,6 +14,7 @@ import six import time from pprint import pprint +import pytest from py.path import local @@ -42,9 +43,24 @@ } } +# Define a test function that must be serialized and unpacked again +def func(a, b, c=0): + return(a, b, c) + + DUMP_CACHE = [] # Used in test_track_times() +def test_local_func(): + """ Dumping and loading a local function """ + filename, mode = 'test.h5', 'w' + with pytest.warns(SerializedWarning): + dump(func, filename, mode) + func_hkl = load(filename) + assert type(func) == type(func_hkl) + assert func(1, 2) == func_hkl(1, 2) + + def test_string(): """ Dumping and loading a string """ if six.PY2: @@ -821,6 +837,7 @@ def test_np_scalar(): test_complex_dict() test_multi_hickle() test_dict_int_key() + test_local_func() # Cleanup print("ALL TESTS PASSED!") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b3846b87..9dce5b33 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -h5py -numpy -dill +dill>=0.3.0 +h5py>=2.8.0 +numpy>=1.8 +six>=1.11.0 \ No newline at end of file