diff --git a/river/preprocessing/scale.py b/river/preprocessing/scale.py index 69c6bca0e8..8283c2d190 100644 --- a/river/preprocessing/scale.py +++ b/river/preprocessing/scale.py @@ -230,9 +230,15 @@ def transform_many(self, X: pd.DataFrame): """ + # Determine dtype of input dtypes = X.dtypes.unique() dtype = dtypes[0] if len(dtypes) == 1 else np.float64 + # Check if the dtype is integer type and convert to corresponding float type + if np.issubdtype(dtype, np.integer): + bytes_size = dtype.itemsize + dtype = np.dtype(f"float{bytes_size * 8}") + means = np.array([self.means[c] for c in X.columns], dtype=dtype) Xt = X.values - means