From be5265273663738924801050d8e26fa3f34e5f40 Mon Sep 17 00:00:00 2001 From: Ryan Ly Date: Sat, 12 Aug 2023 15:19:57 -0700 Subject: [PATCH] Refactor DTR warning (#917) * Refactor DTR warning * Fix --------- Co-authored-by: Matthew Avaylon --- src/hdmf/common/table.py | 20 ++++++ src/hdmf/container.py | 25 +++++-- tests/unit/common/test_linkedtables.py | 90 +++++++++++++++----------- tests/unit/common/test_table.py | 2 +- tests/unit/test_container.py | 12 ++++ 5 files changed, 106 insertions(+), 43 deletions(-) diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 1b4fe76d1..cafd8ff16 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -1421,6 +1421,26 @@ def __repr__(self): id(self.table)) return template + def _validate_on_set_parent(self): + # when this DynamicTableRegion is added to a parent, check: + # 1) if the table was read from a written file, no need to validate further + p = self.table + while p is not None: + if p.container_source is not None: + return super()._validate_on_set_parent() + p = p.parent + + # 2) if none of the ancestors are ancestors of the linked-to table, then when this is written, the table + # field will point to a table that is not in the file + table_ancestor_ids = [id(x) for x in self.table.get_ancestors()] + self_ancestor_ids = [id(x) for x in self.get_ancestors()] + + if set(table_ancestor_ids).isdisjoint(self_ancestor_ids): + msg = (f"The linked table for DynamicTableRegion '{self.name}' does not share an ancestor with the " + "DynamicTableRegion.") + warn(msg) + return super()._validate_on_set_parent() + def _uint_precision(elements): """ Calculate the uint precision needed to encode a set of elements """ diff --git a/src/hdmf/container.py b/src/hdmf/container.py index a10102421..84533220a 100644 --- a/src/hdmf/container.py +++ b/src/hdmf/container.py @@ -302,6 +302,15 @@ def get_ancestor(self, **kwargs): p = p.parent return None + @docval() + def get_ancestors(self, **kwargs): + p = self.parent + ret = [] + while p is not None: + ret.append(p) + p = p.parent + return tuple(ret) + @property def fields(self): ''' @@ -414,12 +423,8 @@ def parent(self, parent_container): parent_container.__children.append(self) parent_container.set_modified() for child in self.children: - if type(child).__name__ == "DynamicTableRegion": - if child.table.parent is None: - msg = "The table for this DynamicTableRegion has not been added to the parent." - warn(msg) - else: - continue + # used by hdmf.common.table.DynamicTableRegion to check for orphaned tables + child._validate_on_set_parent() def _remove_child(self, child): """Remove a child Container. Intended for use in subclasses that allow dynamic addition of child Containers.""" @@ -445,6 +450,14 @@ def reset_parent(self): else: raise ValueError("Cannot reset parent when parent is not an AbstractContainer: %s" % repr(self.parent)) + def _validate_on_set_parent(self): + """Validate this Container after setting the parent. + + This method is called by the parent setter. It can be overridden in subclasses to perform additional + validation. The default implementation does nothing. + """ + pass + class Container(AbstractContainer): """A container that can contain other containers and has special functionality for printing.""" diff --git a/tests/unit/common/test_linkedtables.py b/tests/unit/common/test_linkedtables.py index 25a80efa1..3c1c63170 100644 --- a/tests/unit/common/test_linkedtables.py +++ b/tests/unit/common/test_linkedtables.py @@ -2,6 +2,7 @@ Module for testing functions specific to tables containing DynamicTableRegion columns """ +import warnings import numpy as np from hdmf.common import DynamicTable, AlignedDynamicTable, VectorData, DynamicTableRegion, VectorIndex from hdmf.testing import TestCase @@ -139,11 +140,16 @@ def setUp(self): description='filter value', index=False) # Aligned table - self.aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - columns=[VectorData(name='a1', description='a1', data=np.arange(3)), ], - colnames=['a1', ], - category_tables=[self.category0, self.category1]) + with warnings.catch_warnings(): + msg = "The linked table for DynamicTableRegion '.*' does not share an ancestor with the DynamicTableRegion." + warnings.filterwarnings("ignore", category=UserWarning, message=msg) + self.aligned_table = AlignedDynamicTable( + name='my_aligned_table', + description='my test table', + columns=[VectorData(name='a1', description='a1', data=np.arange(3)), ], + colnames=['a1', ], + category_tables=[self.category0, self.category1] + ) def tearDown(self): del self.table_level0_0 @@ -241,13 +247,16 @@ def test_get_foreign_column_in_main_and_category_table(self): columns=[VectorData(name='c1', description='c1', data=np.arange(4)), DynamicTableRegion(name='c2', description='c2', data=np.arange(4), table=temp_table0)]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - DynamicTableRegion(name='a2', description='c2', - data=np.arange(4), table=temp_table)]) + with warnings.catch_warnings(): + msg = "The linked table for DynamicTableRegion '.*' does not share an ancestor with the DynamicTableRegion." + warnings.filterwarnings("ignore", category=UserWarning, message=msg) + temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', + description='my test table', + category_tables=[temp_table], + colnames=['a1', 'a2'], + columns=[VectorData(name='a1', description='c1', data=np.arange(4)), + DynamicTableRegion(name='a2', description='c2', + data=np.arange(4), table=temp_table)]) # We should get both the DynamicTableRegion from the main table and the category 't1' self.assertListEqual(temp_aligned_table.get_foreign_columns(), [(None, 'a2'), ('t1', 'c2')]) # We should only get the column from the main table @@ -275,12 +284,15 @@ def test_get_linked_tables_none(self): colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), VectorData(name='c2', description='c2', data=np.arange(4))]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - VectorData(name='a2', description='c2', data=np.arange(4))]) + with warnings.catch_warnings(): + msg = "The linked table for DynamicTableRegion '.*' does not share an ancestor with the DynamicTableRegion." + warnings.filterwarnings("ignore", category=UserWarning, message=msg) + temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', + description='my test table', + category_tables=[temp_table], + colnames=['a1', 'a2'], + columns=[VectorData(name='a1', description='c1', data=np.arange(4)), + VectorData(name='a2', description='c2', data=np.arange(4))]) self.assertListEqual(temp_aligned_table.get_linked_tables(), []) self.assertListEqual(temp_aligned_table.get_linked_tables(ignore_category_tables=True), []) @@ -294,13 +306,16 @@ def test_get_linked_tables_complex_link(self): columns=[VectorData(name='c1', description='c1', data=np.arange(4)), DynamicTableRegion(name='c2', description='c2', data=np.arange(4), table=temp_table0)]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - DynamicTableRegion(name='a2', description='c2', - data=np.arange(4), table=temp_table)]) + with warnings.catch_warnings(): + msg = "The linked table for DynamicTableRegion '.*' does not share an ancestor with the DynamicTableRegion." + warnings.filterwarnings("ignore", category=UserWarning, message=msg) + temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', + description='my test table', + category_tables=[temp_table], + colnames=['a1', 'a2'], + columns=[VectorData(name='a1', description='c1', data=np.arange(4)), + DynamicTableRegion(name='a2', description='c2', + data=np.arange(4), table=temp_table)]) # NOTE: in this example templ_aligned_table both points to temp_table and at the # same time contains temp_table as a category. This could lead to temp_table # visited multiple times and we want to make sure this doesn't happen @@ -326,17 +341,20 @@ def test_get_linked_tables_simple_link(self): columns=[VectorData(name='c1', description='c1', data=np.arange(4)), VectorData(name='c2', description='c2', data=np.arange(4))]) temp_table = DynamicTable(name='t1', description='t1', - colnames=['c1', 'c2'], - columns=[VectorData(name='c1', description='c1', data=np.arange(4)), - DynamicTableRegion(name='c2', description='c2', - data=np.arange(4), table=temp_table0)]) - temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', - description='my test table', - category_tables=[temp_table], - colnames=['a1', 'a2'], - columns=[VectorData(name='a1', description='c1', data=np.arange(4)), - DynamicTableRegion(name='a2', description='c2', - data=np.arange(4), table=temp_table0)]) + colnames=['c1', 'c2'], + columns=[VectorData(name='c1', description='c1', data=np.arange(4)), + DynamicTableRegion(name='c2', description='c2', + data=np.arange(4), table=temp_table0)]) + with warnings.catch_warnings(): + msg = "The linked table for DynamicTableRegion '.*' does not share an ancestor with the DynamicTableRegion." + warnings.filterwarnings("ignore", category=UserWarning, message=msg) + temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', + description='my test table', + category_tables=[temp_table], + colnames=['a1', 'a2'], + columns=[VectorData(name='a1', description='c1', data=np.arange(4)), + DynamicTableRegion(name='a2', description='c2', + data=np.arange(4), table=temp_table0)]) # NOTE: in this example temp_aligned_table and temp_table both point to temp_table0 # We should get both the DynamicTableRegion from the main table and the category 't1' linked_tables = temp_aligned_table.get_linked_tables() diff --git a/tests/unit/common/test_table.py b/tests/unit/common/test_table.py index af6b6357e..311e01f8b 100644 --- a/tests/unit/common/test_table.py +++ b/tests/unit/common/test_table.py @@ -1124,7 +1124,7 @@ def setUp(self): super().setUp() def setUpContainer(self): - multi_container = SimpleMultiContainer(name='multi', containers=[self.table, self.target_table]) + multi_container = SimpleMultiContainer(name='multi', containers=[self.target_table, self.table]) return multi_container def _get(self, arg): diff --git a/tests/unit/test_container.py b/tests/unit/test_container.py index 9bbbb9f82..5c71688ff 100644 --- a/tests/unit/test_container.py +++ b/tests/unit/test_container.py @@ -382,6 +382,18 @@ def test_reset_parent_no_parent(self): obj.reset_parent() self.assertIsNone(obj.parent) + def test_get_ancestors(self): + """Test that get_ancestors returns the correct ancestors. + """ + grandparent_obj = Container('obj1') + parent_obj = Container('obj2') + child_obj = Container('obj3') + parent_obj.parent = grandparent_obj + child_obj.parent = parent_obj + self.assertTupleEqual(grandparent_obj.get_ancestors(), tuple()) + self.assertTupleEqual(parent_obj.get_ancestors(), (grandparent_obj, )) + self.assertTupleEqual(child_obj.get_ancestors(), (parent_obj, grandparent_obj)) + class TestHTMLRepr(TestCase):