From 102d30add77e9a618d38e8eba6fa1f8472e7c10c Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 18 Jun 2024 07:41:49 -1000 Subject: [PATCH 1/7] Remove `override_dtypes` and `include_index` from `Frame._copy_type_metadata` (#16043) * `override_dtypes` logic was only needed for `.explode`. I think it's appropriate to make it a postprocessing step in that function * `include_index` logic was able to be transferred more simply to `IndexedFrame._from_columns_like_self` Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/cudf/pull/16043 --- python/cudf/cudf/core/_base_index.py | 4 +- python/cudf/cudf/core/dataframe.py | 6 -- python/cudf/cudf/core/frame.py | 26 +----- python/cudf/cudf/core/index.py | 25 ++---- python/cudf/cudf/core/indexed_frame.py | 101 +++++++---------------- python/cudf/cudf/core/multiindex.py | 6 +- python/cudf/cudf/tests/test_dataframe.py | 18 ++++ 7 files changed, 63 insertions(+), 123 deletions(-) diff --git a/python/cudf/cudf/core/_base_index.py b/python/cudf/cudf/core/_base_index.py index e71e45e410e..ad73cd57f7d 100644 --- a/python/cudf/cudf/core/_base_index.py +++ b/python/cudf/cudf/core/_base_index.py @@ -282,9 +282,7 @@ def __contains__(self, item): hash(item) return item in self._values - def _copy_type_metadata( - self, other: Self, *, override_dtypes=None - ) -> Self: + def _copy_type_metadata(self: Self, other: Self) -> Self: raise NotImplementedError def get_level_values(self, level): diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 065b13561ab..76bb9d2a8ed 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -7361,9 +7361,6 @@ def explode(self, column, ignore_index=False): 3 4 44 3 5 44 """ - if column not in self._column_names: - raise KeyError(column) - return super()._explode(column, ignore_index) def pct_change( @@ -7511,14 +7508,11 @@ def _from_columns_like_self( columns: list[ColumnBase], column_names: abc.Iterable[str] | None = None, index_names: list[str] | None = None, - *, - override_dtypes: abc.Iterable[Dtype | None] | None = None, ) -> DataFrame: result = super()._from_columns_like_self( columns, column_names, index_names, - override_dtypes=override_dtypes, ) result._set_columns_like(self._data) return result diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index c58a0161ee0..38bff3946d6 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -3,7 +3,6 @@ from __future__ import annotations import copy -import itertools import operator import pickle import warnings @@ -80,7 +79,7 @@ def _columns(self) -> tuple[ColumnBase, ...]: return self._data.columns @property - def _dtypes(self) -> abc.Iterator: + def _dtypes(self) -> abc.Iterable: return zip(self._data.names, (col.dtype for col in self._data.columns)) @property @@ -145,8 +144,6 @@ def _from_columns_like_self( self, columns: list[ColumnBase], column_names: abc.Iterable[str] | None = None, - *, - override_dtypes: abc.Iterable[Dtype | None] | None = None, ): """Construct a Frame from a list of columns with metadata from self. @@ -156,7 +153,7 @@ def _from_columns_like_self( column_names = self._column_names data = dict(zip(column_names, columns)) frame = self.__class__._from_data(data) - return frame._copy_type_metadata(self, override_dtypes=override_dtypes) + return frame._copy_type_metadata(self) @_cudf_nvtx_annotate def _mimic_inplace( @@ -1032,29 +1029,14 @@ def _positions_from_column_names(self, column_names) -> list[int]: ] @_cudf_nvtx_annotate - def _copy_type_metadata( - self, - other: Self, - *, - override_dtypes: abc.Iterable[Dtype | None] | None = None, - ) -> Self: + def _copy_type_metadata(self: Self, other: Self) -> Self: """ Copy type metadata from each column of `other` to the corresponding column of `self`. - If override_dtypes is provided, any non-None entry - will be used in preference to the relevant column of other to - provide the new dtype. - See `ColumnBase._with_type_metadata` for more information. """ - if override_dtypes is None: - override_dtypes = itertools.repeat(None) - dtypes = ( - dtype if dtype is not None else col.dtype - for (dtype, col) in zip(override_dtypes, other._data.values()) - ) - for (name, col), dtype in zip(self._data.items(), dtypes): + for (name, col), (_, dtype) in zip(self._data.items(), other._dtypes): self._data.set_by_label( name, col._with_type_metadata(dtype), validate=False ) diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index df21d392311..1c5d05d2d87 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -62,7 +62,7 @@ from cudf.utils.utils import _warn_no_dask_cudf, search_range if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Iterable class IndexMeta(type): @@ -232,9 +232,7 @@ def __init__( raise ValueError("Step must not be zero.") from err raise - def _copy_type_metadata( - self, other: RangeIndex, *, override_dtypes=None - ) -> Self: + def _copy_type_metadata(self: Self, other: Self) -> Self: # There is no metadata to be copied for RangeIndex since it does not # have an underlying column. return self @@ -485,6 +483,10 @@ def dtype(self): dtype = np.dtype(np.int64) return _maybe_convert_to_default_type(dtype) + @property + def _dtypes(self) -> Iterable: + return [(self.name, self.dtype)] + @_cudf_nvtx_annotate def to_pandas( self, *, nullable: bool = False, arrow_type: bool = False @@ -1115,15 +1117,6 @@ def _binaryop( return ret.values return ret - # Override just to make mypy happy. - @_cudf_nvtx_annotate - def _copy_type_metadata( - self, other: Self, *, override_dtypes=None - ) -> Self: - return super()._copy_type_metadata( - other, override_dtypes=override_dtypes - ) - @property # type: ignore @_cudf_nvtx_annotate def _values(self): @@ -1769,10 +1762,8 @@ def __init__( raise ValueError("No unique frequency found") @_cudf_nvtx_annotate - def _copy_type_metadata( - self: DatetimeIndex, other: DatetimeIndex, *, override_dtypes=None - ) -> Index: - super()._copy_type_metadata(other, override_dtypes=override_dtypes) + def _copy_type_metadata(self: Self, other: Self) -> Self: + super()._copy_type_metadata(other) self._freq = _validate_freq(other._freq) return self diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 06da62306e8..f1b74adefed 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -313,17 +313,11 @@ def _from_columns_like_self( columns: list[ColumnBase], column_names: abc.Iterable[str] | None = None, index_names: list[str] | None = None, - *, - override_dtypes: abc.Iterable[Dtype | None] | None = None, ) -> Self: """Construct a `Frame` from a list of columns with metadata from self. If `index_names` is set, the first `len(index_names)` columns are used to construct the index of the frame. - - If override_dtypes is provided then any non-None entry will be - used for the dtype of the matching column in preference to the - dtype of the column in self. """ if column_names is None: column_names = self._column_names @@ -337,22 +331,24 @@ def _from_columns_like_self( index = _index_from_data( dict(enumerate(columns[:n_index_columns])) ) + index = index._copy_type_metadata(self.index) + # TODO: Should this if statement be handled in Index._copy_type_metadata? + if ( + isinstance(self.index, cudf.CategoricalIndex) + and not isinstance(index, cudf.CategoricalIndex) + ) or ( + isinstance(self.index, cudf.MultiIndex) + and not isinstance(index, cudf.MultiIndex) + ): + index = type(self.index)._from_data(index._data) if isinstance(index, cudf.MultiIndex): index.names = index_names else: index.name = index_names[0] data = dict(zip(column_names, data_columns)) - frame = self.__class__._from_data(data) - - if index is not None: - # TODO: triage why using the setter here breaks dask_cuda.ProxifyHostFile - frame._index = index - return frame._copy_type_metadata( - self, - include_index=bool(index_names), - override_dtypes=override_dtypes, - ) + frame = type(self)._from_data(data, index) + return frame._copy_type_metadata(self) def __round__(self, digits=0): # Shouldn't be added to BinaryOperand @@ -1913,45 +1909,6 @@ def nans_to_nulls(self): self._data._from_columns_like_self(result) ) - def _copy_type_metadata( - self, - other: Self, - include_index: bool = True, - *, - override_dtypes: abc.Iterable[Dtype | None] | None = None, - ) -> Self: - """ - Copy type metadata from each column of `other` to the corresponding - column of `self`. - See `ColumnBase._with_type_metadata` for more information. - """ - super()._copy_type_metadata(other, override_dtypes=override_dtypes) - if ( - include_index - and self.index is not None - and other.index is not None - ): - self.index._copy_type_metadata(other.index) - # When other.index is a CategoricalIndex, the current index - # will be a NumericalIndex with an underlying CategoricalColumn - # (the above _copy_type_metadata call will have converted the - # column). Calling cudf.Index on that column generates the - # appropriate index. - if isinstance( - other.index, cudf.core.index.CategoricalIndex - ) and not isinstance(self.index, cudf.core.index.CategoricalIndex): - self.index = cudf.Index( - cast("cudf.Index", self.index)._column, - name=self.index.name, - ) - elif isinstance(other.index, cudf.MultiIndex) and not isinstance( - self.index, cudf.MultiIndex - ): - self.index = cudf.MultiIndex._from_data( - self.index._data, name=self.index.name - ) - return self - @_cudf_nvtx_annotate def interpolate( self, @@ -5195,36 +5152,36 @@ def _explode(self, explode_column: Any, ignore_index: bool): # duplicated. If ignore_index is set, the original index is not # exploded and will be replaced with a `RangeIndex`. if not isinstance(self._data[explode_column].dtype, ListDtype): - data = self._data.copy(deep=True) - idx = None if ignore_index else self.index.copy(deep=True) - return self.__class__._from_data(data, index=idx) + result = self.copy() + if ignore_index: + result.index = RangeIndex(len(result)) + return result column_index = self._column_names.index(explode_column) - if not ignore_index and self.index is not None: - index_offset = self.index.nlevels + if not ignore_index: + idx_cols = self.index._columns else: - index_offset = 0 + idx_cols = () exploded = libcudf.lists.explode_outer( - [ - *(self.index._data.columns if not ignore_index else ()), - *self._columns, - ], - column_index + index_offset, + [*idx_cols, *self._columns], + column_index + len(idx_cols), ) # We must copy inner datatype of the exploded list column to # maintain struct dtype key names - exploded_dtype = cast( + element_type = cast( ListDtype, self._columns[column_index].dtype ).element_type + exploded = [ + column._with_type_metadata(element_type) + if i == column_index + else column + for i, column in enumerate(exploded, start=-len(idx_cols)) + ] return self._from_columns_like_self( exploded, self._column_names, - self._index_names if not ignore_index else None, - override_dtypes=( - exploded_dtype if i == column_index else None - for i in range(len(self._columns)) - ), + self.index.names if not ignore_index else None, ) @_cudf_nvtx_annotate diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index 832cc003d2e..a01242d957d 100644 --- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -37,6 +37,8 @@ if TYPE_CHECKING: from collections.abc import Generator + from typing_extensions import Self + from cudf._typing import DataFrameOrSeries @@ -2100,9 +2102,7 @@ def _intersection(self, other, sort=None): return midx @_cudf_nvtx_annotate - def _copy_type_metadata( - self: MultiIndex, other: MultiIndex, *, override_dtypes=None - ) -> MultiIndex: + def _copy_type_metadata(self: Self, other: Self) -> Self: res = super()._copy_type_metadata(other) if isinstance(other, MultiIndex): res._names = other._names diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 649821b9b7c..3661e13bd39 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -9466,6 +9466,24 @@ def test_explode(data, labels, ignore_index, p_index, label_to_explode): assert_eq(expect, got, check_dtype=False) +def test_explode_preserve_categorical(): + gdf = cudf.DataFrame( + { + "A": [[1, 2], None, [2, 3]], + "B": cudf.Series([0, 1, 2], dtype="category"), + } + ) + result = gdf.explode("A") + expected = cudf.DataFrame( + { + "A": [1, 2, None, 2, 3], + "B": cudf.Series([0, 0, 1, 2, 2], dtype="category"), + } + ) + expected.index = cudf.Index([0, 0, 1, 2, 2]) + assert_eq(result, expected) + + @pytest.mark.parametrize( "df,ascending,expected", [ From 231cb716baf44b64e0284e23ae9666500de7d593 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 18 Jun 2024 11:50:46 -0700 Subject: [PATCH 2/7] Fix a size overflow bug in hash groupby (#16053) This PR fixes a size overflow bug discovered by @matal-nvidia. It converts the groupby problem size to `int64_t` so it won't overflow if larger than `INT_MAX / 2` with 50% hash table occupancy. Unit tests for this scenario will saturate device memory and take longer than necessary, making them likely not worth adding. Authors: - Yunsong Wang (https://github.com/PointKernel) Approvers: - Bradley Dice (https://github.com/bdice) - Matthew Roeschke (https://github.com/mroeschke) - Nghia Truong (https://github.com/ttnghia) URL: https://github.com/rapidsai/cudf/pull/16053 --- cpp/src/groupby/hash/groupby.cu | 3 ++- java/src/test/java/ai/rapids/cudf/TableTest.java | 3 ++- python/cudf/cudf/core/groupby/groupby.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cpp/src/groupby/hash/groupby.cu b/cpp/src/groupby/hash/groupby.cu index 0ec293ae3f0..5fe4a5eb30f 100644 --- a/cpp/src/groupby/hash/groupby.cu +++ b/cpp/src/groupby/hash/groupby.cu @@ -553,7 +553,8 @@ std::unique_ptr groupby(table_view const& keys, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - auto const num_keys = keys.num_rows(); + // convert to int64_t to avoid potential overflow with large `keys` + auto const num_keys = static_cast(keys.num_rows()); auto const null_keys_are_equal = null_equality::EQUAL; auto const has_null = nullate::DYNAMIC{cudf::has_nested_nulls(keys)}; diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index dc6eb55fc6a..050bcbb268f 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -7838,11 +7838,12 @@ void testSumWithStrings() { .build(); Table result = t.groupBy(0).aggregate( GroupByAggregation.sum().onColumn(1)); + Table sorted = result.orderBy(OrderByArg.asc(0)); Table expected = new Table.TestBuilder() .column("1-URGENT", "3-MEDIUM") .column(5289L + 5303L, 5203L + 5206L) .build()) { - assertTablesAreEqual(expected, result); + assertTablesAreEqual(expected, sorted); } } diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index d08268eea3a..77b54a583d3 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -1308,7 +1308,7 @@ def pipe(self, func, *args, **kwargs): To get the difference between each groups maximum and minimum value in one pass, you can do - >>> df.groupby('A').pipe(lambda x: x.max() - x.min()) + >>> df.groupby('A', sort=True).pipe(lambda x: x.max() - x.min()) B A a 2 From fc4b3d3ecbf95ee9afdcd509554bbeb5367a3059 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 18 Jun 2024 09:02:05 -1000 Subject: [PATCH 3/7] Reduce deep copies in Index ops (#16054) 1. Changed `Index.rename(inplace=False)` to shallow copy which matches pandas behavior. Let me know if there's a reason why we should deep copy here. 2. Made `RangeIndex.unique` return a shallow copy like pandas. 3. Made `Index.dropna` with no NA's shallow copy like pandas. Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/16054 --- python/cudf/cudf/core/_base_index.py | 6 +++--- python/cudf/cudf/core/index.py | 5 +++-- python/cudf/cudf/tests/test_index.py | 25 +++++++++++++++++++++++-- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/python/cudf/cudf/core/_base_index.py b/python/cudf/cudf/core/_base_index.py index ad73cd57f7d..caf07b286cd 100644 --- a/python/cudf/cudf/core/_base_index.py +++ b/python/cudf/cudf/core/_base_index.py @@ -1120,7 +1120,7 @@ def difference(self, other, sort=None): res_name = _get_result_name(self.name, other.name) if is_mixed_with_object_dtype(self, other) or len(other) == 0: - difference = self.copy().unique() + difference = self.unique() difference.name = res_name if sort is True: return difference.sort_values() @@ -1744,7 +1744,7 @@ def rename(self, name, inplace=False): self.name = name return None else: - out = self.copy(deep=True) + out = self.copy(deep=False) out.name = name return out @@ -2068,7 +2068,7 @@ def dropna(self, how="any"): raise ValueError(f"{how=} must be 'any' or 'all'") try: if not self.hasnans: - return self.copy() + return self.copy(deep=False) except NotImplementedError: pass # This is to be consistent with IndexedFrame.dropna to handle nans diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 1c5d05d2d87..71658695b80 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -528,7 +528,7 @@ def memory_usage(self, deep: bool = False) -> int: def unique(self) -> Self: # RangeIndex always has unique values - return self + return self.copy() @_cudf_nvtx_annotate def __mul__(self, other): @@ -3197,7 +3197,8 @@ def _get_nearest_indexer( ) right_indexer = _get_indexer_basic( index=index, - positions=positions.copy(deep=True), + # positions no longer used so don't copy + positions=positions, method="backfill", target_col=target_col, tolerance=tolerance, diff --git a/python/cudf/cudf/tests/test_index.py b/python/cudf/cudf/tests/test_index.py index 3d6c71ebc1b..a59836df5ba 100644 --- a/python/cudf/cudf/tests/test_index.py +++ b/python/cudf/cudf/tests/test_index.py @@ -252,10 +252,10 @@ def test_index_rename_inplace(): pds = pd.Index([1, 2, 3], name="asdf") gds = Index(pds) - # inplace=False should yield a deep copy + # inplace=False should yield a shallow copy gds_renamed_deep = gds.rename("new_name", inplace=False) - assert gds_renamed_deep._values.data_ptr != gds._values.data_ptr + assert gds_renamed_deep._values.data_ptr == gds._values.data_ptr # inplace=True returns none expected_ptr = gds._values.data_ptr @@ -3214,6 +3214,27 @@ def test_rangeindex_dropna(): assert_eq(result, expected) +def test_rangeindex_unique_shallow_copy(): + ri_pandas = pd.RangeIndex(1) + result = ri_pandas.unique() + assert result is not ri_pandas + + ri_cudf = cudf.RangeIndex(1) + result = ri_cudf.unique() + assert result is not ri_cudf + assert_eq(result, ri_cudf) + + +def test_rename_shallow_copy(): + idx = pd.Index([1]) + result = idx.rename("a") + assert idx.to_numpy(copy=False) is result.to_numpy(copy=False) + + idx = cudf.Index([1]) + result = idx.rename("a") + assert idx._column is result._column + + @pytest.mark.parametrize("data", [range(2), [10, 11, 12]]) def test_index_contains_hashable(data): gidx = cudf.Index(data) From 2ddbe2a0665066fe8a5021b23c9268ce91ce67a2 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 18 Jun 2024 20:06:04 +0100 Subject: [PATCH 4/7] Test behaviour of containers (#15994) This ensures we cover all implementation. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/15994 --- .../cudf_polars/containers/column.py | 2 +- .../cudf_polars/tests/containers/__init__.py | 6 ++ .../tests/containers/test_column.py | 70 ++++++++++++++ .../tests/containers/test_dataframe.py | 92 +++++++++++++++++++ 4 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 python/cudf_polars/tests/containers/__init__.py create mode 100644 python/cudf_polars/tests/containers/test_column.py create mode 100644 python/cudf_polars/tests/containers/test_dataframe.py diff --git a/python/cudf_polars/cudf_polars/containers/column.py b/python/cudf_polars/cudf_polars/containers/column.py index 156dd395d64..28685f0c4ed 100644 --- a/python/cudf_polars/cudf_polars/containers/column.py +++ b/python/cudf_polars/cudf_polars/containers/column.py @@ -130,7 +130,7 @@ def copy(self) -> Self: def mask_nans(self) -> Self: """Return a copy of self with nans masked out.""" if self.nan_count > 0: - raise NotImplementedError + raise NotImplementedError("Need to port transform.hpp to pylibcudf") return self.copy() @functools.cached_property diff --git a/python/cudf_polars/tests/containers/__init__.py b/python/cudf_polars/tests/containers/__init__.py new file mode 100644 index 00000000000..4611d642f14 --- /dev/null +++ b/python/cudf_polars/tests/containers/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/python/cudf_polars/tests/containers/test_column.py b/python/cudf_polars/tests/containers/test_column.py new file mode 100644 index 00000000000..3291d8db161 --- /dev/null +++ b/python/cudf_polars/tests/containers/test_column.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pyarrow +import pytest + +import cudf._lib.pylibcudf as plc + +from cudf_polars.containers import Column + + +def test_non_scalar_access_raises(): + column = Column( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID + ) + ) + with pytest.raises(ValueError): + _ = column.obj_scalar + + +@pytest.mark.parametrize("length", [0, 1]) +def test_length_leq_one_always_sorted(length): + column = Column( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), length, plc.MaskState.ALL_VALID + ) + ) + assert column.is_sorted == plc.types.Sorted.YES + column.set_sorted( + is_sorted=plc.types.Sorted.NO, + order=plc.types.Order.ASCENDING, + null_order=plc.types.NullOrder.AFTER, + ) + assert column.is_sorted == plc.types.Sorted.YES + + +def test_shallow_copy(): + column = Column( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID + ) + ) + copy = column.copy() + copy = copy.set_sorted( + is_sorted=plc.types.Sorted.YES, + order=plc.types.Order.ASCENDING, + null_order=plc.types.NullOrder.AFTER, + ) + assert column.is_sorted == plc.types.Sorted.NO + assert copy.is_sorted == plc.types.Sorted.YES + + +@pytest.mark.parametrize("typeid", [plc.TypeId.INT8, plc.TypeId.FLOAT32]) +def test_mask_nans(typeid): + dtype = plc.DataType(typeid) + values = pyarrow.array([0, 0, 0], type=plc.interop.to_arrow(dtype)) + column = Column(plc.interop.from_arrow(values)) + masked = column.mask_nans() + assert column.obj is masked.obj + + +def test_mask_nans_float_with_nan_notimplemented(): + dtype = plc.DataType(plc.TypeId.FLOAT32) + values = pyarrow.array([0, 0, float("nan")], type=plc.interop.to_arrow(dtype)) + column = Column(plc.interop.from_arrow(values)) + with pytest.raises(NotImplementedError): + _ = column.mask_nans() diff --git a/python/cudf_polars/tests/containers/test_dataframe.py b/python/cudf_polars/tests/containers/test_dataframe.py new file mode 100644 index 00000000000..2e385e39eef --- /dev/null +++ b/python/cudf_polars/tests/containers/test_dataframe.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import cudf._lib.pylibcudf as plc + +from cudf_polars.containers import DataFrame, NamedColumn + + +def test_select_missing_raises(): + df = DataFrame( + [ + NamedColumn( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID + ), + "a", + ) + ] + ) + with pytest.raises(ValueError): + df.select(["b", "a"]) + + +def test_replace_missing_raises(): + df = DataFrame( + [ + NamedColumn( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID + ), + "a", + ) + ] + ) + replacement = df.columns[0].copy(new_name="b") + with pytest.raises(ValueError): + df.replace_columns(replacement) + + +def test_from_table_wrong_names(): + table = plc.Table( + [ + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 1, plc.MaskState.ALL_VALID + ) + ] + ) + with pytest.raises(ValueError): + DataFrame.from_table(table, ["a", "b"]) + + +def test_sorted_like_raises_mismatching_names(): + df = DataFrame( + [ + NamedColumn( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID + ), + "a", + ) + ] + ) + like = df.copy().rename_columns({"a": "b"}) + with pytest.raises(ValueError): + df.sorted_like(like) + + +def test_shallow_copy(): + column = NamedColumn( + plc.column_factories.make_numeric_column( + plc.DataType(plc.TypeId.INT8), 2, plc.MaskState.ALL_VALID + ), + "a", + ) + column.set_sorted( + is_sorted=plc.types.Sorted.YES, + order=plc.types.Order.ASCENDING, + null_order=plc.types.NullOrder.AFTER, + ) + df = DataFrame([column]) + copy = df.copy() + copy.columns[0].set_sorted( + is_sorted=plc.types.Sorted.NO, + order=plc.types.Order.ASCENDING, + null_order=plc.types.NullOrder.AFTER, + ) + assert df.columns[0].is_sorted == plc.types.Sorted.YES + assert copy.columns[0].is_sorted == plc.types.Sorted.NO From 9bc794aa355c8e4c42fbc611fe9d496c20a4db90 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 18 Jun 2024 20:06:45 +0100 Subject: [PATCH 5/7] Coverage of binops where one or both operands are a scalar (#15998) Just needed the tests here. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/15998 --- .../tests/expressions/test_numeric_binops.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/cudf_polars/tests/expressions/test_numeric_binops.py b/python/cudf_polars/tests/expressions/test_numeric_binops.py index 7eefc59d927..b6bcd0026fa 100644 --- a/python/cudf_polars/tests/expressions/test_numeric_binops.py +++ b/python/cudf_polars/tests/expressions/test_numeric_binops.py @@ -99,3 +99,15 @@ def test_numeric_binop(df, binop): q = df.select(binop(left, right)) assert_gpu_result_equal(q) + + +@pytest.mark.parametrize("left_scalar", [False, True]) +@pytest.mark.parametrize("right_scalar", [False, True]) +def test_binop_with_scalar(left_scalar, right_scalar): + df = pl.LazyFrame({"a": [1, 2, 3], "b": [5, 6, 7]}) + + lop = pl.lit(2) if left_scalar else pl.col("a") + rop = pl.lit(6) if right_scalar else pl.col("b") + q = df.select(lop / rop) + + assert_gpu_result_equal(q) From c83e5b3fdd7f9fe8a08c4f6874fbf847bba70c53 Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Tue, 18 Jun 2024 16:22:44 -0400 Subject: [PATCH 6/7] Fix JSON multi-source reading when total source size exceeds `INT_MAX` bytes (#15930) Fixes #15917. - [X] Batched read and parse operations - [x] Fail when any single source file exceeds `INT_MAX` bytes. This case will be handled with a chunked reader later. Authors: - Shruti Shivakumar (https://github.com/shrshi) Approvers: - Vukasin Milovanovic (https://github.com/vuule) - Karthikeyan (https://github.com/karthikeyann) URL: https://github.com/rapidsai/cudf/pull/15930 --- cpp/include/cudf/io/types.hpp | 13 +++ cpp/src/io/json/read_json.cu | 121 +++++++++++++++++++++---- cpp/tests/CMakeLists.txt | 1 + cpp/tests/large_strings/json_tests.cpp | 58 ++++++++++++ 4 files changed, 177 insertions(+), 16 deletions(-) create mode 100644 cpp/tests/large_strings/json_tests.cpp diff --git a/cpp/include/cudf/io/types.hpp b/cpp/include/cudf/io/types.hpp index 0dab1c606de..0c96268f6c7 100644 --- a/cpp/include/cudf/io/types.hpp +++ b/cpp/include/cudf/io/types.hpp @@ -256,6 +256,19 @@ struct column_name_info { } column_name_info() = default; + + /** + * @brief Compares two column name info structs for equality + * + * @param rhs column name info struct to compare against + * @return boolean indicating if this and rhs are equal + */ + bool operator==(column_name_info const& rhs) const + { + return ((name == rhs.name) && (is_nullable == rhs.is_nullable) && + (is_binary == rhs.is_binary) && (type_length == rhs.type_length) && + (children == rhs.children)); + }; }; /** diff --git a/cpp/src/io/json/read_json.cu b/cpp/src/io/json/read_json.cu index e999be8f83a..74001e5e01a 100644 --- a/cpp/src/io/json/read_json.cu +++ b/cpp/src/io/json/read_json.cu @@ -18,7 +18,9 @@ #include "io/json/nested_json.hpp" #include "read_json.hpp" +#include #include +#include #include #include #include @@ -76,7 +78,7 @@ device_span ingest_raw_input(device_span buffer, auto constexpr num_delimiter_chars = 1; if (compression == compression_type::NONE) { - std::vector delimiter_map{}; + std::vector delimiter_map{}; std::vector prefsum_source_sizes(sources.size()); std::vector> h_buffers; delimiter_map.reserve(sources.size()); @@ -84,7 +86,7 @@ device_span ingest_raw_input(device_span buffer, std::transform_inclusive_scan(sources.begin(), sources.end(), prefsum_source_sizes.begin(), - std::plus{}, + std::plus{}, [](std::unique_ptr const& s) { return s->size(); }); auto upper = std::upper_bound(prefsum_source_sizes.begin(), prefsum_source_sizes.end(), range_offset); @@ -259,6 +261,33 @@ datasource::owning_buffer> get_record_range_raw_input( readbufspan.size() - first_delim_pos - shift_for_nonzero_offset); } +table_with_metadata read_batch(host_span> sources, + json_reader_options const& reader_opts, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + datasource::owning_buffer> bufview = + get_record_range_raw_input(sources, reader_opts, stream); + + // If input JSON buffer has single quotes and option to normalize single quotes is enabled, + // invoke pre-processing FST + if (reader_opts.is_enabled_normalize_single_quotes()) { + normalize_single_quotes(bufview, stream, rmm::mr::get_current_device_resource()); + } + + // If input JSON buffer has unquoted spaces and tabs and option to normalize whitespaces is + // enabled, invoke pre-processing FST + if (reader_opts.is_enabled_normalize_whitespace()) { + normalize_whitespace(bufview, stream, rmm::mr::get_current_device_resource()); + } + + auto buffer = + cudf::device_span(reinterpret_cast(bufview.data()), bufview.size()); + stream.synchronize(); + return device_parse_nested_json(buffer, reader_opts, stream, mr); +} + table_with_metadata read_json(host_span> sources, json_reader_options const& reader_opts, rmm::cuda_stream_view stream, @@ -278,25 +307,85 @@ table_with_metadata read_json(host_span> sources, "Multiple inputs are supported only for JSON Lines format"); } - datasource::owning_buffer> bufview = - get_record_range_raw_input(sources, reader_opts, stream); + std::for_each(sources.begin(), sources.end(), [](auto const& source) { + CUDF_EXPECTS(source->size() < std::numeric_limits::max(), + "The size of each source file must be less than INT_MAX bytes"); + }); - // If input JSON buffer has single quotes and option to normalize single quotes is enabled, - // invoke pre-processing FST - if (reader_opts.is_enabled_normalize_single_quotes()) { - normalize_single_quotes(bufview, stream, rmm::mr::get_current_device_resource()); + constexpr size_t batch_size_ub = std::numeric_limits::max(); + size_t const chunk_offset = reader_opts.get_byte_range_offset(); + size_t chunk_size = reader_opts.get_byte_range_size(); + chunk_size = !chunk_size ? sources_size(sources, 0, 0) : chunk_size; + + // Identify the position of starting source file from which to begin batching based on + // byte range offset. If the offset is larger than the sum of all source + // sizes, then start_source is total number of source files i.e. no file is read + size_t const start_source = [&]() { + size_t sum = 0; + for (size_t src_idx = 0; src_idx < sources.size(); ++src_idx) { + if (sum + sources[src_idx]->size() > chunk_offset) return src_idx; + sum += sources[src_idx]->size(); + } + return sources.size(); + }(); + + // Construct batches of source files, with starting position of batches indicated by + // batch_positions. The size of each batch i.e. the sum of sizes of the source files in the batch + // is capped at INT_MAX bytes. + size_t cur_size = 0; + std::vector batch_positions; + std::vector batch_sizes; + batch_positions.push_back(0); + for (size_t i = start_source; i < sources.size(); i++) { + cur_size += sources[i]->size(); + if (cur_size >= batch_size_ub) { + batch_positions.push_back(i); + batch_sizes.push_back(cur_size - sources[i]->size()); + cur_size = sources[i]->size(); + } } + batch_positions.push_back(sources.size()); + batch_sizes.push_back(cur_size); - // If input JSON buffer has unquoted spaces and tabs and option to normalize whitespaces is - // enabled, invoke pre-processing FST - if (reader_opts.is_enabled_normalize_whitespace()) { - normalize_whitespace(bufview, stream, rmm::mr::get_current_device_resource()); + // If there is a single batch, then we can directly return the table without the + // unnecessary concatenate + if (batch_sizes.size() == 1) return read_batch(sources, reader_opts, stream, mr); + + std::vector partial_tables; + json_reader_options batched_reader_opts{reader_opts}; + + // Dispatch individual batches to read_batch and push the resulting table into + // partial_tables array. Note that the reader options need to be updated for each + // batch to adjust byte range offset and byte range size. + for (size_t i = 0; i < batch_sizes.size(); i++) { + batched_reader_opts.set_byte_range_size(std::min(batch_sizes[i], chunk_size)); + partial_tables.emplace_back(read_batch( + host_span>(sources.begin() + batch_positions[i], + batch_positions[i + 1] - batch_positions[i]), + batched_reader_opts, + stream, + rmm::mr::get_current_device_resource())); + if (chunk_size <= batch_sizes[i]) break; + chunk_size -= batch_sizes[i]; + batched_reader_opts.set_byte_range_offset(0); } - auto buffer = - cudf::device_span(reinterpret_cast(bufview.data()), bufview.size()); - stream.synchronize(); - return device_parse_nested_json(buffer, reader_opts, stream, mr); + auto expects_schema_equality = + std::all_of(partial_tables.begin() + 1, + partial_tables.end(), + [> = partial_tables[0].metadata.schema_info](auto& ptbl) { + return ptbl.metadata.schema_info == gt; + }); + CUDF_EXPECTS(expects_schema_equality, + "Mismatch in JSON schema across batches in multi-source multi-batch reading"); + + auto partial_table_views = std::vector(partial_tables.size()); + std::transform(partial_tables.begin(), + partial_tables.end(), + partial_table_views.begin(), + [](auto const& table) { return table.tbl->view(); }); + return table_with_metadata{cudf::concatenate(partial_table_views, stream, mr), + {partial_tables[0].metadata.schema_info}}; } } // namespace cudf::io::json::detail diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 329edbe4d36..eda470d2309 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -570,6 +570,7 @@ ConfigureTest( LARGE_STRINGS_TEST large_strings/concatenate_tests.cpp large_strings/case_tests.cpp + large_strings/json_tests.cpp large_strings/large_strings_fixture.cpp large_strings/merge_tests.cpp large_strings/parquet_tests.cpp diff --git a/cpp/tests/large_strings/json_tests.cpp b/cpp/tests/large_strings/json_tests.cpp new file mode 100644 index 00000000000..bf16d131ba7 --- /dev/null +++ b/cpp/tests/large_strings/json_tests.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "large_strings_fixture.hpp" + +#include +#include + +struct JsonLargeReaderTest : public cudf::test::StringsLargeTest {}; + +TEST_F(JsonLargeReaderTest, MultiBatch) +{ + std::string json_string = R"( + { "a": { "y" : 6}, "b" : [1, 2, 3], "c": 11 } + { "a": { "y" : 6}, "b" : [4, 5 ], "c": 12 } + { "a": { "y" : 6}, "b" : [6 ], "c": 13 } + { "a": { "y" : 6}, "b" : [7 ], "c": 14 })"; + constexpr size_t expected_file_size = std::numeric_limits::max() / 2; + std::size_t const log_repetitions = + static_cast(std::ceil(std::log2(expected_file_size / json_string.size()))); + + json_string.reserve(json_string.size() * (1UL << log_repetitions)); + std::size_t numrows = 4; + for (std::size_t i = 0; i < log_repetitions; i++) { + json_string += json_string; + numrows <<= 1; + } + + constexpr int num_sources = 2; + std::vector> hostbufs( + num_sources, cudf::host_span(json_string.data(), json_string.size())); + + // Initialize parsing options (reading json lines) + cudf::io::json_reader_options json_lines_options = + cudf::io::json_reader_options::builder( + cudf::io::source_info{ + cudf::host_span>(hostbufs.data(), hostbufs.size())}) + .lines(true) + .compression(cudf::io::compression_type::NONE) + .recovery_mode(cudf::io::json_recovery_mode_t::FAIL); + + // Read full test data via existing, nested JSON lines reader + cudf::io::table_with_metadata current_reader_table = cudf::io::read_json(json_lines_options); + ASSERT_EQ(current_reader_table.tbl->num_rows(), numrows * num_sources); +} From f536e3017205be8b09f3dc2cfd448dc9c5a94d5d Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 19 Jun 2024 16:50:48 +0100 Subject: [PATCH 7/7] Add basic tests of dataframe scan (#16003) Also assert that unsupported file scan operations raise. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - https://github.com/brandon-b-miller URL: https://github.com/rapidsai/cudf/pull/16003 --- python/cudf_polars/cudf_polars/dsl/ir.py | 4 +- .../cudf_polars/testing/asserts.py | 34 ++++++++++++++- python/cudf_polars/docs/overview.md | 18 ++++++++ .../cudf_polars/tests/test_dataframescan.py | 43 +++++++++++++++++++ python/cudf_polars/tests/test_scan.py | 13 +++++- python/cudf_polars/tests/testing/__init__.py | 6 +++ .../cudf_polars/tests/testing/test_asserts.py | 35 +++++++++++++++ 7 files changed, 150 insertions(+), 3 deletions(-) create mode 100644 python/cudf_polars/tests/test_dataframescan.py create mode 100644 python/cudf_polars/tests/testing/__init__.py create mode 100644 python/cudf_polars/tests/testing/test_asserts.py diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 83957e4286d..3ccefac6b0a 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -196,7 +196,9 @@ def __post_init__(self) -> None: if self.file_options.n_rows is not None: raise NotImplementedError("row limit in scan") if self.typ not in ("csv", "parquet"): - raise NotImplementedError(f"Unhandled scan type: {self.typ}") + raise NotImplementedError( + f"Unhandled scan type: {self.typ}" + ) # pragma: no cover; polars raises on the rust side for now def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" diff --git a/python/cudf_polars/cudf_polars/testing/asserts.py b/python/cudf_polars/cudf_polars/testing/asserts.py index 3edaa427432..a9a4ae5f0a6 100644 --- a/python/cudf_polars/cudf_polars/testing/asserts.py +++ b/python/cudf_polars/cudf_polars/testing/asserts.py @@ -11,6 +11,7 @@ from polars.testing.asserts import assert_frame_equal from cudf_polars.callback import execute_with_cudf +from cudf_polars.dsl.translate import translate_ir if TYPE_CHECKING: from collections.abc import Mapping @@ -19,7 +20,7 @@ from cudf_polars.typing import OptimizationArgs -__all__: list[str] = ["assert_gpu_result_equal"] +__all__: list[str] = ["assert_gpu_result_equal", "assert_ir_translation_raises"] def assert_gpu_result_equal( @@ -84,3 +85,34 @@ def assert_gpu_result_equal( atol=atol, categorical_as_str=categorical_as_str, ) + + +def assert_ir_translation_raises(q: pl.LazyFrame, *exceptions: type[Exception]) -> None: + """ + Assert that translation of a query raises an exception. + + Parameters + ---------- + q + Query to translate. + exceptions + Exceptions that one expects might be raised. + + Returns + ------- + None + If translation successfully raised the specified exceptions. + + Raises + ------ + AssertionError + If the specified exceptions were not raised. + """ + try: + _ = translate_ir(q._ldf.visit()) + except exceptions: + return + except Exception as e: + raise AssertionError(f"Translation DID NOT RAISE {exceptions}") from e + else: + raise AssertionError(f"Translation DID NOT RAISE {exceptions}") diff --git a/python/cudf_polars/docs/overview.md b/python/cudf_polars/docs/overview.md index b50d01c26db..874bb849747 100644 --- a/python/cudf_polars/docs/overview.md +++ b/python/cudf_polars/docs/overview.md @@ -224,6 +224,24 @@ def test_whatever(): assert_gpu_result_equal(query) ``` +## Test coverage and asserting failure modes + +Where translation of a query should fail due to the feature being +unsupported we should test this. To assert that _translation_ raises +an exception (usually `NotImplementedError`), use the utility function +`assert_ir_translation_raises`: + +```python +from cudf_polars.testing.asserts import assert_ir_translation_raises + + +def test_whatever(): + unsupported_query = ... + assert_ir_translation_raises(unsupported_query, NotImplementedError) +``` + +This test will fail if translation does not raise. + # Debugging If the callback execution fails during the polars `collect` call, we diff --git a/python/cudf_polars/tests/test_dataframescan.py b/python/cudf_polars/tests/test_dataframescan.py new file mode 100644 index 00000000000..1ffe06ac562 --- /dev/null +++ b/python/cudf_polars/tests/test_dataframescan.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.mark.parametrize( + "subset", + [ + None, + ["a", "c"], + ["b", "c", "d"], + ["b", "d"], + ["b", "c"], + ["c", "e"], + ["d", "e"], + pl.selectors.string(), + pl.selectors.integer(), + ], +) +@pytest.mark.parametrize("predicate_pushdown", [False, True]) +def test_scan_drop_nulls(subset, predicate_pushdown): + df = pl.LazyFrame( + { + "a": [1, 2, 3, 4], + "b": [None, 4, 5, None], + "c": [6, 7, None, None], + "d": [8, None, 9, 10], + "e": [None, None, "A", None], + } + ) + # Drop nulls are pushed into filters + q = df.drop_nulls(subset) + + assert_gpu_result_equal( + q, collect_kwargs={"predicate_pushdown": predicate_pushdown} + ) diff --git a/python/cudf_polars/tests/test_scan.py b/python/cudf_polars/tests/test_scan.py index b2443e357e2..f129cc7ca32 100644 --- a/python/cudf_polars/tests/test_scan.py +++ b/python/cudf_polars/tests/test_scan.py @@ -6,7 +6,10 @@ import polars as pl -from cudf_polars.testing.asserts import assert_gpu_result_equal +from cudf_polars.testing.asserts import ( + assert_gpu_result_equal, + assert_ir_translation_raises, +) @pytest.fixture( @@ -86,3 +89,11 @@ def test_scan(df, columns, mask): if columns is not None: q = df.select(*columns) assert_gpu_result_equal(q) + + +def test_scan_unsupported_raises(tmp_path): + df = pl.DataFrame({"a": [1, 2, 3]}) + + df.write_ndjson(tmp_path / "df.json") + q = pl.scan_ndjson(tmp_path / "df.json") + assert_ir_translation_raises(q, NotImplementedError) diff --git a/python/cudf_polars/tests/testing/__init__.py b/python/cudf_polars/tests/testing/__init__.py new file mode 100644 index 00000000000..4611d642f14 --- /dev/null +++ b/python/cudf_polars/tests/testing/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/python/cudf_polars/tests/testing/test_asserts.py b/python/cudf_polars/tests/testing/test_asserts.py new file mode 100644 index 00000000000..5bc2fe1efb7 --- /dev/null +++ b/python/cudf_polars/tests/testing/test_asserts.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import ( + assert_gpu_result_equal, + assert_ir_translation_raises, +) + + +def test_translation_assert_raises(): + df = pl.LazyFrame({"a": [1, 2, 3]}) + + # This should succeed + assert_gpu_result_equal(df) + + with pytest.raises(AssertionError): + # This should fail, because we can translate this query. + assert_ir_translation_raises(df, NotImplementedError) + + class E(Exception): + pass + + unsupported = df.group_by("a").agg(pl.col("a").cum_max().alias("b")) + # Unsupported query should raise NotImplementedError + assert_ir_translation_raises(unsupported, NotImplementedError) + + with pytest.raises(AssertionError): + # This should fail, because we can't translate this query, but it doesn't raise E. + assert_ir_translation_raises(unsupported, E)