Skip to content

Commit

Permalink
refactor(enums): fix linter warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Sep 20, 2024
1 parent dad88a1 commit 25cb01f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 16 deletions.
16 changes: 11 additions & 5 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ def encode(
return array

# String array
if isinstance(array, numpy.ndarray) and array.dtype.kind in {"U", "S"}:
if array.dtype.kind in {"U", "S"}:
array = numpy.select(
[array == item.name for item in cls],
[item.index for item in cls],
).astype(t.ArrayEnum)

# Enum items arrays
elif isinstance(array, numpy.ndarray) and array.dtype.kind == "O":
elif array.dtype.kind == "O":
# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from
# variable.possible_values, while the array values may come from
Expand All @@ -87,12 +87,18 @@ def encode(
# name to check that the values in the array, if non-empty, are of
# the right type.
if len(array) > 0 and cls.__name__ is array[0].__class__.__name__:
cls = array[0].__class__
klass = array[0].__class__

else:
klass = cls

array = numpy.select(
[array == item for item in cls],
[item.index for item in cls],
[array == item for item in klass],
[item.index for item in klass],
).astype(t.ArrayEnum)

array = numpy.asarray(array, dtype=t.ArrayEnum)
return EnumArray(array, cls)


__all__ = ["Enum"]
18 changes: 14 additions & 4 deletions openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,26 @@ def _is_an_enum(self, other: object) -> TypeGuard[t.Enum]:
if self.possible_values is None:
raise NotImplementedError

if other is None:
raise NotImplementedError

return (
not hasattr(other, "__name__")
and other.__class__.__name__ is self.possible_values.__name__
)

def _is_an_enum_type(self, other: object) -> TypeGuard[type[t.Enum]]:
name: None | str

if self.possible_values is None:
raise NotImplementedError

return (
hasattr(other, "__name__")
and other.__name__ is self.possible_values.__name__
)
if other is None:
raise NotImplementedError

name = getattr(other, "__name__", None)

return isinstance(name, str) and name is self.possible_values.__name__


__all__ = ["EnumArray"]
29 changes: 22 additions & 7 deletions openfisca_core/indexed_enums/types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from openfisca_core.types import Array as Array
from openfisca_core.types import ArrayAny as ArrayAny # noqa: F401
from openfisca_core.types import ArrayBool as ArrayBool # noqa: F401
from openfisca_core.types import ArrayBytes as ArrayBytes # noqa: F401
from openfisca_core.types import ArrayEnum as ArrayEnum
from openfisca_core.types import ArrayInt as ArrayInt # noqa: F401
from openfisca_core.types import ArrayStr as ArrayStr # noqa: F401
from openfisca_core.types import (
Array,
ArrayAny,
ArrayBool,
ArrayBytes,
ArrayEnum,
ArrayInt,
ArrayStr,
)

import abc
import enum
Expand All @@ -20,3 +22,16 @@ class Enum(enum.Enum):

class EnumArray(Array[ArrayEnum], metaclass=abc.ABCMeta):
...


__all__ = [
"Array",
"ArrayAny",
"ArrayBool",
"ArrayBytes",
"ArrayEnum",
"ArrayInt",
"ArrayStr",
"Enum",
"EnumArray",
]

0 comments on commit 25cb01f

Please sign in to comment.