Skip to content

Commit

Permalink
[ENH] make Empirical distribution compatible with multi-index rows (#…
Browse files Browse the repository at this point in the history
…233)

Towards #212, this makes the `Empirical` distribution compatible with
multi-index rows.

Mirror of sktime/sktime#6066
  • Loading branch information
fkiraly authored Apr 5, 2024
1 parent 9a27d45 commit dee8bbd
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions skpro/distributions/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def __init__(self, spl, weights=None, time_indep=True, index=None, columns=None)
self.index = index
self.columns = columns

_timestamps = spl.index.get_level_values(-1).unique()
_spl_instances = spl.index.droplevel(-1).unique()
_timestamps = spl.index.droplevel(0).unique()
_spl_instances = spl.index.get_level_values(0).unique()
self._timestamps = _timestamps
self._spl_instances = _spl_instances
self._N = len(_spl_instances)
Expand All @@ -83,7 +83,8 @@ def _init_sorted(self):
sorted[t] = {}
weights[t] = {}
for col in cols:
spl_t = self.spl.loc[(slice(None), t), col].values
sl = (slice(None),) + self._coerce_tuple(t)
spl_t = self.spl.loc[sl, col].values
sorter = np.argsort(spl_t)
spl_t_sorted = spl_t[sorter]
sorted[t][col] = spl_t_sorted
Expand All @@ -98,6 +99,11 @@ def _init_sorted(self):
self._sorted = sorted
self._weights = weights

def _coerce_tuple(self, x):
if not isinstance(x, tuple):
x = (x,)
return x

def _apply_per_ix(self, func, params, x=None):
"""Apply function per index."""
sorted = self._sorted
Expand Down

0 comments on commit dee8bbd

Please sign in to comment.