This repository has been archived by the owner on Apr 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from lincc-frameworks/ts-accessor
Init impl of .ts accessor
- Loading branch information
Showing
6 changed files
with
453 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .ts_accessor import TsAccessor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from typing import Any, cast | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pyarrow as pa | ||
from pandas.api.extensions import register_series_accessor | ||
|
||
__all__ = ["TsAccessor"] | ||
|
||
|
||
def pa_type_is_any_list(pa_type):
    """Return True when *pa_type* is any pyarrow list flavor.

    Accepts plain ``list``, ``large_list``, and ``fixed_size_list`` types;
    everything else yields False.
    """
    list_predicates = (
        pa.types.is_list,
        pa.types.is_large_list,
        pa.types.is_fixed_size_list,
    )
    return any(predicate(pa_type) for predicate in list_predicates)
|
||
|
||
@register_series_accessor("ts")
class TsAccessor:
    """``.ts`` accessor for Series holding packed "time series" data.

    The wrapped Series must have a pyarrow *struct* dtype whose fields are
    all list-like (list, large_list, or fixed_size_list) — i.e. each row is
    a struct of equally-interpretable nested arrays. Validation happens at
    accessor construction time and raises ``AttributeError`` otherwise.
    """

    def __init__(self, series):
        # Validate the dtype up front so every accessor method can assume
        # a pyarrow struct-of-lists Series.
        self._check_series(series)
        self._series = series

    @staticmethod
    def _check_series(series):
        """Validate that *series* is usable with this accessor (dtype check only)."""
        dtype = series.dtype
        TsAccessor._check_dtype(dtype)

    @staticmethod
    def _check_dtype(dtype):
        """Raise ``AttributeError`` unless *dtype* is a pyarrow struct of list fields."""
        # TODO: check if dtype is TsDtype when it is implemented
        if not hasattr(dtype, "pyarrow_dtype"):
            raise AttributeError("Can only use .ts accessor with a Series with dtype pyarrow struct dtype")
        pyarrow_dtype = dtype.pyarrow_dtype
        if not pa.types.is_struct(pyarrow_dtype):
            raise AttributeError("Can only use .ts accessor with a Series with dtype pyarrow struct dtype")

        # Every struct field must itself be a list type so rows can be
        # exploded/flattened uniformly.
        for field in pyarrow_dtype:
            if not pa_type_is_any_list(field.type):
                raise AttributeError(
                    f"Can only use .ts accessor with a Series with dtype pyarrow struct dtype, all fields must be list types. Given struct has unsupported field {field}"
                )

    def to_nested(self):
        """Convert ts into dataframe of nested arrays"""
        return self._series.struct.explode()

    def to_flat(self):
        """Convert ts into dataframe of flat arrays

        Each struct field becomes a column; the original index labels are
        repeated once per nested element.

        Raises
        ------
        ValueError
            If the struct dtype has no fields.
        """
        fields = self._series.struct.dtypes.index
        if len(fields) == 0:
            raise ValueError("Cannot flatten a struct with no fields")

        flat_series = {}
        index = None
        for field in fields:
            list_array = cast(pa.ListArray, pa.array(self._series.struct.field(field)))
            if index is None:
                # Repeat each label by its row's nested-list length
                # (np.diff of the list offsets); computed once from the
                # first field — assumes all fields share per-row lengths.
                index = np.repeat(self._series.index.values, np.diff(list_array.offsets))
            flat_series[field] = pd.Series(
                list_array.flatten(),
                index=index,
                name=field,
                copy=False,
            )
        return pd.DataFrame(flat_series)

    @property
    def fields(self) -> pd.Index:
        """Names of the nested columns"""
        return self._series.struct.dtypes.index

    def __getitem__(self, key: str) -> pd.Series:
        """Return the nested field *key* as its own (list-typed) Series."""
        return self._series.struct.field(key)

    def get(self, index: Any) -> pd.DataFrame:
        """Get a single ts item by label (index value) as a dataframe

        Parameters
        ----------
        index : Any
            The label of the item to get, must be in the index of
            the series.

        Returns
        -------
        pd.DataFrame
            A dataframe with the nested arrays of the item.

        See Also
        --------
        pandas_ts.TsAccessor.iget : Get a single ts item by position.
        """
        item = self._series.loc[index]
        return pd.DataFrame.from_dict(item)

    def iget(self, index: int) -> pd.DataFrame:
        """Get a single ts item by position as a dataframe

        Parameters
        ----------
        index : int
            The position of the item to get, must be a valid position
            in the series, i.e. between 0 and len(series) - 1.

        Returns
        -------
        pd.DataFrame
            A dataframe with the nested arrays of the item.

        See Also
        --------
        pandas_ts.TsAccessor.get : Get a single ts item by label (index value).
        """
        item = self._series.iloc[index]
        # Fix: removed leftover debug `print(item)` that wrote every looked-up
        # item to stdout.
        return pd.DataFrame.from_dict(item)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from numpy.testing import assert_array_equal | ||
|
||
|
||
def assert_nested_array_series_equal(a, b):
    """Assert two Series of array-likes agree: same index, and element-wise
    equal nested arrays.

    The index comparison runs first, so the lockstep element walk below is
    guaranteed to see sequences of equal length.
    """
    assert_array_equal(a.index, b.index)
    for left_item, right_item in zip(a, b):
        assert_array_equal(
            left_item,
            right_item,
            err_msg=f"Series '{a.name}' is not equal series '{b.name}'",
        )
|
||
|
||
def assert_df_equal(a, b):
    """Assert two DataFrames are equal: index, columns, dtypes, and the
    values of every column (compared in *a*'s column order).
    """
    for attribute in ("index", "columns", "dtypes"):
        assert_array_equal(getattr(a, attribute), getattr(b, attribute))
    for column in a.columns:
        assert_array_equal(
            a[column],
            b[column],
            err_msg=f"Column '{column}' is not equal column '{column}'",
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.