Skip to content

Commit

Permalink
Merge pull request #234 from vivarium-collective/schema
Browse files Browse the repository at this point in the history
Recursively compare schemas during validation
  • Loading branch information
thalassemia authored Sep 20, 2023
2 parents d3d8cf5 + 31fcdb9 commit 1016ee4
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 7 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Changelog

## v1.6.1
* (#233) Fix some bugs in `emitter.py` and allow Numpy arrays in Store schemas
* (#234) Allow Numpy arrays to be deeply nested in Store schemas
* (#233) Fix some bugs in `emitter.py`

## v1.6.0
* (#231) Update `networkx` error message for already deleted Step.
Expand Down
172 changes: 172 additions & 0 deletions vivarium/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import datetime
import time as clock
import uuid
import warnings

import networkx as nx
import numpy as np
import pytest

from vivarium.core.store import (
Expand Down Expand Up @@ -1236,3 +1238,173 @@ def next_update(self, timestep: float, states: State) -> Update:
'stepA2': [],
},
)

def test_numpy_schema_validation() -> None:
class ProcessA(Process):

defaults = {
'ports_schema_dict': {
'b': {'_default': 1}
}
}

def __init__(self, parameters: Optional[dict] = None) -> None:
super().__init__(parameters)
self.ports_schema_dict = self.parameters['ports_schema_dict']

def ports_schema(self) -> Schema:
return self.ports_schema_dict

def next_update(self, timestep: float, states: State) -> Update:
return {}

# Since procA1 and procA2 differ in their _divider schemas,
# _check_schema_support_defaults should raise a warning
with pytest.warns(UserWarning, match='Incompatible schema assignment'):
_ = Engine(
processes={
'procA1': ProcessA(
{
'ports_schema_dict': {
'a': {
'_divider': {
'divider': 'set_value',
'config': {
'value': np.ones(10)
}
}
}
}
}
),
'procA2': ProcessA(
{
'ports_schema_dict': {
'a': {
'_divider': {
'divider': 'set_value',
'config': {
'value': np.zeros(10)
}
}
}
}
}
),
},
topology={
'procA1': {'a': ('a',)},
'procA2': {'a': ('a',)},
}
)

# Check that np.ones(10) is not considered equal to 1
with pytest.raises(ValueError, match='The truth value of an array'):
_ = Engine(
processes={
'procA1': ProcessA(
{
'ports_schema_dict': {
'a': {
'_divider': {
'divider': 'set_value',
'config': {
'value': 1
}
}
}
}
}
),
'procA2': ProcessA(
{
'ports_schema_dict': {
'a': {
'_divider': {
'divider': 'set_value',
'config': {
'value': np.ones(10)
}
}
}
}
}
),
},
topology={
'procA1': {'a': ('a',)},
'procA2': {'a': ('a',)},
}
)

# Since ProcA1 and ProcA2 differ in their _value schemas,
# _check_schema should raise an exception
with pytest.raises(ValueError, match='Incompatible schema assignment'):
_ = Engine(
processes={
'procA1': ProcessA(
{
'ports_schema_dict': {
'a': {
'_value': np.ones(10)
}
}
}
),
'procA2': ProcessA(
{
'ports_schema_dict': {
'a': {
'_value': np.zeros(10)
}
}
}
),
},
topology={
'procA1': {'a': ('a',)},
'procA2': {'a': ('a',)},
}
)

# Matching schemas should raise no warnings or exceptions
with warnings.catch_warnings():
warnings.simplefilter("error")
_ = Engine(
processes={
'procA1': ProcessA(
{
'ports_schema_dict': {
'a': {
'_value': np.ones(10),
'_divider': {
'divider': 'set_value',
'config': {
'value': np.ones(10)
}
}
}
}
}
),
'procA2': ProcessA(
{
'ports_schema_dict': {
'a': {
'_value': np.ones(10),
'_divider': {
'divider': 'set_value',
'config': {
'value': np.ones(10)
}
}
}
}
}
),
},
topology={
'procA1': {'a': ('a',)},
'procA2': {'a': ('a',)},
}
)
20 changes: 14 additions & 6 deletions vivarium/core/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from vivarium.core.registry import divider_registry, serializer_registry, updater_registry
from vivarium.core.process import ParallelProcess, Process
from vivarium.library.dict_utils import deep_merge, deep_merge_check, MULTI_UPDATE_KEY
from vivarium.library.dict_utils import deep_compare, deep_merge, deep_merge_check, MULTI_UPDATE_KEY
from vivarium.library.topology import dict_to_paths
from vivarium.core.types import Processes, Topology, State, Steps, Flow
from vivarium.core.serialize import QuantitySerializer
Expand Down Expand Up @@ -499,9 +499,13 @@ def _check_schema(self, schema_key, new_schema):
value is different from the existing one.
"""
current_schema_value = getattr(self, schema_key)
if current_schema_value is not None and np.all(
current_schema_value != new_schema
):
if isinstance(current_schema_value, dict) and isinstance(new_schema, dict):
schemas_equal = deep_compare(current_schema_value, new_schema)
elif isinstance(current_schema_value, np.ndarray) and isinstance(new_schema, np.ndarray):
schemas_equal = np.array_equal(current_schema_value, new_schema)
else:
schemas_equal = (current_schema_value == new_schema)
if current_schema_value is not None and not schemas_equal:
if schema_key == "units":
# Different Python interpreters (inc. from multiprocessing with
# spawn start method) yield different hashes for the same value
Expand All @@ -527,10 +531,14 @@ def _check_schema_support_defaults(self, schema_key, new_schema, schema_registry
new_schema[schema_key], str):
new_schema[schema_key] = schema_registry.access(
new_schema[schema_key])
if isinstance(current_schema_value, dict) and isinstance(new_schema, dict):
schemas_equal = deep_compare(current_schema_value, new_schema)
else:
schemas_equal = (current_schema_value == new_schema)
if (
current_schema_value
and np.all(current_schema_value != DEFAULT_SCHEMA)
and np.all(current_schema_value != new_schema)):
and current_schema_value != DEFAULT_SCHEMA
and not schemas_equal):
warnings.warn(
f"Incompatible schema assignment at {self.path_for()}. "
f"Trying to assign the value {new_schema} to key {schema_key}, "
Expand Down
39 changes: 39 additions & 0 deletions vivarium/library/dict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import operator
import traceback
from typing import Optional, Any, Callable
import warnings

import numpy as np
from vivarium.library.units import Quantity


Expand All @@ -21,6 +23,43 @@ def merge_dicts(dicts):
return merge


def deep_compare(dct_1, dct_2, path=tuple()):
"""Recursively checks for equality between two dictionaries in a way
that supports Numpy arrays.
Args:
dct_1, dct_2, dictionaries to compare
path: If ``dct_1`` or ``dct_2`` are nested within larger dictionaries,
this is the path to them. This is normally an empty tuple
for the end user but is used for recursive calls
Returns:
True when two dictionaries are equal, False otherwise
Raises:
ValueError: Raised when conflicting values are found between
``dct_1`` and ``dct_2``
"""
key_diff = dct_1.keys() ^ dct_2.keys()
if len(key_diff) > 0:
warnings.warn(f'Unshared keys at {path}: {key_diff}')
return False
for key, val_1 in dct_1.items():
val_2 = dct_2[key]
if isinstance(val_1, dict) and isinstance(val_2, dict):
if not deep_compare(val_1, val_2, path + (key,)):
return False
elif isinstance(val_1, np.ndarray) and isinstance(val_2, np.ndarray):
if not np.array_equal(val_1, val_2):
warnings.warn(f'Dicts differ at {path}: {val_1}, {val_2}')
return False
else:
if not val_1 == val_2:
warnings.warn(f'Dicts differ at {path}: {val_1}, {val_2}')
return False
return True


def deep_merge_check(dct, merge_dct, check_equality=False, path=tuple()):
"""Recursively merge dictionaries with checks to avoid overwriting.
Expand Down

0 comments on commit 1016ee4

Please sign in to comment.