Skip to content

Commit

Permalink
ak.with_named_axis: add check to validate the given named axis mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Sep 23, 2024
1 parent 01b459c commit c970d06
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
37 changes: 35 additions & 2 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,46 @@ def _check_valid_axis(axis: AxisName) -> AxisName:
>>> _check_valid_axis(1)
Traceback (most recent call last):
...
ValueError: Axis names must be hashable and not int, got 1
ValueError: Axis names must be hashable and not int, got 1 [type(axis)=<class 'int'>]
"""
if not _is_valid_named_axis(axis):
raise ValueError(f"Axis names must be hashable and not int, got {axis!r}")
raise ValueError(
f"Axis names must be hashable and not int, got {axis!r} [{type(axis)=}]"
)
return axis


def _check_valid_named_axis_mapping(named_axis: AxisMapping) -> AxisMapping:
"""
Checks if the given named axis mapping is valid. A valid named axis mapping is a dictionary where the keys are valid named axes
(hashable objects that are not integers) and the values are integers.
Args:
named_axis (AxisMapping): The named axis mapping to check.
Raises:
ValueError: If any of the keys in the named axis mapping is not a valid named axis or if any of the values is not an integer.
Examples:
>>> _check_valid_named_axis_mapping({"x": 0, "y": 1, "z": 2}) # No exception is raised
>>> _check_valid_named_axis_mapping({"x": 0, "y": 1, "z": "2"})
Traceback (most recent call last):
...
ValueError: Named axis mapping values must be integers, got '2' [type(axis)=<class 'str'>]
>>> _check_valid_named_axis_mapping({"x": 0, 1: 1, "z": 2})
Traceback (most recent call last):
...
ValueError: Axis names must be hashable and not int, got 1 [type(axis)=<class 'int'>]
"""
for name, axis in named_axis.items():
_check_valid_axis(name)
if not is_integer(axis):
raise ValueError(
f"Named axis mapping values must be integers, got {axis!r} [{type(axis)=}]"
)
return named_axis


def _axis_tuple_to_mapping(axis_tuple: AxisTuple) -> AxisMapping:
"""
Converts a tuple of axis names to a dictionary mapping axis names to their positions.
Expand Down
3 changes: 2 additions & 1 deletion src/awkward/operations/ak_with_named_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
AxisMapping,
AxisTuple,
_axis_tuple_to_mapping,
_check_valid_named_axis_mapping,
_NamedAxisKey,
)
from awkward._nplikes.numpy_like import NumpyMetadata
Expand Down Expand Up @@ -79,7 +80,7 @@ def _impl(array, named_axis, highlevel, behavior, attrs):

return ctx.with_attr(
key=_NamedAxisKey,
value=_named_axis,
value=_check_valid_named_axis_mapping(_named_axis),
).wrap(
layout,
highlevel=highlevel,
Expand Down

0 comments on commit c970d06

Please sign in to comment.