diff --git a/doc/configuration.rst b/doc/configuration.rst index 9a9103c9ee..39bd50993c 100644 --- a/doc/configuration.rst +++ b/doc/configuration.rst @@ -356,6 +356,16 @@ check_complete_on_run missing. Defaults to false. +prevent_parameter_collision + In complex pipelines especially when tasks are inherited, it can happen that + different tasks define parameters with the same name. Luigi would normally use + the same value for both parameter instances, which might not be desired. + When set to ``true``, luigi will check for parameter collisions and refuse to + run if a parameter is defined multiple times. Optionally, an allow-list of + parameters called ``collisions_to_ignore`` can be passed to ``inherits/requires``, + to ignore when checking for duplicate parameters. + Defaults to false. + [elasticsearch] --------------- diff --git a/luigi/util.py b/luigi/util.py index adce164580..94528e2cf1 100644 --- a/luigi/util.py +++ b/luigi/util.py @@ -219,9 +219,10 @@ class TaskB(luigi.Task): import datetime import logging +from configparser import NoOptionError, NoSectionError -from luigi import task -from luigi import parameter +from luigi import parameter, task +from luigi.configuration import get_config logger = logging.getLogger('luigi-interface') @@ -277,9 +278,23 @@ def requires(self): def run(self): print self.n # this will be defined # ... + + inherits/requires decorator optionally takes an argument called + `collisions_to_ignore` with an iterable of parameters that are + allowed to overwrite parameters in upstream tasks. + In complex pipelines, it can happen that different tasks define parameters + with the same name. + If `prevent-parameter-collision` in the `[worker]` section of the config + is true, luigi will raise an exception in case of parameter conflicts - + unless the parameter is explicitly allowed in `collisions_to_ignore`. """ - def __init__(self, *tasks_to_inherit, **kw_tasks_to_inherit): + def __init__( + self, + *tasks_to_inherit, + collisions_to_ignore=(), + **kw_tasks_to_inherit, + ): super(inherits, self).__init__() if not tasks_to_inherit and not kw_tasks_to_inherit: raise TypeError("tasks_to_inherit or kw_tasks_to_inherit must contain at least one task") @@ -287,8 +302,12 @@ def __init__(self, *tasks_to_inherit, **kw_tasks_to_inherit): raise TypeError("Only one of tasks_to_inherit or kw_tasks_to_inherit may be present") self.tasks_to_inherit = tasks_to_inherit self.kw_tasks_to_inherit = kw_tasks_to_inherit + self.collisions_to_ignore = collisions_to_ignore def __call__(self, task_that_inherits): + # Check for parameter collisions and raise an exception if found + self._check_for_parameter_collisions(task_that_inherits) + # Get all parameter objects from each of the underlying tasks task_iterator = self.tasks_to_inherit or self.kw_tasks_to_inherit.values() for task_to_inherit in task_iterator: @@ -323,6 +342,63 @@ def clone_parents(_self, **kwargs): return task_that_inherits + def _check_for_parameter_collisions(self, task_that_inherits): + """ + Check that the parameters from the tasks_to_inherit don't + silently mask each other or by parameters from the inheriting + task. + + An exception will be raised immediately when the first parameter + collision is encountered. + + Collisions can be ignored by passing `collisions_to_ignore` with + an interable of allowed parameters to `inherits/requires`. + """ + # only check for parameter collisions when enabled in config + config = get_config() + try: + if config.getboolean("worker", "prevent_parameter_collision") is not True: + return + except (NoSectionError, NoOptionError, KeyError): + return + + error_msg = ( + 'Parameter "{param}" in "{task}" is duplicated in "{another_task}" ' + "(or an ancestor). Either rename one of the parameters or include " + '"{param}" in `collisions_to_ignore`.' + ) + + for task_to_inherit in self.tasks_to_inherit: + for param_name, _ in task_to_inherit.get_params(): + # Check that the parameters from the inheriting task don't mask any + # parameters from the inherited tasks. + if ( + hasattr(task_that_inherits, param_name) + and param_name not in self.collisions_to_ignore + ): + raise ValueError( + error_msg.format( + param=param_name, + task=task_that_inherits.task_family, + another_task=task_to_inherit.task_family, + ) + ) + # Check that the parameters from an inherited task don't mask the + # parameters from another inherited task. + for another_task_to_inherit in self.tasks_to_inherit: + if ( + hasattr(another_task_to_inherit, param_name) + and another_task_to_inherit is not task_to_inherit + and param_name not in self.collisions_to_ignore + ): + raise ValueError( + error_msg.format( + param=param_name, + task=task_to_inherit.task_family, + another_task=another_task_to_inherit.task_family, + ) + ) + class requires: """ @@ -332,14 +408,21 @@ class requires: """ - def __init__(self, *tasks_to_require, **kw_tasks_to_require): + def __init__( + self, *tasks_to_require, collisions_to_ignore=(), **kw_tasks_to_require + ): super(requires, self).__init__() self.tasks_to_require = tasks_to_require self.kw_tasks_to_require = kw_tasks_to_require + self.collisions_to_ignore = collisions_to_ignore def __call__(self, task_that_requires): - task_that_requires = inherits(*self.tasks_to_require, **self.kw_tasks_to_require)(task_that_requires) + task_that_requires = inherits( + *self.tasks_to_require, + collisions_to_ignore=self.collisions_to_ignore, + **self.kw_tasks_to_require, + )(task_that_requires) # Modify task_that_requires by adding requires method. # If only one task is required, this single task is returned. @@ -387,7 +470,7 @@ def run(_self): def delegates(task_that_delegates): - """ Lets a task call methods on subtask(s). + """Lets a task call methods on subtask(s). The way this works is that the subtask is run as a part of the task, but the task itself doesn't have to care about the requirements of the subtasks. diff --git a/test/parameter_collision_test.py b/test/parameter_collision_test.py new file mode 100644 index 0000000000..f88b156275 --- /dev/null +++ b/test/parameter_collision_test.py @@ -0,0 +1,49 @@ +import unittest + +import luigi +from luigi.util import requires + +from helpers import with_config + + +class A(luigi.Task): + num = luigi.IntParameter() + + +class B(luigi.Task): + num = luigi.IntParameter() + + +class ParameterCollisionDetectionTest(unittest.TestCase): + @with_config({"worker": {"prevent_parameter_collision": "true"}}) + def test_parameter_collision_with_inherited_task(self): + with self.assertRaises(ValueError): + + @requires(A) + class T(luigi.Task): + num = luigi.IntParameter() + + @with_config({"worker": {"prevent_parameter_collision": "true"}}) + def test_parameter_collision_in_inheriting_tasks(self): + with self.assertRaises(ValueError): + + @requires(A, B) + class T(luigi.Task): + pass + + def test_no_parameter_collision_when_disabled_in_config(self): + @requires(A, B) + class T(luigi.Task): + pass + + @with_config({"worker": {"prevent_parameter_collision": "true"}}) + def test_parameter_collision_with_inherited_task_ignored_by_allowlist(self): + @requires(A, collisions_to_ignore=["num"]) + class T(luigi.Task): + num = luigi.IntParameter() + + @with_config({"worker": {"prevent_parameter_collision": "true"}}) + def test_parameter_collision_in_inheriting_tasks_ignored_by_allowlist(self): + @requires(A, B, collisions_to_ignore=["num"]) + class T(luigi.Task): + pass