Skip to content

Commit

Permalink
Refactor DTR warning (#917)
Browse files Browse the repository at this point in the history
* Refactor DTR warning

* Fix

---------

Co-authored-by: Matthew Avaylon <[email protected]>
  • Loading branch information
rly and mavaylon1 authored Aug 12, 2023
1 parent dd39b38 commit be52652
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 43 deletions.
20 changes: 20 additions & 0 deletions src/hdmf/common/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,26 @@ def __repr__(self):
id(self.table))
return template

def _validate_on_set_parent(self):
# when this DynamicTableRegion is added to a parent, check:
# 1) if the table was read from a written file, no need to validate further
p = self.table
while p is not None:
if p.container_source is not None:
return super()._validate_on_set_parent()
p = p.parent

# 2) if none of the ancestors are ancestors of the linked-to table, then when this is written, the table
# field will point to a table that is not in the file
table_ancestor_ids = [id(x) for x in self.table.get_ancestors()]
self_ancestor_ids = [id(x) for x in self.get_ancestors()]

if set(table_ancestor_ids).isdisjoint(self_ancestor_ids):
msg = (f"The linked table for DynamicTableRegion '{self.name}' does not share an ancestor with the "
"DynamicTableRegion.")
warn(msg)
return super()._validate_on_set_parent()


def _uint_precision(elements):
""" Calculate the uint precision needed to encode a set of elements """
Expand Down
25 changes: 19 additions & 6 deletions src/hdmf/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,15 @@ def get_ancestor(self, **kwargs):
p = p.parent
return None

@docval()
def get_ancestors(self, **kwargs):
p = self.parent
ret = []
while p is not None:
ret.append(p)
p = p.parent
return tuple(ret)

