diff --git a/bartender/app.py b/bartender/app.py index 25d83f4..fcd2a76 100644 --- a/bartender/app.py +++ b/bartender/app.py @@ -15,14 +15,14 @@ from bartender.local_plugins.monitor import LocalPluginMonitor from bartender.local_plugins.registry import LocalPluginRegistry from bartender.local_plugins.validator import LocalPluginValidator -from bartender.mongo_pruner import MongoPruner +from bartender.mongo_pruner import MongoPruner, PruneTask, GridFSPrune from bartender.monitor import PluginStatusMonitor from bartender.pika import PikaClient from bartender.pyrabbit import PyrabbitClient from bartender.request_validator import RequestValidator from bartender.thrift.handler import BartenderHandler from bartender.thrift.server import make_server -from bg_utils.mongo.models import Event, Request +from bg_utils.mongo.models import Event, Request, RequestFile from brewtils.models import Events from brewtils.stoppable_thread import StoppableThread @@ -100,14 +100,10 @@ def __init__(self): ), ] - # Only want to run the MongoPruner if it would do anything tasks, run_every = self._setup_pruning_tasks() - if run_every: - self.helper_threads.append( - HelperThread( - MongoPruner, tasks=tasks, run_every=timedelta(minutes=run_every) - ) - ) + self.helper_threads.append( + HelperThread(MongoPruner, tasks=tasks, run_every=run_every) + ) super(BartenderApp, self).__init__(logger=self.logger, name="BartenderApp") @@ -173,48 +169,60 @@ def _shutdown(self): @staticmethod def _setup_pruning_tasks(): + info_ttl = bartender.config.db.ttl.info + action_ttl = bartender.config.db.ttl.action + event_ttl = bartender.config.db.ttl.event + + # Delete request files that were created, but no request ever + # referenced them. + request_file_prune = PruneTask( + RequestFile, + "created_at", + delete_after=timedelta(minutes=15), + additional_query=Q(request=None), + ) - prune_tasks = [] - if bartender.config.db.ttl.info > 0: - prune_tasks.append( - { - "collection": Request, - "field": "created_at", - "delete_after": timedelta(minutes=bartender.config.db.ttl.info), - "additional_query": ( - Q(status="SUCCESS") | Q(status="CANCELED") | Q(status="ERROR") - ) - & Q(has_parent=False) - & Q(command_type="INFO"), - } - ) + # Delete all orphaned gridfs objects. + gridfs_prune = GridFSPrune() + + # Delete INFO/ACTION requests past the TTL. + base_query = ( + Q(status="SUCCESS") | Q(status="CANCELED") | Q(status="ERROR") + ) & Q(has_parent=False) + info_prune = PruneTask( + Request, + "created_at", + delete_after=timedelta(minutes=info_ttl), + additional_query=base_query & Q(command_type="INFO"), + ) + action_prune = PruneTask( + Request, + "created_at", + delete_after=timedelta(minutes=action_ttl), + additional_query=base_query & Q(command_type="ACTION"), + ) - if bartender.config.db.ttl.action > 0: - prune_tasks.append( - { - "collection": Request, - "field": "created_at", - "delete_after": timedelta(minutes=bartender.config.db.ttl.action), - "additional_query": ( - Q(status="SUCCESS") | Q(status="CANCELED") | Q(status="ERROR") - ) - & Q(has_parent=False) - & Q(command_type="ACTION"), - } - ) + # Delete events past their TTL. + event_prune = PruneTask(Event, "timestamp", timedelta(minutes=event_ttl)) - if bartender.config.db.ttl.event > 0: - prune_tasks.append( - { - "collection": Event, - "field": "timestamp", - "delete_after": timedelta(minutes=bartender.config.db.ttl.event), - } - ) + prune_tasks = [request_file_prune] + + if info_ttl > 0: + prune_tasks.append(info_prune) + + if action_ttl > 0: + prune_tasks.append(action_prune) + + if event_ttl > 0: + prune_tasks.append(event_prune) + + # Order matters here, the orphan gridfs prune happens *AFTER* the info/action + # pruning, so that any affected requests get their files/chunks deleted. + prune_tasks.append(gridfs_prune) # Look at the various TTLs to determine how often to run the MongoPruner real_ttls = [x for x in bartender.config.db.ttl.values() if x > 0] - run_every = min(real_ttls) / 2 if real_ttls else None + run_every = min(real_ttls) / 2 if real_ttls else 7.5 return prune_tasks, run_every diff --git a/bartender/mongo_pruner.py b/bartender/mongo_pruner.py index 66ea195..4685200 100644 --- a/bartender/mongo_pruner.py +++ b/bartender/mongo_pruner.py @@ -1,11 +1,56 @@ import logging from datetime import datetime, timedelta +from gridfs import GridFS from mongoengine import Q +from mongoengine.connection import get_db +from bg_utils.mongo.models import RequestFile from brewtils.stoppable_thread import StoppableThread +class PruneTask(object): + def __init__(self, collection, field, delete_after, additional_query=None): + self.logger = logging.getLogger(__name__) + self.collection = collection + self.field = field + self.delete_after = delete_after + self.additional_query = additional_query + + def setup_query(self, delete_older_than): + query = Q(**{self.field + "__lt": delete_older_than}) + if self.additional_query: + query = query & self.additional_query + return query + + def execute(self): + current_time = datetime.utcnow() + delete_older_than = current_time - self.delete_after + query = self.setup_query(delete_older_than) + self.logger.debug( + "Removing %ss older than %s" + % (self.collection.__name__, str(delete_older_than)) + ) + self.collection.objects(query).delete() + + +class GridFSPrune(object): + def __init__(self, gridfs=None): + self.fs = gridfs + + def execute(self): + if self.fs is None: + self.fs = GridFS(get_db()) + + orphan_ids = {f._id for f in self.fs.find()} + for rf in RequestFile.objects: + if rf.body.grid_id in orphan_ids: + orphan_ids.remove(rf.body.grid_id) + + for file_id in orphan_ids: + self.fs.delete(file_id) + + class MongoPruner(StoppableThread): def __init__(self, tasks=None, run_every=timedelta(minutes=15)): self.logger = logging.getLogger(__name__) @@ -15,35 +60,11 @@ def __init__(self, tasks=None, run_every=timedelta(minutes=15)): super(MongoPruner, self).__init__(logger=self.logger, name="Remover") - def add_task( - self, collection=None, field=None, delete_after=None, additional_query=None - ): - self._tasks.append( - { - "collection": collection, - "field": field, - "delete_after": delete_after, - "additional_query": additional_query, - } - ) - def run(self): self.logger.info(self.display_name + " is started") while not self.wait(self._run_every): - current_time = datetime.utcnow() - for task in self._tasks: - delete_older_than = current_time - task["delete_after"] - - query = Q(**{task["field"] + "__lt": delete_older_than}) - if task.get("additional_query", None): - query = query & task["additional_query"] - - self.logger.debug( - "Removing %ss older than %s" - % (task["collection"].__name__, str(delete_older_than)) - ) - task["collection"].objects(query).delete() + task.execute() self.logger.info(self.display_name + " is stopped") diff --git a/bartender/request_validator.py b/bartender/request_validator.py index 6eb9e18..eb8e714 100644 --- a/bartender/request_validator.py +++ b/bartender/request_validator.py @@ -6,10 +6,11 @@ import six import urllib3 from builtins import str +from mongoengine import DoesNotExist from requests import Session import bartender -from bg_utils.mongo.models import System, Choices +from bg_utils.mongo.models import System, Choices, RequestFile from brewtils.choices import parse from brewtils.errors import ModelValidationError from brewtils.rest.system_client import SystemClient @@ -40,7 +41,7 @@ def validate_request(self, request): request.parameters = self.get_and_validate_parameters(request, command) - return request + return request, command def get_and_validate_system(self, request): """Ensure there is a system in the DB that corresponds to this Request. @@ -427,31 +428,33 @@ def _validate_parameter_based_on_type(self, value, parameter, command, request): """Validates the value passed in, ensures the type matches. Recursive calls for dictionaries which also have nested parameters""" + p_type = parameter.type.lower() + try: if value is None and not parameter.nullable: raise ModelValidationError( "There is no value for parameter '%s' " "and this field is not nullable." % parameter.key ) - elif parameter.type.upper() == "STRING": + elif p_type == "string": if isinstance(value, six.string_types): return str(value) else: raise TypeError("Invalid value for string (%s)" % value) - elif parameter.type.upper() == "INTEGER": + elif p_type == "integer": if int(value) != float(value): raise TypeError("Invalid value for integer (%s)" % value) return int(value) - elif parameter.type.upper() == "FLOAT": + elif p_type == "float": return float(value) - elif parameter.type.upper() == "ANY": + elif p_type == "any": return value - elif parameter.type.upper() == "BOOLEAN": + elif p_type == "boolean": if value in [True, False]: return value else: raise TypeError("Invalid value for boolean (%s)" % value) - elif parameter.type.upper() == "DICTIONARY": + elif p_type == "dictionary": dict_value = dict(value) if parameter.parameters: self.logger.debug("Found Nested Parameters.") @@ -459,10 +462,10 @@ def _validate_parameter_based_on_type(self, value, parameter, command, request): request, command, parameter.parameters, dict_value ) return dict_value - elif parameter.type.upper() == "DATE": - return int(value) - elif parameter.type.upper() == "DATETIME": + elif p_type in ["date", "datetime"]: return int(value) + elif p_type == "bytes": + return self._get_bytes_value(value, request) else: raise ModelValidationError( "Unknown type for parameter. Please contact a system administrator." @@ -479,3 +482,31 @@ def _validate_parameter_based_on_type(self, value, parameter, command, request): "Value for key: %s is not the correct type. Should be: %s" % (parameter.key, parameter.type) ) + + def _get_bytes_value(self, value, request): + required_keys = ["storage_type", "id", "filename"] + if not isinstance(value, dict): + raise ModelValidationError( + "Bytes parameters should be a dictionary with at least the following keys: %s" + % required_keys + ) + + for key in required_keys: + if key not in value: + raise ModelValidationError("Bytes parameter missing %s field" % key) + + if value["storage_type"] not in RequestFile.STORAGE_ENGINES: + raise ModelValidationError( + "Bytes parameter had invalid storage type: %s" % value["storage_type"] + ) + + try: + rf = RequestFile.objects.get(id=value["id"]) + rf.request = request + rf.save() + except DoesNotExist: + raise ModelValidationError( + "Bytes parameter had an id, but that id did not exist in the database." + ) + + return value diff --git a/bartender/thrift/handler.py b/bartender/thrift/handler.py index 3b1dcc0..4ae969d 100644 --- a/bartender/thrift/handler.py +++ b/bartender/thrift/handler.py @@ -1,3 +1,4 @@ +import json import logging import random import string @@ -47,12 +48,17 @@ def processRequest(self, request_id): # Validates the request based on what is in the database. # This includes the validation of the request parameters, # systems are there, commands are there etc. - request = self.request_validator.validate_request(request) + request, command = self.request_validator.validate_request(request) request.save() - if not self.clients["pika"].publish_request( - request, confirm=True, mandatory=True - ): + publish_kwargs = {"confirm": True, "mandatory": True} + bytes_params = command.parameter_keys_by_type("Bytes") + if bytes_params: + publish_kwargs["headers"] = { + "resolve_parameters": json.dumps(bytes_params).encode("utf-8") + } + + if not self.clients["pika"].publish_request(request, **publish_kwargs): msg = "Error while publishing request to queue (%s[%s]-%s %s)" % ( request.system, request.system_version, diff --git a/test/app_test.py b/test/app_test.py index 25fb4ed..c3b9d2c 100644 --- a/test/app_test.py +++ b/test/app_test.py @@ -7,8 +7,9 @@ import bartender from bartender.app import BartenderApp, HelperThread +from bartender.mongo_pruner import GridFSPrune from bartender.specification import SPECIFICATION -from bg_utils.mongo.models import Event, Request +from bg_utils.mongo.models import Event, Request, RequestFile @patch("bartender.app.time", Mock()) @@ -117,24 +118,30 @@ def test_setup_pruning_tasks(self): bartender.config.db.ttl.event = 15 prune_tasks, run_every = BartenderApp._setup_pruning_tasks() - self.assertEqual(3, len(prune_tasks)) + self.assertEqual(5, len(prune_tasks)) self.assertEqual(2.5, run_every) - info_task = prune_tasks[0] - action_task = prune_tasks[1] - event_task = prune_tasks[2] + rf_task = prune_tasks[0] + info_task = prune_tasks[1] + action_task = prune_tasks[2] + event_task = prune_tasks[3] + gridfs_task = prune_tasks[4] - self.assertEqual(Request, info_task["collection"]) - self.assertEqual(Request, action_task["collection"]) - self.assertEqual(Event, event_task["collection"]) + self.assertEqual(RequestFile, rf_task.collection) + self.assertEqual(Request, info_task.collection) + self.assertEqual(Request, action_task.collection) + self.assertEqual(Event, event_task.collection) + self.assertIsInstance(gridfs_task, GridFSPrune) - self.assertEqual("created_at", info_task["field"]) - self.assertEqual("created_at", action_task["field"]) - self.assertEqual("timestamp", event_task["field"]) + self.assertEqual("created_at", rf_task.field) + self.assertEqual("created_at", info_task.field) + self.assertEqual("created_at", action_task.field) + self.assertEqual("timestamp", event_task.field) - self.assertEqual(timedelta(minutes=5), info_task["delete_after"]) - self.assertEqual(timedelta(minutes=10), action_task["delete_after"]) - self.assertEqual(timedelta(minutes=15), event_task["delete_after"]) + self.assertEqual(timedelta(minutes=15), rf_task.delete_after) + self.assertEqual(timedelta(minutes=5), info_task.delete_after) + self.assertEqual(timedelta(minutes=10), action_task.delete_after) + self.assertEqual(timedelta(minutes=15), event_task.delete_after) def test_setup_pruning_tasks_empty(self): bartender.config.db.ttl.info = -1 @@ -142,8 +149,8 @@ def test_setup_pruning_tasks_empty(self): bartender.config.db.ttl.event = -1 prune_tasks, run_every = BartenderApp._setup_pruning_tasks() - self.assertEqual([], prune_tasks) - self.assertIsNone(run_every) + self.assertEqual(len(prune_tasks), 2) + self.assertEqual(7.5, run_every) def test_setup_pruning_tasks_one(self): bartender.config.db.ttl.info = -1 @@ -151,7 +158,7 @@ def test_setup_pruning_tasks_one(self): bartender.config.db.ttl.event = -1 prune_tasks, run_every = BartenderApp._setup_pruning_tasks() - self.assertEqual(1, len(prune_tasks)) + self.assertEqual(3, len(prune_tasks)) self.assertEqual(0.5, run_every) def test_setup_pruning_tasks_mixed(self): @@ -160,20 +167,20 @@ def test_setup_pruning_tasks_mixed(self): bartender.config.db.ttl.event = 15 prune_tasks, run_every = BartenderApp._setup_pruning_tasks() - self.assertEqual(2, len(prune_tasks)) + self.assertEqual(4, len(prune_tasks)) self.assertEqual(2.5, run_every) - info_task = prune_tasks[0] - event_task = prune_tasks[1] + info_task = prune_tasks[1] + event_task = prune_tasks[2] - self.assertEqual(Request, info_task["collection"]) - self.assertEqual(Event, event_task["collection"]) + self.assertEqual(Request, info_task.collection) + self.assertEqual(Event, event_task.collection) - self.assertEqual("created_at", info_task["field"]) - self.assertEqual("timestamp", event_task["field"]) + self.assertEqual("created_at", info_task.field) + self.assertEqual("timestamp", event_task.field) - self.assertEqual(timedelta(minutes=5), info_task["delete_after"]) - self.assertEqual(timedelta(minutes=15), event_task["delete_after"]) + self.assertEqual(timedelta(minutes=5), info_task.delete_after) + self.assertEqual(timedelta(minutes=15), event_task.delete_after) class HelperThreadTest(unittest.TestCase): diff --git a/test/mongo_pruner_test.py b/test/mongo_pruner_test.py index f407583..2ec1b4b 100644 --- a/test/mongo_pruner_test.py +++ b/test/mongo_pruner_test.py @@ -1,32 +1,66 @@ -import unittest +import pytest from datetime import timedelta -from mock import MagicMock, Mock, patch +from mock import Mock, patch -from bartender.mongo_pruner import MongoPruner +from bartender.mongo_pruner import MongoPruner, PruneTask, GridFSPrune +from mongoengine import Q -class MongoPrunerTest(unittest.TestCase): - def setUp(self): - self.mongo_pruner = MongoPruner(tasks=None) +class TestPruner(object): + @pytest.fixture + def collection(self): + return Mock(__name__="mock") - self.collection_mock = MagicMock(__name__="MOCK") - self.field_mock = "test" - self.delete_after_mock = timedelta(microseconds=1) - self.additional_query_mock = Mock() + @pytest.fixture + def prune_task(self, collection): + query = Q(foo=None) + return PruneTask(collection, "field", timedelta(microseconds=1), query) - self.task = { - "collection": self.collection_mock, - "field": self.field_mock, - "delete_after": self.delete_after_mock, - "additional_query": self.additional_query_mock, - } + def test_task_execute(self, prune_task): + prune_task.execute() + assert prune_task.collection.objects.return_value.delete.call_count == 1 - self.mongo_pruner.add_task(**self.task) + def test_task_setup_query_no_children(self, collection): + task = PruneTask(collection, "field", timedelta(microseconds=1)) + query = task.setup_query(0) + assert query.query["field__lt"] == 0 - @patch("bartender.mongo_pruner.Q", MagicMock()) - def test_prune_something(self): - self.mongo_pruner._stop_event = Mock(wait=Mock(side_effect=[False, True])) + def test_task_setup_query(self, prune_task): + q = prune_task.setup_query(0) + assert len(q.children) == 2 + assert q.children[0].query["field__lt"] == 0 + assert q.children[1].query["foo"] is None - self.mongo_pruner.run() - self.assertTrue(self.collection_mock.objects.return_value.delete.called) + @patch("bg_utils.mongo.models.RequestFile.objects") + def test_gridfs_prune_empty(self, get_mock): + get_mock.return_value = [] + gridfs = Mock(find=Mock(return_value=[])) + task = GridFSPrune(gridfs) + task.execute() + assert gridfs.delete.call_count == 0 + + @patch("bartender.mongo_pruner.RequestFile") + def test_gridfs_no_delete(self, get_mock): + get_mock.objects = [Mock(body=Mock(grid_id="id"))] + find_mock = Mock(return_value=[Mock(_id="id")]) + gridfs = Mock(find=find_mock) + task = GridFSPrune(gridfs) + task.execute() + assert gridfs.delete.call_count == 0 + + @patch("bartender.mongo_pruner.RequestFile") + def test_gridfs_delete_orphans(self, get_mock): + get_mock.objects = [Mock(body=Mock(grid_id="id1"))] + find_mock = Mock(return_value=[Mock(_id="id1"), Mock(_id="id2")]) + gridfs = Mock(find=find_mock) + task = GridFSPrune(gridfs) + task.execute() + gridfs.delete.assert_called_with("id2") + + def test_pruner_thread(self, prune_task): + pruner = MongoPruner([prune_task]) + prune_task.execute = Mock() + pruner._stop_event = Mock(wait=Mock(side_effect=[False, True])) + pruner.run() + assert prune_task.execute.call_count == 1 diff --git a/test/request_validator_test.py b/test/request_validator_test.py index 28ba6b1..e2cd11b 100644 --- a/test/request_validator_test.py +++ b/test/request_validator_test.py @@ -1,9 +1,17 @@ import pytest from box import Box from mock import Mock, call, patch +from mongoengine import DoesNotExist from bartender.request_validator import RequestValidator -from bg_utils.mongo.models import Command, Parameter, Request, System, Choices +from bg_utils.mongo.models import ( + Command, + Parameter, + Request, + System, + Choices, + RequestFile, +) from brewtils.errors import ModelValidationError @@ -83,7 +91,8 @@ def test_no_verify(self, validator): class TestValidateRequest(object): def test_success(self, validator, system_find, bg_system, bg_request): system_find.return_value = bg_system - assert validator.validate_request(bg_request) == bg_request + request, command = validator.validate_request(bg_request) + assert request == bg_request class TestGetAndValidateSystem(object): @@ -449,6 +458,16 @@ def test_success(self, validator, req_value, param_type, expected): ("foo", "UH OH THIS IS BAD"), (["not an int"], "Integer"), ([1], "Integer"), + ("SHOULD_BE_DICT", "Bytes"), + ({}, "Bytes"), + ( + { + "storage_type": "INVALID_TYPE", + "id": "also technically invalid", + "filename": "some_filename", + }, + "Bytes", + ), ], ) def test_fail(self, validator, req_value, param_type): @@ -469,6 +488,39 @@ def test_nested_parameters(self, validator): ) assert validated_parameters["key1"]["foo"] == "bar" + @patch("bg_utils.mongo.models.RequestFile.objects") + def test_bytes_invalid_id(self, objects_mock, validator): + objects_mock.get.side_effect = DoesNotExist + param = make_param(key="key1", type="Bytes") + request = make_request( + parameters={ + "key1": { + "storage_type": "gridfs", + "id": "does_not_exist", + "filename": "some_filename", + } + } + ) + with pytest.raises(ModelValidationError): + validator.get_and_validate_parameters(request, Mock(parameters=[param])) + + @patch("bg_utils.mongo.models.RequestFile.objects", Mock()) + def test_bytes_success(self, validator): + param = make_param(key="key1", type="Bytes") + expected_value = { + "storage_type": "gridfs", + "id": "pretend_this_exists", + "filename": "some_filename", + } + fake_rf = RequestFile.objects.get() + request = make_request(parameters={"key1": expected_value}) + validated_parameters = validator.get_and_validate_parameters( + request, Mock(parameters=[param]) + ) + assert validated_parameters["key1"] == expected_value + assert fake_rf.request == request + assert fake_rf.save.call_count == 1 + class TestValidateChoices(object): @pytest.mark.parametrize( @@ -749,7 +801,7 @@ def test_validate_choices_static_bad_type(self, validator): '["a", "b", "value"]', '["a", {"text": "text", "value": "value"}]', '["a", {"text": "b", "value": "2"}, "value"]', - ] + ], ) def test_validate_url_choices(self, validator, response): session_mock = Mock() diff --git a/test/thrift/handler_test.py b/test/thrift/handler_test.py index 985f553..df7e430 100644 --- a/test/thrift/handler_test.py +++ b/test/thrift/handler_test.py @@ -1,5 +1,6 @@ import unittest +import json import mongoengine from mock import MagicMock, Mock, PropertyMock, patch, call from pyrabbit2.http import HTTPError @@ -46,8 +47,9 @@ def test_process_request_bad_backend(self): @patch("bg_utils.mongo.models.Request.find_or_none") def test_process_request(self, find_mock): request = Mock() + command = Mock(parameter_keys_by_type=Mock(return_value=[])) find_mock.return_value = request - self.request_validator.validate_request.return_value = request + self.request_validator.validate_request.return_value = request, command self.handler.processRequest("id") find_mock.assert_called_once_with("id") @@ -55,11 +57,32 @@ def test_process_request(self, find_mock): request, confirm=True, mandatory=True ) + @patch("bg_utils.mongo.models.Request.find_or_none") + def test_process_bytes_request(self, find_mock): + request = Mock() + command = Mock(parameter_keys_by_type=Mock(return_value=["bytes_param"])) + find_mock.return_value = request + self.request_validator.validate_request.return_value = request, command + + self.handler.processRequest("id") + find_mock.assert_called_once_with("id") + expected_kwargs = { + "confirm": True, + "mandatory": True, + "headers": { + "resolve_parameters": json.dumps(["bytes_param"]).encode("utf-8") + }, + } + self.clients["pika"].publish_request.assert_called_once_with( + request, **expected_kwargs + ) + @patch("bg_utils.mongo.models.Request.find_or_none") def test_process_request_fail(self, find_mock): request = Mock() + command = Mock(parameter_keys_by_type=Mock(return_value=[])) find_mock.return_value = request - self.request_validator.validate_request.return_value = request + self.request_validator.validate_request.return_value = request, command self.clients["pika"].publish_request.return_value = False self.assertRaises(