Skip to content

Commit

Permalink
test: __call__ (#1238 #1276 #1277)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Oct 14, 2024
1 parent f24c898 commit eecc7ef
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 16 deletions.
11 changes: 9 additions & 2 deletions openfisca_core/populations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@

from . import types
from ._core_population import CorePopulation
from ._errors import InvalidArraySizeError, PeriodValidityError
from .config import ADD, DIVIDE
from ._errors import (
IncompatibleOptionsError,
InvalidArraySizeError,
InvalidOptionError,
PeriodValidityError,
)
from .group_population import GroupPopulation
from .population import Population

ADD, DIVIDE = types.Option
SinglePopulation = Population

__all__ = [
Expand All @@ -45,7 +50,9 @@
"EntityToPersonProjector",
"FirstPersonToEntityProjector",
"GroupPopulation",
"IncompatibleOptionsError",
"InvalidArraySizeError",
"InvalidOptionError",
"PeriodValidityError",
"Population",
"Projector",
Expand Down
99 changes: 87 additions & 12 deletions openfisca_core/populations/_core_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from openfisca_core import holders, periods

from . import types as t
from ._errors import InvalidArraySizeError, PeriodValidityError
from ._errors import (
IncompatibleOptionsError,
InvalidArraySizeError,
InvalidOptionError,
PeriodValidityError,
)

#: Type variable for a covariant data type.
_DT_co = TypeVar("_DT_co", covariant=True, bound=t.VarDType)
Expand Down Expand Up @@ -45,18 +50,87 @@ def __init__(self, entity: t.CoreEntity, *__args: object, **__kwds: object) -> N
def __call__(
self,
variable_name: t.VariableName,
period: None | t.PeriodLike = None,
period: t.PeriodLike,
options: None | Sequence[t.Option] = None,
) -> None | t.FloatArray:
) -> None | t.VarArray:
"""Calculate ``variable_name`` for ``period``, using the formula if it exists.
# Example:
# >>> person("salary", "2017-04")
# >>> array([300.0])
Args:
variable_name: The name of the variable to calculate.
period: The period to calculate the variable for.
options: The options to use for the calculation.
Returns:
None: If there is no :class:`.Simulation`.
ndarray[float32]: The result of the calculation.
ndarray[generic]: The result of the calculation.
Raises:
IncompatibleOptionsError: If the options are incompatible.
InvalidOptionError: If the option is invalid.
Examples:
>>> from openfisca_core import (
... entities,
... periods,
... populations,
... simulations,
... taxbenefitsystems,
... variables,
... )
>>> class Person(entities.SingleEntity): ...
>>> person = Person("person", "people", "", "")
>>> period = periods.Period.eternity()
>>> population = populations.CorePopulation(person)
>>> population.count = 3
>>> population("salary", period)
>>> tbs = taxbenefitsystems.TaxBenefitSystem([person])
>>> person.set_tax_benefit_system(tbs)
>>> simulation = simulations.Simulation(tbs, {person.key: population})
>>> population("salary", period)
Traceback (most recent call last):
VariableNotFoundError: You tried to calculate or to set a value ...
>>> class Salary(variables.Variable):
... definition_period = periods.ETERNITY
... entity = person
... value_type = int
>>> tbs.add_variable(Salary)
<openfisca_core.populations._core_population.Salary object at...
>>> population(Salary().name, period)
array([0, 0, 0], dtype=int32)
>>> class Tax(Salary):
... default_value = 100.0
... definition_period = periods.ETERNITY
... entity = person
... value_type = float
>>> tbs.add_variable(Tax)
<openfisca_core.populations._core_population.Tax object at...
>>> population(Tax().name, period)
array([100., 100., 100.], dtype=float32)
>>> population(Tax().name, period, [populations.ADD])
Traceback (most recent call last):
ValueError: Unable to ADD constant variable 'Tax' over the perio...
>>> population(Tax().name, period, [populations.DIVIDE])
Traceback (most recent call last):
ValueError: Unable to DIVIDE constant variable 'Tax' over the pe...
>>> population(Tax().name, period, [populations.ADD, populations.DIVIDE])
Traceback (most recent call last):
IncompatibleOptionsError: Options ADD and DIVIDE are incompatibl...
>>> population(Tax().name, period, ["LAGRANGIAN"])
Traceback (most recent call last):
InvalidOptionError: Option LAGRANGIAN is not a valid option (try...
"""
if self.simulation is None:
Expand All @@ -77,21 +151,22 @@ def __call__(
calculate.period,
)

if t.Option.ADD in map(str.upper, calculate.option):
if t.Option.ADD in calculate.option and t.Option.DIVIDE in calculate.option:
raise IncompatibleOptionsError(variable_name)

if t.Option.ADD in calculate.option:
return self.simulation.calculate_add(
calculate.variable,
calculate.period,
)

if t.Option.DIVIDE in map(str.upper, calculate.option):
if t.Option.DIVIDE in calculate.option:
return self.simulation.calculate_divide(
calculate.variable,
calculate.period,
)

raise ValueError(
f"Options config.ADD and config.DIVIDE are incompatible (trying to compute variable {variable_name})".encode(),
)
raise InvalidOptionError(calculate.option[0], variable_name)

def empty_array(self) -> t.FloatArray:
"""Return an empty array.
Expand Down
23 changes: 23 additions & 0 deletions openfisca_core/populations/_errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,29 @@
from . import types as t


class IncompatibleOptionsError(ValueError):
"""Raised when two options are incompatible."""

def __init__(self, variable_name: t.VariableName) -> None:
add, divide = t.Option
msg = (
f"Options {add} and {divide} are incompatible (trying to compute "
f"variable {variable_name})."
)
super().__init__(msg)


class InvalidOptionError(ValueError):
"""Raised when an option is invalid."""

def __init__(self, option: str, variable_name: t.VariableName) -> None:
msg = (
f"Option {option} is not a valid option (trying to compute "
f"variable {variable_name})."
)
super().__init__(msg)


class InvalidArraySizeError(ValueError):
"""Raised when an array has an invalid size."""

Expand Down
2 changes: 0 additions & 2 deletions openfisca_core/populations/config.py

This file was deleted.

3 changes: 3 additions & 0 deletions openfisca_core/populations/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class Option(strenum.StrEnum):
ADD = enum.auto()
DIVIDE = enum.auto()

def __contains__(self, option: str) -> bool:
return option.upper() is self


class Calculate(NamedTuple):
variable: VariableName
Expand Down
1 change: 1 addition & 0 deletions openfisca_tasks/lint.mk
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ check-types:
openfisca_core/data_storage \
openfisca_core/experimental \
openfisca_core/entities \
openfisca_core/indexed_enums \
openfisca_core/periods \
openfisca_core/types.py
@$(call print_pass,$@:)
Expand Down

0 comments on commit eecc7ef

Please sign in to comment.