Skip to content

Commit

Permalink
fix wrong columns in test params
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Aug 25, 2023
1 parent 4bf1141 commit 803e636
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions skpro/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,10 @@ def _average(self, method, x=None, weights=None):

def _average_df(self, df_list, weights=None):
"""Average a list of `pd.DataFrame` objects, with weights."""
if weights is None:
if weights is None and hasattr(self, "_weights"):
weights = self._weights
elif weights is None:
weights = np.ones(len(df_list)) / len(df_list)

n_df = len(df_list)
df_weighted = [df * w for df, w in zip(df_list, weights)]
Expand Down Expand Up @@ -169,7 +171,7 @@ def get_test_params(cls, parameter_set="default"):
index = pd.RangeIndex(3)
columns = pd.Index(["a", "b"])
normal1 = Normal(mu=0, sigma=1, index=index, columns=columns)
normal2 = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=1)
normal2 = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=1, columns=columns)

dists = [("normal1", normal1), ("normal2", normal2)]

Expand Down

0 comments on commit 803e636

Please sign in to comment.