diff --git a/docs/notebooks/intro_notebook.ipynb b/docs/notebooks/intro_notebook.ipynb index b678e06..fd5c90a 100644 --- a/docs/notebooks/intro_notebook.ipynb +++ b/docs/notebooks/intro_notebook.ipynb @@ -30,7 +30,12 @@ "cell_type": "code", "execution_count": null, "id": "165a7a0918b5a866", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.765333Z", + "start_time": "2024-02-11T02:05:40.754638Z" + } + }, "outputs": [], "source": [ "import numpy as np\n", @@ -67,7 +72,12 @@ "cell_type": "code", "execution_count": null, "id": "951dbb53d50f21c3", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.790964Z", + "start_time": "2024-02-11T02:05:40.763685Z" + } + }, "outputs": [], "source": [ "packed = pack_flat(sources, name=\"sources\")\n", @@ -88,7 +98,12 @@ "cell_type": "code", "execution_count": null, "id": "edd0b2714196c9d0", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.793801Z", + "start_time": "2024-02-11T02:05:40.775923Z" + } + }, "outputs": [], "source": [ "packed.iloc[0]" @@ -98,7 +113,12 @@ "cell_type": "code", "execution_count": null, "id": "3144e1a6c5964ed9", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.794104Z", + "start_time": "2024-02-11T02:05:40.778911Z" + } + }, "outputs": [], "source": [ "# Get the linearly interpolated flux for time=10\n", @@ -120,7 +140,12 @@ "cell_type": "code", "execution_count": null, "id": "620ad241f94d3e98", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.797306Z", + "start_time": "2024-02-11T02:05:40.783399Z" + } + }, "outputs": [], "source": [ "packed.ts.to_flat()" @@ -130,7 +155,12 @@ "cell_type": "code", "execution_count": null, "id": "63e47b51a269305f", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.839231Z", + "start_time": "2024-02-11T02:05:40.788390Z" + } + }, "outputs": [], "source": [ "packed.ts.to_lists()" @@ -140,7 +170,12 @@ "cell_type": "code", "execution_count": null, "id": "ac15e872786696ef", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.839650Z", + "start_time": "2024-02-11T02:05:40.800540Z" + } + }, "outputs": [], "source": [ "packed.ts[\"flux\"]" @@ -150,7 +185,12 @@ "cell_type": "code", "execution_count": null, "id": "ce413366fa0a3a43", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.839887Z", + "start_time": "2024-02-11T02:05:40.806279Z" + } + }, "outputs": [], "source": [ "packed.ts[[\"time\", \"flux\"]]" @@ -160,7 +200,12 @@ "cell_type": "code", "execution_count": null, "id": "dc7dbd52f1a8407a", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.840131Z", + "start_time": "2024-02-11T02:05:40.808308Z" + } + }, "outputs": [], "source": [ "packed.dtype" @@ -180,11 +225,16 @@ "cell_type": "code", "execution_count": null, "id": "996f07b4d16e17e5", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.840401Z", + "start_time": "2024-02-11T02:05:40.810815Z" + } + }, "outputs": [], "source": [ "# Change flux in place with flat arrays\n", - "packed.ts[\"flux\"] = -1 * packed.ts[\"flux\"].list.flatten()\n", + "packed.ts[\"flux\"] = -2 * packed.ts[\"flux\"]\n", "packed.ts[\"flux\"]" ] }, @@ -192,7 +242,12 @@ "cell_type": "code", "execution_count": null, "id": "21d5c009ef0990a4", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.841105Z", + "start_time": "2024-02-11T02:05:40.815193Z" + } + }, "outputs": [], "source": [ "# Change errors for object 8003\n", @@ -207,14 +262,17 @@ "cell_type": "code", "execution_count": null, "id": "3a713c94897456e1", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.841866Z", + "start_time": "2024-02-11T02:05:40.821262Z" + } + }, "outputs": [], "source": [ "# Delete field and add new one\n", "del packed.ts[\"count\"]\n", - "filters = packed.ts.pop_field(\"band\")\n", - "filters = \"lsst_\" + filters.list.flatten()\n", - "packed.ts[\"filters\"] = filters\n", + "packed.ts[\"filters\"] = \"lsst_\" + packed.ts.pop_field(\"band\")\n", "packed" ] }, @@ -232,7 +290,12 @@ "cell_type": "code", "execution_count": null, "id": "ab27747eba156888", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.870737Z", + "start_time": "2024-02-11T02:05:40.828671Z" + } + }, "outputs": [], "source": [ "# Subsample light curves\n", @@ -245,7 +308,12 @@ "cell_type": "code", "execution_count": null, "id": "26c558e3551b5092", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.872977Z", + "start_time": "2024-02-11T02:05:40.834811Z" + } + }, "outputs": [], "source": [ "# Query sources\n", @@ -257,7 +325,12 @@ "cell_type": "code", "execution_count": null, "id": "945d5bf9417e8220", - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2024-02-11T02:05:40.914308Z", + "start_time": "2024-02-11T02:05:40.841661Z" + } + }, "outputs": [], "source": [] } diff --git a/src/pandas_ts/ts_accessor.py b/src/pandas_ts/ts_accessor.py index 9944f41..3599214 100644 --- a/src/pandas_ts/ts_accessor.py +++ b/src/pandas_ts/ts_accessor.py @@ -163,15 +163,34 @@ def query_flat(self, query: str) -> pd.Series: return pd.Series([], dtype=self._series.dtype) return pack_sorted_df_into_struct(flat) + def get_list_series(self, field: str) -> pd.Series: + """Get the list-array field as a Series + + Parameters + ---------- + field : str + Name of the field to get. + + Returns + ------- + pd.Series + The list-array field. + """ + return self._series.struct.field(field) + def __getitem__(self, key: str | list[str]) -> pd.Series: if isinstance(key, list): new_array = self._series.array.view_fields(key) return pd.Series(new_array, index=self._series.index, name=self._series.name) - return self._series.struct.field(key) + + series = self._series.struct.field(key).list.flatten() + series.index = np.repeat(self._series.index.values, np.diff(self._series.array.list_offsets)) + series.name = key + return series def __setitem__(self, key: str, value: ArrayLike) -> None: # TODO: we can be much-much smarter about the performance here - # TODO: think better about underlying pa.ChunkArray + # TODO: think better about underlying pa.ChunkArray in both self._series.array and value # Everything is empty, do nothing if len(self._series) == 0 and np.ndim(value) != 0: diff --git a/tests/pandas_ts/test_ts_accessor.py b/tests/pandas_ts/test_ts_accessor.py index 737a6db..d9f7f1c 100644 --- a/tests/pandas_ts/test_ts_accessor.py +++ b/tests/pandas_ts/test_ts_accessor.py @@ -184,10 +184,10 @@ def test_set_flat_field(): assert_series_equal( series.ts["a"], pd.Series( - data=[["a", "b", "c"], ["d", "e", "f"]], - index=[0, 1], + data=["a", "b", "c", "d", "e", "f"], + index=[0, 0, 0, 1, 1, 1], name="a", - dtype=pd.ArrowDtype(pa.list_(pa.string())), + dtype=pd.ArrowDtype(pa.string()), ), ) @@ -207,10 +207,10 @@ def test_set_list_field(): assert_series_equal( series.ts["c"], pd.Series( - data=[["a", "b", "c"], ["d", "e", "f"]], - index=[0, 1], + data=["a", "b", "c", "d", "e", "f"], + index=[0, 0, 0, 1, 1, 1], name="c", - dtype=pd.ArrowDtype(pa.list_(pa.string())), + dtype=pd.ArrowDtype(pa.string()), ), ) @@ -231,9 +231,9 @@ def test_pop_field(): assert_series_equal( a, pd.Series( - [np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0])], - dtype=pd.ArrowDtype(pa.list_(pa.float64())), - index=[0, 1], + [1.0, 2.0, 3.0, 1.0, 2.0, 1.0], + dtype=pd.ArrowDtype(pa.float64()), + index=[0, 0, 0, 1, 1, 1], name="a", ), ) @@ -280,6 +280,29 @@ def test_query_flat_2(): assert_series_equal(filtered, desired) +def test_get_list_series(): + struct_array = pa.StructArray.from_arrays( + arrays=[ + pa.array([np.array([1, 2, 3]), np.array([4, 5, 6])]), + pa.array([np.array([6, 4, 2]), np.array([1, 2, 3])]), + ], + names=["a", "b"], + ) + series = pd.Series(struct_array, dtype=TsDtype(struct_array.type), index=[5, 7]) + + lists = series.ts.get_list_series("a") + + assert_series_equal( + lists, + pd.Series( + data=[np.array([1, 2, 3]), np.array([4, 5, 6])], + dtype=pd.ArrowDtype(pa.list_(pa.int64())), + index=[5, 7], + name="a", + ), + ) + + def test___getitem___single_field(): struct_array = pa.StructArray.from_arrays( arrays=[ @@ -293,18 +316,18 @@ def test___getitem___single_field(): assert_series_equal( series.ts["a"], pd.Series( - [np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0])], - dtype=pd.ArrowDtype(pa.list_(pa.float64())), - index=[0, 1], + np.array([1.0, 2.0, 3.0, 1.0, 2.0, 1.0]), + dtype=pd.ArrowDtype(pa.float64()), + index=[0, 0, 0, 1, 1, 1], name="a", ), ) assert_series_equal( series.ts["b"], pd.Series( - [-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0])], - dtype=pd.ArrowDtype(pa.list_(pa.float64())), - index=[0, 1], + -np.array([4.0, 5.0, 6.0, 3.0, 4.0, 5.0]), + dtype=pd.ArrowDtype(pa.float64()), + index=[0, 0, 0, 1, 1, 1], name="b", ), ) @@ -354,10 +377,10 @@ def test___setitem___with_flat(): assert_series_equal( series.ts["a"], pd.Series( - data=[["a", "b", "c"], ["d", "e", "f"]], - index=[0, 1], + data=["a", "b", "c", "d", "e", "f"], + index=[0, 0, 0, 1, 1, 1], name="a", - dtype=pd.ArrowDtype(pa.list_(pa.string())), + dtype=pd.ArrowDtype(pa.string()), ), ) @@ -377,10 +400,10 @@ def test___setitem___with_list(): assert_series_equal( series.ts["c"], pd.Series( - data=[["a", "b", "c"], ["d", "e", "f"]], - index=[0, 1], + data=["a", "b", "c", "d", "e", "f"], + index=[0, 0, 0, 1, 1, 1], name="c", - dtype=pd.ArrowDtype(pa.list_(pa.string())), + dtype=pd.ArrowDtype(pa.string()), ), )