From 803e6367cf0310e01e259e3d179a9bc4db8f6f08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Fri, 25 Aug 2023 23:22:18 +0100 Subject: [PATCH] fix wrong columns in test params --- skpro/distributions/mixture.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/skpro/distributions/mixture.py b/skpro/distributions/mixture.py index 4b51cdc8..b21a9c1a 100644 --- a/skpro/distributions/mixture.py +++ b/skpro/distributions/mixture.py @@ -112,8 +112,10 @@ def _average(self, method, x=None, weights=None): def _average_df(self, df_list, weights=None): """Average a list of `pd.DataFrame` objects, with weights.""" - if weights is None: + if weights is None and hasattr(self, "_weights"): weights = self._weights + elif weights is None: + weights = np.ones(len(df_list)) / len(df_list) n_df = len(df_list) df_weighted = [df * w for df, w in zip(df_list, weights)] @@ -169,7 +171,7 @@ def get_test_params(cls, parameter_set="default"): index = pd.RangeIndex(3) columns = pd.Index(["a", "b"]) normal1 = Normal(mu=0, sigma=1, index=index, columns=columns) - normal2 = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=1) + normal2 = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=1, columns=columns) dists = [("normal1", normal1), ("normal2", normal2)]