From cab701c14cd92183399607f8ba4e77b31ef258d2 Mon Sep 17 00:00:00 2001
From: Mecoli1219
Date: Mon, 21 Oct 2024 14:56:50 -0700
Subject: [PATCH 1/2] Restrict version mismatch

Signed-off-by: Mecoli1219
---
 flytekit/core/python_auto_container.py    | 24 +++++++++++++++++++++--
 flytekit/remote/executions.py             |  5 ++++-
 flytekit/remote/remote.py                 |  8 +++++++-
 tests/flytekit/unit/core/test_resolver.py | 21 +++++++++++++++++++-
 4 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py
index 1466c351ac..8883f42637 100644
--- a/flytekit/core/python_auto_container.py
+++ b/flytekit/core/python_auto_container.py
@@ -294,11 +294,31 @@ def name(self) -> str:
     def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
         _, entity_name, *_ = loader_args
         import gzip
+        import sys
 
         import cloudpickle
 
-        with gzip.open(PICKLE_FILE_PATH, "r") as f:
-            entity_dict = cloudpickle.load(f)
+        try:
+            with gzip.open(PICKLE_FILE_PATH, "r") as f:
+                entity_dict = cloudpickle.load(f)
+        except TypeError as e:
+            raise RuntimeError(
+                "The current Python version is older than the version used to create the pickle file. "
+                f"Current Python version: {sys.version_info.major}.{sys.version_info.minor}. "
+                "Please try using the same Python version to create the pickle file or use another "
+                "container image with a matching version."
+            ) from e
+
+        pickled_version = entity_dict["metadata"]["python_version"].split(".")
+        if sys.version_info.major != int(pickled_version[0]) or sys.version_info.minor != int(pickled_version[1]):
+            raise RuntimeError(
+                "The Python version used to create the pickle file is different from the current Python version. "
+                f"Current Python version: {sys.version_info.major}.{sys.version_info.minor}. "
+                f"Python version used to create the pickle file: {entity_dict['metadata']['python_version']}. "
+                "Please try using the same Python version to create the pickle file or use another "
+                "container image with a matching version."
+            )
+
         return entity_dict[entity_name]
 
     def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]:  # type:ignore
diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py
index 4aba363f3e..65ef77abc1 100644
--- a/flytekit/remote/executions.py
+++ b/flytekit/remote/executions.py
@@ -43,7 +43,10 @@ def outputs(self) -> Optional[LiteralsResolver]:
                 "Please wait until the execution has completed before requesting the outputs."
             )
         if self.error:
-            raise user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.")
+            raise user_exceptions.FlyteAssertion(
+                "Outputs could not be found because the execution ended in failure. Error message: "
+                f"{self.error.message}"
+            )
         return self._outputs
 
 
diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py
index d87f4d7685..7bb75812e7 100644
--- a/flytekit/remote/remote.py
+++ b/flytekit/remote/remote.py
@@ -2589,8 +2589,14 @@ def _get_pickled_target_dict(self, root_entity: typing.Any) -> typing.Dict[str,
         :param root_entity: The entity to get the pickled target for.
         :return: The pickled target dictionary.
""" + import sys + queue = [root_entity] - pickled_target_dict = {} + pickled_target_dict = { + "metadata": { + "python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", + } + } while queue: entity = queue.pop() if isinstance(entity, PythonTask): diff --git a/tests/flytekit/unit/core/test_resolver.py b/tests/flytekit/unit/core/test_resolver.py index 116b1251ae..5005adfd76 100644 --- a/tests/flytekit/unit/core/test_resolver.py +++ b/tests/flytekit/unit/core/test_resolver.py @@ -4,6 +4,7 @@ import cloudpickle import mock import pytest +import sys import flytekit.configuration from flytekit.configuration import Image, ImageConfig @@ -123,10 +124,28 @@ def t1(a: str, b: str) -> str: assert c.loader_args(None, t1) == ["entity-name", "tests.flytekit.unit.core.test_resolver.t1"] - pickled_dict = {"tests.flytekit.unit.core.test_resolver.t1": t1} + pickled_dict = { + "tests.flytekit.unit.core.test_resolver.t1": t1, + "metadata": { + "python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", + } + } custom_pickled_object = cloudpickle.dumps(pickled_dict) mock_gzip_open.return_value.read.return_value = custom_pickled_object mock_cloudpickle.return_value = pickled_dict t = c.load_task(["entity-name", "tests.flytekit.unit.core.test_resolver.t1"]) assert t == t1 + + mismatched_pickled_dict = { + "tests.flytekit.unit.core.test_resolver.t1": t1, + "metadata": { + "python_version": f"{sys.version_info.major}.{sys.version_info.minor - 1}.{sys.version_info.micro}", + } + } + mismatched_custom_pickled_object = cloudpickle.dumps(mismatched_pickled_dict) + mock_gzip_open.return_value.read.return_value = mismatched_custom_pickled_object + mock_cloudpickle.return_value = mismatched_pickled_dict + + with pytest.raises(RuntimeError): + c.load_task(["entity-name", "tests.flytekit.unit.core.test_resolver.t1"]) From 21ae75c9fdc692a6650e1b348a19a127485f7767 Mon Sep 17 00:00:00 2001 From: Mecoli1219 Date: Tue, 22 Oct 2024 16:20:14 -0700 Subject: [PATCH 2/2] Update unit test Signed-off-by: Mecoli1219 --- flytekit/remote/remote.py | 522 ++++++---------------- tests/flytekit/unit/remote/test_remote.py | 7 +- 2 files changed, 134 insertions(+), 395 deletions(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 2378fc9504..eef202bd74 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -147,13 +147,9 @@ except ImportError: ... 
-ExecutionDataResponse = typing.Union[ - WorkflowExecutionGetDataResponse, NodeExecutionGetDataResponse -] +ExecutionDataResponse = typing.Union[WorkflowExecutionGetDataResponse, NodeExecutionGetDataResponse] -MOST_RECENT_FIRST = admin_common_models.Sort( - "created_at", admin_common_models.Sort.Direction.DESCENDING -) +MOST_RECENT_FIRST = admin_common_models.Sort("created_at", admin_common_models.Sort.Direction.DESCENDING) class RegistrationSkipped(Exception): @@ -172,9 +168,7 @@ class ResolvedIdentifiers: version: str -def _get_latest_version( - list_entities_method: typing.Callable, project: str, domain: str, name: str -): +def _get_latest_version(list_entities_method: typing.Callable, project: str, domain: str, name: str): named_entity = common_models.NamedEntityIdentifier(project, domain, name) entity_list, _ = list_entities_method( named_entity, @@ -183,9 +177,7 @@ def _get_latest_version( ) admin_entity = None if not entity_list else entity_list[0] if not admin_entity: - raise user_exceptions.FlyteEntityNotExistException( - "Named entity {} not found".format(named_entity) - ) + raise user_exceptions.FlyteEntityNotExistException("Named entity {} not found".format(named_entity)) return admin_entity.id.version @@ -202,11 +194,7 @@ def _get_entity_identifier( project, domain, name, - ( - version - if version is not None - else _get_latest_version(list_entities_method, project, domain, name) - ), + (version if version is not None else _get_latest_version(list_entities_method, project, domain, name)), ) @@ -239,15 +227,11 @@ def _get_git_repo_url(source_path: str): raise ValueError("Unable to parse url") except Exception as e: - logger.debug( - f"unable to find the git config in {source_path} with error: {str(e)}" - ) + logger.debug(f"unable to find the git config in {source_path} with error: {str(e)}") return "" -def _get_pickled_target_dict( - root_entity: typing.Union[WorkflowBase, PythonTask] -) -> typing.Dict[str, typing.Any]: +def _get_pickled_target_dict(root_entity: typing.Union[WorkflowBase, PythonTask]) -> typing.Dict[str, typing.Any]: """ Get the pickled target dictionary for the entity. :param root_entity: The entity to get the pickled target for. @@ -311,22 +295,14 @@ def __init__( The default location - `s3://my-s3-bucket/data` works for sandbox/demo environment. Please override this for non-sandbox cases. :param interactive_mode_enabled: If set to True, the FlyteRemote will pickle the task/workflow. """ - if ( - config is None - or config.platform is None - or config.platform.endpoint is None - ): + if config is None or config.platform is None or config.platform.endpoint is None: raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.") if interactive_mode_enabled is True: - logger.warning( - "Jupyter notebook and interactive task support is still alpha." 
-            )
+            logger.warning("Jupyter notebook and interactive task support is still alpha.")
 
         if data_upload_location is None:
-            data_upload_location = (
-                FlyteContext.current_context().file_access.raw_output_prefix
-            )
+            data_upload_location = FlyteContext.current_context().file_access.raw_output_prefix
         self._kwargs = kwargs
         self._client_initialized = False
         self._config = config
         fsspec.register_implementation("flyte", get_flyte_fs(remote=self), clobber=True)
 
         self._file_access = FileAccessProvider(
-            local_sandbox_dir=os.path.join(
-                config.local_sandbox_path, "control_plane_metadata"
-            ),
+            local_sandbox_dir=os.path.join(config.local_sandbox_path, "control_plane_metadata"),
             raw_output_prefix=data_upload_location,
             data_config=config.data_config,
         )
 
         # Save the file access object locally, build a context for it and save that as well.
-        self._ctx = (
-            FlyteContextManager.current_context()
-            .with_file_access(self._file_access)
-            .build()
-        )
+        self._ctx = FlyteContextManager.current_context().with_file_access(self._file_access).build()
         self._interactive_mode_enabled = interactive_mode_enabled
 
     @property
 def get(
             return Literal.from_flyte_idl(data_response.literal)
         elif data_response.HasField("pre_signed_urls"):
             if len(data_response.pre_signed_urls.signed_url) == 0:
-                raise ValueError(
-                    f"Flyte url {flyte_uri} resolved to empty download link"
-                )
+                raise ValueError(f"Flyte url {flyte_uri} resolved to empty download link")
             d = data_response.pre_signed_urls.signed_url[0]
             logger.debug(f"Download link is {d}")
             fs = ctx.file_access.get_filesystem_for_path(d)
                     return html
                 # If not return bytes
                 else:
-                    logger.debug(
-                        f"IPython not found, returning HTML as bytes from {flyte_uri}"
-                    )
+                    logger.debug(f"IPython not found, returning HTML as bytes from {flyte_uri}")
                     return fs.open(d, "rb").read()
 
         except user_exceptions.FlyteUserException as e:
-            logger.info(
-                f"Error from Flyte backend when trying to fetch data: {e.__cause__}"
-            )
+            logger.info(f"Error from Flyte backend when trying to fetch data: {e.__cause__}")
 
         logger.info(f"Nothing found from {flyte_uri}")
 
 def fetch_task_lazy(
         Similar to fetch_task, just that it returns a LazyEntity, which will fetch the task lazily.
         """
         if name is None:
-            raise user_exceptions.FlyteAssertion(
-                "the 'name' argument must be specified."
-            )
+            raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.")
 
         def _fetch():
-            return self.fetch_task(
-                project=project, domain=domain, name=name, version=version
-            )
+            return self.fetch_task(project=project, domain=domain, name=name, version=version)
 
         return LazyEntity(name=name, getter=_fetch)
 
 def fetch_task(
         :raises: FlyteAssertion if name is None
         """
         if name is None:
-            raise user_exceptions.FlyteAssertion(
-                "the 'name' argument must be specified."
-            )
+            raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.")
         task_id = _get_entity_identifier(
             self.client.list_tasks_paginated,
             ResourceType.TASK,
             project,
             domain,
             name,
             version,
         )
         admin_task = self.client.get_task(task_id)
-        flyte_task = FlyteTask.promote_from_model(
-            admin_task.closure.compiled_task.template
-        )
+        flyte_task = FlyteTask.promote_from_model(admin_task.closure.compiled_task.template)
         flyte_task.template._id = task_id
         return flyte_task
 
 def fetch_workflow_lazy(
         Similar to fetch_workflow, just that it returns a LazyEntity, which will fetch the workflow lazily.
""" if name is None: - raise user_exceptions.FlyteAssertion( - "the 'name' argument must be specified." - ) + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") def _fetch(): return self.fetch_workflow(project, domain, name, version) @@ -539,9 +493,7 @@ def fetch_workflow( :raises: FlyteAssertion if name is None """ if name is None: - raise user_exceptions.FlyteAssertion( - "the 'name' argument must be specified." - ) + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") workflow_id = _get_entity_identifier( self.client.list_workflows_paginated, ResourceType.WORKFLOW, @@ -569,35 +521,23 @@ def find_launch_plan( for wf_template in wf_templates: for node in FlyteWorkflow.get_non_system_nodes(wf_template.nodes): - if ( - node.workflow_node is not None - and node.workflow_node.launchplan_ref is not None - ): + if node.workflow_node is not None and node.workflow_node.launchplan_ref is not None: lp_ref = node.workflow_node.launchplan_ref find_launch_plan(lp_ref, node_launch_plans) # Inspect conditional branch nodes for launch plans def get_launch_plan_from_branch( branch_node: BranchNode, - node_launch_plans: Dict[ - id_models, launch_plan_models.LaunchPlanSpec - ], + node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec], ) -> None: def get_launch_plan_from_then_node( child_then_node: Node, - node_launch_plans: Dict[ - id_models, launch_plan_models.LaunchPlanSpec - ], + node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec], ) -> None: # then_node could have nested branch_node or be a normal then_node if child_then_node.branch_node: - get_launch_plan_from_branch( - child_then_node.branch_node, node_launch_plans - ) - elif ( - child_then_node.workflow_node - and child_then_node.workflow_node.launchplan_ref - ): + get_launch_plan_from_branch(child_then_node.branch_node, node_launch_plans) + elif child_then_node.workflow_node and child_then_node.workflow_node.launchplan_ref: lp_ref = child_then_node.workflow_node.launchplan_ref find_launch_plan(lp_ref, node_launch_plans) @@ -605,26 +545,17 @@ def get_launch_plan_from_then_node( branch = branch_node.if_else if branch.case and branch.case.then_node: child_then_node = branch.case.then_node - get_launch_plan_from_then_node( - child_then_node, node_launch_plans - ) + get_launch_plan_from_then_node(child_then_node, node_launch_plans) if branch.other: for o in branch.other: if o.then_node: child_then_node = o.then_node - get_launch_plan_from_then_node( - child_then_node, node_launch_plans - ) + get_launch_plan_from_then_node(child_then_node, node_launch_plans) if branch.else_node: # else_node could have nested conditional branch_node if branch.else_node.branch_node: - get_launch_plan_from_branch( - branch.else_node.branch_node, node_launch_plans - ) - elif ( - branch.else_node.workflow_node - and branch.else_node.workflow_node.launchplan_ref - ): + get_launch_plan_from_branch(branch.else_node.branch_node, node_launch_plans) + elif branch.else_node.workflow_node and branch.else_node.workflow_node.launchplan_ref: lp_ref = branch.else_node.workflow_node.launchplan_ref find_launch_plan(lp_ref, node_launch_plans) @@ -639,9 +570,7 @@ def _upgrade_launchplan(self, lp: launch_plan_models.LaunchPlan) -> FlyteLaunchP """ flyte_lp = FlyteLaunchPlan.promote_from_model(lp.id, lp.spec) wf_id = flyte_lp.workflow_id - workflow = self.fetch_workflow( - wf_id.project, wf_id.domain, wf_id.name, wf_id.version - ) + workflow = self.fetch_workflow(wf_id.project, wf_id.domain, wf_id.name, 
wf_id.version) flyte_lp._interface = workflow.interface flyte_lp._flyte_workflow = workflow return flyte_lp @@ -654,9 +583,7 @@ def fetch_active_launchplan( """ try: lp = self.client.get_active_launch_plan( - NamedEntityIdentifier( - project or self.default_project, domain or self.default_domain, name - ) + NamedEntityIdentifier(project or self.default_project, domain or self.default_domain, name) ) if lp is not None: return self._upgrade_launchplan(lp) @@ -682,9 +609,7 @@ def fetch_launch_plan( :raises: FlyteAssertion if name is None """ if name is None: - raise user_exceptions.FlyteAssertion( - "the 'name' argument must be specified." - ) + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") launch_plan_id = _get_entity_identifier( self.client.list_launch_plans_paginated, ResourceType.LAUNCH_PLAN, @@ -696,9 +621,7 @@ def fetch_launch_plan( admin_launch_plan = self.client.get_launch_plan(launch_plan_id) return self._upgrade_launchplan(admin_launch_plan) - def fetch_execution( - self, project: str = None, domain: str = None, name: str = None - ) -> FlyteWorkflowExecution: + def fetch_execution(self, project: str = None, domain: str = None, name: str = None) -> FlyteWorkflowExecution: """Fetch a workflow execution entity from flyte admin. :param project: fetch entity from this project. If None, uses the default_project attribute. @@ -709,9 +632,7 @@ def fetch_execution( :raises: FlyteAssertion if name is None """ if name is None: - raise user_exceptions.FlyteAssertion( - "the 'name' argument must be specified." - ) + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") execution = FlyteWorkflowExecution.promote_from_model( self.client.get_execution( WorkflowExecutionIdentifier( @@ -788,13 +709,9 @@ def set_signal( lit = value else: lt = literal_type or ( - TypeEngine.to_literal_type(python_type) - if python_type - else TypeEngine.to_literal_type(type(value)) - ) - lit = TypeEngine.to_literal( - self.context, value, python_type or type(value), lt + TypeEngine.to_literal_type(python_type) if python_type else TypeEngine.to_literal_type(type(value)) ) + lit = TypeEngine.to_literal(self.context, value, python_type or type(value), lt) logger.debug(f"Converted {value} to literal {lit} using literal type {lt}") req = SignalSetRequest( @@ -842,18 +759,13 @@ def list_tasks_by_version( filters=[filter_models.Filter.from_python_std(f"eq(version,{version})")], limit=limit, ) - return [ - FlyteTask.promote_from_model(t.closure.compiled_task.template) - for t in t_models - ] + return [FlyteTask.promote_from_model(t.closure.compiled_task.template) for t in t_models] ##################### # Register Entities # ##################### - def _resolve_identifier( - self, t: int, name: str, version: str, ss: SerializationSettings - ) -> Identifier: + def _resolve_identifier(self, t: int, name: str, version: str, ss: SerializationSettings) -> Identifier: ident = Identifier( resource_type=t, project=ss.project if ss and ss.project else self.default_project, @@ -893,19 +805,11 @@ def raw_register( if isinstance(cp_entity, RemoteEntity): if isinstance(cp_entity, (FlyteWorkflow, FlyteTask)): if not cp_entity.should_register: - logger.debug( - f"Skipping registration of remote entity: {cp_entity.name}" - ) - raise RegistrationSkipped( - f"Remote task/Workflow {cp_entity.name} is not registrable." 
- ) + logger.debug(f"Skipping registration of remote entity: {cp_entity.name}") + raise RegistrationSkipped(f"Remote task/Workflow {cp_entity.name} is not registrable.") else: - logger.debug( - f"Skipping registration of remote entity: {cp_entity.name}" - ) - raise RegistrationSkipped( - f"Remote entity {cp_entity.name} is not registrable." - ) + logger.debug(f"Skipping registration of remote entity: {cp_entity.name}") + raise RegistrationSkipped(f"Remote entity {cp_entity.name} is not registrable.") if isinstance( cp_entity, @@ -920,17 +824,13 @@ def raw_register( return None elif isinstance(cp_entity, ReferenceSpec): - logger.debug( - f"Skipping registration of Reference entity, name: {cp_entity.template.id.name}" - ) + logger.debug(f"Skipping registration of Reference entity, name: {cp_entity.template.id.name}") return None if isinstance(cp_entity, task_models.TaskSpec): if isinstance(cp_entity, FlyteTask): version = cp_entity.id.version - ident = self._resolve_identifier( - ResourceType.TASK, cp_entity.template.id.name, version, settings - ) + ident = self._resolve_identifier(ResourceType.TASK, cp_entity.template.id.name, version, settings) try: self.client.create_task(task_identifer=ident, task_spec=cp_entity) except FlyteEntityAlreadyExistsException: @@ -940,13 +840,9 @@ def raw_register( if isinstance(cp_entity, admin_workflow_models.WorkflowSpec): if isinstance(cp_entity, FlyteWorkflow): version = cp_entity.id.version - ident = self._resolve_identifier( - ResourceType.WORKFLOW, cp_entity.template.id.name, version, settings - ) + ident = self._resolve_identifier(ResourceType.WORKFLOW, cp_entity.template.id.name, version, settings) try: - self.client.create_workflow( - workflow_identifier=ident, workflow_spec=cp_entity - ) + self.client.create_workflow(workflow_identifier=ident, workflow_spec=cp_entity) except FlyteEntityAlreadyExistsException: logger.debug(f" {ident} Already Exists!") @@ -978,13 +874,9 @@ def raw_register( return ident if isinstance(cp_entity, launch_plan_models.LaunchPlan): - ident = self._resolve_identifier( - ResourceType.LAUNCH_PLAN, cp_entity.id.name, version, settings - ) + ident = self._resolve_identifier(ResourceType.LAUNCH_PLAN, cp_entity.id.name, version, settings) try: - self.client.create_launch_plan( - launch_plan_identifer=ident, launch_plan_spec=cp_entity.spec - ) + self.client.create_launch_plan(launch_plan_identifer=ident, launch_plan_spec=cp_entity.spec) except FlyteEntityAlreadyExistsException: logger.debug(f" {ident} Already Exists!") return ident @@ -1017,13 +909,9 @@ async def _serialize_and_register( if serialization_settings.version is None: serialization_settings.version = version - _ = get_serializable( - m, settings=serialization_settings, entity=entity, options=options - ) + _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) # concurrent register - cp_task_entity_map = OrderedDict( - filter(lambda x: isinstance(x[1], task_models.TaskSpec), m.items()) - ) + cp_task_entity_map = OrderedDict(filter(lambda x: isinstance(x[1], task_models.TaskSpec), m.items())) tasks = [] loop = asyncio.get_running_loop() for entity, cp_entity in cp_task_entity_map.items(): @@ -1041,9 +929,7 @@ async def _serialize_and_register( ) identifiers_or_exceptions = [] - identifiers_or_exceptions.extend( - await asyncio.gather(*tasks, return_exceptions=True) - ) + identifiers_or_exceptions.extend(await asyncio.gather(*tasks, return_exceptions=True)) # Check to make sure any exceptions are just registration skipped exceptions for 
ie in identifiers_or_exceptions: if isinstance(ie, RegistrationSkipped): @@ -1052,15 +938,11 @@ async def _serialize_and_register( if isinstance(ie, Exception): raise ie # serial register - cp_other_entities = OrderedDict( - filter(lambda x: not isinstance(x[1], task_models.TaskSpec), m.items()) - ) + cp_other_entities = OrderedDict(filter(lambda x: not isinstance(x[1], task_models.TaskSpec), m.items())) for entity, cp_entity in cp_other_entities.items(): try: identifiers_or_exceptions.append( - self.raw_register( - cp_entity, serialization_settings, version, og_entity=entity - ) + self.raw_register(cp_entity, serialization_settings, version, og_entity=entity) ) except RegistrationSkipped as e: logger.info(f"Skipping registration... {e}") @@ -1146,9 +1028,7 @@ def register_workflow( default_launch_plan, ) - fwf = self.fetch_workflow( - ident.project, ident.domain, ident.name, ident.version - ) + fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version) fwf._python_interface = entity.python_interface return fwf @@ -1177,27 +1057,21 @@ def fast_register_workflow( "Please use register_script for other types of workflows" ) if not isinstance(entity._module_file, pathlib.Path): - raise ValueError( - f"entity._module_file should be pathlib.Path object, got {type(entity._module_file)}" - ) + raise ValueError(f"entity._module_file should be pathlib.Path object, got {type(entity._module_file)}") mod_name = ".".join(entity.name.split(".")[:-1]) # get the path representation of the module module_path = f"{os.sep}".join(entity.name.split(".")[:-1]) module_file = str(entity._module_file.with_suffix("")) if not module_file.endswith(module_path): - raise ValueError( - f"Module file path should end with entity.__module__, got {module_file} and {module_path}" - ) + raise ValueError(f"Module file path should end with entity.__module__, got {module_file} and {module_path}") # remove module suffix to get the root module_root = str(pathlib.Path(module_file[: -len(module_path)])) return self.register_script( entity, - image_config=( - serialization_settings.image_config if serialization_settings else None - ), + image_config=(serialization_settings.image_config if serialization_settings else None), project=serialization_settings.project if serialization_settings else None, domain=serialization_settings.domain if serialization_settings else None, version=version, @@ -1245,9 +1119,7 @@ def upload_file( :return: The uploaded location. """ if not to_upload.is_file(): - raise ValueError( - f"{to_upload} is not a single file, upload arg must be a single file." - ) + raise ValueError(f"{to_upload} is not a single file, upload arg must be a single file.") md5_bytes, str_digest, _ = hash_file(to_upload) upload_location = self.client.get_upload_signed_url( @@ -1332,12 +1204,8 @@ def _version_from_hash( # and does not increase entropy of the hash while making it very inconvenient to copy-and-paste. 
return base64.urlsafe_b64encode(h.digest()).decode("ascii").rstrip("=") - def _get_image_names( - self, entity: typing.Union[PythonAutoContainerTask, WorkflowBase] - ) -> typing.List[str]: - if isinstance(entity, PythonAutoContainerTask) and isinstance( - entity.container_image, ImageSpec - ): + def _get_image_names(self, entity: typing.Union[PythonAutoContainerTask, WorkflowBase]) -> typing.List[str]: + if isinstance(entity, PythonAutoContainerTask) and isinstance(entity.container_image, ImageSpec): return [entity.container_image.image_name()] if isinstance(entity, WorkflowBase): image_names = [] @@ -1385,29 +1253,20 @@ def register_script( " the copy_style field in fast_package_options instead." ) if not fast_package_options: - fast_package_options = FastPackageOptions( - [], copy_style=CopyFileDetection.ALL - ) + fast_package_options = FastPackageOptions([], copy_style=CopyFileDetection.ALL) else: - fast_package_options = dc_replace( - fast_package_options, copy_style=CopyFileDetection.ALL - ) + fast_package_options = dc_replace(fast_package_options, copy_style=CopyFileDetection.ALL) if image_config is None: image_config = ImageConfig.auto_default_image() with tempfile.TemporaryDirectory() as tmp_dir: - if ( - fast_package_options - and fast_package_options.copy_style != CopyFileDetection.NO_COPY - ): + if fast_package_options and fast_package_options.copy_style != CopyFileDetection.NO_COPY: md5_bytes, upload_native_url = self.fast_package( pathlib.Path(source_path), False, tmp_dir, fast_package_options ) else: - archive_fname = pathlib.Path( - os.path.join(tmp_dir, "script_mode.tar.gz") - ) + archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz")) compress_scripts( source_path, str(archive_fname), @@ -1452,14 +1311,10 @@ def register_script( return self.register_task(entity, serialization_settings, version) if isinstance(entity, WorkflowBase): - return self.register_workflow( - entity, serialization_settings, version, default_launch_plan, options - ) + return self.register_workflow(entity, serialization_settings, version, default_launch_plan, options) if isinstance(entity, LaunchPlan): # If it's a launch plan, we need to register the workflow first - return self.register_launch_plan( - entity, version, project, domain, options, serialization_settings - ) + return self.register_launch_plan(entity, version, project, domain, options, serialization_settings) raise ValueError(f"Unsupported entity type {type(entity)}") def register_launch_plan( @@ -1499,9 +1354,7 @@ def register_launch_plan( options, False, ) - flp = self.fetch_launch_plan( - ident.project, ident.domain, ident.name, ident.version - ) + flp = self.fetch_launch_plan(ident.project, ident.domain, ident.name, ident.version) flp._python_interface = entity.python_interface return flp @@ -1546,9 +1399,7 @@ def _execute( :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` """ if execution_name is not None and execution_name_prefix is not None: - raise ValueError( - "Only one of execution_name and execution_name_prefix can be set, but got both set" - ) + raise ValueError("Only one of execution_name and execution_name_prefix can be set, but got both set") # todo: The prefix should be passed to the backend if execution_name_prefix is not None: execution_name = execution_name_prefix + "-" + uuid.uuid4().hex[:19] @@ -1576,19 +1427,13 @@ def _execute( if isinstance(v, Literal): lit = v elif isinstance(v, Artifact): - raise user_exceptions.FlyteValueException( - v, "Running with an artifact object 
is not yet possible."
-            )
+            raise user_exceptions.FlyteValueException(v, "Running with an artifact object is not yet possible.")
             else:
                 if k not in type_hints:
                     try:
-                        type_hints[k] = TypeEngine.guess_python_type(
-                            input_flyte_type_map[k].type
-                        )
+                        type_hints[k] = TypeEngine.guess_python_type(input_flyte_type_map[k].type)
                     except ValueError:
-                        logger.debug(
-                            f"Could not guess type for {input_flyte_type_map[k].type}, skipping..."
-                        )
+                        logger.debug(f"Could not guess type for {input_flyte_type_map[k].type}, skipping...")
                 variable = entity.interface.inputs.get(k)
                 hint = type_hints[k]
                 self.file_access._get_upload_signed_url_fn = functools.partial(
                 security_context=options.security_context,
                 envs=common_models.Envs(envs) if envs else None,
                 tags=tags,
-                cluster_assignment=(
-                    ClusterAssignment(cluster_pool=cluster_pool)
-                    if cluster_pool
-                    else None
-                ),
+                cluster_assignment=(ClusterAssignment(cluster_pool=cluster_pool) if cluster_pool else None),
                 execution_cluster_label=(
-                    ExecutionClusterLabel(execution_cluster_label)
-                    if execution_cluster_label
-                    else None
+                    ExecutionClusterLabel(execution_cluster_label) if execution_cluster_label else None
                 ),
             ),
             literal_inputs,
                 domain=domain or self.default_domain,
                 name=execution_name,
             )
-            execution = FlyteWorkflowExecution.promote_from_model(
-                self.client.get_execution(exec_id)
-            )
+            execution = FlyteWorkflowExecution.promote_from_model(self.client.get_execution(exec_id))
 
         if wait:
             return self.wait(execution)
                 cluster_pool=cluster_pool,
                 execution_cluster_label=execution_cluster_label,
             )
-        raise NotImplementedError(
-            f"entity type {type(entity)} not recognized for execution"
-        )
+        raise NotImplementedError(f"entity type {type(entity)} not recognized for execution")
 
     # Flyte Remote Entities
     # ---------------------
 def execute_remote_wf(
         NOTE: the name and version arguments are currently not used and are only there for consistency in the
         function signature
         """
-        launch_plan = self.fetch_launch_plan(
-            entity.id.project, entity.id.domain, entity.id.name, entity.id.version
-        )
+        launch_plan = self.fetch_launch_plan(entity.id.project, entity.id.domain, entity.id.name, entity.id.version)
         return self.execute_remote_task_lp(
             launch_plan,
             inputs,
 def execute_reference_workflow(
         try:
             flyte_lp = self.fetch_launch_plan(**resolved_identifiers_dict)
         except FlyteEntityNotExistException:
-            logger.info(
-                "Try to register default launch plan because it wasn't found in Flyte Admin!"
-            )
+            logger.info("Try to register default launch plan because it wasn't found in Flyte Admin!")
             default_lp = LaunchPlan.get_default_launch_plan(self.context, entity)
             self.register_launch_plan(
                 default_lp,
 def execute_reference_launch_plan(
         )
         resolved_identifiers_dict = asdict(resolved_identifiers)
         try:
-            flyte_launchplan: FlyteLaunchPlan = self.fetch_launch_plan(
-                **resolved_identifiers_dict
-            )
+            flyte_launchplan: FlyteLaunchPlan = self.fetch_launch_plan(**resolved_identifiers_dict)
         except FlyteEntityNotExistException:
             raise ValueError(
                 f'missing entity of type ReferenceLaunchPlan with identifier project:"{entity.reference.project}" domain:"{entity.reference.domain}" name:"{entity.reference.name}" version:"{entity.reference.version}"'
             )
 def execute_local_task(
         :param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed.
         :return: FlyteWorkflowExecution object.
""" - resolved_identifiers = self._resolve_identifier_kwargs( - entity, project, domain, name, version - ) + resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) resolved_identifiers_dict = asdict(resolved_identifiers) not_found = False try: @@ -2193,9 +2020,7 @@ def execute_local_task( if not_found: fast_serialization_settings = None if self.interactive_mode_enabled: - md5_bytes, fast_serialization_settings = self._pickle_and_upload_entity( - entity - ) + md5_bytes, fast_serialization_settings = self._pickle_and_upload_entity(entity) ss = SerializationSettings( image_config=image_config or ImageConfig.auto_default_image(), @@ -2207,9 +2032,7 @@ def execute_local_task( default_inputs = entity.python_interface.default_inputs_as_kwargs if version is None and self.interactive_mode_enabled: - version = self._version_from_hash( - md5_bytes, ss, default_inputs, *self._get_image_names(entity) - ) + version = self._version_from_hash(md5_bytes, ss, default_inputs, *self._get_image_names(entity)) flyte_task: FlyteTask = self.register_task(entity, ss, version) @@ -2267,18 +2090,14 @@ def execute_local_workflow( :param execution_cluster_label: :return: """ - resolved_identifiers = self._resolve_identifier_kwargs( - entity, project, domain, name, version - ) + resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) resolved_identifiers_dict = asdict(resolved_identifiers) if not image_config: image_config = ImageConfig.auto_default_image() fast_serialization_settings = None if self.interactive_mode_enabled: - md5_bytes, fast_serialization_settings = self._pickle_and_upload_entity( - entity - ) + md5_bytes, fast_serialization_settings = self._pickle_and_upload_entity(entity) ss = SerializationSettings( image_config=image_config, @@ -2295,17 +2114,13 @@ def execute_local_workflow( logger.info("Registering workflow because it wasn't found in Flyte Admin.") default_inputs = entity.python_interface.default_inputs_as_kwargs if not version and self.interactive_mode_enabled: - version = self._version_from_hash( - md5_bytes, ss, default_inputs, *self._get_image_names(entity) - ) + version = self._version_from_hash(md5_bytes, ss, default_inputs, *self._get_image_names(entity)) self.register_workflow(entity, ss, version=version, options=options) try: flyte_lp = self.fetch_launch_plan(**resolved_identifiers_dict) except FlyteEntityNotExistException: - logger.info( - "Try to register default launch plan because it wasn't found in Flyte Admin!" - ) + logger.info("Try to register default launch plan because it wasn't found in Flyte Admin!") default_lp = LaunchPlan.get_default_launch_plan(self.context, entity) self.register_launch_plan( default_lp, @@ -2369,16 +2184,12 @@ def execute_local_launch_plan( :param execution_cluster_label: Specify label of cluster(s) on which newly created execution should be placed. 
:return: FlyteWorkflowExecution object
         """
-        resolved_identifiers = self._resolve_identifier_kwargs(
-            entity, project, domain, name, version
-        )
+        resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
         resolved_identifiers_dict = asdict(resolved_identifiers)
         project = resolved_identifiers.project
         domain = resolved_identifiers.domain
         try:
-            flyte_launchplan: FlyteLaunchPlan = self.fetch_launch_plan(
-                **resolved_identifiers_dict
-            )
+            flyte_launchplan: FlyteLaunchPlan = self.fetch_launch_plan(**resolved_identifiers_dict)
         except FlyteEntityNotExistException:
             flyte_launchplan: FlyteLaunchPlan = self.register_launch_plan(
                 entity,
 def wait(
                 return execution
             time.sleep(poll_interval.total_seconds())
 
-        raise user_exceptions.FlyteTimeout(
-            f"Execution {self} did not complete before timeout."
-        )
+        raise user_exceptions.FlyteTimeout(f"Execution {execution} did not complete before timeout.")
 
     ########################
     # Sync Execution State #
     ########################
 def sync(
         :return: Returns the same execution object, but with additional information pulled in.
         """
         if not isinstance(execution, FlyteWorkflowExecution):
-            raise ValueError(
-                f"remote.sync should only be called on workflow executions, got {type(execution)}"
-            )
+            raise ValueError(f"remote.sync should only be called on workflow executions, got {type(execution)}")
         return self.sync_execution(execution, entity_definition, sync_nodes)
 
 def sync_execution(
         Sync a FlyteWorkflowExecution object with its corresponding remote state.
         """
         if entity_definition is not None:
-            raise ValueError(
-                "Entity definition arguments aren't supported when syncing workflow executions"
-            )
+            raise ValueError("Entity definition arguments aren't supported when syncing workflow executions")
 
         # Update closure, and then data, because we don't want the execution to finish between when we get the data,
         # and then for the closure to have is_done to be true.
         underlying_node_executions = []
         if sync_nodes:
             underlying_node_executions = [
-                FlyteNodeExecution.promote_from_model(n)
-                for n in iterate_node_executions(self.client, execution.id)
+                FlyteNodeExecution.promote_from_model(n) for n in iterate_node_executions(self.client, execution.id)
             ]
 
         # This condition is only true for single-task executions
         if execution.spec.launch_plan.resource_type == ResourceType.TASK:
-            flyte_entity = self.fetch_task(
-                lp_id.project, lp_id.domain, lp_id.name, lp_id.version
-            )
+            flyte_entity = self.fetch_task(lp_id.project, lp_id.domain, lp_id.name, lp_id.version)
             node_interface = flyte_entity.interface
             if sync_nodes:
                 # Need to construct the mapping. Exactly three nodes should have been returned: a start,
                 # an end, and a task node.
                 task_node_exec = [
                     x
                     for x in filter(
-                        lambda x: x.id.node_id != constants.START_NODE_ID
-                        and x.id.node_id != constants.END_NODE_ID,
+                        lambda x: x.id.node_id != constants.START_NODE_ID and x.id.node_id != constants.END_NODE_ID,
                         underlying_node_executions,
                     )
                 ]
             )
         # This is the default case, an execution of a normal workflow through a launch plan
         else:
-            fetched_lp = self.fetch_launch_plan(
-                lp_id.project, lp_id.domain, lp_id.name, lp_id.version
-            )
+            fetched_lp = self.fetch_launch_plan(lp_id.project, lp_id.domain, lp_id.name, lp_id.version)
             node_interface = fetched_lp.flyte_workflow.interface
             execution._flyte_workflow = fetched_lp.flyte_workflow
             node_mapping = fetched_lp.flyte_workflow._node_map
 
         if sync_nodes:
             node_execs = {}
             for n in underlying_node_executions:
-                node_execs[n.id.node_id] = self.sync_node_execution(
-                    n, node_mapping
-                )  # noqa
+                node_execs[n.id.node_id] = self.sync_node_execution(n, node_mapping)  # noqa
             execution._node_executions = node_execs
-        return self._assign_inputs_and_outputs(
-            execution, execution_data, node_interface
-        )
+        return self._assign_inputs_and_outputs(execution, execution_data, node_interface)
 
 def sync_node_execution(
         # For single task execution - the metadata spec node id is missing. In these cases, revert to regular node id
         node_id = execution.metadata.spec_node_id
         # This case supports single-task execution compiled workflows.
-        if (
-            node_id
-            and node_id not in node_mapping
-            and execution.id.node_id in node_mapping
-        ):
+        if node_id and node_id not in node_mapping and execution.id.node_id in node_mapping:
             node_id = execution.id.node_id
             logger.debug(
                 f"Using node execution ID {node_id} instead of spec node id "
             raise ValueError(f"Missing node from mapping: {node_id}")
 
         # Get the node execution data
-        node_execution_get_data_response = self.client.get_node_execution_data(
-            execution.id
-        )
+        node_execution_get_data_response = self.client.get_node_execution_data(execution.id)
 
         # Calling a launch plan directly case
         # If a node ran a launch plan directly (i.e. not through a dynamic task or anything) then
         # the closure should have a workflow_node_metadata populated with the launched execution id.
         # The parent node flag should not be populated here
         # This is the simplest case
-        if (
-            not execution.metadata.is_parent_node
-            and execution.closure.workflow_node_metadata
-        ):
+        if not execution.metadata.is_parent_node and execution.closure.workflow_node_metadata:
             launched_exec_id = execution.closure.workflow_node_metadata.execution_id
             # This is a recursive call, basically going through the same process that brought us here in the first
             # place, but on the launched execution.
@@ -2635,30 +2421,21 @@ def sync_node_execution( # If this was a dynamic task, then there should be a CompiledWorkflowClosure inside the # NodeExecutionGetDataResponse if node_execution_get_data_response.dynamic_workflow is not None: - compiled_wf = ( - node_execution_get_data_response.dynamic_workflow.compiled_workflow - ) + compiled_wf = node_execution_get_data_response.dynamic_workflow.compiled_workflow node_launch_plans = {} # TODO: Inspect branch nodes for launch plans - for template in [compiled_wf.primary.template] + [ - swf.template for swf in compiled_wf.sub_workflows - ]: + for template in [compiled_wf.primary.template] + [swf.template for swf in compiled_wf.sub_workflows]: for node in FlyteWorkflow.get_non_system_nodes(template.nodes): if ( node.workflow_node is not None and node.workflow_node.launchplan_ref is not None - and node.workflow_node.launchplan_ref - not in node_launch_plans + and node.workflow_node.launchplan_ref not in node_launch_plans ): - node_launch_plans[node.workflow_node.launchplan_ref] = ( - self.client.get_launch_plan( - node.workflow_node.launchplan_ref - ).spec - ) - - dynamic_flyte_wf = FlyteWorkflow.promote_from_closure( - compiled_wf, node_launch_plans - ) + node_launch_plans[node.workflow_node.launchplan_ref] = self.client.get_launch_plan( + node.workflow_node.launchplan_ref + ).spec + + dynamic_flyte_wf = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans) execution._underlying_node_executions = [ self.sync_node_execution( FlyteNodeExecution.promote_from_model(cne), @@ -2667,8 +2444,7 @@ def sync_node_execution( for cne in child_node_executions ] execution._task_executions = [ - node_exes.task_executions - for node_exes in execution.subworkflow_node_executions.values() + node_exes.task_executions for node_exes in execution.subworkflow_node_executions.values() ] execution._interface = dynamic_flyte_wf.interface @@ -2678,9 +2454,7 @@ def sync_node_execution( sub_flyte_workflow = execution._node.flyte_entity sub_node_mapping = {n.id: n for n in sub_flyte_workflow.flyte_nodes} execution._underlying_node_executions = [ - self.sync_node_execution( - FlyteNodeExecution.promote_from_model(cne), sub_node_mapping - ) + self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping) for cne in child_node_executions ] execution._interface = sub_flyte_workflow.interface @@ -2700,26 +2474,18 @@ def sync_node_execution( if t.interface: execution._interface = t.interface else: - logger.error( - f"Fetched map task does not have an interface, skipping i/o {t}" - ) + logger.error(f"Fetched map task does not have an interface, skipping i/o {t}") return execution else: logger.error(f"Array node not over task, skipping i/o {t}") return execution else: - logger.error( - f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}" - ) - raise ValueError( - f"Node execution undeterminable, entity has type {type(execution._node)}" - ) + logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}") + raise ValueError(f"Node execution undeterminable, entity has type {type(execution._node)}") # Handle the case for gate nodes elif execution._node.gate_node is not None: - logger.info( - "Skipping gate node execution for now - gate nodes don't have inputs and outputs filled in" - ) + logger.info("Skipping gate node execution for now - gate nodes don't have inputs and outputs filled in") return execution # This is the plain ol' task execution case @@ -2751,12 +2517,8 @@ def sync_task_execution( 
execution_data = self.client.get_task_execution_data(execution.id) task_id = execution.id.task_id if entity_definition is None: - entity_definition = self.fetch_task( - task_id.project, task_id.domain, task_id.name, task_id.version - ) - return self._assign_inputs_and_outputs( - execution, execution_data, entity_definition.interface - ) + entity_definition = self.fetch_task(task_id.project, task_id.domain, task_id.name, task_id.version) + return self._assign_inputs_and_outputs(execution, execution_data, entity_definition.interface) ############################# # Terminate Execution State # @@ -2776,28 +2538,20 @@ def terminate(self, execution: FlyteWorkflowExecution, cause: str): def _assign_inputs_and_outputs( self, - execution: typing.Union[ - FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution - ], + execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution], execution_data, interface: TypedInterface, ): """Helper for assigning synced inputs and outputs to an execution object.""" input_literal_map = self._get_input_literal_map(execution_data) - execution._inputs = LiteralsResolver( - input_literal_map.literals, interface.inputs, self.context - ) + execution._inputs = LiteralsResolver(input_literal_map.literals, interface.inputs, self.context) if execution.is_done and not execution.error: output_literal_map = self._get_output_literal_map(execution_data) - execution._outputs = LiteralsResolver( - output_literal_map.literals, interface.outputs, self.context - ) + execution._outputs = LiteralsResolver(output_literal_map.literals, interface.outputs, self.context) return execution - def _get_input_literal_map( - self, execution_data: ExecutionDataResponse - ) -> literal_models.LiteralMap: + def _get_input_literal_map(self, execution_data: ExecutionDataResponse) -> literal_models.LiteralMap: # Inputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. if bool(execution_data.full_inputs.literals): return execution_data.full_inputs @@ -2810,9 +2564,7 @@ def _get_input_literal_map( ) return literal_models.LiteralMap({}) - def _get_output_literal_map( - self, execution_data: ExecutionDataResponse - ) -> literal_models.LiteralMap: + def _get_output_literal_map(self, execution_data: ExecutionDataResponse) -> literal_models.LiteralMap: # Outputs are returned inline unless they are too big, in which case a url blob pointing to them is returned. if bool(execution_data.full_outputs.literals): return execution_data.full_outputs @@ -2857,15 +2609,11 @@ def generate_console_url( Generate a Flyteconsole URL for the given Flyte remote endpoint. 
This will automatically determine whether this is an execution or an entity, and build the matching console URL.
         """
-        if isinstance(
-            entity, (FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution)
-        ):
+        if isinstance(entity, (FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution)):
             return f"{self.generate_console_http_domain()}/console/projects/{entity.id.project}/domains/{entity.id.domain}/executions/{entity.id.name}"  # noqa
 
         if not isinstance(entity, (FlyteWorkflow, FlyteTask, FlyteLaunchPlan)):
-            raise ValueError(
-                f"Only remote entities can be looked at in the console, got type {type(entity)}"
-            )
+            raise ValueError(f"Only remote entities can be looked at in the console, got type {type(entity)}")
         return self.generate_url_from_id(id=entity.id)
 
     def generate_url_from_id(self, id: Identifier):
 def launch_backfill(
         :return: In case of dry-run, return WorkflowBase, else if no_execute return FlyteWorkflow else in the default
         case return a FlyteWorkflowExecution
         """
-        lp = self.fetch_launch_plan(
-            project=project, domain=domain, name=launchplan, version=launchplan_version
-        )
+        lp = self.fetch_launch_plan(project=project, domain=domain, name=launchplan, version=launchplan_version)
         wf, start, end = create_backfill_workflow(
             start_date=from_date,
             end_date=to_date,
             failure_policy=failure_policy,
         )
         if dry_run:
-            logger.warning(
-                "Dry Run enabled. Workflow will not be registered and or executed."
-            )
+            logger.warning("Dry Run enabled. Workflow will not be registered and/or executed.")
             return wf
 
         unique_fingerprint = f"{start}-{end}-{launchplan}-{launchplan_version}"
         h = hashlib.md5()
         h.update(unique_fingerprint.encode("utf-8"))
-        unique_fingerprint_encoded = base64.urlsafe_b64encode(h.digest()).decode(
-            "ascii"
-        )
+        unique_fingerprint_encoded = base64.urlsafe_b64encode(h.digest()).decode("ascii")
         if not version:
             version = unique_fingerprint_encoded
         ss = SerializationSettings(
 def download(
             download_literal(self.file_access, "data", data, download_to)
         else:
             if not recursive:
-                raise click.UsageError(
-                    "Please specify --recursive to download all variables in a literal map."
-                )
+                raise click.UsageError("Please specify --recursive to download all variables in a literal map.")
             if isinstance(data, LiteralsResolver):
                 lm = data.literals
             else:
 def _pickle_and_upload_entity(
             raise ValueError(
                 "The size of the task to be pickled exceeds the limit of 150MB. Please reduce the size of the task."
             )
-        logger.debug(
-            f"Uploading pickled representation of Workflow `{entity.name}` to remote storage..."
-        )
+        logger.debug(f"Uploading pickled representation of Workflow `{entity.name}` to remote storage...")
         md5_bytes, native_url = self.upload_file(dest)
 
-        return md5_bytes, FastSerializationSettings(
-            enabled=True, distribution_location=native_url, destination_dir="."
- ) + return md5_bytes, FastSerializationSettings(enabled=True, distribution_location=native_url, destination_dir=".") @classmethod def for_endpoint( diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 4df39329c7..3ac6b879f5 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -2,6 +2,7 @@ import pathlib import shutil import subprocess +import sys import tempfile import typing import uuid @@ -707,7 +708,8 @@ def w() -> int: return t2(a=t1()) target_dict = _get_pickled_target_dict(w) - assert len(target_dict) == 2 + assert len(target_dict) == 3 + assert target_dict["metadata"]["python_version"] == f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" assert t1.name in target_dict assert t2.name in target_dict assert target_dict[t1.name] == t1 @@ -723,7 +725,8 @@ def w() -> int: return map_task(partial(t1, y=2))(x=[1, 2, 3]) target_dict = _get_pickled_target_dict(w) - assert len(target_dict) == 1 + assert len(target_dict) == 2 + assert target_dict["metadata"]["python_version"] == f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" assert t1.name in target_dict assert target_dict[t1.name] == t1
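
Illustration (not part of the patches): taken together, the two commits make `_get_pickled_target_dict` stamp the pickled payload with a `metadata["python_version"]` entry, and make `load_task` refuse to resolve entities when the running interpreter's major/minor version differs from that stamp. The standalone Python sketch below shows that round trip under stated assumptions — the helper names `dump_with_version_stamp` and `load_with_version_guard`, and the temp-file handling, are hypothetical and not flytekit APIs; only the payload layout and the major/minor comparison mirror the diff.

import gzip
import sys
import tempfile

import cloudpickle  # third-party; the same library the patch uses to load entities


def dump_with_version_stamp(entities: dict, path: str) -> None:
    # Store the entities plus a "metadata" entry, matching the layout
    # that _get_pickled_target_dict writes in patch 1.
    payload = {
        **entities,
        "metadata": {
            "python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
        },
    }
    with gzip.open(path, "wb") as f:
        cloudpickle.dump(payload, f)


def load_with_version_guard(path: str, entity_name: str):
    # Refuse to resolve entities on a major/minor mismatch, as load_task now does.
    with gzip.open(path, "rb") as f:
        payload = cloudpickle.load(f)
    major, minor, _ = payload["metadata"]["python_version"].split(".")
    if (sys.version_info.major, sys.version_info.minor) != (int(major), int(minor)):
        raise RuntimeError(
            f"Pickle file was created with Python {payload['metadata']['python_version']}, "
            f"but this interpreter is {sys.version_info.major}.{sys.version_info.minor}."
        )
    return payload[entity_name]


if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(suffix=".pkl.gz") as tmp:
        dump_with_version_stamp({"my_task": lambda x: x + 1}, tmp.name)
        print(load_with_version_guard(tmp.name, "my_task")(41))  # prints 42 on a matching interpreter

Note that the guard compares only major and minor components: micro releases keep the same pickle compatibility, which is why the unit tests above perturb `sys.version_info.minor` (not `micro`) to trigger the RuntimeError.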