Skip to content

Commit

Permalink
Fix: Prevent UnionTransformer type ambiguity in combination with PyTo…
Browse files Browse the repository at this point in the history
…rchTypeTransformer (#2726)

* Fix: Prevent UnionTransformer type ambiguity in combination with PyTorchTypeTransformer

Signed-off-by: Fabio Grätz <[email protected]>

* Add test requested in code review

Signed-off-by: Fabio Grätz <[email protected]>

---------

Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
  • Loading branch information
fg91 and Fabio Grätz authored Sep 14, 2024
1 parent e3dc8f9 commit 0b26c92
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
3 changes: 3 additions & 0 deletions flytekit/extras/pytorch/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def to_literal(
python_type: Type[T],
expected: LiteralType,
) -> Literal:
if not isinstance(python_val, torch.Tensor) and not isinstance(python_val, torch.nn.Module):
raise TypeTransformerFailedError("Expected a torch.Tensor or nn.Module")

meta = BlobMetadata(
type=_core_types.BlobType(
format=self.PYTORCH_FORMAT,
Expand Down
44 changes: 43 additions & 1 deletion tests/flytekit/unit/extras/pytorch/test_transformations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from collections import OrderedDict
from typing import Union

import pytest
import torch

import flytekit
from flytekit import task
from flytekit import task, workflow
from flytekit.configuration import Image, ImageConfig
from flytekit.core import context_manager
from flytekit.core.type_engine import TypeTransformerFailedError
from flytekit.extras.pytorch import (
PyTorchCheckpoint,
PyTorchCheckpointTransformer,
Expand All @@ -18,6 +20,7 @@
from flytekit.models.types import LiteralType
from flytekit.tools.translator import get_serializable


default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
project="project",
Expand Down Expand Up @@ -130,3 +133,42 @@ def t1() -> PyTorchCheckpoint:
task_spec.template.interface.outputs["o0"].type.blob.format
is PyTorchCheckpointTransformer.PYTORCH_CHECKPOINT_FORMAT
)


def test_to_literal_unambiguity():
"""Test that the pytorch type transformers raise an error when the input is a list of tensors or modules.
The PyTorchTypeTransformer uses `torch.save` for serialization which is able to serialize a list of tensors
or modules but this causes ambiguity in the Union type transformer as it cannot distinguish whether the
ListTransformer should invoke the PyTorchTypeTransformer for every element in the list or the
PyTorchTypeTransformer for the entire list.
"""
ctx = context_manager.FlyteContext.current_context()

with pytest.raises(TypeTransformerFailedError):
test_inp = torch.tensor([1, 2, 3])
trans = PyTorchTensorTransformer()
trans.to_literal(ctx, [test_inp], torch.Tensor, trans.get_literal_type(torch.Tensor))


with pytest.raises(TypeTransformerFailedError):
model = torch.nn.Linear(2, 2)
trans = PyTorchModuleTransformer()
trans.to_literal(ctx, [model], torch.nn.Module, trans.get_literal_type(torch.nn.Module))


def test_torch_tensor_list_union():
"""Test that a task can return a union of list of tensor and tensor.
See test_to_literal_unambiguity for more details why this failed.
"""

@task
def foo() -> Union[list[torch.Tensor], torch.Tensor]:
return [torch.tensor([1, 2, 3])]

@workflow
def wf():
foo()

wf()

0 comments on commit 0b26c92

Please sign in to comment.