diff --git a/ruskit/opsm/__init__.py b/ruskit/opsm/__init__.py new file mode 100644 index 0000000..025d30b --- /dev/null +++ b/ruskit/opsm/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .lib import Task, SequenceTask, ParallelTask, RetryTask +from .lib import TaskSuccess, TaskFailure + +from .exceptions import OPSMReturnOnErrorShortcutException +from .exceptions import PreviousTaskFailedError + +from .decorators import enable_failure_unwrap diff --git a/ruskit/opsm/decorators.py b/ruskit/opsm/decorators.py new file mode 100644 index 0000000..6c2bf48 --- /dev/null +++ b/ruskit/opsm/decorators.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +import functools + +from . import exceptions + + +def enable_failure_unwrap(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + try: + ret = f(*args, **kwargs) + except exceptions.OPSMReturnOnErrorShortcutException as e: + return e.failure + return ret + return wrapper diff --git a/ruskit/opsm/exceptions.py b/ruskit/opsm/exceptions.py new file mode 100644 index 0000000..5278099 --- /dev/null +++ b/ruskit/opsm/exceptions.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +import sys +if sys.version_info.major < 3: + # Circular import issue in python 2, but works in python 3 + # Do not use absolute_import: from __future__ import absolute_import + import lib +else: + from . import lib + + +class OPSMReturnOnErrorShortcutException(Exception): + def __init__(self, failure): + assert isinstance(failure, lib.TaskFailure) + self.failure = failure + + def __str__(self): + return 'OPSM_ROESE: {}'.format(self.failure) + + +class PreviousTaskFailedError(Exception): + def __init__(self, msg='PTF'): + self.msg = msg + + def __str__(self): + return self.msg diff --git a/ruskit/opsm/lib.py b/ruskit/opsm/lib.py new file mode 100644 index 0000000..13eb5a3 --- /dev/null +++ b/ruskit/opsm/lib.py @@ -0,0 +1,268 @@ +# -*- coding: utf-8 -*- +from __future__ import division, print_function, absolute_import + +from collections import namedtuple +import sys + +from . import exceptions +from . import utils + + +class TaskSuccess(namedtuple('TaskSuccess', 'task_name, value')): + def __py2_str__(self): + ret = u'{}({}) ✓'.encode('utf8') + ret = ret.format(self.task_name, self.value) + return ret + + def __py3_str__(self): + return '{}({}) ✓'.format(self.task_name, self.value) + + def __py2_repr__(self): + ret = u'{} ✓'.encode('utf8') + ret = ret.format(self.task_name) + return ret + + def __py3_repr__(self): + return '{} ✓'.format(self.task_name) + + def __str__(self): + if sys.version_info.major < 3: + return self.__py2_str__() + else: + return self.__py3_str__() + + def __repr__(self): + if sys.version_info.major < 3: + return self.__py2_repr__() + else: + return self.__py3_repr__() + + def ok(self): + return True + + def val(self): + return self.value + + def err(self): + return None + + def unwrap(self): + return self.value + + def aggregate(self): + if utils.is_iterable_not_str(self.value): + return tuple(v.aggregate() for v in self.value) + elif isinstance(self.value, tuple): + return self.value + else: + return (self.value,) + + +class TaskFailure(namedtuple('TaskFailure', 'task_name, error, grdst')): + def __py2_str__(self): + if self.grdst: + ret = u'{}({}) ✗ => {}'.encode('utf8') + ret = ret.format(self.task_name, self.error, self.grdst) + else: + ret = u'{}({}) ✗'.encode('utf8') + ret = ret.format(self.task_name, self.error) + return ret + + def __py3_str__(self): + if self.grdst: + ret = '{}({}) ✗ => {}'.format(self.task_name, self.error, + self.grdst) + else: + ret = '{}({}) ✗'.format(self.task_name, self.error) + return ret + + def __py2_repr__(self): + if self.grdst: + ret = u'{} ✗ => {}'.encode('utf8') + ret = ret.format(self.task_name, self.grdst.task_name) + else: + ret = u'{} ✗'.encode('utf8') + ret = ret.format(self.task_name) + return ret + + def __py3_repr__(self): + if self.grdst: + ret = '{} ✗ => {}'.format(self.task_name, self.grdst.task_name) + else: + ret = '{} ✗'.format(self.task_name) + return ret + + def __str__(self): + if sys.version_info.major < 3: + return self.__py2_str__() + else: + return self.__py3_str__() + + def __repr__(self): + if sys.version_info.major < 3: + return self.__py2_repr__() + else: + return self.__py3_repr__() + + def ok(self): + return False + + def val(self): + return None + + def err(self): + return self.error + + def unwrap(self): + raise exceptions.OPSMReturnOnErrorShortcutException() + + def aggregate(self): + raise exceptions.OPSMReturnOnErrorShortcutException() + + +class Task(object): + ''' + A Task is a base abstract object for implementing task-based + event scheduling. + + Anyone who'd like to create a task should follow the example: + + @example: + class EchoTask(Task): + def _setup(self, *args, **kwargs): + self.msg = kwargs.get('msg') + def _run(self): + if not self.msg: + raise ValueError('No message found') + print(self.msg) + + the `setup` method is optional, while the run method should + be implemented. + + A Guard is a `Task` used to guard of + If you have specified any `guard`, raising any exception + or setting `self.ok = False` will trigger the `guard`. + ''' + + def __init__(self, *args, **kwargs): + self.ok = True + self._task_name = self.__class__.__name__ + self.guard = kwargs.get('guard') + self._setup(*args, **kwargs) + + def _setup(self, *args, **kwargs): + ''' + Optional for derived classes + ''' + pass + + def _try_guard(self): + try: + if self.guard: + ret = self.guard.run() + else: + ret = None + except Exception as e: + ret = TaskFailure(self._task_name, error=e, grdst=None) + return ret + + def run(self): + try: + rslt = self._run() + ret = TaskSuccess(self._task_name, value=rslt) + except Exception as e: + ret = TaskFailure(self._task_name, error=e, grdst=None) + self.ok = False + finally: + if self.ok: + return ret + else: + grdst = self._try_guard() + return TaskFailure( + self._task_name, error=ret.error, grdst=grdst) + + def _run(self): + raise NotImplementedError("Should override _run") + + +class SequenceTask(Task): + def __init__(self, *tasks, **kwargs): + super(SequenceTask, self).__init__(**kwargs) + + self.subtasks = list(tasks) + + def add(self, task): + self.subtasks.append(task) + + def run(self): + return self._run() + + def _run_one(self, task): + if self.ok is True: + ret = None + try: + ret = task.run() + except Exception as e: + ret = TaskFailure(task._task_name, error=e, grdst=None) + if not ret.ok(): + self.ok = False + return ret + else: + return TaskFailure( + task._task_name, + error=exceptions.PreviousTaskFailedError(), + grdst=None) + + def _run(self): + ret = [self._run_one(task) for task in self.subtasks] + if self.ok is True: + return TaskSuccess(self._task_name, value=ret) + else: + grdst = self._try_guard() + return TaskFailure(self._task_name, error=ret, grdst=grdst) + + +class ParallelTask(Task): + def __init__(self, pool, *tasks, **kwargs): + super(ParallelTask, self).__init__(**kwargs) + self.subtasks = list(tasks) + self.gevent_pool = pool + + def add(self, task): + self.subtasks.append(task) + + def run(self): + return self._run() + + def _run_one(self, task): + ret = None + try: + ret = task.run() + except Exception as e: + ret = TaskFailure(task._task_name, error=e, grdst=None) + if not ret.ok(): + self.ok = False + return ret + + def _run(self): + ret = self.gevent_pool.map(self._run_one, self.subtasks) + if self.ok: + return TaskSuccess(self._task_name, value=ret) + else: + grdst = self._try_guard() + return TaskFailure(self._task_name, error=ret, grdst=grdst) + + +class RetryTask(Task): + def __init__(self, *args, **kwargs): + super(RetryTask, self).__init__(*args, **kwargs) + self.retry_times = kwargs.get('retry_times', 1) + + def run(self): + for i in range(self.retry_times): + # Cleanup self.ok flag + self.ok = True + ret = super(RetryTask, self).run() + if ret.ok(): + break + return ret diff --git a/ruskit/opsm/utils.py b/ruskit/opsm/utils.py new file mode 100644 index 0000000..0c90405 --- /dev/null +++ b/ruskit/opsm/utils.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- + + +def is_iterable_not_str(obj): + return hasattr(obj, '__iter__') and not isinstance(obj, str) diff --git a/ruskit/utils.py b/ruskit/utils.py index a340586..7a21d37 100644 --- a/ruskit/utils.py +++ b/ruskit/utils.py @@ -4,6 +4,7 @@ import os import sys from functools import wraps +import contextlib from ruskit import cli @@ -119,3 +120,11 @@ def _wrapper(*arguments): ClusterNode.socket_timeout = args.timeout return func(*arguments) return _wrapper + + +@contextlib.contextmanager +def contextlib_suppress(*exceptions): + try: + yield + except exceptions: + pass diff --git a/tests/test_opsm.py b/tests/test_opsm.py new file mode 100644 index 0000000..6308b14 --- /dev/null +++ b/tests/test_opsm.py @@ -0,0 +1,416 @@ +from __future__ import absolute_import, print_function + +from six.moves import range +import contextlib +import operator +import random +from functools import reduce + +import gevent +import gevent.pool +import mock + +import ruskit.opsm as opsm + +sleep_time_lb = 0.001 +sleep_time_ub = 0.005 +raise_msg = 'raise' +cleanup_msg = 'cleanup' +rterr = RuntimeError('RAISE EXCEPTION') + + +@contextlib.contextmanager +def global_echo_mock(): + global echo + echo = mock.Mock() + yield + del echo + + +def typical_fail(task_name, grdst=None): + return opsm.TaskFailure(task_name=task_name, error=rterr, grdst=grdst) + + +def typical_failclean(task_name, guard_name): + return opsm.TaskFailure( + task_name=task_name, + error=rterr, + grdst=opsm.TaskSuccess( + task_name=guard_name, value=cleanup_msg)) + + +def previous_fail(task_name): + return opsm.TaskFailure( + task_name=task_name, error=opsm.PreviousTaskFailedError(), grdst=None) + + +def usorted(lst): + return sorted(lst, key=lambda e: str(e)) + + +def assert_task_result(expect, actual): + def is_iterable(obj): + return hasattr(obj, '__iter__') and not isinstance(obj, str) + + def assert_task_success(expect, actual): + if expect.task_name != actual.task_name: + return False + + if is_iterable(expect.value): + if len(expect.value) != len(actual.value): + return False + return reduce(operator.and_, [ + assert_task_result_one(*pair) + for pair in zip(expect.value, actual.value) + ]) + else: + return expect.value == actual.value + + def assert_task_failure(expect, actual): + if expect.task_name != actual.task_name: + return False + + check_error = True + if is_iterable(expect.error): + if len(expect.error) != len(actual.error): + check_error = False + check_error = reduce(operator.and_, [ + assert_task_result_one(*pair) + for pair in zip(expect.error, actual.error) + ]) + elif isinstance(expect.error, Exception): + check_error = isinstance(actual, expect.__class__) + else: + check_error = expect.error == actual.error + assert isinstance(actual.grdst, expect.grdst.__class__) + if expect.grdst: + return check_error and assert_task_result_one(expect.grdst, + actual.grdst) + else: + return check_error + + def assert_task_result_one(expect, actual): + _type_dispatch = { + opsm.TaskSuccess: assert_task_success, + opsm.TaskFailure: assert_task_failure, + } + if not isinstance(actual, expect.__class__): + return False + return _type_dispatch[expect.__class__](expect, actual) + + assert assert_task_result_one(expect, actual), '''Mismatch: + Expect: {} + Actual: {}'''.format(expect, actual) + + +class EchoTaskS(opsm.Task): + def _setup(self, *args, **kwargs): + self.msg = kwargs['msg'] + + def _run(self): + if self.msg == raise_msg: + raise rterr + else: + echo(self.msg) + return self.msg + + +class EchoTaskP(opsm.Task): + def _setup(self, *args, **kwargs): + self.msg = kwargs['msg'] + + def _run(self): + if self.msg == raise_msg: + raise rterr + else: + gevent.sleep(random.uniform(sleep_time_lb, sleep_time_ub)) + echo(self.msg) + return self.msg + + +class CleanupTask(opsm.Task): + def _run(self): + echo(cleanup_msg) + return cleanup_msg + + +def test_task_success(): + with global_echo_mock(): + msg = 'hello' + + # Expect + ret_expect = opsm.TaskSuccess(task_name='EchoTaskS', value=msg) + + # Actual + ee = EchoTaskS(msg=msg) + ret = ee.run() + + echo.assert_called_once_with(msg) + assert_task_result(ret_expect, ret) + + +def test_task_failure(): + with global_echo_mock(): + ee = EchoTaskS(msg=raise_msg, guard=CleanupTask()) + ret = ee.run() + + assert_task_result(typical_failclean('EchoTaskS', 'CleanupTask'), ret) + + +def test_sequence_task_all_success(): + with global_echo_mock(): + num = 10 + + ret_expect = opsm.TaskSuccess( + task_name='SequenceTask', + value=[ + opsm.TaskSuccess( + 'EchoTaskS', value=i) for i in range(num) + ]) + mock_call_expect = [mock.call(i) for i in range(num)] + + worker = opsm.SequenceTask(guard=CleanupTask()) + for i in range(num): + worker.add(EchoTaskS(msg=i)) + ret = worker.run() + + assert mock_call_expect == echo.mock_calls + assert_task_result(ret_expect, ret) + + +def test_sequence_task_partial_failure(): + with global_echo_mock(): + succ_num1 = 5 + fail_num1 = 3 + succ_num2 = 8 + + # Expects + ret_expect = [] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskS', value=i) for i in range(succ_num1) + ] + ret_expect += [typical_fail('EchoTaskS')] + ret_expect += [ + previous_fail('EchoTaskS') + for i in range(fail_num1 - 1 + succ_num2) + ] + ret_expect = opsm.TaskFailure( + task_name='SequenceTask', + error=ret_expect, + grdst=opsm.TaskSuccess( + task_name='CleanupTask', value=cleanup_msg)) + mock_calls_expect = [mock.call(i) for i in range(succ_num1)] + mock_calls_expect += [mock.call('cleanup')] + + # Actuals + worker = opsm.SequenceTask(guard=CleanupTask()) + for i in range(succ_num1): + worker.add(EchoTaskS(msg=i)) + for _ in range(fail_num1): + worker.add(EchoTaskS(msg=raise_msg)) + for i in range(succ_num2): + worker.add(EchoTaskS(msg=i)) + ret = worker.run() + + assert mock_calls_expect == echo.mock_calls + assert_task_result(ret_expect, ret) + + +def test_sequence_task_partial_failure_without_guard(): + with global_echo_mock(): + succ_num1 = 5 + fail_num1 = 3 + succ_num2 = 8 + + # Expects + ret_expect = [] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskS', value=i) for i in range(succ_num1) + ] + ret_expect += [typical_fail('EchoTaskS')] + ret_expect += [ + previous_fail('EchoTaskS') + for i in range(fail_num1 - 1 + succ_num2) + ] + ret_expect = opsm.TaskFailure( + task_name='SequenceTask', error=ret_expect, grdst=None) + mock_calls_expect = [mock.call(i) for i in range(succ_num1)] + + # Actuals + worker = opsm.SequenceTask() + for i in range(succ_num1): + worker.add(EchoTaskS(msg=i)) + for _ in range(fail_num1): + worker.add(EchoTaskS(msg=raise_msg)) + for i in range(succ_num2): + worker.add(EchoTaskS(msg=i)) + ret = worker.run() + + assert mock_calls_expect == echo.mock_calls + assert_task_result(ret_expect, ret) + + +def test_parallel_task_all_success(): + with global_echo_mock(): + num = 10 + thread_num = 3 + + ret_expect = opsm.TaskSuccess( + task_name='ParallelTask', + value=[ + opsm.TaskSuccess( + task_name='EchoTaskP', value=i) for i in range(num) + ]) + mock_call_expect = [mock.call(i) for i in range(num)] + + pool = gevent.pool.Pool(thread_num) + worker = opsm.ParallelTask(pool, guard=CleanupTask()) + for i in range(num): + worker.add(EchoTaskP(msg=i)) + ret = worker.run() + + assert usorted(mock_call_expect) == usorted(echo.mock_calls) + assert_task_result(ret_expect, ret) + + +def test_parallel_task_partial_failure(): + with global_echo_mock(): + thread_num = 3 + succ_num1 = 5 + fail_num1 = 3 + succ_num2 = 8 + + # Expects + ret_expect = [] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskP', value=i) for i in range(succ_num1) + ] + ret_expect += [typical_fail('EchoTaskP') for i in range(fail_num1)] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskP', value=i) for i in range(succ_num2) + ] + ret_expect = opsm.TaskFailure( + task_name='ParallelTask', + error=ret_expect, + grdst=opsm.TaskSuccess( + task_name='CleanupTask', value=cleanup_msg)) + mock_calls_expect = [mock.call(i) for i in range(succ_num1)] + mock_calls_expect += [mock.call(i) for i in range(succ_num2)] + mock_calls_expect += [mock.call('cleanup')] + + # Actuals + pool = gevent.pool.Pool(thread_num) + worker = opsm.ParallelTask(pool, guard=CleanupTask()) + for i in range(succ_num1): + worker.add(EchoTaskP(msg=i)) + for _ in range(fail_num1): + worker.add(EchoTaskP(msg=raise_msg)) + for i in range(succ_num2): + worker.add(EchoTaskP(msg=i)) + ret = worker.run() + + assert usorted(mock_calls_expect) == usorted(echo.mock_calls) + assert_task_result(ret_expect, ret) + + +def test_parallel_task_partial_failure_without_guard(): + with global_echo_mock(): + thread_num = 3 + succ_num1 = 5 + fail_num1 = 3 + succ_num2 = 8 + + # Expects + ret_expect = [] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskP', value=i) for i in range(succ_num1) + ] + ret_expect += [typical_fail('EchoTaskP') for i in range(fail_num1)] + ret_expect += [ + opsm.TaskSuccess( + task_name='EchoTaskP', value=i) for i in range(succ_num2) + ] + ret_expect = opsm.TaskFailure( + task_name='ParallelTask', error=ret_expect, grdst=None) + mock_calls_expect = [mock.call(i) for i in range(succ_num1)] + mock_calls_expect += [mock.call(i) for i in range(succ_num2)] + + # Actuals + pool = gevent.pool.Pool(thread_num) + worker = opsm.ParallelTask(pool) + for i in range(succ_num1): + worker.add(EchoTaskP(msg=i)) + for _ in range(fail_num1): + worker.add(EchoTaskP(msg=raise_msg)) + for i in range(succ_num2): + worker.add(EchoTaskP(msg=i)) + ret = worker.run() + + assert usorted(mock_calls_expect) == usorted(echo.mock_calls) + assert_task_result(ret_expect, ret) + + +def test_complex_guard_failure(): + with global_echo_mock(): + # Expects + grdst = opsm.TaskSuccess( + task_name='SequenceTask', + value=[ + opsm.TaskSuccess( + task_name='EchoTaskS', value='complex'), opsm.TaskSuccess( + task_name='EchoTaskS', value='guard') + ]) + ret_expect = typical_fail('EchoTaskS', grdst) + + mock_calls_expect = [mock.call('complex'), mock.call('guard')] + + # Actuals + complex_guard = opsm.SequenceTask() + complex_guard.add(EchoTaskS(msg='complex')) + complex_guard.add(EchoTaskS(msg='guard')) + worker = EchoTaskS(msg=raise_msg, guard=complex_guard) + ret = worker.run() + + assert mock_calls_expect == echo.mock_calls + assert_task_result(ret_expect, ret) + + +def test_retry_task(): + class EchoTaskR(opsm.RetryTask): + def _setup(self, *args, **kwargs): + self.error_times = kwargs['error_times'] + self.msg = kwargs['msg'] + + def _run(self): + if self.error_times > 0: + self.error_times -= 1 + raise rterr + else: + echo(self.msg) + return self.msg + + with global_echo_mock(): + msg = 'hello' + error_times = 2 + retry_times = 5 + + # Expects + ret_expect = opsm.TaskSuccess(task_name='EchoTaskR', value=msg) + mock_calls_expect = [mock.call('cleanup') for i in range(error_times)] + mock_calls_expect.append(mock.call(msg)) + + # Actuals + retry = EchoTaskR( + error_times=error_times, + msg=msg, + retry_times=retry_times, + guard=CleanupTask()) + ret = retry.run() + + assert mock_calls_expect == echo.mock_calls + assert_task_result(ret_expect, ret)