diff --git a/shotgun_api3/shotgun.py b/shotgun_api3/shotgun.py index fbe6c6c0..323c6da6 100755 --- a/shotgun_api3/shotgun.py +++ b/shotgun_api3/shotgun.py @@ -134,7 +134,7 @@ def __init__(self, host, meta): self.is_dev = False self.version = tuple(self.version[:3]) - self._ensure_json_supported() + self.ensure_json_supported() def _ensure_support(self, feature, raise_hell=True): @@ -157,19 +157,30 @@ def _ensure_support(self, feature, raise_hell=True): return True - def _ensure_json_supported(self): + def ensure_json_supported(self): """Wrapper for ensure_support""" - self._ensure_support({ + return self._ensure_support({ 'version': (2, 4, 0), 'label': 'JSON API' }) - def ensure_include_archived_projects(self): + def ensure_include_archived_projects(self, value=True): """Wrapper for ensure_support""" - self._ensure_support({ + # This defaults to True on the server + # So we only need to raise a version error if it's False + return self._ensure_support({ 'version': (5, 3, 14), 'label': 'include_archived_projects parameter' - }) + }, (value == False)) + + def ensure_include_template_projects(self, value=False): + """Wrapper for ensure_support""" + # This defaults to False on the server + # So we only need to raise a version error if it's True + return self._ensure_support({ + 'version': (6, 0, 0), + 'label': 'include_template_projects parameter' + }, (value == True)) def ensure_per_project_customization(self): """Wrapper for ensure_support""" @@ -485,7 +496,8 @@ def info(self): return self._call_rpc("info", None, include_auth_params=False) def find_one(self, entity_type, filters, fields=None, order=None, - filter_operator=None, retired_only=False, include_archived_projects=True): + filter_operator=None, retired_only=False, + include_archived_projects=True, include_template_projects=False): """Calls the find() method and returns the first result, or None. :param entity_type: Required, entity type (string) to find. @@ -504,16 +516,26 @@ def find_one(self, entity_type, filters, fields=None, order=None, :param limit: Optional, number of entities to return per page. Defaults to 0 which returns all entities that match. - :param page: Optional, page of results to return. By default all - results are returned. Use together with limit. - :param retired_only: Optional, flag to return only entities that have been retried. Defaults to False which returns only entities which have not been retired. + + :param page: Optional, page of results to return. By default all + results are returned. Use together with limit. + + :param include_archived_projects: Optional, flag to include entities + whose projects have been archived. Default: True + + :param include_template_projects: Optional, flag to include entities + belonging to template projects. Default: False + + :returns: dict of requested entity's fields, or None if not found. """ results = self.find(entity_type, filters, fields, order, - filter_operator, 1, retired_only, include_archived_projects=include_archived_projects) + filter_operator, 1, retired_only, + include_archived_projects=include_archived_projects, + include_template_projects=include_template_projects) if results: return results[0] @@ -521,7 +543,7 @@ def find_one(self, entity_type, filters, fields=None, order=None, def find(self, entity_type, filters, fields=None, order=None, filter_operator=None, limit=0, retired_only=False, page=0, - include_archived_projects=True): + include_archived_projects=True, include_template_projects=False): """Find entities matching the given filters. :param entity_type: Required, entity type (string) to find. @@ -548,7 +570,10 @@ def find(self, entity_type, filters, fields=None, order=None, have not been retired. :param include_archived_projects: Optional, flag to include entities - whose projects have been archived + whose projects have been archived. Default: True + + :param include_template_projects: Optional, flag to include entities + belonging to template projects. Default: False :returns: list of the dicts for each entity with the requested fields, and their id and type. @@ -567,18 +592,15 @@ def find(self, entity_type, filters, fields=None, order=None, raise ShotgunError("Deprecated: Use of filter_operator for find()" " is not valid any more. See the documentation on find()") - if not include_archived_projects: - # This defaults to True on the server (no argument is sent) - # So we only need to check the server version if it is False - self.server_caps.ensure_include_archived_projects() - - params = self._construct_read_parameters(entity_type, fields, filters, retired_only, - order, - include_archived_projects) + order) + + params = self._construct_flag_parameters(params, + include_archived_projects, + include_template_projects) if limit and limit <= self.config.records_per_page: params["paging"]["entities_per_page"] = limit @@ -615,26 +637,24 @@ def find(self, entity_type, filters, fields=None, order=None, return self._parse_records(records) - def _construct_read_parameters(self, entity_type, fields, filters, retired_only, - order, - include_archived_projects): - params = {} - params["type"] = entity_type - params["return_fields"] = fields or ["id"] - params["filters"] = filters - params["return_only"] = (retired_only and 'retired') or "active" - params["return_paging_info"] = True - params["paging"] = { "entities_per_page": self.config.records_per_page, - "current_page": 1 } - - if include_archived_projects is False: - # Defaults to True on the server, so only pass it if it's False - params["include_archived_projects"] = False + order): + + params = { + "type": entity_type, + "return_fields": fields or ["id"], + "filters": filters, + "return_only": (retired_only and 'retired') or "active", + "return_paging_info": True, + "paging": { + "entities_per_page": self.config.records_per_page, + "current_page": 1 + } + } if order: sort_list = [] @@ -648,6 +668,21 @@ def _construct_read_parameters(self, 'direction' : sort['direction'] }) params['sorts'] = sort_list + + return params + + + def _construct_flag_parameters(self, + params, + include_archived_projects, + include_template_projects): + + if self.server_caps.ensure_include_archived_projects(include_archived_projects): + params["include_archived_projects"] = include_archived_projects + + if self.server_caps.ensure_include_template_projects(include_template_projects): + params["include_template_projects"] = include_template_projects + return params @@ -665,7 +700,8 @@ def summarize(self, summary_fields, filter_operator=None, grouping=None, - include_archived_projects=True): + include_archived_projects=True, + include_template_projects=False): """ Return group and summary information for entity_type for summary_fields based on the given filters. @@ -678,18 +714,13 @@ def summarize(self, if isinstance(filters, (list, tuple)): filters = _translate_filters(filters, filter_operator) - if not include_archived_projects: - # This defaults to True on the server (no argument is sent) - # So we only need to check the server version if it is False - self.server_caps.ensure_include_archived_projects() - params = {"type": entity_type, "summaries": summary_fields, "filters": filters} - if include_archived_projects is False: - # Defaults to True on the server, so only pass it if it's False - params["include_archived_projects"] = False + params = self._construct_flag_parameters(params, + include_archived_projects, + include_template_projects) if grouping != None: params['grouping'] = grouping @@ -1624,6 +1655,7 @@ def _call_rpc(self, method, params, include_auth_params=True, first=False): """ + log_time = datetime.datetime.now() LOG.debug("Starting rpc call to %s with params %s" % ( method, params)) @@ -1638,7 +1670,10 @@ def _call_rpc(self, method, params, include_auth_params=True, first=False): } http_status, resp_headers, body = self._make_call("POST", self.config.api_path, encoded_payload, req_headers) - LOG.debug("Completed rpc call to %s" % (method)) + + log_time = datetime.datetime.now() - log_time + LOG.debug("Completed rpc call to %s in %s" % (method, str(log_time))) + try: self._parse_http_status(http_status) except ProtocolError, e: diff --git a/tests/base.py b/tests/base.py index 201844b4..39e0ebc7 100644 --- a/tests/base.py +++ b/tests/base.py @@ -3,13 +3,14 @@ import unittest from ConfigParser import ConfigParser - import mock import shotgun_api3 as api from shotgun_api3.shotgun import json from shotgun_api3.shotgun import ServerCapabilities +import logging + CONFIG_PATH = 'tests/config' class TestBase(unittest.TestCase): @@ -35,6 +36,10 @@ def __init__(self, *args, **kws): def setUp(self, auth_mode='ApiUser'): + + self.LOG = logging.getLogger("shotgun_api3") + self.LOG.setLevel(logging.WARN) + self.config = SgTestConfig() self.config.read_config(CONFIG_PATH) self.human_login = self.config.human_login @@ -185,10 +190,14 @@ def _setup_mock_data(self): class LiveTestBase(TestBase): '''Test base for tests relying on connection to server.''' + def setUp(self, auth_mode='ApiUser'): super(LiveTestBase, self).setUp(auth_mode) + self.sg_version = self.sg.info()['version'][:3] + self._setup_db(self.config) + if self.sg.server_caps.version and \ self.sg.server_caps.version >= (3, 3, 0) and \ (self.sg.server_caps.host.startswith('0.0.0.0') or \ @@ -197,17 +206,22 @@ def setUp(self, auth_mode='ApiUser'): else: self.server_address = self.sg.server_caps.host + def _setup_db(self, config): data = {'name':self.config.project_name} self.project = _find_or_create_entity(self.sg, 'Project', data) + self.template_project = _find_or_create_entity(self.sg, 'Project', { + 'name': 'Template Project', + 'is_template': True + }) + data = {'name':self.config.human_name, 'login':self.config.human_login, 'password_proxy':self.config.human_password} if self.sg_version >= (3, 0, 0): data['locked_until'] = None - self.human_user = _find_or_create_entity(self.sg, 'HumanUser', data) data = {'code':self.config.asset_code, @@ -256,6 +270,12 @@ def _setup_db(self, config): keys = ['title','project', 'sg_priority'] self.ticket = _find_or_create_entity(self.sg, 'Ticket', data, keys) + data = {'project': self.template_project, + 'title': self.config.ticket_title, + 'sg_priority': '1'} + keys = ['title', 'project', 'sg_priority'] + self.template_ticket = _find_or_create_entity(self.sg, 'Ticket', data, keys) + keys = ['code'] data = {'code':'api wrapper test storage', 'mac_path':'nowhere', diff --git a/tests/test_api.py b/tests/test_api.py index f4a004c5..eccae70d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -439,7 +439,7 @@ def test_simple_summary(self): assert(result['groups'][0]['summaries']) assert(result['summaries']) - def test_summary_include_archived_projects(self): + def _test_summary_include_archived_projects(self): if self.sg.server_caps.version > (5, 3, 13): # archive project self.sg.update('Project', self.project['id'], {'archived':True}) @@ -1345,7 +1345,7 @@ def test_zero_is_not_none(self): result = self.sg.find_one( 'Asset', [['id','is',self.asset['id']],[num_field, 'is_not', None]] ,[num_field] ) self.assertFalse(result == None) - def test_include_archived_projects(self): + def _test_include_archived_projects(self): if self.sg.server_caps.version > (5, 3, 13): # Ticket #25082 result = self.sg.find_one('Shot', [['id','is',self.shot['id']]]) diff --git a/tests/test_client.py b/tests/test_client.py index 7460b597..21a5182f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -74,18 +74,18 @@ def test_server_version_json(self): sc = ServerCapabilities("foo", {"version" : (2,4,0)}) sc.version = (2,3,99) - self.assertRaises(api.ShotgunError, sc._ensure_json_supported) + self.assertRaises(api.ShotgunError, sc.ensure_json_supported) self.assertRaises(api.ShotgunError, ServerCapabilities, "foo", {"version" : (2,2,0)}) sc.version = (0,0,0) - self.assertRaises(api.ShotgunError, sc._ensure_json_supported) + self.assertRaises(api.ShotgunError, sc.ensure_json_supported) sc.version = (2,4,0) - sc._ensure_json_supported() + sc.ensure_json_supported() sc.version = (2,5,0) - sc._ensure_json_supported() + sc.ensure_json_supported() def test_session_uuid(self): diff --git a/tests/test_flags.py b/tests/test_flags.py new file mode 100644 index 00000000..38f320be --- /dev/null +++ b/tests/test_flags.py @@ -0,0 +1,236 @@ +"""Test using the Shotgun API flags.""" + +import shotgun_api3 +from shotgun_api3 import * +from shotgun_api3.lib.httplib2 import Http + +import base + +import logging + +class TestFlags(base.LiveTestBase): + + def setUp(self): + super(TestFlags, self).setUp() + + # We will need the created_at field for the shot + fields = self.shot.keys()[:] + fields.append('created_at') + self.shot = self.sg.find_one('Shot', [['id', 'is', self.shot['id']]], fields) + + + def test_summary_include_archived_projects(self): + """Test summary with 'include_archived_projects'""" + + if self.sg.server_caps.version > (5, 3, 13): + # Ticket #25082 ability to hide archived projects in summary + + summaries = [{'field': 'id', 'type': 'count'}] + grouping = [{'direction': 'asc', 'field': 'id', 'type': 'exact'}] + filters = [['project', 'is', self.project]] + + # archive project + self.sg.update('Project', self.project['id'], {'archived':True}) + + # should get no result + result = self.sg.summarize('Shot', + filters=filters, + summary_fields=summaries, + grouping=grouping, + include_archived_projects=False) + self.assertEquals(result['summaries']['id'], 0) + + # should get result + result = self.sg.summarize('Shot', + filters=filters, + summary_fields=summaries, + grouping=grouping, + include_archived_projects=True) + self.assertEquals(result['summaries']['id'], 1) + + # setting defaults to True, should get result + result = self.sg.summarize('Shot', + filters=filters, + summary_fields=summaries, + grouping=grouping) + self.assertEquals(result['summaries']['id'], 1) + + # reset project + self.sg.update('Project', self.project['id'], {'archived':False}) + + + def test_summary_include_template_projects(self): + """Test summary with 'include_template_projects'""" + + # Ticket #28441 + + self.LOG.setLevel(logging.DEBUG) + + summaries = [{'field': 'id', 'type': 'count'}] + grouping = [{'direction': 'asc', 'field': 'id', 'type': 'exact'}] + filters = [['project', 'is', self.project]] + + # control + result = self.sg.summarize('Shot', + filters=filters, + summary_fields=summaries, + grouping=grouping) + self.assertEquals(result['summaries']['id'], 1) + + # backwards-compatibility + if self.sg.server_caps.version < (6, 0, 0): + + # flag should not be passed, should get result + self.assertRaises(ShotgunError, self.sg.summarize, 'Shot', + filters=filters, + summary_fields=summaries, + grouping=grouping, + include_template_projects=True) + + # test new features + if self.sg.server_caps.version >= (6, 0, 0): + # set as template project + self.sg.update('Project', self.project['id'], {'is_template':True}) + + # should get result + result = self.sg.summarize('Shot', + filters=filters, + summary_fields=summaries, + grouping=grouping, + include_template_projects=True) + self.assertEquals(result['summaries']['id'], 1) + + # should get no result + result = self.sg.summarize('Shot', + filters=filters, + summary_fields=summaries, + grouping=grouping, + include_template_projects=False) + self.assertEquals(result['summaries']['id'], 0) + + # setting defaults to False, should get no result + result = self.sg.summarize('Shot', + filters=filters, + summary_fields=summaries, + grouping=grouping) + self.assertEquals(result['summaries']['id'], 0) + + # reset project + self.sg.update('Project', self.project['id'], {'is_template':False}) + + self.LOG.setLevel(logging.WARN) + + + def test_include_archived_projects(self): + """Test find with 'include_archived_projects'""" + + # Ticket #25082 + + filters = [['id', 'is', self.shot['id']]] + + if self.sg.server_caps.version > (5, 3, 13): + + # control + result = self.sg.find_one('Shot', + filters=filters) + self.assertEquals(result['id'], self.shot['id']) + + # archive project + self.sg.update('Project', self.project['id'], {'archived':True}) + + # should get no result + result = self.sg.find_one('Shot', + filters=filters, + include_archived_projects=False) + self.assertEquals(result, None) + + # should get result + result = self.sg.find_one('Shot', + filters=filters, + include_archived_projects=True) + self.assertEquals(result['id'], self.shot['id']) + + # setting defaults to True, should get result + result = self.sg.find_one('Shot', + filters=filters) + self.assertEquals(result['id'], self.shot['id']) + + # reset project + self.sg.update('Project', self.project['id'], {'archived':False}) + + + def test_include_template_projects(self): + """Test find with 'include_template_projects'""" + + # Ticket #28441 + + self.LOG.setLevel(logging.DEBUG) + + filters = [['id', 'is', self.shot['id']]] + + # control + result = self.sg.find_one('Shot', + filters=filters) + self.assertEquals(result['id'], self.shot['id']) + + # backwards-compatibility + if self.sg.server_caps.version < (6, 0, 0): + + self.assertRaises(ShotgunError, self.sg.find_one, 'Shot', + filters=filters, + include_template_projects=True) + + # test new features + if self.sg.server_caps.version >= (6, 0, 0): + + # set as template project + self.sg.update('Project', self.project['id'], {'is_template':True}) + + # should get result + result = self.sg.find_one('Shot', + filters=filters, + include_template_projects=True) + self.assertEquals(result['id'], self.shot['id']) + + # should get no result + result = self.sg.find_one('Shot', + filters=filters, + include_template_projects=False) + self.assertEquals(result, None) + + # setting defaults to False, should get no result + result = self.sg.find_one('Shot', + filters=filters) + self.assertEquals(result, None) + + # reset project + self.sg.update('Project', self.project['id'], {'is_template':False}) + + self.LOG.setLevel(logging.WARN) + + + def test_find_template_project(self): + """Test find the 'Template Project'""" + + # Ticket #28441 + + if self.sg.server_caps.version >= (6, 0, 0): + + # find by name + result = self.sg.find_one('Project', [['name', 'is', self.template_project['name']]]) + self.assertEquals(result['id'], self.template_project['id']) + + # find by ID + result = self.sg.find_one('Project', [['id', 'is', self.template_project['id']]]) + self.assertEquals(result['id'], self.template_project['id']) + + # find attached entity + result = self.sg.find_one( + 'Ticket', + [ + ['id', 'is', self.template_ticket['id']], + ['project.Project.name', 'is', 'Template Project'], + ['project.Project.layout_project', 'is', None] + ] + ) + self.assertEquals(result['id'], self.template_ticket['id'])