diff --git a/.github/workflows/monodocs_build.yml b/.github/workflows/monodocs_build.yml index 6ecaa2cf87..46da3c14a5 100644 --- a/.github/workflows/monodocs_build.yml +++ b/.github/workflows/monodocs_build.yml @@ -57,4 +57,5 @@ jobs: DOCSEARCH_API_KEY: fake_docsearch_api_key # must be set to get doc build to succeed run: | conda activate monodocs-env + pip install grpcio-health-checking==1.49.0 make -C docs clean html SPHINXOPTS="-W -vvv" diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 7dda4f5588..f89257ed43 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -294,6 +294,11 @@ jobs: tags: localhost:30000/flytekit:dev cache-from: type=gha cache-to: type=gha,mode=max + - name: Install dependencies + run: | + pip install grpcio + pip install grpcio-tools + pip install grpcio-health-checking - name: Integration Test with coverage env: FLYTEKIT_IMAGE: localhost:30000/flytekit:dev diff --git a/dev-requirements.in b/dev-requirements.in index 5241f02605..683b98d0b9 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -60,3 +60,5 @@ ipykernel orjson kubernetes>=12.0.1 + +grpcio-health-checking==1.49.0 diff --git a/dev-requirements.txt b/dev-requirements.txt index 9acff98cb6..c8911ee6f7 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -564,3 +564,5 @@ zipp==3.19.1 # The following packages are considered to be unsafe in a requirements file: # setuptools + +grpcio-health-checking==1.49.0 diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index df643d554d..72236062f0 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import typing import grpc @@ -10,6 +11,7 @@ from flyteidl.service import dataproxy_pb2_grpc as dataproxy_service from flyteidl.service import signal_pb2_grpc as signal_service from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub +from grpc_health.v1 import health_pb2, health_pb2_grpc from flytekit.clients.auth_helper import ( get_channel, @@ -18,6 +20,12 @@ wrap_exceptions_channel, ) from flytekit.configuration import PlatformConfig +from flytekit.exceptions.system import FlyteSystemUnavailableException +from flytekit.exceptions.user import ( + FlyteEntityAlreadyExistsException, + FlyteEntityNotExistException, + FlyteInvalidInputException, +) from flytekit.loggers import logger @@ -51,12 +59,18 @@ def __init__(self, cfg: PlatformConfig, **kwargs): # 32KB for error messages, 20MB for actual messages. options = (("grpc.max_metadata_size", 32 * 1024), ("grpc.max_receive_message_length", 20 * 1024 * 1024)) self._cfg = cfg - self._channel = wrap_exceptions_channel( - cfg, - upgrade_channel_to_authenticated( - cfg, upgrade_channel_to_proxy_authenticated(cfg, get_channel(cfg, options=options)) - ), - ) + base_channel = get_channel(cfg, options=options) + + if self.check_grpc_health_with_authentication(base_channel): + self._channel = wrap_exceptions_channel(cfg, base_channel) + else: + self._channel = wrap_exceptions_channel( + cfg, + upgrade_channel_to_authenticated( + cfg, upgrade_channel_to_proxy_authenticated(cfg, get_channel(cfg, options=options)) + ), + ) + self._stub = _admin_service.AdminServiceStub(self._channel) self._signal = signal_service.SignalServiceStub(self._channel) self._dataproxy_stub = dataproxy_service.DataProxyServiceStub(self._channel) @@ -67,6 +81,27 @@ def __init__(self, cfg: PlatformConfig, **kwargs): # metadata will hold the value of the token to send to the various endpoints. self._metadata = None + @staticmethod + def check_grpc_health_with_authentication(in_channel): + health_stub = health_pb2_grpc.HealthStub(in_channel) + request = health_pb2.HealthCheckRequest() + try: + response = health_stub.Check(request) + if response.status == health_pb2.HealthCheckResponse.SERVING: + logging.info("Service is healthy and ready to serve.") + return True + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAUTHENTICATED: + return False + elif e.code() == grpc.StatusCode.ALREADY_EXISTS: + raise FlyteEntityAlreadyExistsException() from e + elif e.code() == grpc.StatusCode.NOT_FOUND: + raise FlyteEntityNotExistException() from e + elif e.code() == grpc.StatusCode.INVALID_ARGUMENT: + raise FlyteInvalidInputException(request) from e + elif e.code() == grpc.StatusCode.UNAVAILABLE: + raise FlyteSystemUnavailableException() from e + @classmethod def with_root_certificate(cls, cfg: PlatformConfig, root_cert_file: str) -> RawSynchronousFlyteClient: b = None diff --git a/tests/flytekit/unit/clients/test_friendly.py b/tests/flytekit/unit/clients/test_friendly.py index b553ae78a0..6dbf0f6ac4 100644 --- a/tests/flytekit/unit/clients/test_friendly.py +++ b/tests/flytekit/unit/clients/test_friendly.py @@ -8,10 +8,11 @@ from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient from flytekit.configuration import PlatformConfig from flytekit.models.project import Project as _Project - +from grpc_health.v1 import health_pb2 @mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.update_project") -def test_update_project(mock_raw_update_project): +@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.check_grpc_health_with_authentication", return_value=health_pb2.HealthCheckResponse.SERVING) +def test_update_project(mock_check_health, mock_raw_update_project): client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True)) project = _Project("foo", "name", "description", state=_Project.ProjectState.ACTIVE) client.update_project(project) @@ -19,7 +20,8 @@ def test_update_project(mock_raw_update_project): @mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.list_projects") -def test_list_projects_paginated(mock_raw_list_projects): +@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.check_grpc_health_with_authentication", return_value=health_pb2.HealthCheckResponse.SERVING) +def test_list_projects_paginated(mock_check_health, mock_raw_list_projects): client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True)) client.list_projects_paginated(limit=100, token="") project_list_request = _project_pb2.ProjectListRequest(limit=100, token="", filters=None, sort_by=None) @@ -27,7 +29,8 @@ def test_list_projects_paginated(mock_raw_list_projects): @mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.create_upload_location") -def test_create_upload_location(mock_raw_create_upload_location): +@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.check_grpc_health_with_authentication", return_value=health_pb2.HealthCheckResponse.SERVING) +def test_create_upload_location(mock_check_health, mock_raw_create_upload_location): client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True)) client.get_upload_signed_url("foo", "bar", bytes(), "baz.qux", timedelta(minutes=42), add_content_md5_metadata=True) duration_pb = Duration() diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index ee4e516354..6f9b46e6f0 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -4,11 +4,13 @@ from flytekit.clients.raw import RawSynchronousFlyteClient from flytekit.configuration import PlatformConfig - +from grpc_health.v1 import health_pb2 @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw.grpc.insecure_channel") -def test_update_project(mock_channel, mock_admin): +@mock.patch.object(RawSynchronousFlyteClient, "check_grpc_health_with_authentication", return_value=True) +def test_update_project(mock_check_health, mock_channel, mock_admin): + mock_health_stub = mock.Mock() client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True)) project = _project_pb2.Project(id="foo", name="name", description="description", state=_project_pb2.Project.ACTIVE) client.update_project(project) @@ -17,7 +19,8 @@ def test_update_project(mock_channel, mock_admin): @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw.grpc.insecure_channel") -def test_list_projects_paginated(mock_channel, mock_admin): +@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.check_grpc_health_with_authentication", return_value=health_pb2.HealthCheckResponse.SERVING) +def test_list_projects_paginated(mock_check_health, mock_channel, mock_admin): client = RawSynchronousFlyteClient(PlatformConfig(endpoint="a.b.com", insecure=True)) project_list_request = _project_pb2.ProjectListRequest(limit=100, token="", filters=None, sort_by=None) client.list_projects(project_list_request)