diff --git a/src/azul/indexer/__init__.py b/src/azul/indexer/__init__.py
index 6e673d9ec..f8228cc83 100644
--- a/src/azul/indexer/__init__.py
+++ b/src/azul/indexer/__init__.py
@@ -552,6 +552,9 @@ def spec_cls(cls) -> type[SourceSpec]:
         spec_cls, ref_cls = get_generic_type_params(cls, SourceSpec, SourceRef)
         return spec_cls
 
+    def with_prefix(self, prefix: Prefix) -> Self:
+        return attrs.evolve(self, spec=attrs.evolve(self.spec, prefix=prefix))
+
 
 class SourcedBundleFQIDJSON(BundleFQIDJSON):
     source: SourceJSON
diff --git a/src/azul/plugins/__init__.py b/src/azul/plugins/__init__.py
index 361ed1856..893499fb2 100644
--- a/src/azul/plugins/__init__.py
+++ b/src/azul/plugins/__init__.py
@@ -640,7 +640,7 @@ def partition_source(self,
             prefix = Prefix.for_main_deployment(count)
         else:
             prefix = Prefix.for_lesser_deployment(count)
-        source = attr.evolve(source, spec=attr.evolve(source.spec, prefix=prefix))
+        source = source.with_prefix(prefix)
         return source
 
     @abstractmethod
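The with_prefix helper added above encapsulates the nested-evolve idiom that
partition_source previously spelled out inline: both the source ref and its
spec are immutable attrs instances, so changing the nested prefix means
evolving the spec and then evolving the ref with the new spec. A minimal
runnable sketch of that idiom, using hypothetical SimpleSpec/SimpleRef
stand-ins rather than Azul's actual classes:

    # Sketch only: SimpleSpec and SimpleRef are stand-ins, not Azul classes
    import attrs

    @attrs.frozen
    class SimpleSpec:
        name: str
        prefix: str

    @attrs.frozen
    class SimpleRef:
        id: str
        spec: SimpleSpec

    def with_prefix(ref: SimpleRef, prefix: str) -> SimpleRef:
        # Both classes are frozen, so we return updated copies instead of
        # mutating in place
        return attrs.evolve(ref, spec=attrs.evolve(ref.spec, prefix=prefix))

    ref = SimpleRef(id='42', spec=SimpleSpec(name='src', prefix='ab'))
    assert with_prefix(ref, 'abcd').spec.prefix == 'abcd'
    assert ref.spec.prefix == 'ab'  # the original ref is unchanged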
diff --git a/test/integration_test.py b/test/integration_test.py
index 44dd36726..005f8d346 100644
--- a/test/integration_test.py
+++ b/test/integration_test.py
@@ -3,10 +3,8 @@
 )
 from collections.abc import (
     Iterable,
-    Iterator,
     Mapping,
     Sequence,
-    Set,
 )
 from concurrent.futures.thread import (
     ThreadPoolExecutor,
@@ -20,10 +18,8 @@
     BytesIO,
     TextIOWrapper,
 )
-import itertools
 from itertools import (
     count,
-    starmap,
 )
 import json
 import os
@@ -43,7 +39,6 @@
     Callable,
     ContextManager,
     IO,
-    Optional,
     Protocol,
     TypedDict,
     cast,
@@ -80,6 +75,7 @@
     first,
     grouper,
     one,
+    only,
 )
 from openapi_spec_validator import (
     validate_spec,
@@ -107,6 +103,9 @@
 from azul.chalice import (
     AzulChaliceApp,
 )
+from azul.collections import (
+    alist,
+)
 from azul.drs import (
     AccessMethod,
 )
@@ -117,6 +116,7 @@
     http_client,
 )
 from azul.indexer import (
+    Prefix,
     SourceJSON,
     SourceRef,
     SourceSpec,
@@ -152,7 +152,6 @@
     Link,
 )
 from azul.plugins.repository.tdr_anvil import (
-    BundleType,
     TDRAnvilBundleFQID,
 )
 from azul.portal_service import (
@@ -286,73 +285,34 @@ def managed_access_sources_by_catalog(self
                     managed_access_sources[catalog].add(ref)
         return managed_access_sources
 
-    def _list_partitions(self,
-                         catalog: CatalogName,
-                         *,
-                         min_bundles: int,
-                         public_1st: bool
-                         ) -> Iterator[tuple[SourceRef, str, list[SourcedBundleFQID]]]:
-        """
-        Iterate through the sources in the given catalog and yield partitions of
-        bundle FQIDs until a desired minimum number of bundles are found. For
-        each emitted source, every partition is included, even if it's empty.
-        """
-        total_bundles = 0
-        sources = sorted(config.sources(catalog))
-        self.random.shuffle(sources)
-        if public_1st:
-            managed_access_sources = frozenset(
+    def _choose_source(self,
+                       catalog: CatalogName,
+                       *,
+                       public: bool | None = None
+                       ) -> SourceRef | None:
+        plugin = self.repository_plugin(catalog)
+        sources = set(config.sources(catalog))
+        if public is not None:
+            ma_sources = {
                 str(source.spec)
+                # Evaluating this property would raise a KeyError during the
+                # can-bundle script test, which uses a mock catalog, so we
+                # only evaluate it when it's actually needed
                 for source in self.managed_access_sources_by_catalog[catalog]
-            )
-            index = first(
-                i
-                for i, source in enumerate(sources)
-                if source not in managed_access_sources
-            )
-            sources[0], sources[index] = sources[index], sources[0]
-        plugin = self.azul_client.repository_plugin(catalog)
-        # This iteration prefers sources occurring first, so we shuffle them
-        # above to neutralize the bias.
-        for source in sources:
-            source = plugin.resolve_source(source)
-            source = plugin.partition_source(catalog, source)
-            for prefix in source.spec.prefix.partition_prefixes():
-                new_fqids = self.azul_client.list_bundles(catalog, source, prefix)
-                total_bundles += len(new_fqids)
-                yield source, prefix, new_fqids
-            # We postpone this check until after we've yielded all partitions in
-            # the current source to ensure test coverage for handling multiple
-            # partitions per source
-            if total_bundles >= min_bundles:
-                break
+            }
+            self.assertIsSubset(ma_sources, sources)
+            if public is True:
+                sources -= ma_sources
+            elif public is False:
+                sources &= ma_sources
+            else:
+                assert False, public
+        if len(sources) == 0:
+            assert public is False, 'Every catalog should contain a public source'
+            return None
         else:
-            log.warning('Checked all sources and found only %d bundles instead of the '
-                        'expected minimum %d', total_bundles, min_bundles)
-
-    def _list_managed_access_bundles(self,
-                                     catalog: CatalogName
-                                     ) -> Iterator[tuple[SourceRef, str, list[SourcedBundleFQID]]]:
-        sources = self.azul_client.catalog_sources(catalog)
-        # We need at least one managed_access bundle per IT. To index them with
-        # remote_reindex and avoid collateral bundles, we use as specific a
-        # prefix as possible.
-        for source in self.managed_access_sources_by_catalog[catalog]:
-            assert str(source.spec) in sources
-            source = self.repository_plugin(catalog).partition_source(catalog, source)
-            bundle_fqids = sorted(
-                bundle_fqid
-                for bundle_fqid in self.azul_client.list_bundles(catalog, source, prefix='')
-                if not (
-                    # DUOS bundles are too sparse to fulfill the managed access tests
-                    config.is_anvil_enabled(catalog)
-                    and cast(TDRAnvilBundleFQID, bundle_fqid).table_name is BundleType.duos
-                )
-            )
-            bundle_fqid = self.random.choice(bundle_fqids)
-            prefix = bundle_fqid.uuid[:8]
-            new_fqids = self.azul_client.list_bundles(catalog, source, prefix)
-            yield source, prefix, new_fqids
+            source = self.random.choice(sorted(sources))
+            return plugin.resolve_source(source)
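_choose_source replaces both of the listing helpers deleted above with plain
set algebra over the catalog's source specs: subtracting the managed-access
specs yields the public candidates, intersecting with them yields the
managed-access candidates, and one candidate is drawn at random. A runnable
sketch of just that selection logic, with plain strings standing in for the
source specs and an arbitrary RNG seed:

    # Sketch only; in the test, sources are specs resolved to SourceRef
    # by the repository plugin
    import random

    def choose_source(sources: set[str],
                      ma_sources: set[str],
                      public: bool) -> str | None:
        assert ma_sources <= sources
        candidates = sources - ma_sources if public else sources & ma_sources
        if not candidates:
            # Only the managed-access pool may legitimately be empty
            assert not public, 'Every catalog should contain a public source'
            return None
        # Sorting first makes the choice reproducible for a fixed seed
        return random.Random(42).choice(sorted(candidates))

    assert choose_source({'a', 'b'}, {'b'}, public=True) == 'a'
    assert choose_source({'a', 'b'}, {'b'}, public=False) == 'b'
    assert choose_source({'a'}, set(), public=False) is None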
 
 class IndexingIntegrationTest(IntegrationTestCase, AlwaysTearDownTestCase):
@@ -428,6 +388,8 @@ class Catalog:
             name: CatalogName
             bundles: set[SourcedBundleFQID]
             notifications: list[JSON]
+            public_source: SourceRef
+            ma_source: SourceRef | None
 
         def _wait_for_indexer():
             self.azul_client.wait_for_indexer()
@@ -444,12 +406,23 @@ def _wait_for_indexer():
         catalogs: list[Catalog] = []
         for catalog in config.integration_test_catalogs:
             if index:
-                notifications, fqids = self._prepare_notifications(catalog)
+                public_source = self._choose_source(catalog, public=True)
+                ma_source = self._choose_source(catalog, public=False)
+                notifications, fqids = self._prepare_notifications(catalog,
+                                                                   sources=alist(public_source, ma_source))
             else:
-                notifications, fqids = [], set()
+                with self._service_account_credentials:
+                    fqids = self._get_indexed_bundles(catalog)
+                indexed_sources = {fqid.source for fqid in fqids}
+                ma_sources = self.managed_access_sources_by_catalog[catalog]
+                public_source = one(s for s in indexed_sources if s not in ma_sources)
+                ma_source = only(s for s in indexed_sources if s in ma_sources)
+                notifications = []
             catalogs.append(Catalog(name=catalog,
                                     bundles=fqids,
-                                    notifications=notifications))
+                                    notifications=notifications,
+                                    public_source=public_source,
+                                    ma_source=ma_source))
 
         if index:
             for catalog in catalogs:
@@ -465,12 +438,9 @@ def _wait_for_indexer():
                 self._test_manifest_tagging_race(catalog.name)
                 self._test_dos_and_drs(catalog.name)
                 self._test_repository_files(catalog.name)
-            if index:
-                bundle_fqids = catalog.bundles
-            else:
-                with self._service_account_credentials:
-                    bundle_fqids = self._get_indexed_bundles(catalog.name)
-            self._test_managed_access(catalog=catalog.name, bundle_fqids=bundle_fqids)
+            self._test_managed_access(catalog=catalog.name,
+                                      public_source=catalog.public_source,
+                                      ma_source=catalog.ma_source)
 
         if index and delete:
             # FIXME: Test delete notifications
@@ -802,8 +772,8 @@ def _check_endpoint(self,
                         method: str,
                         path: str,
                         *,
-                        args: Optional[Mapping[str, Any]] = None,
-                        endpoint: Optional[furl] = None,
+                        args: Mapping[str, Any] | None = None,
+                        endpoint: furl | None = None,
                         fetch: bool = False
                         ) -> bytes:
         if endpoint is None:
@@ -1207,7 +1177,7 @@ def _test_dos(self, catalog: CatalogName, file: FileInnerEntity):
 
     def _get_gs_url_content(self,
                             url: furl,
-                            size: Optional[int] = None
+                            size: int | None = None
                             ) -> BytesIO:
         self.assertEquals('gs', url.scheme)
         path = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
@@ -1228,33 +1198,21 @@ def _validate_fastq_content(self, content: ReadableFileObject):
         self.assertTrue(lines[2].startswith(b'+'))
 
     def _prepare_notifications(self,
-                               catalog: CatalogName
+                               catalog: CatalogName,
+                               sources: Iterable[SourceRef]
                                ) -> tuple[JSONs, set[SourcedBundleFQID]]:
-        bundle_fqids: set[SourcedBundleFQID] = set()
+        plugin = self.repository_plugin(catalog)
+        bundle_fqids = set()
         notifications = []
-
-        def update(source: SourceRef,
-                   prefix: str,
-                   partition_bundle_fqids: Iterable[SourcedBundleFQID]):
-            bundle_fqids.update(partition_bundle_fqids)
-            notifications.append(self.azul_client.reindex_message(catalog,
-                                                                  source,
-                                                                  prefix))
-
-        list(starmap(update, self._list_managed_access_bundles(catalog)))
-        num_bundles = max(self.min_bundles - len(bundle_fqids), 1)
-        log.info('Selected %d bundles to satisfy managed access coverage; '
-                 'selecting at least %d more', len(bundle_fqids), num_bundles)
-        # _list_partitions selects both public and managed access sources at random.
-        # If we don't index at least one public source, every request would need
-        # service account credentials and we couldn't compare the responses for
-        # public and managed access data. `public_1st` ensures that at least
-        # one of the sources will be public because sources are indexed starting
-        # with the first one yielded by the iteration.
-        list(starmap(update, self._list_partitions(catalog,
-                                                   min_bundles=num_bundles,
-                                                   public_1st=True)))
-
+        for source in sources:
+            source = plugin.partition_source(catalog, source)
+            # Some partitions may be empty, but we include them anyway to
+            # ensure test coverage for handling multiple partitions per source
+            for partition_prefix in source.spec.prefix.partition_prefixes():
+                bundle_fqids.update(self.azul_client.list_bundles(catalog, source, partition_prefix))
+                notifications.append(self.azul_client.reindex_message(catalog,
+                                                                      source,
+                                                                      partition_prefix))
         # Index some bundles again to test that we handle duplicate additions.
         # Note: random.choices() may pick the same element multiple times so
         # some notifications may end up being sent three or more times.
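_prepare_notifications now emits one reindex notification per partition of
each chosen source, so a source whose prefix carries a partition length of n
fans out into 16**n notifications. The sketch below illustrates the fan-out
that partition_prefixes() is assumed to perform, namely extending a common
prefix with every hexadecimal string of the partition length; Azul's actual
Prefix class is more involved:

    # Illustrative model of the assumed partition_prefixes() semantics
    from itertools import product
    from typing import Iterator

    def partition_prefixes(common: str, partition: int) -> Iterator[str]:
        for digits in product('0123456789abcdef', repeat=partition):
            # Each yielded prefix addresses one partition of the bundle UUIDs
            yield common + ''.join(digits)

    prefixes = list(partition_prefixes('aa', 1))
    assert len(prefixes) == 16
    assert (prefixes[0], prefixes[-1]) == ('aa0', 'aaf')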
@@ -1268,7 +1226,7 @@ def update(source: SourceRef,
 
     def _get_indexed_bundles(self,
                              catalog: CatalogName,
-                             filters: Optional[JSON] = None
+                             filters: JSON | None = None
                              ) -> set[SourcedBundleFQID]:
         indexed_fqids = set()
         hits = self._get_entities(catalog, 'bundles', filters)
@@ -1286,7 +1244,7 @@ def _get_indexed_bundles(self,
 
     def _assert_catalog_complete(self,
                                  catalog: CatalogName,
-                                 bundle_fqids: Set[SourcedBundleFQID]
+                                 bundle_fqids: set[SourcedBundleFQID]
                                  ) -> None:
         with self.subTest('catalog_complete', catalog=catalog):
             expected_fqids = bundle_fqids
@@ -1345,7 +1303,7 @@ def _assert_catalog_empty(self, catalog: CatalogName):
 
     def _get_entities(self,
                       catalog: CatalogName,
                       entity_type: EntityType,
-                      filters: Optional[JSON] = None
+                      filters: JSON | None = None
                       ) -> MutableJSONs:
         entities = []
         size = 100
@@ -1377,40 +1335,34 @@ def _assert_indices_exist(self, catalog: CatalogName):
 
     def _test_managed_access(self,
                              catalog: CatalogName,
-                             bundle_fqids: Set[SourcedBundleFQID]
+                             public_source: SourceRef,
+                             ma_source: SourceRef | None
                              ) -> None:
         with self.subTest('managed_access', catalog=catalog):
-            indexed_source_ids = {fqid.source.id for fqid in bundle_fqids}
-            managed_access_sources = self.managed_access_sources_by_catalog[catalog]
-            managed_access_source_ids = {source.id for source in managed_access_sources}
-            self.assertIsSubset(managed_access_source_ids, indexed_source_ids)
-
-            if not managed_access_sources:
+            if ma_source is None:
                 if config.deployment_stage in ('dev', 'sandbox'):
                     # There should always be at least one managed-access source
                     # indexed and tested on the default catalog for these deployments
                     self.assertNotEqual(catalog, config.it_catalog_for(config.default_catalog))
                 self.skipTest(f'No managed access sources found in catalog {catalog!r}')
             with self.subTest('managed_access_indices', catalog=catalog):
-                self._test_managed_access_indices(catalog, managed_access_source_ids)
+                self._test_managed_access_indices(catalog, public_source, ma_source)
             with self.subTest('managed_access_repository_files', catalog=catalog):
-                files = self._test_managed_access_repository_files(catalog, managed_access_source_ids)
+                files = self._test_managed_access_repository_files(catalog, ma_source)
             with self.subTest('managed_access_summary', catalog=catalog):
                 self._test_managed_access_summary(catalog, files)
             with self.subTest('managed_access_repository_sources', catalog=catalog):
-                public_source_ids = self._test_managed_access_repository_sources(catalog,
-                                                                                 indexed_source_ids,
-                                                                                 managed_access_source_ids)
-                with self.subTest('managed_access_manifest', catalog=catalog):
-                    source_id = self.random.choice(sorted(public_source_ids & indexed_source_ids))
-                    self._test_managed_access_manifest(catalog, files, source_id)
+                self._test_managed_access_repository_sources(catalog,
+                                                             public_source,
+                                                             ma_source)
+            with self.subTest('managed_access_manifest', catalog=catalog):
+                self._test_managed_access_manifest(catalog, files, public_source)
 
     def _test_managed_access_repository_sources(self,
                                                 catalog: CatalogName,
-                                                indexed_source_ids: Set[str],
-                                                managed_access_source_ids: Set[str]
-                                                ) -> set[str]:
+                                                public_source: SourceRef,
+                                                ma_source: SourceRef
+                                                ) -> None:
         """
         Test the managed access controls for the /repository/sources endpoint
-        :return: the set of public sources
         """
@@ -1423,11 +1375,14 @@ def list_source_ids() -> set[str]:
             return {source['sourceId'] for source in cast(JSONs, response['sources'])}
 
         with self._service_account_credentials:
-            self.assertIsSubset(indexed_source_ids, list_source_ids())
+            self.assertIsSubset({public_source.id, ma_source.id}, list_source_ids())
         with self._public_service_account_credentials:
             public_source_ids = list_source_ids()
+            self.assertIn(public_source.id, public_source_ids)
+            self.assertNotIn(ma_source.id, public_source_ids)
         with self._unregistered_service_account_credentials:
             self.assertEqual(public_source_ids, list_source_ids())
+        # The same must hold for anonymous requests, outside of any
+        # credentials context
+        self.assertEqual(public_source_ids, list_source_ids())
         invalid_auth = OAuth2('foo')
         with self.assertRaises(UnauthorizedError):
             TDRClient.for_registered_user(invalid_auth)
         invalid_provider = UserCredentialsProvider(invalid_auth)
         invalid_client = OAuth2Client(credentials_provider=invalid_provider)
         with self._authorization_context(invalid_client):
             self.assertEqual(401, self._get_url_unchecked(GET, url).status)
-        self.assertEqual(set(), list_source_ids() & managed_access_source_ids)
-        self.assertEqual(public_source_ids, list_source_ids())
-        return public_source_ids
 
     def _test_managed_access_indices(self,
                                      catalog: CatalogName,
-                                     managed_access_source_ids: Set[str]
+                                     public_source: SourceRef,
+                                     ma_source: SourceRef
                                      ) -> JSONs:
         """
         Test the managed-access controls for the /index/bundles and
@@ -1451,11 +1404,6 @@ def _test_managed_access_indices(self,
         """
         special_fields = self.metadata_plugin(catalog).special_fields
-
-        def source_id_from_hit(hit: JSON) -> str:
-            sources: JSONs = hit['sources']
-            return one(sources)[special_fields.source_id]
-
         bundle_type = self._bundle_type(catalog)
         project_type = self._project_type(catalog)
@@ -1468,31 +1416,22 @@ def source_id_from_hit(hit: JSON) -> str:
                 hits = self._get_entities(catalog, project_type, filters=filters)
                 if accessible is None:
                     unfiltered_hits = hits
-                accessible_sources, inaccessible_sources = set(), set()
                 for hit in hits:
-                    source_id = source_id_from_hit(hit)
-                    source_accessible = source_id not in managed_access_source_ids
+                    source_id = one(hit['sources'])[special_fields.source_id]
+                    source_accessible = {public_source.id: True, ma_source.id: False}[source_id]
                     hit_accessible = one(hit[project_type])[special_fields.accessible]
                     self.assertEqual(source_accessible, hit_accessible, hit['entryId'])
                     if accessible is not None:
                         self.assertEqual(accessible, hit_accessible)
-                    if source_accessible:
-                        accessible_sources.add(source_id)
-                    else:
-                        inaccessible_sources.add(source_id)
-                self.assertIsDisjoint(accessible_sources, inaccessible_sources)
-                self.assertIsDisjoint(managed_access_source_ids, accessible_sources)
-                self.assertEqual(set() if accessible else managed_access_source_ids,
-                                 inaccessible_sources)
 
         self.assertIsNotNone(unfiltered_hits, 'Cannot recover from subtest failure')
 
         bundle_fqids = self._get_indexed_bundles(catalog)
         hit_source_ids = {fqid.source.id for fqid in bundle_fqids}
-        self.assertEqual(set(), hit_source_ids & managed_access_source_ids)
+        self.assertEqual(hit_source_ids, {public_source.id})
 
         source_filter = {
             special_fields.source_id: {
-                'is': list(managed_access_source_ids)
+                'is': [ma_source.id]
             }
         }
         params = {
@@ -1501,18 +1440,18 @@ def source_id_from_hit(hit: JSON) -> str:
         }
         url = config.service_endpoint.set(path=('index', bundle_type), args=params)
         response = self._get_url_unchecked(GET, url)
-        self.assertEqual(403 if managed_access_source_ids else 200, response.status)
+        self.assertEqual(403, response.status)
         with self._service_account_credentials:
             bundle_fqids = self._get_indexed_bundles(catalog, filters=source_filter)
         hit_source_ids = {fqid.source.id for fqid in bundle_fqids}
-        self.assertEqual(managed_access_source_ids, hit_source_ids)
+        self.assertEqual({ma_source.id}, hit_source_ids)
 
         return unfiltered_hits
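The dict lookup that computes source_accessible above is deliberately strict:
every hit must originate from one of the two sources chosen for the catalog,
so an unexpected source id raises a KeyError and fails the test immediately
instead of passing by accident. The ids below are hypothetical:

    accessibility = {'public-id': True, 'ma-id': False}
    assert accessibility['public-id'] is True
    try:
        accessibility['unexpected-id']
    except KeyError:
        pass  # a hit from an unexpected source fails the test loudly
    else:
        raise AssertionError('lookup should have failed')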
 
     def _test_managed_access_repository_files(self,
                                               catalog: CatalogName,
-                                              managed_access_source_ids: set[str]
+                                              ma_source: SourceRef
                                               ) -> JSONs:
         """
         Test the managed access controls for the /repository/files endpoint
@@ -1522,7 +1461,7 @@ def _test_managed_access_repository_files(self,
         with self._service_account_credentials:
             files = self._get_entities(catalog, 'files', filters={
                 special_fields.source_id: {
-                    'is': list(managed_access_source_ids)
+                    'is': [ma_source.id]
                 }
             })
         managed_access_file_urls = {
@@ -1559,7 +1498,7 @@ def _get_summary_file_count() -> int:
 
     def _test_managed_access_manifest(self,
                                       catalog: CatalogName,
                                       files: JSONs,
-                                      source_id: str
+                                      public_source: SourceRef
                                       ) -> None:
         """
         Test the managed access controls for the /manifest/files endpoint and
@@ -1581,7 +1520,7 @@ def bundle_uuids(hit: JSON) -> set[str]:
             for file in files
             if len(file['sources']) == 1
         ))
-        filters = {special_fields.source_id: {'is': [source_id]}}
+        filters = {special_fields.source_id: {'is': [public_source.id]}}
         params = {'size': 1, 'catalog': catalog, 'filters': json.dumps(filters)}
         files_url = furl(url=endpoint, path='index/files', args=params)
         response = self._get_url_json(GET, files_url)
@@ -1946,13 +1885,10 @@ def test_can_bundle_canned_repository(self):
         self._test_catalog(mock_catalog)
 
     def bundle_fqid(self, catalog: CatalogName) -> SourcedBundleFQID:
-        # Skip through empty partitions
-        bundle_fqids = itertools.chain.from_iterable(
-            bundle_fqids
-            for _, _, bundle_fqids in self._list_partitions(catalog,
-                                                            min_bundles=1,
-                                                            public_1st=False)
-        )
+        source = self._choose_source(catalog)
+        # The plugin would raise an exception if the source spec lacked a
+        # prefix, so we explicitly set one that covers all bundles
+        source = source.with_prefix(Prefix.of_everything)
+        bundle_fqids = self.repository_plugin(catalog).list_bundles(source, '')
         return self.random.choice(sorted(bundle_fqids))
 
     def _can_bundle(self,
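bundle_fqid now widens the chosen source's prefix to Prefix.of_everything
before listing, so a single list_bundles call with the empty partition prefix
sees every bundle in the source. A toy model of that prefix-based listing,
with a stand-in Bundle class rather than Azul's plugin machinery:

    # Sketch only: models UUID-prefix filtering, not the real plugin API
    from dataclasses import dataclass

    @dataclass
    class Bundle:
        uuid: str

    def list_bundles(bundles: list[Bundle], prefix: str) -> list[Bundle]:
        # A bundle matches when its UUID starts with the given prefix
        return [b for b in bundles if b.uuid.startswith(prefix)]

    bundles = [Bundle('aabb'), Bundle('abcd'), Bundle('ffee')]
    assert list_bundles(bundles, '') == bundles  # the prefix of everything
    assert [b.uuid for b in list_bundles(bundles, 'aa')] == ['aabb']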