diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py
index 1466c351ac..8883f42637 100644
--- a/flytekit/core/python_auto_container.py
+++ b/flytekit/core/python_auto_container.py
@@ -294,11 +294,31 @@ def name(self) -> str:
     def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
         _, entity_name, *_ = loader_args
         import gzip
+        import sys
 
         import cloudpickle
 
-        with gzip.open(PICKLE_FILE_PATH, "r") as f:
-            entity_dict = cloudpickle.load(f)
+        try:
+            with gzip.open(PICKLE_FILE_PATH, "r") as f:
+                entity_dict = cloudpickle.load(f)
+        except TypeError:
+            raise RuntimeError(
+                "The current Python version is older than the version used to create the pickle file. "
+                f"Current Python version: {sys.version_info.major}.{sys.version_info.minor}. "
+                "Please use the same Python version that created the pickle file, or use another "
+                "container image with a matching version."
+            )
+
+        pickled_version = entity_dict["metadata"]["python_version"].split(".")
+        if sys.version_info.major != int(pickled_version[0]) or sys.version_info.minor != int(pickled_version[1]):
+            raise RuntimeError(
+                "The Python version used to create the pickle file is different from the current Python version. "
+                f"Current Python version: {sys.version_info.major}.{sys.version_info.minor}. "
+                f"Python version used to create the pickle file: {entity_dict['metadata']['python_version']}. "
+                "Please use the same Python version that created the pickle file, or use another "
+                "container image with a matching version."
+            )
+
         return entity_dict[entity_name]
 
     def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]:  # type:ignore
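
A note on the load-time guard above: it compares major/minor as integers rather than as raw version strings, which matters because string comparison mis-orders versions (lexicographically, "3.9" > "3.10"). A minimal standalone sketch of the same check, independent of flytekit — the `stored` value here is illustrative:

    import sys

    def check_python_version(stored: str) -> None:
        # "stored" mirrors the "python_version" string stamped into the pickle
        # metadata at serialization time, e.g. "3.10.13".
        major, minor, *_ = (int(part) for part in stored.split("."))
        # Compare numerically: as strings, "3.9" > "3.10".
        if (sys.version_info.major, sys.version_info.minor) != (major, minor):
            raise RuntimeError(
                f"Pickle was created with Python {major}.{minor}, "
                f"but this interpreter is {sys.version_info.major}.{sys.version_info.minor}."
            )

    check_python_version("3.10.13")  # raises unless running on Python 3.10.x
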
Error message: " + f"{self.error.message}" + ) return self._outputs diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index e17943f80b..eef202bd74 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -35,7 +35,13 @@ from flytekit import ImageSpec from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions -from flytekit.configuration import Config, DataConfig, FastSerializationSettings, ImageConfig, SerializationSettings +from flytekit.configuration import ( + Config, + DataConfig, + FastSerializationSettings, + ImageConfig, + SerializationSettings, +) from flytekit.configuration.file import ConfigFile from flytekit.constants import CopyFileDetection from flytekit.core import constants, utils @@ -56,7 +62,12 @@ from flytekit.core.task import ReferenceTask from flytekit.core.tracker import extract_task_module from flytekit.core.type_engine import LiteralsResolver, TypeEngine -from flytekit.core.workflow import PythonFunctionWorkflow, ReferenceWorkflow, WorkflowBase, WorkflowFailurePolicy +from flytekit.core.workflow import ( + PythonFunctionWorkflow, + ReferenceWorkflow, + WorkflowBase, + WorkflowFailurePolicy, +) from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.user import ( FlyteAssertion, @@ -77,7 +88,12 @@ from flytekit.models.common import NamedEntityIdentifier from flytekit.models.core import identifier as id_models from flytekit.models.core import workflow as workflow_model -from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier +from flytekit.models.core.identifier import ( + Identifier, + ResourceType, + SignalIdentifier, + WorkflowExecutionIdentifier, +) from flytekit.models.core.workflow import BranchNode, Node, NodeMetadata from flytekit.models.execution import ( ClusterAssignment, @@ -92,15 +108,30 @@ from flytekit.models.matchable_resource import ExecutionClusterLabel from flytekit.remote.backfill import create_backfill_workflow from flytekit.remote.data import download_literal -from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow -from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution +from flytekit.remote.entities import ( + FlyteLaunchPlan, + FlyteNode, + FlyteTask, + FlyteTaskNode, + FlyteWorkflow, +) +from flytekit.remote.executions import ( + FlyteNodeExecution, + FlyteTaskExecution, + FlyteWorkflowExecution, +) from flytekit.remote.interface import TypedInterface from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote_callable import RemoteEntity from flytekit.remote.remote_fs import get_flyte_fs from flytekit.tools.fast_registration import FastPackageOptions, fast_package from flytekit.tools.interactive import ipython_check -from flytekit.tools.script_mode import _find_project_root, compress_scripts, get_all_modules, hash_file +from flytekit.tools.script_mode import ( + _find_project_root, + compress_scripts, + get_all_modules, + hash_file, +) from flytekit.tools.translator import ( FlyteControlPlaneEntity, FlyteLocalEntity, @@ -163,7 +194,7 @@ def _get_entity_identifier( project, domain, name, - version if version is not None else _get_latest_version(list_entities_method, project, domain, name), + (version if version is not None else _get_latest_version(list_entities_method, project, domain, name)), ) @@ -206,8 +237,14 @@ def 
@@ -373,7 +410,11 @@ def remote_context(self):
         )
 
     def fetch_task_lazy(
-        self, project: str = None, domain: str = None, name: str = None, version: str = None
+        self,
+        project: str = None,
+        domain: str = None,
+        name: str = None,
+        version: str = None,
     ) -> LazyEntity:
         """
         Similar to fetch_task, just that it returns a LazyEntity, which will fetch the workflow lazily.
@@ -386,7 +427,13 @@ def _fetch():
 
         return LazyEntity(name=name, getter=_fetch)
 
-    def fetch_task(self, project: str = None, domain: str = None, name: str = None, version: str = None) -> FlyteTask:
+    def fetch_task(
+        self,
+        project: str = None,
+        domain: str = None,
+        name: str = None,
+        version: str = None,
+    ) -> FlyteTask:
         """Fetch a task entity from flyte admin.
 
         :param project: fetch entity from this project. If None, uses the default_project attribute.
@@ -413,7 +460,11 @@ def fetch_task(self, project: str = None, domain: str = None, name: str = None,
         return flyte_task
 
     def fetch_workflow_lazy(
-        self, project: str = None, domain: str = None, name: str = None, version: str = None
+        self,
+        project: str = None,
+        domain: str = None,
+        name: str = None,
+        version: str = None,
     ) -> LazyEntity[FlyteWorkflow]:
         """
         Similar to fetch_workflow, just that it returns a LazyEntity, which will fetch the workflow lazily.
@@ -427,7 +478,11 @@ def _fetch():
         return LazyEntity(name=name, getter=_fetch)
 
     def fetch_workflow(
-        self, project: str = None, domain: str = None, name: str = None, version: str = None
+        self,
+        project: str = None,
+        domain: str = None,
+        name: str = None,
+        version: str = None,
     ) -> FlyteWorkflow:
         """
         Fetch a workflow entity from flyte admin.
@@ -457,7 +512,8 @@ def fetch_workflow(
         node_launch_plans = {}
 
         def find_launch_plan(
-            lp_ref: id_models, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec]
+            lp_ref: id_models,
+            node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec],
         ) -> None:
             if lp_ref not in node_launch_plans:
                 admin_launch_plan = self.client.get_launch_plan(lp_ref)
@@ -471,10 +527,12 @@ def find_launch_plan(
 
         # Inspect conditional branch nodes for launch plans
         def get_launch_plan_from_branch(
-            branch_node: BranchNode, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec]
+            branch_node: BranchNode,
+            node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec],
         ) -> None:
             def get_launch_plan_from_then_node(
-                child_then_node: Node, node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec]
+                child_then_node: Node,
+                node_launch_plans: Dict[id_models, launch_plan_models.LaunchPlanSpec],
             ) -> None:
                 # then_node could have nested branch_node or be a normal then_node
                 if child_then_node.branch_node:
@@ -534,7 +592,11 @@ def fetch_active_launchplan(
         return None
 
     def fetch_launch_plan(
-        self, project: str = None, domain: str = None, name: str = None, version: str = None
+        self,
+        project: str = None,
+        domain: str = None,
+        name: str = None,
+        version: str = None,
    ) -> FlyteLaunchPlan:
         """Fetch a launchplan entity from flyte admin.
 
@@ -602,9 +664,15 @@ def list_signals(
         :param filters: Optional list of filters
         """
         wf_exec_id = WorkflowExecutionIdentifier(
-            project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
+            project=project or self.default_project,
+            domain=domain or self.default_domain,
+            name=execution_name,
+        )
+        req = SignalListRequest(
+            workflow_execution_id=wf_exec_id.to_flyte_idl(),
+            limit=limit,
+            filters=filters,
         )
-        req = SignalListRequest(workflow_execution_id=wf_exec_id.to_flyte_idl(), limit=limit, filters=filters)
         resp = self.client.list_signals(req)
         s = resp.signals
         return s
@@ -632,7 +700,9 @@ def set_signal(
             is not a Literal
         """
         wf_exec_id = WorkflowExecutionIdentifier(
-            project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
+            project=project or self.default_project,
+            domain=domain or self.default_domain,
+            name=execution_name,
         )
         if isinstance(value, Literal):
             logger.debug(f"Using provided {value} as existing Literal value")
@@ -644,7 +714,10 @@
             lit = TypeEngine.to_literal(self.context, value, python_type or type(value), lt)
             logger.debug(f"Converted {value} to literal {lit} using literal type {lt}")
 
-        req = SignalSetRequest(id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=lit.to_flyte_idl())
+        req = SignalSetRequest(
+            id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(),
+            value=lit.to_flyte_idl(),
+        )
 
         # Response is empty currently, nothing to give back to the user.
         self.client.set_signal(req)
@@ -782,7 +855,9 @@ def raw_register(
             # Let us also create a default launch-plan, ideally the default launchplan should be added
             # to the orderedDict, but we do not.
             self.file_access._get_upload_signed_url_fn = functools.partial(
-                self.client.get_upload_signed_url, project=settings.project, domain=settings.domain
+                self.client.get_upload_signed_url,
+                project=settings.project,
+                domain=settings.domain,
             )
             default_lp = LaunchPlan.get_default_launch_plan(self.context, og_entity)
             lp_entity = get_serializable_launch_plan(
@@ -843,7 +918,13 @@ async def _serialize_and_register(
             tasks.append(
                 loop.run_in_executor(
                     None,
-                    functools.partial(self.raw_register, cp_entity, serialization_settings, version, og_entity=entity),
+                    functools.partial(
+                        self.raw_register,
+                        cp_entity,
+                        serialization_settings,
+                        version,
+                        og_entity=entity,
+                    ),
                 )
             )
 
@@ -895,7 +976,12 @@ def register_task(
             domain=self.default_domain,
         )
 
-        ident = run_sync(self._serialize_and_register, entity=entity, settings=serialization_settings, version=version)
+        ident = run_sync(
+            self._serialize_and_register,
+            entity=entity,
+            settings=serialization_settings,
+            version=version,
+        )
 
         ft = self.fetch_task(
             ident.project,
@@ -934,7 +1020,12 @@ def register_workflow(
         )
 
         ident = run_sync(
-            self._serialize_and_register, entity, serialization_settings, version, options, default_launch_plan
+            self._serialize_and_register,
+            entity,
+            serialization_settings,
+            version,
+            options,
+            default_launch_plan,
         )
 
         fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version)
@@ -980,7 +1071,7 @@ def fast_register_workflow(
 
         return self.register_script(
             entity,
-            image_config=serialization_settings.image_config if serialization_settings else None,
+            image_config=(serialization_settings.image_config if serialization_settings else None),
             project=serialization_settings.project if serialization_settings else None,
             domain=serialization_settings.domain if serialization_settings else None,
             version=version,
@@ -1045,15 +1136,20 @@ def upload_file(
         local_file_path = str(to_upload)
         content_length = os.stat(local_file_path).st_size
         with open(local_file_path, "+rb") as local_file:
-            headers = {"Content-Length": str(content_length), "Content-MD5": encoded_md5}
+            headers = {
+                "Content-Length": str(content_length),
+                "Content-MD5": encoded_md5,
+            }
             headers.update(extra_headers)
             rsp = requests.put(
                 upload_location.signed_url,
                 data=local_file,  # NOTE: We pass the file object directly to stream our upload.
                 headers=headers,
-                verify=False
-                if self._config.platform.insecure_skip_verify is True
-                else self._config.platform.ca_cert_file_path,
+                verify=(
+                    False
+                    if self._config.platform.insecure_skip_verify is True
+                    else self._config.platform.ca_cert_file_path
+                ),
             )
 
             # Check both HTTP 201 and 200, because some storage backends (e.g. Azure) return 201 instead of 200.
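
The parenthesized verify= expression preserves the existing semantics: requests accepts either a boolean or a CA-bundle path for verify, so False disables TLS verification while a path (or None, which defers to the default) keeps it enabled. A small sketch of the same pattern in isolation — the URL and the two config values are placeholders:

    import requests

    insecure_skip_verify = False  # stands in for self._config.platform.insecure_skip_verify
    ca_cert_file_path = None  # stands in for self._config.platform.ca_cert_file_path

    # verify may be a bool or a path to a CA bundle; None falls back to
    # requests' default certificate verification.
    rsp = requests.put(
        "https://storage.example.com/signed-url",  # placeholder signed URL
        data=b"payload",
        verify=(False if insecure_skip_verify else ca_cert_file_path),
    )
    rsp.raise_for_status()
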
@@ -1171,9 +1267,15 @@ def register_script(
             )
         else:
             archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz"))
-            compress_scripts(source_path, str(archive_fname), get_all_modules(source_path, module_name))
+            compress_scripts(
+                source_path,
+                str(archive_fname),
+                get_all_modules(source_path, module_name),
+            )
             md5_bytes, upload_native_url = self.upload_file(
-                archive_fname, project or self.default_project, domain or self.default_domain
+                archive_fname,
+                project or self.default_project,
+                domain or self.default_domain,
             )
 
         serialization_settings = SerializationSettings(
@@ -1199,7 +1301,10 @@ def register_script(
             # but we don't have to use it when registering with the Flyte backend.
             # For that add the hash of the compilation settings to hash of file
             version = self._version_from_hash(
-                md5_bytes, serialization_settings, default_inputs, *self._get_image_names(entity)
+                md5_bytes,
+                serialization_settings,
+                default_inputs,
+                *self._get_image_names(entity),
             )
 
         if isinstance(entity, PythonTask):
@@ -1316,7 +1421,8 @@ def _execute(
         for k, v in inputs.items():
             if input_flyte_type_map.get(k) is None:
                 raise user_exceptions.FlyteValueException(
-                    k, f"The {entity.__class__.__name__} doesn't have this input key."
+                    k,
+                    f"The {entity.__class__.__name__} doesn't have this input key.",
                 )
             if isinstance(v, Literal):
                 lit = v
@@ -1368,10 +1474,10 @@ def _execute(
                     security_context=options.security_context,
                     envs=common_models.Envs(envs) if envs else None,
                     tags=tags,
-                    cluster_assignment=ClusterAssignment(cluster_pool=cluster_pool) if cluster_pool else None,
-                    execution_cluster_label=ExecutionClusterLabel(execution_cluster_label)
-                    if execution_cluster_label
-                    else None,
+                    cluster_assignment=(ClusterAssignment(cluster_pool=cluster_pool) if cluster_pool else None),
+                    execution_cluster_label=(
+                        ExecutionClusterLabel(execution_cluster_label) if execution_cluster_label else None
+                    ),
                 ),
                 literal_inputs,
             )
@@ -1381,7 +1487,9 @@ def _execute(
                     f"Assuming this is the same execution, returning!"
                 )
                 exec_id = WorkflowExecutionIdentifier(
-                    project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
+                    project=project or self.default_project,
+                    domain=domain or self.default_domain,
+                    name=execution_name,
                 )
                 execution = FlyteWorkflowExecution.promote_from_model(self.client.get_execution(exec_id))
@@ -1416,7 +1524,14 @@ def _resolve_identifier_kwargs(
 
     def execute(
         self,
-        entity: typing.Union[FlyteTask, FlyteLaunchPlan, FlyteWorkflow, PythonTask, WorkflowBase, LaunchPlan],
+        entity: typing.Union[
+            FlyteTask,
+            FlyteLaunchPlan,
+            FlyteWorkflow,
+            PythonTask,
+            WorkflowBase,
+            LaunchPlan,
+        ],
         inputs: typing.Dict[str, typing.Any],
         project: str = None,
         domain: str = None,
@@ -2280,7 +2395,9 @@ def sync_node_execution(
             # This is a recursive call, basically going through the same process that brought us here in the first
             # place, but on the launched execution.
             launched_exec = self.fetch_execution(
-                project=launched_exec_id.project, domain=launched_exec_id.domain, name=launched_exec_id.name
+                project=launched_exec_id.project,
+                domain=launched_exec_id.domain,
+                name=launched_exec_id.name,
             )
             self.sync_execution(launched_exec)
             if launched_exec.is_done:
@@ -2320,7 +2437,10 @@ def sync_node_execution(
                 dynamic_flyte_wf = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans)
 
                 execution._underlying_node_executions = [
-                    self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), dynamic_flyte_wf._node_map)
+                    self.sync_node_execution(
+                        FlyteNodeExecution.promote_from_model(cne),
+                        dynamic_flyte_wf._node_map,
+                    )
                     for cne in child_node_executions
                 ]
                 execution._task_executions = [
@@ -2372,7 +2492,8 @@ def sync_node_execution(
             else:
                 execution._task_executions = [
                     self.sync_task_execution(
-                        FlyteTaskExecution.promote_from_model(t), node_mapping[node_id].task_node.flyte_task
+                        FlyteTaskExecution.promote_from_model(t),
+                        node_mapping[node_id].task_node.flyte_task,
                    )
                     for t in iterate_task_executions(self.client, execution.id)
                 ]
@@ -2387,7 +2508,9 @@ def sync_node_execution(
         return execution
 
     def sync_task_execution(
-        self, execution: FlyteTaskExecution, entity_definition: typing.Optional[FlyteTask] = None
+        self,
+        execution: FlyteTaskExecution,
+        entity_definition: typing.Optional[FlyteTask] = None,
     ) -> FlyteTaskExecution:
         """Sync a FlyteTaskExecution object with its corresponding remote state."""
         execution._closure = self.client.get_task_execution(execution.id).closure
@@ -2474,7 +2597,12 @@ def generate_console_http_domain(self) -> str:
     def generate_console_url(
         self,
         entity: typing.Union[
-            FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflow, FlyteTask, FlyteLaunchPlan
+            FlyteWorkflowExecution,
+            FlyteNodeExecution,
+            FlyteTaskExecution,
+            FlyteWorkflow,
+            FlyteTask,
+            FlyteLaunchPlan,
         ],
     ):
         """
@@ -2544,7 +2672,11 @@ def launch_backfill(
         """
         lp = self.fetch_launch_plan(project=project, domain=domain, name=launchplan, version=launchplan_version)
         wf, start, end = create_backfill_workflow(
-            start_date=from_date, end_date=to_date, for_lp=lp, parallel=parallel, failure_policy=failure_policy
+            start_date=from_date,
+            end_date=to_date,
+            for_lp=lp,
+            parallel=parallel,
+            failure_policy=failure_policy,
         )
         if dry_run:
             logger.warning("Dry Run enabled. Workflow will not be registered and or executed.")
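
For context, launch_backfill is driven by the keyword arguments visible in this hunk. A hedged usage sketch — the project, domain, and launch plan names are hypothetical, and only the parameters evidenced in the hunk are assumed:

    from datetime import datetime

    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")

    # Re-run a daily launch plan over a historical window; dry_run=True builds
    # the backfill workflow without registering or executing it.
    execution = remote.launch_backfill(
        project="flytesnacks",
        domain="development",
        from_date=datetime(2024, 1, 1),
        to_date=datetime(2024, 1, 31),
        launchplan="my_daily_lp",
        parallel=False,
        dry_run=False,
    )
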
@@ -2589,7 +2721,10 @@ def activate_launchplan(self, ident: Identifier):
         self.client.update_launch_plan(id=ident, state=LaunchPlanState.ACTIVE)
 
     def download(
-        self, data: typing.Union[LiteralsResolver, Literal, LiteralMap], download_to: str, recursive: bool = True
+        self,
+        data: typing.Union[LiteralsResolver, Literal, LiteralMap],
+        download_to: str,
+        recursive: bool = True,
     ):
         """
         Download the data to the specified location. If the data is a LiteralsResolver, LiteralMap and if recursive is
diff --git a/tests/flytekit/unit/core/test_resolver.py b/tests/flytekit/unit/core/test_resolver.py
index 116b1251ae..5005adfd76 100644
--- a/tests/flytekit/unit/core/test_resolver.py
+++ b/tests/flytekit/unit/core/test_resolver.py
@@ -4,6 +4,7 @@
 import cloudpickle
 import mock
 import pytest
+import sys
 
 import flytekit.configuration
 from flytekit.configuration import Image, ImageConfig
@@ -123,10 +124,28 @@ def t1(a: str, b: str) -> str:
 
     assert c.loader_args(None, t1) == ["entity-name", "tests.flytekit.unit.core.test_resolver.t1"]
 
-    pickled_dict = {"tests.flytekit.unit.core.test_resolver.t1": t1}
+    pickled_dict = {
+        "tests.flytekit.unit.core.test_resolver.t1": t1,
+        "metadata": {
+            "python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
+        }
+    }
     custom_pickled_object = cloudpickle.dumps(pickled_dict)
     mock_gzip_open.return_value.read.return_value = custom_pickled_object
     mock_cloudpickle.return_value = pickled_dict
 
     t = c.load_task(["entity-name", "tests.flytekit.unit.core.test_resolver.t1"])
     assert t == t1
+
+    mismatched_pickled_dict = {
+        "tests.flytekit.unit.core.test_resolver.t1": t1,
+        "metadata": {
+            "python_version": f"{sys.version_info.major}.{sys.version_info.minor - 1}.{sys.version_info.micro}",
+        }
+    }
+    mismatched_custom_pickled_object = cloudpickle.dumps(mismatched_pickled_dict)
+    mock_gzip_open.return_value.read.return_value = mismatched_custom_pickled_object
+    mock_cloudpickle.return_value = mismatched_pickled_dict
+
+    with pytest.raises(RuntimeError):
+        c.load_task(["entity-name", "tests.flytekit.unit.core.test_resolver.t1"])
diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py
index 4df39329c7..3ac6b879f5 100644
--- a/tests/flytekit/unit/remote/test_remote.py
+++ b/tests/flytekit/unit/remote/test_remote.py
@@ -2,6 +2,7 @@
 import pathlib
 import shutil
 import subprocess
+import sys
 import tempfile
 import typing
 import uuid
@@ -707,7 +708,8 @@ def w() -> int:
         return t2(a=t1())
 
     target_dict = _get_pickled_target_dict(w)
-    assert len(target_dict) == 2
+    assert len(target_dict) == 3
+    assert target_dict["metadata"]["python_version"] == f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
     assert t1.name in target_dict
     assert t2.name in target_dict
     assert target_dict[t1.name] == t1
@@ -723,7 +725,8 @@ def w() -> int:
         return map_task(partial(t1, y=2))(x=[1, 2, 3])
 
     target_dict = _get_pickled_target_dict(w)
-    assert len(target_dict) == 1
+    assert len(target_dict) == 2
+    assert target_dict["metadata"]["python_version"] == f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
     assert t1.name in target_dict
     assert target_dict[t1.name] == t1
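
Finally, a runnable round-trip sketch of the gzip + cloudpickle format that these tests mock — the file path below is illustrative rather than flytekit's actual PICKLE_FILE_PATH constant:

    import gzip
    import sys

    import cloudpickle

    pickle_path = "/tmp/pickled_entities.pkl.gz"  # placeholder path

    entity_dict = {
        "metadata": {
            "python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
        },
        "example.module.t1": object(),  # stand-in for a real task object
    }

    with gzip.open(pickle_path, "wb") as f:
        cloudpickle.dump(entity_dict, f)

    with gzip.open(pickle_path, "r") as f:  # "r" means binary read for gzip, matching load_task
        loaded = cloudpickle.load(f)

    assert loaded["metadata"]["python_version"] == entity_dict["metadata"]["python_version"]
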