Skip to content

Commit

Permalink
Simplify serialization protocols (#17552)
Browse files Browse the repository at this point in the history
This rewrites all serialization protocols in cudf to remove the need for
pickling intermediates.
  • Loading branch information
vyasr authored Dec 10, 2024
1 parent 439321e commit 2f5bf76
Show file tree
Hide file tree
Showing 18 changed files with 179 additions and 141 deletions.
11 changes: 4 additions & 7 deletions python/cudf/cudf/_lib/copying.pyx
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

import pickle

from libcpp cimport bool
from libcpp.memory cimport unique_ptr
from libcpp.utility cimport move
Expand Down Expand Up @@ -367,24 +365,23 @@ class PackedColumns(Serializable):
header["index-names"] = self.index_names
header["metadata"] = self._metadata.tobytes()
for name, dtype in self.column_dtypes.items():
dtype_header, dtype_frames = dtype.serialize()
dtype_header, dtype_frames = dtype.device_serialize()
self.column_dtypes[name] = (
dtype_header,
(len(frames), len(frames) + len(dtype_frames)),
)
frames.extend(dtype_frames)
header["column-dtypes"] = self.column_dtypes
header["type-serialized"] = pickle.dumps(type(self))
return header, frames

@classmethod
def deserialize(cls, header, frames):
column_dtypes = {}
for name, dtype in header["column-dtypes"].items():
dtype_header, (start, stop) = dtype
column_dtypes[name] = pickle.loads(
dtype_header["type-serialized"]
).deserialize(dtype_header, frames[start:stop])
column_dtypes[name] = Serializable.device_deserialize(
dtype_header, frames[start:stop]
)
return cls(
plc.contiguous_split.pack(
plc.contiguous_split.unpack_from_memoryviews(
Expand Down
8 changes: 0 additions & 8 deletions python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import pickle
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal
Expand Down Expand Up @@ -330,13 +329,6 @@ def get_level_values(self, level):
else:
raise KeyError(f"Requested level with name {level} " "not found")

@classmethod
def deserialize(cls, header, frames):
# Dispatch deserialization to the appropriate index type in case
# deserialization is ever attempted with the base class directly.
idx_type = pickle.loads(header["type-serialized"])
return idx_type.deserialize(header, frames)

@property
def names(self):
"""
Expand Down
16 changes: 11 additions & 5 deletions python/cudf/cudf/core/abc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
"""Common abstract base classes for cudf."""

import pickle

import numpy

import cudf
Expand All @@ -22,6 +20,14 @@ class Serializable:
latter converts back from that representation into an equivalent object.
"""

# A mapping from class names to the classes themselves. This is used to
# reconstruct the correct class when deserializing an object.
_name_type_map: dict = {}

def __init_subclass__(cls, /, **kwargs):
super().__init_subclass__(**kwargs)
cls._name_type_map[cls.__name__] = cls

def serialize(self):
"""Generate an equivalent serializable representation of an object.
Expand Down Expand Up @@ -98,7 +104,7 @@ def device_serialize(self):
)
for f in frames
)
header["type-serialized"] = pickle.dumps(type(self))
header["type-serialized-name"] = type(self).__name__
header["is-cuda"] = [
hasattr(f, "__cuda_array_interface__") for f in frames
]
Expand Down Expand Up @@ -128,10 +134,10 @@ def device_deserialize(cls, header, frames):
:meta private:
"""
typ = pickle.loads(header["type-serialized"])
typ = cls._name_type_map[header["type-serialized-name"]]
frames = [
cudf.core.buffer.as_buffer(f) if c else memoryview(f)
for c, f in zip(header["is-cuda"], frames)
for c, f in zip(header["is-cuda"], frames, strict=True)
]
return typ.deserialize(header, frames)

Expand Down
8 changes: 4 additions & 4 deletions python/cudf/cudf/core/buffer/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import math
import pickle
import weakref
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, Literal
Expand Down Expand Up @@ -432,8 +431,7 @@ def serialize(self) -> tuple[dict, list]:
second element is a list containing single frame.
"""
header: dict[str, Any] = {}
header["type-serialized"] = pickle.dumps(type(self))
header["owner-type-serialized"] = pickle.dumps(type(self._owner))
header["owner-type-serialized-name"] = type(self._owner).__name__
header["frame_count"] = 1
frames = [self]
return header, frames
Expand All @@ -460,7 +458,9 @@ def deserialize(cls, header: dict, frames: list) -> Self:
if isinstance(frame, cls):
return frame # The frame is already deserialized

owner_type: BufferOwner = pickle.loads(header["owner-type-serialized"])
owner_type: BufferOwner = Serializable._name_type_map[
header["owner-type-serialized-name"]
]
if hasattr(frame, "__cuda_array_interface__"):
owner = owner_type.from_device_memory(frame, exposed=False)
else:
Expand Down
4 changes: 1 addition & 3 deletions python/cudf/cudf/core/buffer/spillable_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import collections.abc
import pickle
import time
import weakref
from threading import RLock
Expand Down Expand Up @@ -415,8 +414,7 @@ def serialize(self) -> tuple[dict, list]:
header: dict[str, Any] = {}
frames: list[Buffer | memoryview]
with self._owner.lock:
header["type-serialized"] = pickle.dumps(self.__class__)
header["owner-type-serialized"] = pickle.dumps(type(self._owner))
header["owner-type-serialized-name"] = type(self._owner).__name__
header["frame_count"] = 1
if self.is_spilled:
frames = [self.memoryview()]
Expand Down
23 changes: 11 additions & 12 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import pickle
from collections import abc
from collections.abc import MutableSequence, Sequence
from functools import cached_property
Expand Down Expand Up @@ -1224,28 +1223,27 @@ def serialize(self) -> tuple[dict, list]:

header: dict[Any, Any] = {}
frames = []
header["type-serialized"] = pickle.dumps(type(self))
try:
dtype, dtype_frames = self.dtype.serialize()
dtype, dtype_frames = self.dtype.device_serialize()
header["dtype"] = dtype
frames.extend(dtype_frames)
header["dtype-is-cudf-serialized"] = True
except AttributeError:
header["dtype"] = pickle.dumps(self.dtype)
header["dtype"] = self.dtype.str
header["dtype-is-cudf-serialized"] = False

if self.data is not None:
data_header, data_frames = self.data.serialize()
data_header, data_frames = self.data.device_serialize()
header["data"] = data_header
frames.extend(data_frames)

if self.mask is not None:
mask_header, mask_frames = self.mask.serialize()
mask_header, mask_frames = self.mask.device_serialize()
header["mask"] = mask_header
frames.extend(mask_frames)
if self.children:
child_headers, child_frames = zip(
*(c.serialize() for c in self.children)
*(c.device_serialize() for c in self.children)
)
header["subheaders"] = list(child_headers)
frames.extend(chain(*child_frames))
Expand All @@ -1257,8 +1255,7 @@ def serialize(self) -> tuple[dict, list]:
def deserialize(cls, header: dict, frames: list) -> ColumnBase:
def unpack(header, frames) -> tuple[Any, list]:
count = header["frame_count"]
klass = pickle.loads(header["type-serialized"])
obj = klass.deserialize(header, frames[:count])
obj = cls.device_deserialize(header, frames[:count])
return obj, frames[count:]

assert header["frame_count"] == len(frames), (
Expand All @@ -1268,7 +1265,7 @@ def unpack(header, frames) -> tuple[Any, list]:
if header["dtype-is-cudf-serialized"]:
dtype, frames = unpack(header["dtype"], frames)
else:
dtype = pickle.loads(header["dtype"])
dtype = np.dtype(header["dtype"])
if "data" in header:
data, frames = unpack(header["data"], frames)
else:
Expand Down Expand Up @@ -2219,7 +2216,9 @@ def serialize_columns(columns: list[ColumnBase]) -> tuple[list[dict], list]:
frames = []

if len(columns) > 0:
header_columns = [c.serialize() for c in columns]
header_columns: list[tuple[dict, list]] = [
c.device_serialize() for c in columns
]
headers, column_frames = zip(*header_columns)
for f in column_frames:
frames.extend(f)
Expand All @@ -2236,7 +2235,7 @@ def deserialize_columns(headers: list[dict], frames: list) -> list[ColumnBase]:

for meta in headers:
col_frame_count = meta["frame_count"]
col_typ = pickle.loads(meta["type-serialized"])
col_typ = Serializable._name_type_map[meta["type-serialized-name"]]
colobj = col_typ.deserialize(meta, frames[:col_frame_count])
columns.append(colobj)
# Advance frames
Expand Down
9 changes: 3 additions & 6 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import itertools
import numbers
import os
import pickle
import re
import sys
import textwrap
Expand Down Expand Up @@ -44,7 +43,6 @@
)
from cudf.core import column, df_protocol, indexing_utils, reshape
from cudf.core._compat import PANDAS_LT_300
from cudf.core.abc import Serializable
from cudf.core.buffer import acquire_spill_lock
from cudf.core.column import (
CategoricalColumn,
Expand Down Expand Up @@ -582,7 +580,7 @@ class _DataFrameiAtIndexer(_DataFrameIlocIndexer):
pass


class DataFrame(IndexedFrame, Serializable, GetAttrGetItemMixin):
class DataFrame(IndexedFrame, GetAttrGetItemMixin):
"""
A GPU Dataframe object.
Expand Down Expand Up @@ -1184,7 +1182,7 @@ def _constructor_expanddim(self):
def serialize(self):
header, frames = super().serialize()

header["index"], index_frames = self.index.serialize()
header["index"], index_frames = self.index.device_serialize()
header["index_frame_count"] = len(index_frames)
# For backwards compatibility with older versions of cuDF, index
# columns are placed before data columns.
Expand All @@ -1199,8 +1197,7 @@ def deserialize(cls, header, frames):
header, frames[header["index_frame_count"] :]
)

idx_typ = pickle.loads(header["index"]["type-serialized"])
index = idx_typ.deserialize(header["index"], frames[:index_nframes])
index = cls.device_deserialize(header["index"], frames[:index_nframes])
obj.index = index

return obj
Expand Down
Loading

0 comments on commit 2f5bf76

Please sign in to comment.