diff --git a/docs/notebooks/intro_notebook.ipynb b/docs/notebooks/intro_notebook.ipynb index 6f0730d..7b010f5 100644 --- a/docs/notebooks/intro_notebook.ipynb +++ b/docs/notebooks/intro_notebook.ipynb @@ -32,8 +32,8 @@ "id": "165a7a0918b5a866", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.191161Z", - "start_time": "2024-02-10T02:23:42.178595Z" + "end_time": "2024-02-11T01:14:37.845904Z", + "start_time": "2024-02-11T01:14:37.832532Z" } }, "outputs": [], @@ -74,8 +74,8 @@ "id": "951dbb53d50f21c3", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.217564Z", - "start_time": "2024-02-10T02:23:42.193011Z" + "end_time": "2024-02-11T01:14:37.862593Z", + "start_time": "2024-02-11T01:14:37.846548Z" } }, "outputs": [], @@ -100,8 +100,8 @@ "id": "edd0b2714196c9d0", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.229553Z", - "start_time": "2024-02-10T02:23:42.204915Z" + "end_time": "2024-02-11T01:14:37.872008Z", + "start_time": "2024-02-11T01:14:37.859405Z" } }, "outputs": [], @@ -115,8 +115,8 @@ "id": "3144e1a6c5964ed9", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.229932Z", - "start_time": "2024-02-10T02:23:42.211090Z" + "end_time": "2024-02-11T01:14:37.872916Z", + "start_time": "2024-02-11T01:14:37.861951Z" } }, "outputs": [], @@ -142,8 +142,8 @@ "id": "620ad241f94d3e98", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.231416Z", - "start_time": "2024-02-10T02:23:42.216005Z" + "end_time": "2024-02-11T01:14:37.902167Z", + "start_time": "2024-02-11T01:14:37.866603Z" } }, "outputs": [], @@ -157,8 +157,8 @@ "id": "63e47b51a269305f", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.284198Z", - "start_time": "2024-02-10T02:23:42.221854Z" + "end_time": "2024-02-11T01:14:37.926479Z", + "start_time": "2024-02-11T01:14:37.870840Z" } }, "outputs": [], @@ -172,8 +172,8 @@ "id": "ac15e872786696ef", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.284672Z", - "start_time": "2024-02-10T02:23:42.236972Z" + "end_time": "2024-02-11T01:14:37.926882Z", + "start_time": "2024-02-11T01:14:37.885781Z" } }, "outputs": [], @@ -181,14 +181,29 @@ "packed.ts[\"flux\"]" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce413366fa0a3a43", + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T01:14:37.927143Z", + "start_time": "2024-02-11T01:14:37.888538Z" + } + }, + "outputs": [], + "source": [ + "packed.ts[[\"time\", \"flux\"]]" + ] + }, { "cell_type": "code", "execution_count": null, "id": "dc7dbd52f1a8407a", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.284955Z", - "start_time": "2024-02-10T02:23:42.240108Z" + "end_time": "2024-02-11T01:14:37.927377Z", + "start_time": "2024-02-11T01:14:37.893102Z" } }, "outputs": [], @@ -212,8 +227,8 @@ "id": "996f07b4d16e17e5", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.285205Z", - "start_time": "2024-02-10T02:23:42.242144Z" + "end_time": "2024-02-11T01:14:37.927613Z", + "start_time": "2024-02-11T01:14:37.894893Z" } }, "outputs": [], @@ -229,8 +244,8 @@ "id": "21d5c009ef0990a4", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.324854Z", - "start_time": "2024-02-10T02:23:42.287797Z" + "end_time": "2024-02-11T01:14:37.928728Z", + "start_time": "2024-02-11T01:14:37.900463Z" } }, "outputs": [], @@ -249,8 +264,8 @@ "id": "3a713c94897456e1", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.358157Z", - "start_time": "2024-02-10T02:23:42.316636Z" + "end_time": "2024-02-11T01:14:37.929587Z", + "start_time": "2024-02-11T01:14:37.906818Z" } }, "outputs": [], @@ -279,8 +294,8 @@ "id": "ab27747eba156888", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.385333Z", - "start_time": "2024-02-10T02:23:42.360474Z" + "end_time": "2024-02-11T01:14:37.960832Z", + "start_time": "2024-02-11T01:14:37.914626Z" } }, "outputs": [], @@ -297,8 +312,8 @@ "id": "26c558e3551b5092", "metadata": { "ExecuteTime": { - "end_time": "2024-02-10T02:23:42.387707Z", - "start_time": "2024-02-10T02:23:42.370602Z" + "end_time": "2024-02-11T01:14:37.961088Z", + "start_time": "2024-02-11T01:14:37.921294Z" } }, "outputs": [], diff --git a/src/pandas_ts/ts_accessor.py b/src/pandas_ts/ts_accessor.py index 77d635c..6d07eed 100644 --- a/src/pandas_ts/ts_accessor.py +++ b/src/pandas_ts/ts_accessor.py @@ -1,3 +1,6 @@ +# Python 3.9 doesn't support "|" for types +from __future__ import annotations + from collections.abc import Generator, MutableMapping from typing import cast @@ -32,13 +35,38 @@ def _check_series(series): if not isinstance(dtype, TsDtype): raise AttributeError(f"Can only use .ts accessor with a Series of TsDtype, got {dtype}") - def to_lists(self) -> pd.DataFrame: - """Convert ts into dataframe of list-array columns""" - return self._series.struct.explode() + def to_lists(self, fields: list[str] | None = None) -> pd.DataFrame: + """Convert ts into dataframe of list-array columns + + Parameters + ---------- + fields : list[str] or None, optional + Names of the fields to include. Default is None, which means all fields. + + Returns + ------- + pd.DataFrame + Dataframe of list-arrays. + """ + df = self._series.struct.explode() + if fields is None: + return df + return df[fields] - def to_flat(self) -> pd.DataFrame: - """Convert ts into dataframe of flat arrays""" - fields = self._series.struct.dtypes.index + def to_flat(self, fields: list[str] | None = None) -> pd.DataFrame: + """Convert ts into dataframe of flat arrays + + Parameters + ---------- + fields : list[str] or None, optional + Names of the fields to include. Default is None, which means all fields. + + Returns + ------- + pd.DataFrame + Dataframe of flat arrays. + """ + fields = fields if fields is not None else list(self._series.struct.dtypes.index) if len(fields) == 0: raise ValueError("Cannot flatten a struct with no fields") @@ -112,7 +140,10 @@ def delete_field(self, field: str) -> pd.Series: self._series.array.delete_field(field) return series - def __getitem__(self, key: str) -> pd.Series: + def __getitem__(self, key: str | list[str]) -> pd.Series: + if isinstance(key, list): + new_array = self._series.array.view_fields(key) + return pd.Series(new_array, index=self._series.index, name=self._series.name) return self._series.struct.field(key) def __setitem__(self, key: str, value: ArrayLike) -> None: diff --git a/src/pandas_ts/ts_ext_array.py b/src/pandas_ts/ts_ext_array.py index 208d0e9..8dcb83b 100644 --- a/src/pandas_ts/ts_ext_array.py +++ b/src/pandas_ts/ts_ext_array.py @@ -1,3 +1,4 @@ +# typing.Self and "|" union syntax don't exist in Python 3.9 from __future__ import annotations from collections.abc import Collection, Iterable, Iterator, Sequence @@ -180,6 +181,38 @@ def flat_length(self) -> int: """Length of the flat arrays""" return sum(chunk.field(0).value_lengths().sum().as_py() for chunk in self._pa_array.iterchunks()) + def view_fields(self, fields: str | list[str]) -> Self: # type: ignore[name-defined] # noqa: F821 + """Get a view of the series with only the specified fields + + Parameters + ---------- + fields : str or list of str + The name of the field or a list of names of the fields to include. + + Returns + ------- + TsExtensionArray + The view of the series with only the specified fields. + """ + if isinstance(fields, str): + fields = [fields] + if len(set(fields)) != len(fields): + raise ValueError("Duplicate field names are not allowed") + if not set(fields).issubset(self.field_names): + raise ValueError(f"Some fields are not found, given: {fields}, available: {self.field_names}") + + chunks = [] + for chunk in self._pa_array.iterchunks(): + chunk = cast(pa.StructArray, chunk) + struct_dict = {} + for field in fields: + struct_dict[field] = chunk.field(field) + struct_array = pa.StructArray.from_arrays(struct_dict.values(), struct_dict.keys()) + chunks.append(struct_array) + pa_array = pa.chunked_array(chunks) + + return self.__class__(pa_array, validate=False) + def set_flat_field(self, field: str, value: ArrayLike) -> None: """Set the field from flat-array of values diff --git a/tests/pandas_ts/test_ts_accessor.py b/tests/pandas_ts/test_ts_accessor.py index ebdd23a..771db09 100644 --- a/tests/pandas_ts/test_ts_accessor.py +++ b/tests/pandas_ts/test_ts_accessor.py @@ -8,6 +8,7 @@ from pandas.testing import assert_frame_equal, assert_series_equal from pandas_ts import TsDtype +from pandas_ts.ts_ext_array import TsExtensionArray def test_ts_accessor_registered(): @@ -52,6 +53,30 @@ def test_ts_accessor_to_lists(): assert_frame_equal(lists, desired) +def test_ts_accessor_to_lists_with_fields(): + struct_array = pa.StructArray.from_arrays( + arrays=[ + pa.array([np.array([1.0, 2.0, 3.0]), -np.array([1.0, 2.0, 1.0])]), + pa.array([np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0])]), + ], + names=["a", "b"], + ) + series = pd.Series(struct_array, dtype=TsDtype(struct_array.type), index=[0, 1]) + + lists = series.ts.to_lists(fields=["a"]) + + desired = pd.DataFrame( + data={ + "a": pd.Series( + data=[np.array([1.0, 2.0, 3.0]), -np.array([1.0, 2.0, 1.0])], + dtype=pd.ArrowDtype(pa.list_(pa.float64())), + index=[0, 1], + ), + }, + ) + assert_frame_equal(lists, desired) + + def test_ts_accessor_to_flat(): struct_array = pa.StructArray.from_arrays( arrays=[ @@ -88,6 +113,36 @@ def test_ts_accessor_to_flat(): assert_array_equal(flat[column], desired[column]) +def test_to_flat_with_fields(): + struct_array = pa.StructArray.from_arrays( + arrays=[ + pa.array([np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0])]), + pa.array([-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0])]), + ], + names=["a", "b"], + ) + series = pd.Series(struct_array, dtype=TsDtype(struct_array.type), index=[0, 1]) + + flat = series.ts.to_flat(fields=["a"]) + + desired = pd.DataFrame( + data={ + "a": pd.Series( + data=[1.0, 2.0, 3.0, 1.0, 2.0, 1.0], + index=[0, 0, 0, 1, 1, 1], + name="a", + copy=False, + ), + }, + ) + + assert_array_equal(flat.dtypes, desired.dtypes) + assert_array_equal(flat.index, desired.index) + + for column in flat.columns: + assert_array_equal(flat[column], desired[column]) + + def test_ts_accessor_fields(): struct_array = pa.StructArray.from_arrays( arrays=[ @@ -184,7 +239,7 @@ def test_delete_field(): ) -def test_ts_accessor___getitem__(): +def test_ts_accessor___getitem___single_field(): struct_array = pa.StructArray.from_arrays( arrays=[ pa.array([np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0])]), @@ -214,6 +269,35 @@ def test_ts_accessor___getitem__(): ) +def test_ts_accessor___getitem___multiple_fields(): + arrays = [ + pa.array([np.array([1.0, 2.0, 3.0]), -np.array([1.0, 2.0, 1.0])]), + pa.array([np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0])]), + ] + series = pd.Series( + TsExtensionArray( + pa.StructArray.from_arrays( + arrays=arrays, + names=["a", "b"], + ) + ), + index=[0, 1], + ) + + assert_series_equal( + series.ts[["b", "a"]], + pd.Series( + TsExtensionArray( + pa.StructArray.from_arrays( + arrays=arrays[::-1], + names=["b", "a"], + ) + ), + index=[0, 1], + ), + ) + + def test_ts_accessor___setitem___with_flat(): struct_array = pa.StructArray.from_arrays( arrays=[ diff --git a/tests/pandas_ts/test_ts_ext_array.py b/tests/pandas_ts/test_ts_ext_array.py index 62a7955..59644ab 100644 --- a/tests/pandas_ts/test_ts_ext_array.py +++ b/tests/pandas_ts/test_ts_ext_array.py @@ -377,6 +377,83 @@ def test_flat_length(): assert ext_array.flat_length == 7 +def test_view_fields_with_single_field(): + arrays = [ + pa.array([np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0, 2.0])]), + pa.array([-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0, 6.0])]), + ] + ext_array = TsExtensionArray( + pa.StructArray.from_arrays( + arrays=arrays, + names=["a", "b"], + ) + ) + + view = ext_array.view_fields("a") + assert view.field_names == ["a"] + + desired = TsExtensionArray( + pa.StructArray.from_arrays( + arrays=arrays[:1], + names=["a"], + ) + ) + + assert_series_equal(pd.Series(view), pd.Series(desired)) + + +def test_view_fields_with_multiple_fields(): + arrays = [ + pa.array([np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.0, 4.0])]), + pa.array([-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0, 6.0])]), + pa.array([["x", "y", "z"], ["x1", "x2", "x3", "x4"]]), + ] + ext_array = TsExtensionArray( + pa.StructArray.from_arrays( + arrays=arrays, + names=["a", "b", "c"], + ) + ) + + view = ext_array.view_fields(["b", "a"]) + assert view.field_names == ["b", "a"] + + assert_series_equal( + pd.Series(view), + pd.Series( + TsExtensionArray(pa.StructArray.from_arrays(arrays=[arrays[1], arrays[0]], names=["b", "a"])) + ), + ) + + +def test_view_fields_raises_for_invalid_field(): + struct_array = pa.StructArray.from_arrays( + arrays=[ + pa.array([np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.0, 4.0])]), + pa.array([-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0, 6.0])]), + ], + names=["a", "b"], + ) + ext_array = TsExtensionArray(struct_array) + + with pytest.raises(ValueError): + ext_array.view_fields("c") + + +def test_view_fields_raises_for_non_unique_fields(): + struct_array = pa.StructArray.from_arrays( + arrays=[ + pa.array([np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.0, 4.0])]), + pa.array([-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0, 6.0])]), + ], + names=["a", "b"], + ) + ext_array = TsExtensionArray(struct_array) + + with pytest.raises(ValueError): + ext_array.view_fields(["a", "a"]) + + def test_set_flat_field_new_field_scalar(): struct_array = pa.StructArray.from_arrays( arrays=[