Skip to content

Commit

Permalink
Fix array node map task for offloaded literal (#2772) (#2793)
Browse files Browse the repository at this point in the history
* Fix array node map task for offloaded literal



* fix offloaded literal reading in array node



* nit



* review comments



---------

Signed-off-by: pmahindrakar-oss <[email protected]>
Co-authored-by: Prafulla Mahindrakar <[email protected]>
  • Loading branch information
eapolinario and pmahindrakar-oss authored Oct 8, 2024
1 parent 410b81e commit 014d08c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 7 deletions.
4 changes: 3 additions & 1 deletion flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,9 @@ def _literal_map_to_python_input(
inputs_interface = self._run_task.python_interface.inputs
for k in self.interface.inputs.keys():
v = literal_map.literals[k]

# If the input is offloaded, we need to unwrap it
if v.offloaded_metadata:
v = TypeEngine.unwrap_offloaded_literal(ctx, v)
if k not in self.bound_inputs:
# assert that v.collection is not None
if not v.collection or not isinstance(v.collection.literals, list):
Expand Down
18 changes: 12 additions & 6 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,19 +1150,25 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
lv.hash = hash
return lv

@classmethod
def unwrap_offloaded_literal(cls, ctx: FlyteContext, lv: Literal) -> Literal:
if not lv.offloaded_metadata:
return lv

literal_local_file = ctx.file_access.get_random_local_path()
assert lv.offloaded_metadata.uri, "missing offloaded uri"
ctx.file_access.download(lv.offloaded_metadata.uri, literal_local_file)
input_proto = load_proto_from_file(literals_pb2.Literal, literal_local_file)
return Literal.from_flyte_idl(input_proto)

@classmethod
def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> typing.Any:
"""
Converts a Literal value with an expected python type into a python value.
"""
# Initiate the process of loading the offloaded literal if offloaded_metadata is set
if lv.offloaded_metadata:
literal_local_file = ctx.file_access.get_random_local_path()
assert lv.offloaded_metadata.uri, "missing offloaded uri"
ctx.file_access.download(lv.offloaded_metadata.uri, literal_local_file)
input_proto = load_proto_from_file(literals_pb2.Literal, literal_local_file)
lv = Literal.from_flyte_idl(input_proto)

lv = cls.unwrap_offloaded_literal(ctx, lv)
transformer = cls.get_transformer(expected_python_type)
return transformer.to_python_value(ctx, lv, expected_python_type)

Expand Down
43 changes: 43 additions & 0 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.extras.accelerators import GPUAccelerator
from flytekit.experimental.eager_function import eager
from flytekit.models.literals import (
Literal,
LiteralMap,
LiteralOffloadedMetadata,
)
from flytekit.tools.translator import get_serializable
from flytekit.types.pickle import BatchSize

Expand Down Expand Up @@ -464,3 +469,41 @@ def wf():

with pytest.raises(AssertionError):
wf.compile()


def test_load_offloaded_literal(tmp_path, monkeypatch):
@task
def say_hello(name: str) -> str:
return f"hello {name}!"

ctx = context_manager.FlyteContextManager.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
)
) as ctx:
list_strs = ["a", "b", "c"]
lt = TypeEngine.to_literal_type(typing.List[str])
to_be_offloaded = TypeEngine.to_literal(ctx, list_strs, typing.List[str], lt)
with open(f"{tmp_path}/literal.pb", "wb") as f:
f.write(to_be_offloaded.to_flyte_idl().SerializeToString())

literal = Literal(
offloaded_metadata=LiteralOffloadedMetadata(
uri=f"{tmp_path}/literal.pb",
inferred_type=lt,
),
)

lm = LiteralMap({
"name": literal
})

for index, map_input_str in enumerate(list_strs):
monkeypatch.setenv("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "name")
monkeypatch.setenv("name", str(index))
t = map_task(say_hello)
res = t.dispatch_execute(ctx, lm)
assert len(res.literals) == 1
assert res.literals[f"o{0}"].scalar.primitive.string_value == f"hello {map_input_str}!"
monkeypatch.undo()

0 comments on commit 014d08c

Please sign in to comment.