@property
def fields(self):
'''
Expand Down Expand Up @@ -414,12 +423,8 @@ def parent(self, parent_container):
parent_container.__children.append(self)
parent_container.set_modified()
for child in self.children:
if type(child).__name__ == "DynamicTableRegion":
if child.table.parent is None:
msg = "The table for this DynamicTableRegion has not been added to the parent."
warn(msg)
else:
continue
# used by hdmf.common.table.DynamicTableRegion to check for orphaned tables
child._validate_on_set_parent()

def _remove_child(self, child):
"""Remove a child Container. Intended for use in subclasses that allow dynamic addition of child Containers."""
Expand All @@ -445,6 +450,14 @@ def reset_parent(self):
else:
raise ValueError("Cannot reset parent when parent is not an AbstractContainer: %s" % repr(self.parent))

def _validate_on_set_parent(self):
"""Validate this Container after setting the parent.
This method is called by the parent setter. It can be overridden in subclasses to perform additional
validation. The default implementation does nothing.
"""
pass


class Container(AbstractContainer):
"""A container that can contain other containers and has special functionality for printing."""
Expand Down
90 changes: 54 additions & 36 deletions tests/unit/common/test_linkedtables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Module for testing functions specific to tables containing DynamicTableRegion columns
"""

import warnings
import numpy as np
from hdmf.common import DynamicTable, AlignedDynamicTable, VectorData, DynamicTableRegion, VectorIndex
from hdmf.testing import TestCase
Expand Down Expand Up @@ -139,11 +140,16 @@ def setUp(self):
description='filter value',
index=False)
# Aligned table
self.aligned_table = AlignedDynamicTable(name='my_aligned_table',
description='my test table',
columns=[VectorData(name='a1', description='a1', data=np.arange(3)), ],
colnames=['a1', ],
category_tables=[self.category0, self.category1])
with warnings.catch_warnings():
msg = "The linked table for DynamicTableRegion '.*' does not share an ancestor with the DynamicTableRegion."
warnings.filterwarnings("ignore", category=UserWarning, message=msg)
self.aligned_table = AlignedDynamicTable(
name='my_aligned_table',
description='my test table',
columns=[VectorData(name='a1', description='a1', data=np.arange(3)), ],
colnames=['a1', ],
category_tables=[self.category0, self.category1]
)

def tearDown(self):
del self.table_level0_0
Expand Down Expand Up @@ -241,13 +247,16 @@ def test_get_foreign_column_in_main_and_category_table(self):
columns=[VectorData(name='c1', description='c1', data=np.arange(4)),
DynamicTableRegion(name='c2', description='c2',
data=np.arange(4), table=temp_table0)])
temp_aligned_table = AlignedDynamicTable(name='my_aligned_table',
description='my test table',
category_tables=[temp_table],
colnames=['a1', 'a2'],
columns=[VectorData(name='a1', description='c1', data=np.arange(4)),
DynamicTableRegion(name='a2', description='c2',
data=np.arange(4), table=temp_table)])
with warnings.catch_warnings():
msg = "The linked table for DynamicTableRegion '.*' does not share an ancestor with the DynamicTableRegion."
warnings.filterwarnings("ignore", category=UserWarning, message=msg)
temp_aligned_table = AlignedDynamicTable(name='my_aligned_table',
description='my test table',
category_tables=[temp_table],
colnames=['a1', 'a2'],
columns=[VectorData(name='a1', description='c1', data=np.arange(4)),
DynamicTableRegion(name='a2', description='c2',
data=np.arange(4), table=temp_table)])
# We should get both the DynamicTableRegion from the main table and the category 't1'
self.assertListEqual(temp_aligned_table.get_foreign_columns(), [(None, 'a2'), ('t1', 'c2')])
# We should only get the column from the main table
Expand Down Expand Up @@ -275,12 +284,15 @@ def test_get_linked_tables_none(self):
colnames=['c1', 'c2'],
columns=[VectorData(name='c1', description='c1', data=np.arange(4)),
VectorData(name='c2', description='c2', data=np.arange(4))])
temp_aligned_table = AlignedDynamicTable(name='my_aligned_table',
description='my test table',
category_tables=[temp_table],
colnames=['a1', 'a2'],
columns=[VectorData(name='a1', description='c1', data=np.arange(4)),
VectorData(name='a2', description='c2', data=np.arange(4))])
with warnings.catch_warnings():
msg = "The linked table for DynamicTableRegion '.*' does not share an ancestor with the DynamicTableRegion."
warnings.filterwarnings("ignore", category=UserWarning, message=msg)
temp_aligned_table = AlignedDynamicTable(name='my_aligned_table',
description='my test table',
category_tables=[temp_table],
colnames=['a1', 'a2'],
columns=[VectorData(name='a1', description='c1', data=np.arange(4)),
VectorData(name='a2', description='c2', data=np.arange(4))])
self.assertListEqual(temp_aligned_table.get_linked_tables(), [])
self.assertListEqual(temp_aligned_table.get_linked_tables(ignore_category_tables=True), [])

Expand All @@ -294,13 +306,16 @@ def test_get_linked_tables_complex_link(self):
columns=[VectorData(name='c1', description='c1', data=np.arange(4)),
DynamicTableRegion(name='c2', description='c2',
data=np.arange(4), table=temp_table0)])
temp_aligned_table = AlignedDynamicTable(name='my_aligned_table',
description='my test table',
category_tables=[temp_table],
colnames=['a1', 'a2'],
columns=[VectorData(name='a1', description='c1', data=np.arange(4)),
DynamicTableRegion(name='a2', description='c2',
data=np.arange(4), table=temp_table)])
with warnings.catch_warnings():
msg = "The linked table for DynamicTableRegion '.*' does not share an ancestor with the DynamicTableRegion."
warnings.filterwarnings("ignore", category=UserWarning, message=msg)
temp_aligned_table = AlignedDynamicTable(name='my_aligned_table',
description='my test table',
category_tables=[temp_table],
colnames=['a1', 'a2'],
columns=[VectorData(name='a1', description='c1', data=np.arange(4)),
DynamicTableRegion(name='a2', description='c2',
data=np.arange(4), table=temp_table)])
# NOTE: in this example templ_aligned_table both points to temp_table and at the
# same time contains temp_table as a category. This could lead to temp_table
# visited multiple times and we want to make sure this doesn't happen
Expand All @@ -326,17 +341,20 @@ def test_get_linked_tables_simple_link(self):
columns=[VectorData(name='c1', description='c1', data=np.arange(4)),
VectorData(name='c2', description='c2', data=np.arange(4))])
temp_table = DynamicTable(name='t1', description='t1',
colnames=['c1', 'c2'],
columns=[VectorData(name='c1', description='c1', data=np.arange(4)),
DynamicTableRegion(name='c2', description='c2',
data=np.arange(4), table=temp_table0)])
temp_aligned_table = AlignedDynamicTable(name='my_aligned_table',
description='my test table',
category_tables=[temp_table],
colnames=['a1', 'a2'],
columns=[VectorData(name='a1', description='c1', data=np.arange(4)),
DynamicTableRegion(name='a2', description='c2',
data=np.arange(4), table=temp_table0)])
colnames=['c1', 'c2'],
columns=[VectorData(name='c1', description='c1', data=np.arange(4)),
DynamicTableRegion(name='c2', description='c2',
data=np.arange(4), table=temp_table0)])
with warnings.catch_warnings():
msg = "The linked table for DynamicTableRegion '.*' does not share an ancestor with the DynamicTableRegion."
warnings.filterwarnings("ignore", category=UserWarning, message=msg)
temp_aligned_table = AlignedDynamicTable(name='my_aligned_table',
description='my test table',
category_tables=[temp_table],
colnames=['a1', 'a2'],
columns=[VectorData(name='a1', description='c1', data=np.arange(4)),
DynamicTableRegion(name='a2', description='c2',
data=np.arange(4), table=temp_table0)])
# NOTE: in this example temp_aligned_table and temp_table both point to temp_table0
# We should get both the DynamicTableRegion from the main table and the category 't1'
linked_tables = temp_aligned_table.get_linked_tables()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/common/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ def setUp(self):
super().setUp()

def setUpContainer(self):
multi_container = SimpleMultiContainer(name='multi', containers=[self.table, self.target_table])
multi_container = SimpleMultiContainer(name='multi', containers=[self.target_table, self.table])
return multi_container

def _get(self, arg):
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,18 @@ def test_reset_parent_no_parent(self):
obj.reset_parent()
self.assertIsNone(obj.parent)

def test_get_ancestors(self):
"""Test that get_ancestors returns the correct ancestors.
"""
grandparent_obj = Container('obj1')
parent_obj = Container('obj2')
child_obj = Container('obj3')
parent_obj.parent = grandparent_obj
child_obj.parent = parent_obj
self.assertTupleEqual(grandparent_obj.get_ancestors(), tuple())
self.assertTupleEqual(parent_obj.get_ancestors(), (grandparent_obj, ))
self.assertTupleEqual(child_obj.get_ancestors(), (parent_obj, grandparent_obj))


class TestHTMLRepr(TestCase):

Expand Down

0 comments on commit be52652

Please sign in to comment.