Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Post init option for class generator #1089

Merged
merged 28 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Added `TypeConfigurator` to automatically wrap fields with `TermSetWrapper` according to a configuration file. @mavaylon1 [#1016](https://github.com/hdmf-dev/hdmf/pull/1016)
- Updated `TermSetWrapper` to support validating a single field within a compound array. @mavaylon1 [#1061](https://github.com/hdmf-dev/hdmf/pull/1061)
- Updated testing to not install in editable mode and not run `coverage` by default. @rly [#1107](https://github.com/hdmf-dev/hdmf/pull/1107)
- Add `post_init_method` parameter when generating classes to perform post-init functionality, i.e., validation. @mavaylon1 [#1089](https://github.com/hdmf-dev/hdmf/pull/1089)

## HDMF 3.13.0 (March 20, 2024)

Expand Down
32 changes: 29 additions & 3 deletions src/hdmf/build/classgenerator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy
from datetime import datetime, date
from collections.abc import Callable

import numpy as np

Expand Down Expand Up @@ -35,15 +36,19 @@ def register_generator(self, **kwargs):
{'name': 'spec', 'type': BaseStorageSpec, 'doc': ''},
{'name': 'parent_cls', 'type': type, 'doc': ''},
{'name': 'attr_names', 'type': dict, 'doc': ''},
{'name': 'post_init_method', 'type': Callable, 'default': None,
'doc': 'The function used as a post_init method to validate the class generation.'},
{'name': 'type_map', 'type': 'hdmf.build.manager.TypeMap', 'doc': ''},
returns='the class for the given namespace and data_type', rtype=type)
def generate_class(self, **kwargs):
"""Get the container class from data type specification.
If no class has been associated with the ``data_type`` from ``namespace``, a class will be dynamically
created and returned.
"""
data_type, spec, parent_cls, attr_names, type_map = getargs('data_type', 'spec', 'parent_cls', 'attr_names',
'type_map', kwargs)
data_type, spec, parent_cls, attr_names, type_map, post_init_method = getargs('data_type', 'spec',
'parent_cls', 'attr_names',
'type_map',
'post_init_method', kwargs)

not_inherited_fields = dict()
for k, field_spec in attr_names.items():
Expand Down Expand Up @@ -82,6 +87,8 @@ def generate_class(self, **kwargs):
+ str(e)
+ " Please define that type before defining '%s'." % name)
cls = ExtenderMeta(data_type, tuple(bases), classdict)
cls.post_init_method = post_init_method

return cls


Expand Down Expand Up @@ -316,8 +323,19 @@ def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name):
elif attr_name not in attrs_not_to_set:
attrs_to_set.append(attr_name)

@docval(*docval_args, allow_positional=AllowPositional.WARNING)
# We want to use the skip_post_init of the current class and not the parent class
for item in docval_args:
if item['name'] == 'skip_post_init':
docval_args.remove(item)

@docval(*docval_args,
{'name': 'skip_post_init', 'type': bool, 'default': False,
'doc': 'bool to skip post_init'},
allow_positional=AllowPositional.WARNING)
def __init__(self, **kwargs):
skip_post_init = popargs('skip_post_init', kwargs)

original_kwargs = dict(kwargs)
if name is not None: # force container name to be the fixed name in the spec
kwargs.update(name=name)

Expand All @@ -343,6 +361,9 @@ def __init__(self, **kwargs):
for f in fixed_value_attrs_to_set:
self.fields[f] = getattr(not_inherited_fields[f], 'value')

if self.post_init_method is not None and not skip_post_init:
self.post_init_method(**original_kwargs)

classdict['__init__'] = __init__


Expand Down Expand Up @@ -417,6 +438,7 @@ def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name):
def __init__(self, **kwargs):
# store the values passed to init for each MCI attribute so that they can be added
# after calling __init__
original_kwargs = dict(kwargs)
new_kwargs = list()
for field_clsconf in classdict['__clsconf__']:
attr_name = field_clsconf['attr']
Expand All @@ -437,12 +459,16 @@ def __init__(self, **kwargs):
kwargs[attr_name] = list()

# call the parent class init without the MCI attribute
kwargs['skip_post_init'] = True
previous_init(self, **kwargs)

# call the add method for each MCI attribute
for new_kwarg in new_kwargs:
add_method = getattr(self, new_kwarg['add_method_name'])
add_method(new_kwarg['value'])

if self.post_init_method is not None:
self.post_init_method(**original_kwargs)

# override __init__
classdict['__init__'] = __init__
15 changes: 13 additions & 2 deletions src/hdmf/build/manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from collections import OrderedDict, deque
from copy import copy
from collections.abc import Callable

