From 014d08c025b497c92db48f9c37198c639c67bc26 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Tue, 8 Oct 2024 12:03:22 -0700 Subject: [PATCH] Fix array node map task for offloaded literal (#2772) (#2793) * Fix array node map task for offloaded literal * fix offloaded literal reading in array node * nit * review comments --------- Signed-off-by: pmahindrakar-oss Co-authored-by: Prafulla Mahindrakar --- flytekit/core/array_node_map_task.py | 4 +- flytekit/core/type_engine.py | 18 +++++--- .../unit/core/test_array_node_map_task.py | 43 +++++++++++++++++++ 3 files changed, 58 insertions(+), 7 deletions(-) diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 301628915e..94454f417b 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -251,7 +251,9 @@ def _literal_map_to_python_input( inputs_interface = self._run_task.python_interface.inputs for k in self.interface.inputs.keys(): v = literal_map.literals[k] - + # If the input is offloaded, we need to unwrap it + if v.offloaded_metadata: + v = TypeEngine.unwrap_offloaded_literal(ctx, v) if k not in self.bound_inputs: # assert that v.collection is not None if not v.collection or not isinstance(v.collection.literals, list): diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 861909eedd..2d9d21f0ff 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1150,6 +1150,17 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type lv.hash = hash return lv + @classmethod + def unwrap_offloaded_literal(cls, ctx: FlyteContext, lv: Literal) -> Literal: + if not lv.offloaded_metadata: + return lv + + literal_local_file = ctx.file_access.get_random_local_path() + assert lv.offloaded_metadata.uri, "missing offloaded uri" + ctx.file_access.download(lv.offloaded_metadata.uri, literal_local_file) + input_proto = load_proto_from_file(literals_pb2.Literal, literal_local_file) + return Literal.from_flyte_idl(input_proto) + @classmethod def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> typing.Any: """ @@ -1157,12 +1168,7 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T """ # Initiate the process of loading the offloaded literal if offloaded_metadata is set if lv.offloaded_metadata: - literal_local_file = ctx.file_access.get_random_local_path() - assert lv.offloaded_metadata.uri, "missing offloaded uri" - ctx.file_access.download(lv.offloaded_metadata.uri, literal_local_file) - input_proto = load_proto_from_file(literals_pb2.Literal, literal_local_file) - lv = Literal.from_flyte_idl(input_proto) - + lv = cls.unwrap_offloaded_literal(ctx, lv) transformer = cls.get_transformer(expected_python_type) return transformer.to_python_value(ctx, lv, expected_python_type) diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index fa964a71ef..fae81d1355 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -17,6 +17,11 @@ from flytekit.core.type_engine import TypeEngine from flytekit.extras.accelerators import GPUAccelerator from flytekit.experimental.eager_function import eager +from flytekit.models.literals import ( + Literal, + LiteralMap, + LiteralOffloadedMetadata, +) from flytekit.tools.translator import get_serializable from flytekit.types.pickle import BatchSize @@ -464,3 +469,41 @@ def wf(): with pytest.raises(AssertionError): wf.compile() + + +def test_load_offloaded_literal(tmp_path, monkeypatch): + @task + def say_hello(name: str) -> str: + return f"hello {name}!" + + ctx = context_manager.FlyteContextManager.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ) as ctx: + list_strs = ["a", "b", "c"] + lt = TypeEngine.to_literal_type(typing.List[str]) + to_be_offloaded = TypeEngine.to_literal(ctx, list_strs, typing.List[str], lt) + with open(f"{tmp_path}/literal.pb", "wb") as f: + f.write(to_be_offloaded.to_flyte_idl().SerializeToString()) + + literal = Literal( + offloaded_metadata=LiteralOffloadedMetadata( + uri=f"{tmp_path}/literal.pb", + inferred_type=lt, + ), + ) + + lm = LiteralMap({ + "name": literal + }) + + for index, map_input_str in enumerate(list_strs): + monkeypatch.setenv("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "name") + monkeypatch.setenv("name", str(index)) + t = map_task(say_hello) + res = t.dispatch_execute(ctx, lm) + assert len(res.literals) == 1 + assert res.literals[f"o{0}"].scalar.primitive.string_value == f"hello {map_input_str}!" + monkeypatch.undo()