def test_numpy_schema_validation() -> None:
    """Exercise Store schema validation when schemas contain Numpy arrays.

    Each scenario wires two ``ProcessA`` instances to the same store
    ``('a',)`` and checks how the ``Engine`` constructor reacts to the pair
    of schemas declared for that store.
    """

    class ProcessA(Process):
        """Minimal process whose ports schema is taken from its parameters."""

        defaults = {
            'ports_schema_dict': {
                'b': {'_default': 1}
            }
        }

        def __init__(self, parameters: Optional[dict] = None) -> None:
            super().__init__(parameters)
            self.ports_schema_dict = self.parameters['ports_schema_dict']

        def ports_schema(self) -> Schema:
            return self.ports_schema_dict

        def next_update(self, timestep: float, states: State) -> Update:
            return {}

    def set_value_divider(value):
        # Divider schema that configures the 'set_value' divider with `value`.
        return {
            'divider': 'set_value',
            'config': {
                'value': value
            }
        }

    def build_engine(schema_1, schema_2):
        # Both processes declare port 'a' and point it at the same store,
        # so their schemas for that store get checked against each other.
        proc_1 = ProcessA({'ports_schema_dict': {'a': schema_1}})
        proc_2 = ProcessA({'ports_schema_dict': {'a': schema_2}})
        return Engine(
            processes={
                'procA1': proc_1,
                'procA2': proc_2,
            },
            topology={
                'procA1': {'a': ('a',)},
                'procA2': {'a': ('a',)},
            }
        )

    # procA1 and procA2 disagree on their _divider schemas, so
    # _check_schema_support_defaults is expected to emit a warning.
    with pytest.warns(UserWarning, match='Incompatible schema assignment'):
        _ = build_engine(
            {'_divider': set_value_divider(np.ones(10))},
            {'_divider': set_value_divider(np.zeros(10))},
        )

    # np.ones(10) must not be treated as equal to the scalar 1.
    with pytest.raises(ValueError, match='The truth value of an array'):
        _ = build_engine(
            {'_divider': set_value_divider(1)},
            {'_divider': set_value_divider(np.ones(10))},
        )

    # procA1 and procA2 disagree on their _value schemas, so
    # _check_schema is expected to raise an exception.
    with pytest.raises(ValueError, match='Incompatible schema assignment'):
        _ = build_engine(
            {'_value': np.ones(10)},
            {'_value': np.zeros(10)},
        )

    # Identical schemas must produce neither warnings nor exceptions;
    # any warning is promoted to an error so it fails the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        _ = build_engine(
            {
                '_value': np.ones(10),
                '_divider': set_value_divider(np.ones(10)),
            },
            {
                '_value': np.ones(10),
                '_divider': set_value_divider(np.ones(10)),
            },
        )
MULTI_UPDATE_KEY +from vivarium.library.dict_utils import deep_compare, deep_merge, deep_merge_check, MULTI_UPDATE_KEY from vivarium.library.topology import dict_to_paths from vivarium.core.types import Processes, Topology, State, Steps, Flow from vivarium.core.serialize import QuantitySerializer @@ -499,9 +499,13 @@ def _check_schema(self, schema_key, new_schema): value is different from the existing one. """ current_schema_value = getattr(self, schema_key) - if current_schema_value is not None and np.all( - current_schema_value != new_schema - ): + if isinstance(current_schema_value, dict) and isinstance(new_schema, dict): + schemas_equal = deep_compare(current_schema_value, new_schema) + elif isinstance(current_schema_value, np.ndarray) and isinstance(new_schema, np.ndarray): + schemas_equal = np.array_equal(current_schema_value, new_schema) + else: + schemas_equal = (current_schema_value == new_schema) + if current_schema_value is not None and not schemas_equal: if schema_key == "units": # Different Python interpreters (inc. from multiprocessing with # spawn start method) yield different hashes for the same value @@ -527,10 +531,14 @@ def _check_schema_support_defaults(self, schema_key, new_schema, schema_registry new_schema[schema_key], str): new_schema[schema_key] = schema_registry.access( new_schema[schema_key]) + if isinstance(current_schema_value, dict) and isinstance(new_schema, dict): + schemas_equal = deep_compare(current_schema_value, new_schema) + else: + schemas_equal = (current_schema_value == new_schema) if ( current_schema_value - and np.all(current_schema_value != DEFAULT_SCHEMA) - and np.all(current_schema_value != new_schema)): + and current_schema_value != DEFAULT_SCHEMA + and not schemas_equal): warnings.warn( f"Incompatible schema assignment at {self.path_for()}. 
" f"Trying to assign the value {new_schema} to key {schema_key}, " diff --git a/vivarium/library/dict_utils.py b/vivarium/library/dict_utils.py index 47033db5..e56f3389 100644 --- a/vivarium/library/dict_utils.py +++ b/vivarium/library/dict_utils.py @@ -5,7 +5,9 @@ import operator import traceback from typing import Optional, Any, Callable +import warnings +import numpy as np from vivarium.library.units import Quantity @@ -21,6 +23,43 @@ def merge_dicts(dicts): return merge +def deep_compare(dct_1, dct_2, path=tuple()): + """Recursively checks for equality between two dictionaries in a way + that supports Numpy arrays. + + Args: + dct_1, dct_2, dictionaries to compare + path: If ``dct_1`` or ``dct_2`` are nested within larger dictionaries, + this is the path to them. This is normally an empty tuple + for the end user but is used for recursive calls + + Returns: + True when two dictionaries are equal, False otherwise + + Raises: + ValueError: Raised when conflicting values are found between + ``dct_1`` and ``dct_2`` + """ + key_diff = dct_1.keys() ^ dct_2.keys() + if len(key_diff) > 0: + warnings.warn(f'Unshared keys at {path}: {key_diff}') + return False + for key, val_1 in dct_1.items(): + val_2 = dct_2[key] + if isinstance(val_1, dict) and isinstance(val_2, dict): + if not deep_compare(val_1, val_2, path + (key,)): + return False + elif isinstance(val_1, np.ndarray) and isinstance(val_2, np.ndarray): + if not np.array_equal(val_1, val_2): + warnings.warn(f'Dicts differ at {path}: {val_1}, {val_2}') + return False + else: + if not val_1 == val_2: + warnings.warn(f'Dicts differ at {path}: {val_1}, {val_2}') + return False + return True + + def deep_merge_check(dct, merge_dct, check_equality=False, path=tuple()): """Recursively merge dictionaries with checks to avoid overwriting.