from .builders import DatasetBuilder, GroupBuilder, LinkBuilder, Builder, BaseBuilder
from .classgenerator import ClassGenerator, CustomClassGenerator, MCIClassGenerator
Expand Down Expand Up @@ -498,11 +499,14 @@ def get_container_cls(self, **kwargs):
created and returned.
"""
# NOTE: this internally used function get_container_cls will be removed in favor of get_dt_container_cls
# Deprecated: Will be removed by HDMF 4.0
namespace, data_type, autogen = getargs('namespace', 'data_type', 'autogen', kwargs)
return self.get_dt_container_cls(data_type, namespace, autogen)

@docval({"name": "data_type", "type": str, "doc": "the data type to create a AbstractContainer class for"},
{"name": "namespace", "type": str, "doc": "the namespace containing the data_type", "default": None},
{'name': 'post_init_method', 'type': Callable, 'default': None,
'doc': 'The function used as a post_init method to validate the class generation.'},
{"name": "autogen", "type": bool, "doc": "autogenerate class if one does not exist", "default": True},
returns='the class for the given namespace and data_type', rtype=type)
def get_dt_container_cls(self, **kwargs):
Expand All @@ -513,7 +517,8 @@ def get_dt_container_cls(self, **kwargs):
Replaces get_container_cls but namespace is optional. If namespace is unknown, it will be looked up from
all namespaces.
"""
namespace, data_type, autogen = getargs('namespace', 'data_type', 'autogen', kwargs)
namespace, data_type, post_init_method, autogen = getargs('namespace', 'data_type',
'post_init_method','autogen', kwargs)

# namespace is unknown, so look it up
if namespace is None:
Expand All @@ -527,12 +532,18 @@ def get_dt_container_cls(self, **kwargs):
raise ValueError("Namespace could not be resolved.")

cls = self.__get_container_cls(namespace, data_type)

if cls is None and autogen: # dynamically generate a class
spec = self.__ns_catalog.get_spec(namespace, data_type)
self.__check_dependent_types(spec, namespace)
parent_cls = self.__get_parent_cls(namespace, data_type, spec)
attr_names = self.__default_mapper_cls.get_attr_names(spec)
cls = self.__class_generator.generate_class(data_type, spec, parent_cls, attr_names, self)
cls = self.__class_generator.generate_class(data_type=data_type,
spec=spec,
parent_cls=parent_cls,
attr_names=attr_names,
post_init_method=post_init_method,
type_map=self)
self.register_container_type(namespace, data_type, cls)
return cls

