Skip to content

Commit

Permalink
allowed resultdata to receive also lists
Browse files Browse the repository at this point in the history
  • Loading branch information
victormvy committed May 27, 2024
1 parent ff407fb commit 734784b
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 68 deletions.
34 changes: 14 additions & 20 deletions remayn/result/result_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np

from ..utils import check_array


class ResultData:
"""Stores the results of a experiment.
Expand Down Expand Up @@ -51,26 +53,18 @@ def __init__(
best_params: Optional[dict] = None,
best_model: Optional[object] = None,
):
if not isinstance(targets, np.ndarray):
raise TypeError("targets must be a numpy array")
if not isinstance(predictions, np.ndarray):
raise TypeError("predictions must be a numpy array")
if train_targets is not None and not isinstance(train_targets, np.ndarray):
raise TypeError("train_targets must be a numpy array")
if train_predictions is not None and not isinstance(
train_predictions, np.ndarray
):
raise TypeError("train_predictions must be a numpy array")
if val_targets is not None and not isinstance(val_targets, np.ndarray):
raise TypeError("val_targets must be a numpy array")
if val_predictions is not None and not isinstance(val_predictions, np.ndarray):
raise TypeError("val_predictions must be a numpy array")
if time is not None and not isinstance(time, (int, float)):
raise TypeError("time must be a number")
if train_history is not None and not isinstance(train_history, np.ndarray):
raise TypeError("train_history must be a numpy array")
if val_history is not None and not isinstance(val_history, np.ndarray):
raise TypeError("val_history must be a numpy array")
targets = check_array(targets)
predictions = check_array(predictions)
train_targets = check_array(train_targets, allow_none=True)
train_predictions = check_array(train_predictions, allow_none=True)
val_targets = check_array(val_targets, allow_none=True)
val_predictions = check_array(val_predictions, allow_none=True)
train_history = check_array(train_history, allow_none=True)
val_history = check_array(val_history, allow_none=True)

if not isinstance(time, (float, int)) and time is not None:
raise TypeError("time must be a float")

if best_params is not None and not isinstance(best_params, dict):
raise TypeError("best_params must be a dictionary")

Expand Down
81 changes: 33 additions & 48 deletions remayn/result/tests/test_result_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,42 +81,30 @@ def test_validation():
with pytest.raises(TypeError):
ResultData(targets=None, predictions=None)

with pytest.raises(TypeError):
ResultData(targets=[], predictions=[])

with pytest.raises(TypeError):
ResultData(targets=random_targets(), predictions=[])

with pytest.raises(TypeError):
ResultData(targets=[], predictions=random_predictions())

with pytest.raises(TypeError):
ResultData(
targets=random_targets(),
predictions=random_predictions(),
train_targets=[],
)

with pytest.raises(TypeError):
ResultData(
targets=random_targets(),
predictions=random_predictions(),
train_predictions=[],
)

with pytest.raises(TypeError):
ResultData(
targets=random_targets(),
predictions=random_predictions(),
val_targets=[],
)
ResultData(targets=[], predictions=[])
ResultData(targets=random_targets(), predictions=[])
ResultData(targets=[], predictions=random_predictions())
ResultData(
targets=random_targets(),
predictions=random_predictions(),
train_targets=[],
)
ResultData(
targets=random_targets(),
predictions=random_predictions(),
train_predictions=[],
)

with pytest.raises(TypeError):
ResultData(
targets=random_targets(),
predictions=random_predictions(),
val_predictions=[],
)
ResultData(
targets=random_targets(),
predictions=random_predictions(),
val_targets=[],
)
ResultData(
targets=random_targets(),
predictions=random_predictions(),
val_predictions=[],
)

with pytest.raises(TypeError):
ResultData(
Expand All @@ -125,19 +113,16 @@ def test_validation():
time="time",
)

with pytest.raises(TypeError):
ResultData(
targets=random_targets(),
predictions=random_predictions(),
train_history=[],
)

with pytest.raises(TypeError):
ResultData(
targets=random_targets(),
predictions=random_predictions(),
val_history=[],
)
ResultData(
targets=random_targets(),
predictions=random_predictions(),
train_history=[],
)
ResultData(
targets=random_targets(),
predictions=random_predictions(),
val_history=[],
)

with pytest.raises(TypeError):
ResultData(
Expand Down
2 changes: 2 additions & 0 deletions remayn/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .array import check_array
from .dicts import dict_contains_dict, get_deep_item_from_dict
from .json import NonDefaultStrMethodError, sanitize_json

Expand All @@ -6,4 +7,5 @@
"dict_contains_dict",
"NonDefaultStrMethodError",
"get_deep_item_from_dict",
"check_array",
]
32 changes: 32 additions & 0 deletions remayn/utils/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import List, Union

import numpy as np


def check_array(array: Union[np.ndarray, List], allow_none=False) -> np.ndarray:
"""Check if the input is a numpy array. If not, convert it to a numpy array.
Parameters
----------
array : Union[np.ndarray, List]
The input array to check.
allow_none : bool, optional
Whether to allow the input to be None, by default False
Returns
-------
np.ndarray
The input array as a numpy array.
"""

if array is None:
if allow_none:
return None
else:
raise TypeError("None is not a valid array")
if not isinstance(array, np.ndarray):
if isinstance(array, list):
return np.array(array)
else:
raise TypeError(f"{array} must be a numpy array")
return array
19 changes: 19 additions & 0 deletions remayn/utils/tests/test_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np
import pytest

from remayn.utils.array import check_array


def test_check_array():
array = [1, 2, 3]
assert check_array(array).tolist() == array
assert (check_array(np.array(array)) == np.array(array)).all()

array = None
assert check_array(array, allow_none=True) is None
with pytest.raises(TypeError):
check_array(array)

array = "string"
with pytest.raises(TypeError):
check_array(array)

0 comments on commit 734784b

Please sign in to comment.