diff --git a/remayn/result/result_data.py b/remayn/result/result_data.py index 0caa781..4da647b 100644 --- a/remayn/result/result_data.py +++ b/remayn/result/result_data.py @@ -2,6 +2,8 @@ import numpy as np +from ..utils import check_array + class ResultData: """Stores the results of a experiment. @@ -51,26 +53,18 @@ def __init__( best_params: Optional[dict] = None, best_model: Optional[object] = None, ): - if not isinstance(targets, np.ndarray): - raise TypeError("targets must be a numpy array") - if not isinstance(predictions, np.ndarray): - raise TypeError("predictions must be a numpy array") - if train_targets is not None and not isinstance(train_targets, np.ndarray): - raise TypeError("train_targets must be a numpy array") - if train_predictions is not None and not isinstance( - train_predictions, np.ndarray - ): - raise TypeError("train_predictions must be a numpy array") - if val_targets is not None and not isinstance(val_targets, np.ndarray): - raise TypeError("val_targets must be a numpy array") - if val_predictions is not None and not isinstance(val_predictions, np.ndarray): - raise TypeError("val_predictions must be a numpy array") - if time is not None and not isinstance(time, (int, float)): - raise TypeError("time must be a number") - if train_history is not None and not isinstance(train_history, np.ndarray): - raise TypeError("train_history must be a numpy array") - if val_history is not None and not isinstance(val_history, np.ndarray): - raise TypeError("val_history must be a numpy array") + targets = check_array(targets) + predictions = check_array(predictions) + train_targets = check_array(train_targets, allow_none=True) + train_predictions = check_array(train_predictions, allow_none=True) + val_targets = check_array(val_targets, allow_none=True) + val_predictions = check_array(val_predictions, allow_none=True) + train_history = check_array(train_history, allow_none=True) + val_history = check_array(val_history, allow_none=True) + + if not isinstance(time, (float, int)) and time is not None: + raise TypeError("time must be a float") + if best_params is not None and not isinstance(best_params, dict): raise TypeError("best_params must be a dictionary") diff --git a/remayn/result/tests/test_result_data.py b/remayn/result/tests/test_result_data.py index e6099c2..3b0ca60 100644 --- a/remayn/result/tests/test_result_data.py +++ b/remayn/result/tests/test_result_data.py @@ -81,42 +81,30 @@ def test_validation(): with pytest.raises(TypeError): ResultData(targets=None, predictions=None) - with pytest.raises(TypeError): - ResultData(targets=[], predictions=[]) - - with pytest.raises(TypeError): - ResultData(targets=random_targets(), predictions=[]) - - with pytest.raises(TypeError): - ResultData(targets=[], predictions=random_predictions()) - - with pytest.raises(TypeError): - ResultData( - targets=random_targets(), - predictions=random_predictions(), - train_targets=[], - ) - - with pytest.raises(TypeError): - ResultData( - targets=random_targets(), - predictions=random_predictions(), - train_predictions=[], - ) - - with pytest.raises(TypeError): - ResultData( - targets=random_targets(), - predictions=random_predictions(), - val_targets=[], - ) + ResultData(targets=[], predictions=[]) + ResultData(targets=random_targets(), predictions=[]) + ResultData(targets=[], predictions=random_predictions()) + ResultData( + targets=random_targets(), + predictions=random_predictions(), + train_targets=[], + ) + ResultData( + targets=random_targets(), + predictions=random_predictions(), + train_predictions=[], + ) - with pytest.raises(TypeError): - ResultData( - targets=random_targets(), - predictions=random_predictions(), - val_predictions=[], - ) + ResultData( + targets=random_targets(), + predictions=random_predictions(), + val_targets=[], + ) + ResultData( + targets=random_targets(), + predictions=random_predictions(), + val_predictions=[], + ) with pytest.raises(TypeError): ResultData( @@ -125,19 +113,16 @@ def test_validation(): time="time", ) - with pytest.raises(TypeError): - ResultData( - targets=random_targets(), - predictions=random_predictions(), - train_history=[], - ) - - with pytest.raises(TypeError): - ResultData( - targets=random_targets(), - predictions=random_predictions(), - val_history=[], - ) + ResultData( + targets=random_targets(), + predictions=random_predictions(), + train_history=[], + ) + ResultData( + targets=random_targets(), + predictions=random_predictions(), + val_history=[], + ) with pytest.raises(TypeError): ResultData( diff --git a/remayn/utils/__init__.py b/remayn/utils/__init__.py index 86c338f..3e5aa46 100644 --- a/remayn/utils/__init__.py +++ b/remayn/utils/__init__.py @@ -1,3 +1,4 @@ +from .array import check_array from .dicts import dict_contains_dict, get_deep_item_from_dict from .json import NonDefaultStrMethodError, sanitize_json @@ -6,4 +7,5 @@ "dict_contains_dict", "NonDefaultStrMethodError", "get_deep_item_from_dict", + "check_array", ] diff --git a/remayn/utils/array.py b/remayn/utils/array.py new file mode 100644 index 0000000..1d47ef7 --- /dev/null +++ b/remayn/utils/array.py @@ -0,0 +1,32 @@ +from typing import List, Union + +import numpy as np + + +def check_array(array: Union[np.ndarray, List], allow_none=False) -> np.ndarray: + """Check if the input is a numpy array. If not, convert it to a numpy array. + + Parameters + ---------- + array : Union[np.ndarray, List] + The input array to check. + allow_none : bool, optional + Whether to allow the input to be None, by default False + + Returns + ------- + np.ndarray + The input array as a numpy array. + """ + + if array is None: + if allow_none: + return None + else: + raise TypeError("None is not a valid array") + if not isinstance(array, np.ndarray): + if isinstance(array, list): + return np.array(array) + else: + raise TypeError(f"{array} must be a numpy array") + return array diff --git a/remayn/utils/tests/test_array.py b/remayn/utils/tests/test_array.py new file mode 100644 index 0000000..6835328 --- /dev/null +++ b/remayn/utils/tests/test_array.py @@ -0,0 +1,19 @@ +import numpy as np +import pytest + +from remayn.utils.array import check_array + + +def test_check_array(): + array = [1, 2, 3] + assert check_array(array).tolist() == array + assert (check_array(np.array(array)) == np.array(array)).all() + + array = None + assert check_array(array, allow_none=True) is None + with pytest.raises(TypeError): + check_array(array) + + array = "string" + with pytest.raises(TypeError): + check_array(array)