Expand Down
7 changes: 5 additions & 2 deletions src/hdmf/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
'''
import os.path
from copy import deepcopy
from collections.abc import Callable

CORE_NAMESPACE = 'hdmf-common'
EXP_NAMESPACE = 'hdmf-experimental'
Expand Down Expand Up @@ -136,12 +137,14 @@ def available_namespaces():
@docval({'name': 'data_type', 'type': str,
'doc': 'the data_type to get the Container class for'},
{'name': 'namespace', 'type': str, 'doc': 'the namespace the data_type is defined in'},
{'name': 'post_init_method', 'type': Callable, 'default': None,
'doc': 'The function used as a post_init method to validate the class generation.'},
is_method=False)
def get_class(**kwargs):
"""Get the class object of the Container subclass corresponding to a given neurdata_type.
"""
data_type, namespace = getargs('data_type', 'namespace', kwargs)
return __TYPE_MAP.get_dt_container_cls(data_type, namespace)
data_type, namespace, post_init_method = getargs('data_type', 'namespace', 'post_init_method', kwargs)
return __TYPE_MAP.get_dt_container_cls(data_type, namespace, post_init_method)


@docval({'name': 'extensions', 'type': (str, TypeMap, list),
Expand Down
91 changes: 85 additions & 6 deletions tests/unit/build_tests/test_classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import shutil
import tempfile
from warnings import warn

from hdmf.build import TypeMap, CustomClassGenerator
from hdmf.build.classgenerator import ClassGenerator, MCIClassGenerator
Expand Down Expand Up @@ -82,6 +83,79 @@ def test_no_generators(self):
self.assertTrue(hasattr(cls, '__init__'))


class TestPostInitGetClass(TestCase):
def setUp(self):
def post_init_method(self, **kwargs):
attr1 = kwargs['attr1']
if attr1<10:
msg = "attr1 should be >=10"
warn(msg)
self.post_init=post_init_method

def test_post_init(self):
spec = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Baz',
attributes=[
AttributeSpec(name='attr1', doc='a int attribute', dtype='int')
]
)

spec_catalog = SpecCatalog()
spec_catalog.register_spec(spec, 'test.yaml')
namespace = SpecNamespace(
doc='a test namespace',
name=CORE_NAMESPACE,
schema=[{'source': 'test.yaml'}],
version='0.1.0',
catalog=spec_catalog
)
namespace_catalog = NamespaceCatalog()
namespace_catalog.add_namespace(CORE_NAMESPACE, namespace)
type_map = TypeMap(namespace_catalog)

cls = type_map.get_dt_container_cls('Baz', CORE_NAMESPACE, self.post_init)

with self.assertWarns(Warning):
cls(name='instance', attr1=9)

def test_multi_container_post_init(self):
bar_spec = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Bar',
datasets=[
DatasetSpec(
doc='a dataset',
dtype='int',
name='data',
attributes=[AttributeSpec(name='attr2', doc='an integer attribute', dtype='int')]
)
],
attributes=[AttributeSpec(name='attr1', doc='a string attribute', dtype='text')])

multi_spec = GroupSpec(doc='A test extension that contains a multi',
data_type_def='Multi',
groups=[GroupSpec(data_type_inc=bar_spec, doc='test multi', quantity='*')],
attributes=[AttributeSpec(name='attr1', doc='a float attribute', dtype='float')])

spec_catalog = SpecCatalog()
spec_catalog.register_spec(bar_spec, 'test.yaml')
spec_catalog.register_spec(multi_spec, 'test.yaml')
namespace = SpecNamespace(
doc='a test namespace',
name=CORE_NAMESPACE,
schema=[{'source': 'test.yaml'}],
version='0.1.0',
catalog=spec_catalog
)
namespace_catalog = NamespaceCatalog()
namespace_catalog.add_namespace(CORE_NAMESPACE, namespace)
type_map = TypeMap(namespace_catalog)
Multi = type_map.get_dt_container_cls('Multi', CORE_NAMESPACE, self.post_init)

with self.assertWarns(Warning):
Multi(name='instance', attr1=9.1)

class TestDynamicContainer(TestCase):

def setUp(self):
Expand Down Expand Up @@ -109,13 +183,15 @@ def test_dynamic_container_creation(self):
AttributeSpec('attr4', 'another float attribute', 'float')])
self.spec_catalog.register_spec(baz_spec, 'extension.yaml')
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE)
expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4'}
expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'skip_post_init'}
received_args = set()

for x in get_docval(cls.__init__):
if x['name'] != 'foo':
received_args.add(x['name'])
with self.subTest(name=x['name']):
self.assertNotIn('default', x)
if x['name'] != 'skip_post_init':
self.assertNotIn('default', x)
self.assertSetEqual(expected_args, received_args)
self.assertEqual(cls.__name__, 'Baz')
self.assertTrue(issubclass(cls, Bar))
Expand All @@ -135,7 +211,7 @@ def test_dynamic_container_creation_defaults(self):
AttributeSpec('attr4', 'another float attribute', 'float')])
self.spec_catalog.register_spec(baz_spec, 'extension.yaml')
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE)
expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'foo'}
expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'foo', 'skip_post_init'}
received_args = set(map(lambda x: x['name'], get_docval(cls.__init__)))
self.assertSetEqual(expected_args, received_args)
self.assertEqual(cls.__name__, 'Baz')
Expand Down Expand Up @@ -285,13 +361,14 @@ def __init__(self, **kwargs):
AttributeSpec('attr4', 'another float attribute', 'float')])
self.spec_catalog.register_spec(baz_spec, 'extension.yaml')
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE)
expected_args = {'name', 'data', 'attr2', 'attr3', 'attr4'}
expected_args = {'name', 'data', 'attr2', 'attr3', 'attr4', 'skip_post_init'}
received_args = set()
for x in get_docval(cls.__init__):
if x['name'] != 'foo':
received_args.add(x['name'])
with self.subTest(name=x['name']):
self.assertNotIn('default', x)
if x['name'] != 'skip_post_init':
self.assertNotIn('default', x)
self.assertSetEqual(expected_args, received_args)
self.assertTrue(issubclass(cls, FixedAttrBar))
inst = cls(name="My Baz", data=[1, 2, 3, 4], attr2=1000, attr3=98.6, attr4=1.0)
Expand Down Expand Up @@ -445,7 +522,7 @@ def setUp(self):

def test_init_docval(self):
cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # generate the class
expected_args = {'name'} # 'attr1' should not be included
expected_args = {'name', 'skip_post_init'} # 'attr1' should not be included
received_args = set()
for x in get_docval(cls.__init__):
received_args.add(x['name'])
Expand Down Expand Up @@ -518,6 +595,8 @@ def test_gen_parent_class(self):
{'name': 'my_baz1', 'doc': 'A composition inside with a fixed name', 'type': baz1_cls},
{'name': 'my_baz2', 'doc': 'A composition inside with a fixed name', 'type': baz2_cls},
{'name': 'my_baz1_link', 'doc': 'A composition inside without a fixed name', 'type': baz1_cls},
{'name': 'skip_post_init', 'type': bool, 'default': False,
'doc': 'bool to skip post_init'}
))

def test_init_fields(self):
Expand Down
Loading