Write scalar datasets with compound data type #1176

Merged: 7 commits, Aug 22, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,9 @@
- Improved "already exists" error message when adding a container to a `MultiContainerInterface`. @rly [#1165](https://github.com/hdmf-dev/hdmf/pull/1165)
- Added support to write multidimensional string arrays. @stephprince [#1173](https://github.com/hdmf-dev/hdmf/pull/1173)

+ ### Bug fixes
+ - Fixed issue where scalar datasets with a compound data type were being written as non-scalar datasets. @stephprince [#1176](https://github.com/hdmf-dev/hdmf/pull/1176)
+
## HDMF 3.14.3 (July 29, 2024)

### Enhancements
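For context, a minimal sketch (using NumPy and h5py directly, outside of HDMF's API) of the kind of data this fix concerns: a 0-d array with a compound dtype, which should land on disk as a scalar dataset with shape () rather than shape (1,) or (2,). The file and dataset names here are illustrative.

import h5py
import numpy as np

cmpd_dtype = np.dtype([('x', np.int32), ('y', np.float64)])
scalar = np.array((1, 0.1), dtype=cmpd_dtype)  # 0-d compound array
print(np.shape(scalar))  # ()

with h5py.File('example.h5', 'w') as f:
    f.create_dataset('scalar_cmpd', data=scalar)
    print(f['scalar_cmpd'].shape)  # () -- a scalar dataset, as desired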
4 changes: 4 additions & 0 deletions src/hdmf/backends/hdf5/h5tools.py
@@ -698,6 +698,8 @@ def __read_dataset(self, h5obj, name=None):
                d = ReferenceBuilder(target_builder)
            kwargs['data'] = d
            kwargs['dtype'] = d.dtype
+       elif h5obj.dtype.kind == 'V':  # scalar compound data type
+           kwargs['data'] = np.array(scalar, dtype=h5obj.dtype)
        else:
            kwargs["data"] = scalar
    else:
@@ -1227,6 +1229,8 @@ def _filler():

                return
        # If the compound data type contains only regular data (i.e., no references) then we can write it as usual
+       elif len(np.shape(data)) == 0:
+           dset = self.__scalar_fill__(parent, name, data, options)
        else:
            dset = self.__list_fill__(parent, name, data, options)
    # Write a dataset containing references, i.e., a region or object reference.
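Both new branches above hinge on small NumPy/h5py behaviors; a minimal sketch, assuming the example.h5 file from the previous snippet: on read, h5py returns a scalar compound element as a single np.void value whose dtype kind is 'V', which the fix re-wraps as a 0-d array; on write, len(np.shape(data)) == 0 is what routes such data to __scalar_fill__ instead of __list_fill__.

import h5py
import numpy as np

with h5py.File('example.h5', 'r') as f:
    dset = f['scalar_cmpd']
    scalar = dset[()]                  # np.void: one compound element
    print(dset.dtype.kind)             # 'V' -- void kind marks compound dtypes
    data = np.array(scalar, dtype=dset.dtype)
    print(np.shape(data), data['x'])   # () 1 -- 0-d array, fields still named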
6 changes: 1 addition & 5 deletions src/hdmf/container.py
@@ -629,12 +629,8 @@ def __repr__(self):
        template += "\nFields:\n"
        for k in sorted(self.fields):  # sorted to enable tests
            v = self.fields[k]
-           # if isinstance(v, DataIO) or not hasattr(v, '__len__') or len(v) > 0:
            if hasattr(v, '__len__'):
-               if isinstance(v, (np.ndarray, list, tuple)):
-                   if len(v) > 0:
-                       template += " {}: {}\n".format(k, self.__smart_str(v, 1))
-               elif v:
+               if isinstance(v, (np.ndarray, list, tuple)) or v:
                    template += " {}: {}\n".format(k, self.__smart_str(v, 1))
            else:
                template += " {}: {}\n".format(k, v)
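The rewritten condition matters for scalar compound values because a 0-d ndarray passes the hasattr(v, '__len__') guard but raises on the old len(v) > 0 check; a minimal illustration:

import numpy as np

v = np.array((1, 0.1), dtype=[('x', np.int32), ('y', np.float64)])
print(hasattr(v, '__len__'))  # True -- ndarray defines __len__ ...
try:
    len(v)
except TypeError as e:
    print(e)                  # "len() of unsized object" -- ... but not for 0-d arrays
print(isinstance(v, (np.ndarray, list, tuple)) or v)  # True; `or` short-circuits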
13 changes: 9 additions & 4 deletions src/hdmf/validate/validator.py
@@ -134,7 +134,7 @@ def get_type(data, builder_dtype=None):
    elif isinstance(data, ReferenceResolver):
        return data.dtype, None
    # Numpy nd-array data
-   elif isinstance(data, np.ndarray):
+   elif isinstance(data, np.ndarray) and len(data.dtype) <= 1:
        if data.size > 0:
            return get_type(data[0], builder_dtype)
        else:
@@ -147,11 +147,14 @@
    # Case for h5py.Dataset and other I/O specific array types
    else:
        # Compound dtype
-       if builder_dtype and isinstance(builder_dtype, list):
+       if builder_dtype and len(builder_dtype) > 1:
            dtypes = []
            string_formats = []
            for i in range(len(builder_dtype)):
-               dtype, string_format = get_type(data[0][i])
+               if len(np.shape(data)) == 0:
+                   dtype, string_format = get_type(data[()][i])
+               else:
+                   dtype, string_format = get_type(data[0][i])
                dtypes.append(dtype)
                string_formats.append(string_format)
            return dtypes, string_formats
@@ -438,7 +441,9 @@ def validate(self, **kwargs):
        except EmptyArrayError:
            # do not validate dtype of empty array. HDMF does not yet set dtype when writing a list/tuple
            pass
-       if isinstance(builder.dtype, list):
+       if builder.dtype is not None and len(builder.dtype) > 1 and len(np.shape(builder.data)) == 0:
+           shape = ()  # scalar compound dataset
+       elif isinstance(builder.dtype, list):
            shape = (len(builder.data), )  # only 1D datasets with compound types are supported
        else:
            shape = get_data_shape(data)
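A sketch of the two NumPy facts the validator changes lean on: len() on a compound dtype counts its fields (it is 0 for a plain dtype), and np.shape() separates scalar from 1-D compound data, which decides whether fields are reached through data[()] or data[0]:

import numpy as np

cmpd_dtype = np.dtype([('x', np.int32), ('y', np.float64)])
print(len(cmpd_dtype), len(np.dtype('int32')))  # 2 0

scalar = np.array((1, 0.1), dtype=cmpd_dtype)
vector = np.array([(1, 0.1)], dtype=cmpd_dtype)
print(np.shape(scalar), np.shape(vector))  # () (1,)
print(scalar[()][0], vector[0][0])         # 1 1 -- same fields, different indexing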
21 changes: 21 additions & 0 deletions tests/unit/test_io_hdf5_h5tools.py
@@ -144,6 +144,16 @@ def test_write_dataset_string(self):
            read_a = read_a.decode('utf-8')
        self.assertEqual(read_a, a)

+   def test_write_dataset_scalar_compound(self):
+       cmpd_dtype = np.dtype([('x', np.int32), ('y', np.float64)])
+       a = np.array((1, 0.1), dtype=cmpd_dtype)
+       self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a,
+                                                    dtype=[DtypeSpec('x', doc='x', dtype='int32'),
+                                                           DtypeSpec('y', doc='y', dtype='float64')]))
+       dset = self.f['test_dataset']
+       self.assertTupleEqual(dset.shape, ())
+       self.assertEqual(dset[()].tolist(), a.tolist())
+
    ##########################################
    # write_dataset tests: TermSetWrapper
    ##########################################
@@ -787,6 +797,17 @@ def test_read_str(self):
        self.assertEqual(str(bldr['test_dataset'].data),
                         '<HDF5 dataset "test_dataset": shape (5,), type "|O">')

+   def test_read_scalar_compound(self):
+       cmpd_dtype = np.dtype([('x', np.int32), ('y', np.float64)])
+       a = np.array((1, 0.1), dtype=cmpd_dtype)
+       self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a,
+                                                    dtype=[DtypeSpec('x', doc='x', dtype='int32'),
+                                                           DtypeSpec('y', doc='y', dtype='float64')]))
+       self.io.close()
+       with HDF5IO(self.path, 'r') as io:
+           bldr = io.read_builder()
+           np.testing.assert_array_equal(bldr['test_dataset'].data[()], a)
+

class TestRoundTrip(TestCase):

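An independent way to check what test_write_dataset_scalar_compound asserts is to read the written file back with plain h5py (the path here is illustrative):

import h5py

with h5py.File('test.h5', 'r') as f:
    dset = f['test_dataset']
    print(dset.shape)        # () -- written as a true scalar
    print(dset.dtype.names)  # ('x', 'y')
    print(dset[()])          # (1, 0.1)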
22 changes: 22 additions & 0 deletions tests/unit/validator_tests/test_validate.py
@@ -501,6 +501,28 @@ def test_np_bool_for_bool(self):
        results = self.vmap.validate(bar_builder)
        self.assertEqual(len(results), 0)

+   def test_scalar_compound_dtype(self):
+       """Test that the validator allows scalar compound-dtype data where a compound dtype is specified."""
+       spec_catalog = SpecCatalog()
+       dtype = [DtypeSpec('x', doc='x', dtype='int'), DtypeSpec('y', doc='y', dtype='float')]
+       spec = GroupSpec('A test group specification with a data type',
+                        data_type_def='Bar',
+                        datasets=[DatasetSpec('an example dataset', dtype, name='data')],
+                        attributes=[AttributeSpec('attr1', 'an example attribute', 'text')])
+       spec_catalog.register_spec(spec, 'test2.yaml')
+       self.namespace = SpecNamespace(
+           'a test namespace', CORE_NAMESPACE, [{'source': 'test2.yaml'}], version='0.1.0', catalog=spec_catalog)
+       self.vmap = ValidatorMap(self.namespace)
+
+       value = np.array((1, 2.2), dtype=[('x', 'int'), ('y', 'float')])
+       bar_builder = GroupBuilder('my_bar',
+                                  attributes={'data_type': 'Bar', 'attr1': 'test'},
+                                  datasets=[DatasetBuilder(name='data',
+                                                           data=value,
+                                                           dtype=[DtypeSpec('x', doc='x', dtype='int'),
+                                                                  DtypeSpec('y', doc='y', dtype='float')])])
+       results = self.vmap.validate(bar_builder)
+       self.assertEqual(len(results), 0)

class Test1DArrayValidation(TestCase):
