Async/data persistence #2829

Merged · 16 commits · Oct 22, 2024
104 changes: 76 additions & 28 deletions flytekit/core/data_persistence.py
@@ -18,6 +18,7 @@

"""

import asyncio
import io
import os
import pathlib
@@ -29,6 +30,7 @@

import fsspec
from decorator import decorator
from fsspec.asyn import AsyncFileSystem
from fsspec.utils import get_protocol
from typing_extensions import Unpack

@@ -40,6 +42,7 @@
from flytekit.exceptions.user import FlyteAssertion, FlyteDataNotFoundException
from flytekit.interfaces.random import random
from flytekit.loggers import logger
from flytekit.utils.asyn import loop_manager

# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198
# for key and secret
@@ -208,8 +211,17 @@ def get_filesystem(
storage_options = get_fsspec_storage_options(
protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs
)
kwargs.update(storage_options)

return fsspec.filesystem(protocol, **storage_options)
return fsspec.filesystem(protocol, **kwargs)

async def get_async_filesystem_for_path(
self, path: str = "", anonymous: bool = False, **kwargs
) -> Union[AsyncFileSystem, fsspec.AbstractFileSystem]:
protocol = get_protocol(path)
loop = asyncio.get_running_loop()

return self.get_filesystem(protocol, anonymous=anonymous, path=path, asynchronous=True, loop=loop, **kwargs)

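The new `get_async_filesystem_for_path` helper leans on fsspec's convention that a filesystem built with `asynchronous=True` and the caller's running event loop exposes coroutine counterparts (`_get`, `_put`, ...) of its blocking methods. A minimal, self-contained sketch of that pattern, assuming an async-capable backend such as s3fs is installed (the bucket path is hypothetical):

```python
import asyncio

import fsspec
from fsspec.asyn import AsyncFileSystem
from fsspec.utils import get_protocol


async def fetch(path: str, local: str) -> None:
    protocol = get_protocol(path)  # e.g. "s3" for "s3://bucket/key"
    loop = asyncio.get_running_loop()
    # asynchronous=True plus the running loop makes fsspec drive the
    # filesystem from this loop rather than its own helper thread.
    fs = fsspec.filesystem(protocol, asynchronous=True, loop=loop)
    if isinstance(fs, AsyncFileSystem):
        await fs._get(path, local)  # coroutine twin of fs.get()
    else:
        fs.get(path, local)  # e.g. the local "file" protocol stays synchronous


# asyncio.run(fetch("s3://my-bucket/data.txt", "/tmp/data.txt"))  # hypothetical path
```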
def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem:
protocol = get_protocol(path)
@@ -282,8 +294,8 @@ def exists(self, path: str) -> bool:
raise oe

@retry_request
def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
file_system = self.get_filesystem_for_path(from_path)
async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
file_system = await self.get_async_filesystem_for_path(from_path)
if recursive:
from_path, to_path = self.recursive_paths(from_path, to_path)
try:
@@ -294,23 +306,33 @@ def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True
)
logger.info(f"Getting {from_path} to {to_path}")
dst = file_system.get(from_path, to_path, recursive=recursive, **kwargs)
if isinstance(file_system, AsyncFileSystem):
dst = await file_system._get(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
else:
dst = file_system.get(from_path, to_path, recursive=recursive, **kwargs)
if isinstance(dst, (str, pathlib.Path)):
return dst
return to_path
except OSError as oe:
logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}")
if not file_system.exists(from_path):
raise FlyteDataNotFoundException(from_path)
file_system = self.get_filesystem(get_protocol(from_path), anonymous=True)
file_system = self.get_filesystem(get_protocol(from_path), anonymous=True, asynchronous=True)
if file_system is not None:
logger.debug(f"Attempting anonymous get with {file_system}")
return file_system.get(from_path, to_path, recursive=recursive, **kwargs)
if isinstance(file_system, AsyncFileSystem):
return await file_system._get(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
else:
return file_system.get(from_path, to_path, recursive=recursive, **kwargs)
raise oe

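Since `get` is now a coroutine, code that is already on an event loop (agents, async type transformers) awaits it directly; synchronous callers go through the `loop_manager`-wrapped facades added further down. A hedged usage sketch, with hypothetical paths and an illustrative `file_access` provider instance:

```python
async def download_inputs(file_access) -> str:
    # recursive=True corresponds to multipart (directory) downloads.
    return await file_access.get(
        "s3://my-bucket/inputs", "/tmp/inputs", recursive=True
    )
```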
@retry_request
def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
file_system = self.get_filesystem_for_path(to_path)
async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
"""
More of an internal function to be called by put_data and put_raw_data
This does not need a separate sync function.
"""
file_system = await self.get_async_filesystem_for_path(to_path)
from_path = self.strip_file_header(from_path)
if recursive:
# Only check this for the local filesystem
@@ -327,13 +349,16 @@ def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
if "metadata" not in kwargs:
kwargs["metadata"] = {}
kwargs["metadata"].update(self._execution_metadata)
dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs)
if isinstance(file_system, AsyncFileSystem):
dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212
else:
dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs)
if isinstance(dst, (str, pathlib.Path)):
return dst
else:
return to_path

def put_raw_data(
async def async_put_raw_data(
self,
lpath: Uploadable,
upload_prefix: Optional[str] = None,
@@ -364,7 +389,7 @@ def put_raw_data(
:param read_chunk_size_bytes: If lpath is a buffer, this is the chunk size to read from it
:param encoding: If lpath is a io.StringIO, this is the encoding to use to encode it to binary.
:param skip_raw_data_prefix: If True, the raw data prefix will not be prepended to the upload_prefix
:param kwargs: Additional kwargs are passed into the the fsspec put() call or the open() call
:param kwargs: Additional kwargs are passed into the fsspec put() call or the open() call
:return: Returns the final path data was written to.
"""
# First figure out what the destination path should be, then call put.
@@ -388,42 +413,60 @@
raise FlyteAssertion(f"File {from_path} is a symlink, can't upload")
if p.is_dir():
logger.debug(f"Detected directory {from_path}, using recursive put")
r = self.put(from_path, to_path, recursive=True, **kwargs)
r = await self._put(from_path, to_path, recursive=True, **kwargs)
else:
logger.debug(f"Detected file {from_path}, call put non-recursive")
r = self.put(from_path, to_path, **kwargs)
r = await self._put(from_path, to_path, **kwargs)
return r or to_path

# raw bytes
if isinstance(lpath, bytes):
fs = self.get_filesystem_for_path(to_path)
with fs.open(to_path, "wb", **kwargs) as s:
s.write(lpath)
fs = await self.get_async_filesystem_for_path(to_path)
if isinstance(fs, AsyncFileSystem):
async with fs.open_async(to_path, "wb", **kwargs) as s:
s.write(lpath)
else:
with fs.open(to_path, "wb", **kwargs) as s:
s.write(lpath)

return to_path

# If lpath is a buffered reader of some kind
if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO):
if not lpath.readable():
raise FlyteAssertion("Buffered reader must be readable")
fs = self.get_filesystem_for_path(to_path)
fs = await self.get_async_filesystem_for_path(to_path)
lpath.seek(0)
with fs.open(to_path, "wb", **kwargs) as s:
while data := lpath.read(read_chunk_size_bytes):
s.write(data)
if isinstance(fs, AsyncFileSystem):
async with fs.open_async(to_path, "wb", **kwargs) as s:
while data := lpath.read(read_chunk_size_bytes):
s.write(data)
else:
with fs.open(to_path, "wb", **kwargs) as s:
while data := lpath.read(read_chunk_size_bytes):
s.write(data)
return to_path

if isinstance(lpath, io.StringIO):
if not lpath.readable():
raise FlyteAssertion("Buffered reader must be readable")
fs = self.get_filesystem_for_path(to_path)
fs = await self.get_async_filesystem_for_path(to_path)
lpath.seek(0)
with fs.open(to_path, "wb", **kwargs) as s:
while data_str := lpath.read(read_chunk_size_bytes):
s.write(data_str.encode(encoding))
if isinstance(fs, AsyncFileSystem):
async with fs.open_async(to_path, "wb", **kwargs) as s:
while data_str := lpath.read(read_chunk_size_bytes):
s.write(data_str.encode(encoding))
else:
with fs.open(to_path, "wb", **kwargs) as s:
while data_str := lpath.read(read_chunk_size_bytes):
s.write(data_str.encode(encoding))
return to_path

raise FlyteAssertion(f"Unsupported lpath type {type(lpath)}")

# Public synchronous version
put_raw_data = loop_manager.synced(async_put_raw_data)

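`loop_manager.synced` comes from `flytekit.utils.asyn`, which is not part of this diff, so the snippet below is only a sketch of the general pattern such a wrapper implements: run the coroutine to completion on a dedicated background event loop so synchronous callers keep the old blocking signature. It is not flytekit's actual implementation.

```python
import asyncio
import functools
import threading


class BackgroundLoop:  # illustrative stand-in for flytekit's loop manager
    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    def synced(self, coro_func):
        """Wrap an async function so it can be called from sync code."""

        @functools.wraps(coro_func)
        def wrapper(*args, **kwargs):
            future = asyncio.run_coroutine_threadsafe(
                coro_func(*args, **kwargs), self._loop
            )
            return future.result()  # block the calling thread until done

        return wrapper
```

Under that reading, `put_raw_data` keeps its synchronous signature for existing callers while the upload itself happens in `async_put_raw_data`.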
@staticmethod
def get_random_string() -> str:
return UUID(int=random.getrandbits(128)).hex
@@ -549,7 +592,7 @@ def upload_directory(self, local_path: str, remote_path: str, **kwargs):
"""
return self.put_data(local_path, remote_path, is_multipart=True, **kwargs)

def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs):
async def async_get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs):
"""
:param remote_path:
:param local_path:
@@ -558,7 +601,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False
try:
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
with timeit(f"Download data to local from {remote_path}"):
self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs)
await self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs)
except FlyteDataNotFoundException:
raise
except Exception as ex:
@@ -567,7 +610,9 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False
f"Original exception: {str(ex)}"
)

def put_data(
get_data = loop_manager.synced(async_get_data)

async def async_put_data(
self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False, **kwargs
) -> str:
"""
@@ -581,7 +626,7 @@ def put_data(
try:
local_path = str(local_path)
with timeit(f"Upload data to {remote_path}"):
put_result = self.put(cast(str, local_path), remote_path, recursive=is_multipart, **kwargs)
put_result = await self._put(cast(str, local_path), remote_path, recursive=is_multipart, **kwargs)
# This is an unfortunate workaround to ensure that we return the correct path for the remote location
# Callers of this put_data function in flytekit have been changed to assign the remote path to the
# output
@@ -595,6 +640,9 @@ def put_data(
f"Original exception: {str(ex)}"
) from ex

# Public synchronous version
put_data = loop_manager.synced(async_put_data)

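With both facades in place, the same provider serves blocking and coroutine call sites. A hedged usage sketch; the paths are hypothetical and the method names come from this diff:

```python
from flytekit.core.context_manager import FlyteContextManager

ctx = FlyteContextManager.current_context()

# Synchronous call sites keep the pre-existing blocking shape:
ctx.file_access.put_data("/tmp/outputs.pb", "s3://my-bucket/outputs.pb")


# Code already running on an event loop awaits the coroutine directly,
# as the agent change later in this PR does for inputs.pb:
async def upload_inputs() -> str:
    return await ctx.file_access.async_put_data(
        "/tmp/inputs.pb", "s3://my-bucket/inputs.pb"
    )
```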

flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-")
default_local_file_access_provider = FileAccessProvider(
10 changes: 6 additions & 4 deletions flytekit/core/type_engine.py
@@ -1934,7 +1934,9 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
return None, None

@staticmethod
def dict_to_binary_literal(ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool) -> Literal:
async def dict_to_binary_literal(
ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool
) -> Literal:
"""
Converts a Python dictionary to a Flyte-specific ``Literal`` using MessagePack encoding.
Falls back to Pickle if encoding fails and `allow_pickle` is True.
@@ -1948,7 +1950,7 @@ def dict_to_binary_literal(ctx: FlyteContext, v: dict, python_type: Type[dict],
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
except TypeError as e:
if allow_pickle:
remote_path = FlytePickle.to_pickle(ctx, v)
remote_path = await FlytePickle.to_pickle(ctx, v)
return Literal(
scalar=Scalar(
generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct())
@@ -2006,7 +2008,7 @@ async def async_to_literal(
allow_pickle, base_type = DictTransformer.is_pickle(python_type)

if expected and expected.simple and expected.simple == SimpleType.STRUCT:
return self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle)
return await self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle)

lit_map = {}
for k, v in python_val.items():
@@ -2062,7 +2064,7 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p
from flytekit.types.pickle import FlytePickle

uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file")
return FlytePickle.from_pickle(uri)
return await FlytePickle.from_pickle(uri)

try:
return json.loads(_json_format.MessageToJson(lv.scalar.generic))
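`dict_to_binary_literal` becomes a coroutine only because the pickle fallback now awaits `FlytePickle.to_pickle`; the happy path is still a plain MessagePack encode. A minimal sketch of that encode-or-fallback decision, independent of Flyte's `Literal` wrappers (the helper name is illustrative):

```python
from typing import Optional, Tuple

import msgpack


def encode_or_flag_pickle(value: dict) -> Tuple[Optional[bytes], bool]:
    try:
        # Success: these bytes end up in Binary(value=..., tag="msgpack").
        return msgpack.dumps(value), False
    except TypeError:
        # Unencodable value: the transformer awaits FlytePickle.to_pickle(...)
        # and stores the returned remote path instead (when allow_pickle=True).
        return None, True


payload, needs_pickle = encode_or_flag_pickle({"a": 1, "b": [1, 2, 3]})
```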
2 changes: 1 addition & 1 deletion flytekit/extend/backend/base_agent.py
@@ -368,7 +368,7 @@ async def _create(
literal_map = await TypeEngine._dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
path = ctx.file_access.get_random_local_path()
utils.write_proto_to_file(literal_map.to_flyte_idl(), path)
ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb")
await ctx.file_access.async_put_data(path, f"{output_prefix}/inputs.pb")
task_template = render_task_template(task_template, output_prefix)
else:
literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
12 changes: 6 additions & 6 deletions flytekit/extras/tensorflow/model.py
@@ -4,13 +4,13 @@
import tensorflow as tf

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType


class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]):
class TensorFlowModelTransformer(AsyncTypeTransformer[tf.keras.Model]):
TENSORFLOW_FORMAT = "TensorFlowModel"

def __init__(self):
@@ -24,7 +24,7 @@ def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType:
)
)

def to_literal(
async def async_to_literal(
self,
ctx: FlyteContext,
python_val: tf.keras.Model,
@@ -44,10 +44,10 @@ def to_literal(
# save model in SavedModel format
tf.keras.models.save_model(python_val, local_path)

remote_path = ctx.file_access.put_raw_data(local_path)
remote_path = await ctx.file_access.async_put_raw_data(local_path)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def to_python_value(
async def async_to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.Model]
) -> tf.keras.Model:
try:
@@ -56,7 +56,7 @@ def to_python_value(
TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, is_multipart=True)
await ctx.file_access.async_get_data(uri, local_path, is_multipart=True)

# load model
return tf.keras.models.load_model(local_path)
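The TensorFlow transformer shows the general migration recipe: subclass `AsyncTypeTransformer`, implement `async_to_literal` / `async_to_python_value`, and await the file-access coroutines instead of blocking. A hedged skeleton for a hypothetical custom type; `MyModel` and its save/load helpers are illustrative, and the constructor arguments mirror what the flytekit transformers appear to use:

```python
from typing import Type

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import AsyncTypeTransformer
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType


class MyModel:  # hypothetical user-defined type
    def save(self, path: str) -> None: ...

    @classmethod
    def load(cls, path: str) -> "MyModel": ...


class MyModelTransformer(AsyncTypeTransformer[MyModel]):
    def __init__(self):
        super().__init__(name="MyModel", t=MyModel)

    def get_literal_type(self, t: Type[MyModel]) -> LiteralType:
        return LiteralType(
            blob=_core_types.BlobType(
                format="MyModel",
                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
            )
        )

    async def async_to_literal(
        self, ctx: FlyteContext, python_val: MyModel,
        python_type: Type[MyModel], expected: LiteralType,
    ) -> Literal:
        local_dir = ctx.file_access.get_random_local_directory()
        python_val.save(local_dir)
        # Await the coroutine added in this PR instead of blocking the loop.
        remote = await ctx.file_access.async_put_raw_data(local_dir)
        meta = BlobMetadata(type=self.get_literal_type(python_type).blob)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote)))

    async def async_to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[MyModel]
    ) -> MyModel:
        local_dir = ctx.file_access.get_random_local_directory()
        await ctx.file_access.async_get_data(lv.scalar.blob.uri, local_dir, is_multipart=True)
        return MyModel.load(local_dir)
```

Registering an instance with `TypeEngine.register(...)` is unchanged; only the conversion hooks move to coroutines.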
12 changes: 7 additions & 5 deletions flytekit/types/directory/types.py
@@ -19,7 +19,7 @@
from flytekit import BlobType
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_batch_size
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models import types as _type_models
from flytekit.models.core import types as _core_types
@@ -407,7 +407,7 @@ def __str__(self):
return str(self.path)


class FlyteDirToMultipartBlobTransformer(TypeTransformer[FlyteDirectory]):
class FlyteDirToMultipartBlobTransformer(AsyncTypeTransformer[FlyteDirectory]):
"""
This transformer handles conversion between the Python native FlyteDirectory class defined above, and the Flyte
IDL literal/type of Multipart Blob. Please see the FlyteDirectory comments for additional information.
@@ -444,7 +444,7 @@ def assert_type(self, t: typing.Type[FlyteDirectory], v: typing.Union[FlyteDirec
def get_literal_type(self, t: typing.Type[FlyteDirectory]) -> LiteralType:
return _type_models.LiteralType(blob=self._blob_type(format=FlyteDirToMultipartBlobTransformer.get_format(t)))

def to_literal(
async def async_to_literal(
self,
ctx: FlyteContext,
python_val: FlyteDirectory,
@@ -499,7 +499,9 @@ def to_literal(
remote_directory = ctx.file_access.get_random_remote_directory()
if not pathlib.Path(source_path).is_dir():
raise FlyteAssertion("Expected a directory. {} is not a directory".format(source_path))
ctx.file_access.put_data(source_path, remote_directory, is_multipart=True, batch_size=batch_size)
await ctx.file_access.async_put_data(
source_path, remote_directory, is_multipart=True, batch_size=batch_size
)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_directory)))

# If not uploading, then we can only take the original source path as the uri.
@@ -535,7 +537,7 @@ def from_binary_idl(
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

def to_python_value(
async def async_to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory]
) -> FlyteDirectory:
if lv.scalar.binary:
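For task authors nothing changes at the call site: returning a `FlyteDirectory` still uploads the directory, it simply flows through `async_to_literal` and `async_put_data` now. A small illustrative task (the task name and output path are hypothetical):

```python
import pathlib

from flytekit import task
from flytekit.types.directory import FlyteDirectory


@task
def produce_report() -> FlyteDirectory:
    out = pathlib.Path("/tmp/report")  # hypothetical output location
    out.mkdir(parents=True, exist_ok=True)
    (out / "summary.txt").write_text("ok")
    # On serialization, the async transformer above uploads this directory
    # via ctx.file_access.async_put_data(..., is_multipart=True).
    return FlyteDirectory(path=str(out))
```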