From ef11168987a96293b65e188813cd55c65e469def Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 11 Oct 2024 17:54:16 -0700 Subject: [PATCH 01/15] eod Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 55 ++++++++++++++----- flytekit/extend/backend/base_agent.py | 2 +- flytekit/extras/pytorch/checkpoint.py | 2 +- flytekit/extras/pytorch/native.py | 2 +- flytekit/extras/sklearn/native.py | 2 +- flytekit/extras/tensorflow/model.py | 2 +- flytekit/extras/tensorflow/record.py | 4 +- flytekit/types/directory/types.py | 4 +- flytekit/types/file/file.py | 4 +- flytekit/types/iterator/json_iterator.py | 2 +- flytekit/types/numpy/ndarray.py | 2 +- flytekit/types/pickle/pickle.py | 2 +- flytekit/types/schema/types.py | 6 +- flytekit/types/schema/types_pandas.py | 4 +- .../types/structured/structured_dataset.py | 2 +- .../flytekitplugins/modin/schema.py | 2 +- .../flytekitplugins/onnxpytorch/schema.py | 2 +- .../flytekitplugins/onnxscikitlearn/schema.py | 2 +- .../flytekitplugins/onnxtensorflow/schema.py | 2 +- .../flytekitplugins/pandera/schema.py | 2 +- .../flytekitplugins/polars/sd_transformers.py | 4 +- .../flytekitplugins/vaex/sd_transformers.py | 2 +- .../flytekitplugins/whylogs/schema.py | 2 +- tests/flytekit/unit/core/test_data.py | 17 +++--- .../unit/core/test_data_persistence.py | 14 ++--- tests/flytekit/unit/core/test_flyte_file.py | 2 +- tests/flytekit/unit/remote/test_fs_remote.py | 2 +- 27 files changed, 87 insertions(+), 61 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index cdd07afba7..4ceff28824 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -26,6 +26,7 @@ from time import sleep from typing import Any, Dict, Optional, Union, cast from uuid import UUID +from fsspec.asyn import AsyncFileSystem import fsspec from decorator import decorator @@ -39,6 +40,7 @@ from flytekit.exceptions.system import FlyteDownloadDataException, FlyteUploadDataException 
from flytekit.exceptions.user import FlyteAssertion, FlyteDataNotFoundException from flytekit.interfaces.random import random +from flytekit.utils.asyn import loop_manager from flytekit.loggers import logger # Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198 @@ -211,6 +213,10 @@ def get_filesystem( return fsspec.filesystem(protocol, **storage_options) + def get_async_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> AsyncFileSystem: + protocol = get_protocol(path) + return self.get_filesystem(protocol, anonymous=anonymous, path=path, asynchronous=True, **kwargs) + def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem: protocol = get_protocol(path) return self.get_filesystem(protocol, anonymous=anonymous, path=path, **kwargs) @@ -309,8 +315,12 @@ def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): raise oe @retry_request - def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): - file_system = self.get_filesystem_for_path(to_path) + async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): + """ + More of an internal function to be called by put_data and put_raw_data + This does not need a separate sync function. 
+ """ + file_system = self.get_async_filesystem_for_path(to_path) from_path = self.strip_file_header(from_path) if recursive: # Only check this for the local filesystem @@ -327,13 +337,16 @@ def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): if "metadata" not in kwargs: kwargs["metadata"] = {} kwargs["metadata"].update(self._execution_metadata) - dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs) + """ + Need to check here for async fs or sync + """ + dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 if isinstance(dst, (str, pathlib.Path)): return dst else: return to_path - def put_raw_data( + async def async_put_raw_data( self, lpath: Uploadable, upload_prefix: Optional[str] = None, @@ -364,10 +377,16 @@ def put_raw_data( :param read_chunk_size_bytes: If lpath is a buffer, this is the chunk size to read from it :param encoding: If lpath is a io.StringIO, this is the encoding to use to encode it to binary. :param skip_raw_data_prefix: If True, the raw data prefix will not be prepended to the upload_prefix - :param kwargs: Additional kwargs are passed into the the fsspec put() call or the open() call + :param kwargs: Additional kwargs are passed into the fsspec put() call or the open() call :return: Returns the final path data was written to. """ # First figure out what the destination path should be, then call put. 
+ """ + Make this function work with both sync and async filesystems, + update maybe delete the get async file system function + + Do this first and then make the local file system async + """ upload_prefix = self.get_random_string() if upload_prefix is None else upload_prefix to_path = self.join(self.raw_output_prefix, upload_prefix) if not skip_raw_data_prefix else upload_prefix if file_name: @@ -388,16 +407,16 @@ def put_raw_data( raise FlyteAssertion(f"File {from_path} is a symlink, can't upload") if p.is_dir(): logger.debug(f"Detected directory {from_path}, using recursive put") - r = self.put(from_path, to_path, recursive=True, **kwargs) + r = await self._put(from_path, to_path, recursive=True, **kwargs) else: logger.debug(f"Detected file {from_path}, call put non-recursive") - r = self.put(from_path, to_path, **kwargs) + r = await self._put(from_path, to_path, **kwargs) return r or to_path # raw bytes if isinstance(lpath, bytes): - fs = self.get_filesystem_for_path(to_path) - with fs.open(to_path, "wb", **kwargs) as s: + fs = self.get_async_filesystem_for_path(to_path) + async with fs.open_async(to_path, "wb", **kwargs) as s: s.write(lpath) return to_path @@ -405,9 +424,9 @@ def put_raw_data( if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = self.get_filesystem_for_path(to_path) + fs = self.get_async_filesystem_for_path(to_path) lpath.seek(0) - with fs.open(to_path, "wb", **kwargs) as s: + async with fs.open_async(to_path, "wb", **kwargs) as s: while data := lpath.read(read_chunk_size_bytes): s.write(data) return to_path @@ -415,9 +434,9 @@ def put_raw_data( if isinstance(lpath, io.StringIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = self.get_filesystem_for_path(to_path) + fs = self.get_async_filesystem_for_path(to_path) lpath.seek(0) - with fs.open(to_path, "wb", **kwargs) as s: + async with 
fs.open_async(to_path, "wb", **kwargs) as s: while data_str := lpath.read(read_chunk_size_bytes): s.write(data_str.encode(encoding)) return to_path @@ -567,7 +586,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False f"Original exception: {str(ex)}" ) - def put_data( + async def async_put_data( self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False, **kwargs ) -> str: """ @@ -578,10 +597,13 @@ def put_data( :param remote_path: :param is_multipart: """ + """ + write a test to confirm that a local path that's a folder is using async + """ try: local_path = str(local_path) with timeit(f"Upload data to {remote_path}"): - put_result = self.put(cast(str, local_path), remote_path, recursive=is_multipart, **kwargs) + put_result = await self._put(cast(str, local_path), remote_path, recursive=is_multipart, **kwargs) # This is an unfortunate workaround to ensure that we return the correct path for the remote location # Callers of this put_data function in flytekit have been changed to assign the remote path to the # output @@ -595,6 +617,9 @@ def put_data( f"Original exception: {str(ex)}" ) from ex + # Public synchronous version of async_put_data + put_data = loop_manager.synced(async_put_data) + flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-") default_local_file_access_provider = FileAccessProvider( diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 2f973e94f0..eb476bc983 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -368,7 +368,7 @@ async def _create( literal_map = await TypeEngine._dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) path = ctx.file_access.get_random_local_path() utils.write_proto_to_file(literal_map.to_flyte_idl(), path) - ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb") + await ctx.file_access.async_put_data(path, f"{output_prefix}/inputs.pb") task_template = 
render_task_template(task_template, output_prefix) else: literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) diff --git a/flytekit/extras/pytorch/checkpoint.py b/flytekit/extras/pytorch/checkpoint.py index dfb21f5932..942795fd54 100644 --- a/flytekit/extras/pytorch/checkpoint.py +++ b/flytekit/extras/pytorch/checkpoint.py @@ -98,7 +98,7 @@ def to_literal( # save checkpoint to a file torch.save(to_save, local_path) - remote_path = ctx.file_access.put_raw_data(local_path) + remote_path = ctx.file_access.async_put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def to_python_value( diff --git a/flytekit/extras/pytorch/native.py b/flytekit/extras/pytorch/native.py index 4afce9aa4b..dab6803f3c 100644 --- a/flytekit/extras/pytorch/native.py +++ b/flytekit/extras/pytorch/native.py @@ -44,7 +44,7 @@ def to_literal( # save pytorch tensor/module to a file torch.save(python_val, local_path) - remote_path = ctx.file_access.put_raw_data(local_path) + remote_path = ctx.file_access.async_put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: diff --git a/flytekit/extras/sklearn/native.py b/flytekit/extras/sklearn/native.py index 37426fdfa4..568e32558b 100644 --- a/flytekit/extras/sklearn/native.py +++ b/flytekit/extras/sklearn/native.py @@ -42,7 +42,7 @@ def to_literal( # save sklearn estimator to a file joblib.dump(python_val, local_path) - remote_path = ctx.file_access.put_raw_data(local_path) + remote_path = ctx.file_access.async_put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: diff --git a/flytekit/extras/tensorflow/model.py b/flytekit/extras/tensorflow/model.py index 2978fe1d69..e5d08188d5 100644 --- 
a/flytekit/extras/tensorflow/model.py +++ b/flytekit/extras/tensorflow/model.py @@ -44,7 +44,7 @@ def to_literal( # save model in SavedModel format tf.keras.models.save_model(python_val, local_path) - remote_path = ctx.file_access.put_raw_data(local_path) + remote_path = ctx.file_access.async_put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def to_python_value( diff --git a/flytekit/extras/tensorflow/record.py b/flytekit/extras/tensorflow/record.py index 3e86b6b2ee..ff072f3af2 100644 --- a/flytekit/extras/tensorflow/record.py +++ b/flytekit/extras/tensorflow/record.py @@ -84,7 +84,7 @@ def to_literal( local_path = os.path.join(local_dir, "0000.tfrecord") with tf.io.TFRecordWriter(local_path) as writer: writer.write(python_val.SerializeToString()) - remote_path = ctx.file_access.put_raw_data(local_path) + remote_path = ctx.file_access.async_put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def to_python_value( @@ -150,7 +150,7 @@ def to_literal( local_path = f"{local_dir}/part_{i}.tfrecord" with tf.io.TFRecordWriter(local_path) as writer: writer.write(val.SerializeToString()) - remote_path = ctx.file_access.put_raw_data(local_dir) + remote_path = ctx.file_access.async_put_raw_data(local_dir) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def to_python_value( diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 518525914d..7eea572f82 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -444,7 +444,7 @@ def assert_type(self, t: typing.Type[FlyteDirectory], v: typing.Union[FlyteDirec def get_literal_type(self, t: typing.Type[FlyteDirectory]) -> LiteralType: return _type_models.LiteralType(blob=self._blob_type(format=FlyteDirToMultipartBlobTransformer.get_format(t))) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: FlyteDirectory, @@ -499,7 
+499,7 @@ def to_literal( remote_directory = ctx.file_access.get_random_remote_directory() if not pathlib.Path(source_path).is_dir(): raise FlyteAssertion("Expected a directory. {} is not a directory".format(source_path)) - ctx.file_access.put_data(source_path, remote_directory, is_multipart=True, batch_size=batch_size) + await ctx.file_access.async_put_data(source_path, remote_directory, is_multipart=True, batch_size=batch_size) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_directory))) # If not uploading, then we can only take the original source path as the uri. diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 602f5bc12e..be1d92c2d4 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -511,9 +511,9 @@ async def async_to_literal( if should_upload: headers = self.get_additional_headers(source_path) if remote_path is not None: - remote_path = ctx.file_access.put_data(source_path, remote_path, is_multipart=False, **headers) + remote_path = ctx.file_access.async_put_data(source_path, remote_path, is_multipart=False, **headers) else: - remote_path = ctx.file_access.put_raw_data(source_path, **headers) + remote_path = ctx.file_access.async_put_raw_data(source_path, **headers) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=unquote(str(remote_path))))) # If not uploading, then we can only take the original source path as the uri. 
else: diff --git a/flytekit/types/iterator/json_iterator.py b/flytekit/types/iterator/json_iterator.py index d8ed2ce570..3852d74c9f 100644 --- a/flytekit/types/iterator/json_iterator.py +++ b/flytekit/types/iterator/json_iterator.py @@ -83,7 +83,7 @@ def to_literal( ) ) - return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=ctx.file_access.put_raw_data(uri)))) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=ctx.file_access.async_put_raw_data(uri)))) def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[Iterator[JSON]] diff --git a/flytekit/types/numpy/ndarray.py b/flytekit/types/numpy/ndarray.py index 1ca25bde11..0df6321b80 100644 --- a/flytekit/types/numpy/ndarray.py +++ b/flytekit/types/numpy/ndarray.py @@ -84,7 +84,7 @@ def to_literal( arr=python_val, allow_pickle=metadata.get("allow_pickle", False), ) - remote_path = ctx.file_access.put_raw_data(local_path) + remote_path = ctx.file_access.async_put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[np.ndarray]) -> np.ndarray: diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index d26ede7b1b..698b1a8fa9 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -60,7 +60,7 @@ def to_pickle(cls, ctx: FlyteContext, python_val: typing.Any) -> str: with open(uri, "w+b") as outfile: cloudpickle.dump(python_val, outfile) - return ctx.file_access.put_raw_data(uri) + return ctx.file_access.async_put_raw_data(uri) @classmethod def from_pickle(cls, uri: str) -> typing.Any: diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 5cf8308b03..634bb17528 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -406,7 +406,7 @@ def assert_type(self, t: Type[FlyteSchema], v: typing.Any): def get_literal_type(self, t: Type[FlyteSchema]) -> 
LiteralType: return LiteralType(schema=self._get_schema_type(t)) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: FlyteSchema, python_type: Type[FlyteSchema], expected: LiteralType ) -> Literal: if isinstance(python_val, FlyteSchema): @@ -421,7 +421,7 @@ def to_literal( # This means the local path is empty. Don't try to overwrite the remote data logger.debug(f"Skipping upload for {python_val} because it was never downloaded.") else: - remote_path = ctx.file_access.put_data(python_val.local_path, remote_path, is_multipart=True) + remote_path = await ctx.file_access.async_put_data(python_val.local_path, remote_path, is_multipart=True) return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type(python_type)))) remote_path = ctx.file_access.join(ctx.file_access.raw_output_prefix, ctx.file_access.get_random_string()) @@ -438,7 +438,7 @@ def to_literal( writer = schema.open(type(python_val)) writer.write(python_val) if not h.handles_remote_io: - schema.remote_path = ctx.file_access.put_data(schema.local_path, schema.remote_path, is_multipart=True) + schema.remote_path = await ctx.file_access.async_put_data(schema.local_path, schema.remote_path, is_multipart=True) return Literal(scalar=Scalar(schema=Schema(schema.remote_path, self._get_schema_type(python_type)))) def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[FlyteSchema]) -> FlyteSchema: diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index a7ade2fe46..38789eaebf 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -91,7 +91,7 @@ def _get_schema_type() -> SchemaType: def get_literal_type(self, t: Type[pandas.DataFrame]) -> LiteralType: return LiteralType(schema=self._get_schema_type()) - def to_literal( + async def to_literal( self, ctx: FlyteContext, python_val: pandas.DataFrame, @@ -105,7 +105,7 @@ def to_literal( 
ctx.file_access.raw_output_prefix, ctx.file_access.get_random_string(), ) - remote_path = ctx.file_access.put_data(local_dir, remote_path, is_multipart=True) + remote_path = await ctx.file_access.async_put_data(local_dir, remote_path, is_multipart=True) return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type()))) def to_python_value( diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 75b20fe08c..81a7ebb549 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -647,7 +647,7 @@ def to_literal( if not uri: raise ValueError(f"If dataframe is not specified, then the uri should be specified. {python_val}") if not ctx.file_access.is_remote(uri): - uri = ctx.file_access.put_raw_data(uri) + uri = ctx.file_access.async_put_raw_data(uri) sd_model = literals.StructuredDataset( uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type=sdt), diff --git a/plugins/flytekit-modin/flytekitplugins/modin/schema.py b/plugins/flytekit-modin/flytekitplugins/modin/schema.py index 0504c38746..5d93ec6208 100644 --- a/plugins/flytekit-modin/flytekitplugins/modin/schema.py +++ b/plugins/flytekit-modin/flytekitplugins/modin/schema.py @@ -89,7 +89,7 @@ def to_literal( ctx.file_access.raw_output_prefix, ctx.file_access.get_random_string(), ) - remote_path = ctx.file_access.put_data(local_dir, remote_path, is_multipart=True) + remote_path = await ctx.file_access.async_put_data(local_dir, remote_path, is_multipart=True) return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type()))) def to_python_value( diff --git a/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py index 78793b84d3..db82c7c278 100644 --- a/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py +++ 
b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py @@ -100,7 +100,7 @@ def to_literal( if config: local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) - remote_path = ctx.file_access.put_raw_data(local_path) + remote_path = ctx.file_access.async_put_raw_data(local_path) else: raise TypeTransformerFailedError(f"{python_type}'s config is None") diff --git a/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py index ea85c0b6fb..8963aa1580 100644 --- a/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py +++ b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py @@ -119,7 +119,7 @@ def to_literal( if config: local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) - remote_path = ctx.file_access.put_raw_data(local_path) + remote_path = ctx.file_access.async_put_raw_data(local_path) else: raise TypeTransformerFailedError(f"{python_type}'s config is None") diff --git a/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py index 2e7c6cc579..6bfcb60067 100644 --- a/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py +++ b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py @@ -90,7 +90,7 @@ def to_literal( if config: local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) - remote_path = ctx.file_access.put_raw_data(local_path) + remote_path = ctx.file_access.async_put_raw_data(local_path) else: raise TypeTransformerFailedError(f"{python_type}'s config is None") diff --git a/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py b/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py index 6fe833d836..9261b37938 100644 --- a/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py +++ 
b/plugins/flytekit-pandera/flytekitplugins/pandera/schema.py @@ -73,7 +73,7 @@ def to_literal( local_dir=local_dir, cols=self._get_col_dtypes(python_type), fmt=SchemaFormat.PARQUET ) w.write(self._pandera_schema(python_type)(python_val)) - remote_path = ctx.file_access.put_raw_data(local_dir) + remote_path = ctx.file_access.async_put_raw_data(local_dir) return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type(python_type)))) else: raise AssertionError( diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 474901544d..ed71913e25 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -75,7 +75,7 @@ def encode( output_uri = structured_dataset.uri else: remote_fn = "00000" # 00000 is our default unnamed parquet filename - output_uri = ctx.file_access.put_raw_data(output_bytes, file_name=remote_fn) + output_uri = ctx.file_access.async_put_raw_data(output_bytes, file_name=remote_fn) return literals.StructuredDataset(uri=output_uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) @@ -137,7 +137,7 @@ def encode( output_bytes = io.BytesIO() remote_fn = "00000" # 00000 is our default unnamed parquet filename _write_method(output_bytes) - output_uri = ctx.file_access.put_raw_data(output_bytes, file_name=remote_fn) + output_uri = ctx.file_access.async_put_raw_data(output_bytes, file_name=remote_fn) return literals.StructuredDataset(uri=output_uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) diff --git a/plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py b/plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py index 28e6537aa6..fea73fdbf6 100644 --- a/plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py +++ b/plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py @@ -30,7 +30,7 @@ def encode( 
local_dir = ctx.file_access.get_random_local_directory() local_path = os.path.join(local_dir, f"{0:05}") df.export_parquet(local_path) - path = ctx.file_access.put_raw_data(local_dir) + path = ctx.file_access.async_put_raw_data(local_dir) return literals.StructuredDataset( uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type), diff --git a/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py b/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py index 82b1d3b616..a9aca16945 100644 --- a/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py +++ b/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py @@ -30,7 +30,7 @@ def to_literal( ) -> Literal: local_dir = ctx.file_access.get_random_local_path() python_val.write(local_dir) - remote_path = ctx.file_access.put_raw_data(local_dir) + remote_path = ctx.file_access.async_put_raw_data(local_dir) return Literal(scalar=Scalar(blob=Blob(uri=remote_path, metadata=BlobMetadata(type=self._TYPE_INFO)))) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[DatasetProfileView]) -> T: diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 2de6e8c196..0a9fa7d080 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -102,16 +102,17 @@ def source_folder(): def test_local_raw_fsspec(source_folder): # Test copying using raw fsspec local filesystem, should not create a nested folder with tempfile.TemporaryDirectory() as dest_tmpdir: - local.put(source_folder, dest_tmpdir, recursive=True) + local._put(source_folder, dest_tmpdir, recursive=True) new_temp_dir_2 = tempfile.mkdtemp() new_temp_dir_2 = os.path.join(new_temp_dir_2, "doesnotexist") - local.put(source_folder, new_temp_dir_2, recursive=True) + local._put(source_folder, new_temp_dir_2, recursive=True) files = local.find(new_temp_dir_2) assert len(files) == 2 -def test_local_provider(source_folder): 
+@pytest.mark.asyncio +async def test_local_provider(source_folder): # Test that behavior putting from a local dir to a local remote dir is the same whether or not the local # dest folder exists. dc = Config.for_sandbox().data_config @@ -119,14 +120,14 @@ def test_local_provider(source_folder): provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=dest_tmpdir, data_config=dc) r = provider.get_random_string() doesnotexist = provider.join(provider.raw_output_prefix, r) - provider.put_data(source_folder, doesnotexist, is_multipart=True) + await provider.async_put_data(source_folder, doesnotexist, is_multipart=True) files = provider.raw_output_fs.find(doesnotexist) assert len(files) == 2 r = provider.get_random_string() exists = provider.join(provider.raw_output_prefix, r) provider.raw_output_fs.mkdir(exists) - provider.put_data(source_folder, exists, is_multipart=True) + await provider.async_put_data(source_folder, exists, is_multipart=True) files = provider.raw_output_fs.find(exists) assert len(files) == 2 @@ -161,7 +162,7 @@ async def _lsdir( fsspec.register_implementation("test", MockAsyncFileSystem) ctx = FlyteContextManager.current_context() - dst = ctx.file_access.put(local_path, remote_path) + dst = ctx.file_access._put(local_path, remote_path) assert dst == remote_path dst = ctx.file_access.get(remote_path, local_path) assert dst == local_path @@ -175,7 +176,7 @@ def test_s3_provider(source_folder): local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc ) doesnotexist = provider.join(provider.raw_output_prefix, provider.get_random_string()) - provider.put_data(source_folder, doesnotexist, is_multipart=True) + provider.async_put_data(source_folder, doesnotexist, is_multipart=True) fs = provider.get_filesystem_for_path(doesnotexist) files = fs.find(doesnotexist) assert len(files) == 2 @@ -378,7 +379,7 @@ def test_crawl_s3(source_folder): local_sandbox_dir="/tmp/unittest", 
raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc ) s3_random_target = provider.join(provider.raw_output_prefix, provider.get_random_string()) - provider.put_data(source_folder, s3_random_target, is_multipart=True) + provider.async_put_data(source_folder, s3_random_target, is_multipart=True) ctx = FlyteContextManager.current_context() expected = {f"{s3_random_target}/original.txt", f"{s3_random_target}/nested/more.txt"} diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index 5063e484d2..771cdff073 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -63,17 +63,17 @@ def test_write_folder_put_raw(mock_uuid_class): df.to_parquet(bio2, engine="pyarrow") # Write foo/a.txt by specifying the upload prefix and a file name - fs.put_raw_data(sio, upload_prefix="foo", file_name="a.txt") + fs.async_put_raw_data(sio, upload_prefix="foo", file_name="a.txt") # Write bar/00000 by specifying the folder in the filename - fs.put_raw_data(bio, file_name="bar/00000") + fs.async_put_raw_data(bio, file_name="bar/00000") # Write pd.parquet and baz by specifying an empty string upload prefix - fs.put_raw_data(bio2, upload_prefix="", file_name="pd.parquet") - fs.put_raw_data(bio, upload_prefix="", file_name="baz/00000") + fs.async_put_raw_data(bio2, upload_prefix="", file_name="pd.parquet") + fs.async_put_raw_data(bio, upload_prefix="", file_name="baz/00000") # Write sio again with known folder but random file name - fs.put_raw_data(sio, upload_prefix="baz") + fs.async_put_raw_data(sio, upload_prefix="baz") paths = [str(p) for p in pathlib.Path(raw).rglob("*")] assert len(paths) == 9 @@ -107,7 +107,7 @@ def test_write_large_put_raw(): sio.seek(0) # Write foo/a.txt by specifying the upload prefix and a file name - fs.put_raw_data(sio, upload_prefix="foo", file_name="a.txt", block_size=5, read_chunk_size_bytes=1) + fs.async_put_raw_data(sio, 
upload_prefix="foo", file_name="a.txt", block_size=5, read_chunk_size_bytes=1) output_file = os.path.join(raw, "foo", "a.txt") with open(output_file, "rb") as f: assert f.read() == arbitrary_text.encode("utf-8") @@ -130,7 +130,7 @@ def test_write_known_location(): # Write foo/a.txt by specifying the upload prefix and a file name known_dest_dir = tempfile.mkdtemp() set_path = fs.join(known_dest_dir, "a.txt") - output_path = fs.put_raw_data(sio, upload_prefix=known_dest_dir, file_name="a.txt", skip_raw_data_prefix=True) + output_path = fs.async_put_raw_data(sio, upload_prefix=known_dest_dir, file_name="a.txt", skip_raw_data_prefix=True) assert output_path == set_path with open(output_path, "rb") as f: assert f.read() == arbitrary_text.encode("utf-8") diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 7e09e918ae..352984ca37 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -650,7 +650,7 @@ def write_this_file_to_s3() -> FlyteFile: ctx = FlyteContextManager.current_context() r = ctx.file_access.get_random_string() dest = ctx.file_access.join(ctx.file_access.raw_output_prefix, r) - ctx.file_access.put(__file__, dest) + ctx.file_access._put(__file__, dest) return FlyteFile(path=dest) @task diff --git a/tests/flytekit/unit/remote/test_fs_remote.py b/tests/flytekit/unit/remote/test_fs_remote.py index efc6e94e8b..5c635376b4 100644 --- a/tests/flytekit/unit/remote/test_fs_remote.py +++ b/tests/flytekit/unit/remote/test_fs_remote.py @@ -109,7 +109,7 @@ def test_remote_upload_with_data_persistence(sandbox_remote): f.write("asdf") f.flush() # Test uploading a file and folder. - res = fp.put(f.name, "flyte://data/", recursive=True) + res = fp._put(f.name, "flyte://data/", recursive=True) # Unlike using the RemoteFS directly, the trailing slash is automatically added by data persistence, # not sure why but preserving the behavior for now. 
only_file = pathlib.Path(f.name).name From 603f10877efa1f1cb186e095ecba3bc49409d9c3 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 16 Oct 2024 18:05:16 -0700 Subject: [PATCH 02/15] test type engine Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 19 +++++++++++-------- flytekit/core/type_engine.py | 8 +++++--- flytekit/types/directory/types.py | 10 ++++++---- flytekit/types/file/file.py | 6 ++++-- flytekit/types/pickle/pickle.py | 16 +++++++++------- flytekit/types/schema/types.py | 12 ++++++++---- .../flytekitplugins/modin/schema.py | 2 +- tests/flytekit/unit/core/test_data.py | 15 ++++++++------- tests/flytekit/unit/core/test_type_engine.py | 6 +++--- 9 files changed, 55 insertions(+), 39 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 4ceff28824..1350d128ae 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -26,10 +26,10 @@ from time import sleep from typing import Any, Dict, Optional, Union, cast from uuid import UUID -from fsspec.asyn import AsyncFileSystem import fsspec from decorator import decorator +from fsspec.asyn import AsyncFileSystem from fsspec.utils import get_protocol from typing_extensions import Unpack @@ -40,8 +40,8 @@ from flytekit.exceptions.system import FlyteDownloadDataException, FlyteUploadDataException from flytekit.exceptions.user import FlyteAssertion, FlyteDataNotFoundException from flytekit.interfaces.random import random -from flytekit.utils.asyn import loop_manager from flytekit.loggers import logger +from flytekit.utils.asyn import loop_manager # Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198 # for key and secret @@ -288,7 +288,7 @@ def exists(self, path: str) -> bool: raise oe @retry_request - def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): + async def get(self, from_path: str, to_path: str, recursive: bool = False, 
**kwargs): file_system = self.get_filesystem_for_path(from_path) if recursive: from_path, to_path = self.recursive_paths(from_path, to_path) @@ -340,7 +340,10 @@ async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kw """ Need to check here for async fs or sync """ - dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 + if isinstance(file_system, AsyncFileSystem): + dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 + else: + dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs) if isinstance(dst, (str, pathlib.Path)): return dst else: @@ -382,9 +385,7 @@ async def async_put_raw_data( """ # First figure out what the destination path should be, then call put. """ - Make this function work with both sync and async filesystems, update maybe delete the get async file system function - Do this first and then make the local file system async """ upload_prefix = self.get_random_string() if upload_prefix is None else upload_prefix @@ -568,7 +569,7 @@ def upload_directory(self, local_path: str, remote_path: str, **kwargs): """ return self.put_data(local_path, remote_path, is_multipart=True, **kwargs) - def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs): + async def async_get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs): """ :param remote_path: :param local_path: @@ -577,7 +578,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False try: pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) with timeit(f"Download data to local from {remote_path}"): - self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs) + await self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs) except FlyteDataNotFoundException: raise except Exception as ex: @@ -586,6 +587,8 @@ def 
get_data(self, remote_path: str, local_path: str, is_multipart: bool = False f"Original exception: {str(ex)}" ) + get_data = loop_manager.synced(async_get_data) + async def async_put_data( self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False, **kwargs ) -> str: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 0c6541a860..d857a0d9c8 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1938,7 +1938,9 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple: return None, None @staticmethod - def dict_to_binary_literal(ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool) -> Literal: + async def dict_to_binary_literal( + ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool + ) -> Literal: """ Converts a Python dictionary to a Flyte-specific ``Literal`` using MessagePack encoding. Falls back to Pickle if encoding fails and `allow_pickle` is True. 
@@ -1952,7 +1954,7 @@ def dict_to_binary_literal(ctx: FlyteContext, v: dict, python_type: Type[dict], return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) except TypeError as e: if allow_pickle: - remote_path = FlytePickle.to_pickle(ctx, v) + remote_path = await FlytePickle.to_pickle(ctx, v) return Literal( scalar=Scalar( generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct()) @@ -2010,7 +2012,7 @@ async def async_to_literal( allow_pickle, base_type = DictTransformer.is_pickle(python_type) if expected and expected.simple and expected.simple == SimpleType.STRUCT: - return self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle) + return await self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle) lit_map = {} for k, v in python_val.items(): diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 7eea572f82..52249e2977 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -19,7 +19,7 @@ from flytekit import BlobType from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_batch_size +from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size from flytekit.exceptions.user import FlyteAssertion from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types @@ -407,7 +407,7 @@ def __str__(self): return str(self.path) -class FlyteDirToMultipartBlobTransformer(TypeTransformer[FlyteDirectory]): +class FlyteDirToMultipartBlobTransformer(AsyncTypeTransformer[FlyteDirectory]): """ This transformer handles conversion between the Python native FlyteDirectory class defined above, and the Flyte IDL literal/type of Multipart Blob. 
Please see the FlyteDirectory comments for additional information. @@ -499,7 +499,9 @@ async def async_to_literal( remote_directory = ctx.file_access.get_random_remote_directory() if not pathlib.Path(source_path).is_dir(): raise FlyteAssertion("Expected a directory. {} is not a directory".format(source_path)) - await ctx.file_access.async_put_data(source_path, remote_directory, is_multipart=True, batch_size=batch_size) + await ctx.file_access.async_put_data( + source_path, remote_directory, is_multipart=True, batch_size=batch_size + ) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_directory))) # If not uploading, then we can only take the original source path as the uri. @@ -535,7 +537,7 @@ def from_binary_idl( else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") - def to_python_value( + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory] ) -> FlyteDirectory: if lv.scalar.binary: diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index be1d92c2d4..e952cc0f15 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -511,9 +511,11 @@ async def async_to_literal( if should_upload: headers = self.get_additional_headers(source_path) if remote_path is not None: - remote_path = ctx.file_access.async_put_data(source_path, remote_path, is_multipart=False, **headers) + remote_path = await ctx.file_access.async_put_data( + source_path, remote_path, is_multipart=False, **headers + ) else: - remote_path = ctx.file_access.async_put_raw_data(source_path, **headers) + remote_path = await ctx.file_access.async_put_raw_data(source_path, **headers) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=unquote(str(remote_path))))) # If not uploading, then we can only take the original source path as the uri. 
else: diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 698b1a8fa9..a3aa93662a 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -5,7 +5,7 @@ import cloudpickle from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer +from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType @@ -52,7 +52,7 @@ def python_type(cls) -> typing.Type: return _SpecificFormatClass @classmethod - def to_pickle(cls, ctx: FlyteContext, python_val: typing.Any) -> str: + async def to_pickle(cls, ctx: FlyteContext, python_val: typing.Any) -> str: local_dir = ctx.file_access.get_random_local_directory() os.makedirs(local_dir, exist_ok=True) local_path = ctx.file_access.get_random_local_path() @@ -60,7 +60,7 @@ def to_pickle(cls, ctx: FlyteContext, python_val: typing.Any) -> str: with open(uri, "w+b") as outfile: cloudpickle.dump(python_val, outfile) - return ctx.file_access.async_put_raw_data(uri) + return await ctx.file_access.async_put_raw_data(uri) @classmethod def from_pickle(cls, uri: str) -> typing.Any: @@ -76,7 +76,7 @@ def from_pickle(cls, uri: str) -> typing.Any: return data -class FlytePickleTransformer(TypeTransformer[FlytePickle]): +class FlytePickleTransformer(AsyncTypeTransformer[FlytePickle]): PYTHON_PICKLE_FORMAT = "PythonPickle" def __init__(self): @@ -86,11 +86,13 @@ def assert_type(self, t: Type[T], v: T): # Every type can serialize to pickle, so we don't need to check the type here. ... 
- def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: uri = lv.scalar.blob.uri return FlytePickle.from_pickle(uri) - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + async def async_to_literal( + self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType + ) -> Literal: if python_val is None: raise AssertionError("Cannot pickle None Value.") meta = BlobMetadata( @@ -98,7 +100,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE ) ) - remote_path = FlytePickle.to_pickle(ctx, python_val) + remote_path = await FlytePickle.to_pickle(ctx, python_val) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlytePickle[typing.Any]]: diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 634bb17528..9bdafb1d52 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -17,7 +17,7 @@ from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError from flytekit.loggers import logger from flytekit.models.literals import Binary, Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType @@ -373,7 +373,7 @@ def _get_numpy_type_mappings() -> typing.Dict[Type, SchemaType.SchemaColumn.Sche return {} -class FlyteSchemaTransformer(TypeTransformer[FlyteSchema]): +class 
FlyteSchemaTransformer(AsyncTypeTransformer[FlyteSchema]): _SUPPORTED_TYPES: typing.Dict[Type, SchemaType.SchemaColumn.SchemaColumnType] = { float: SchemaType.SchemaColumn.SchemaColumnType.FLOAT, int: SchemaType.SchemaColumn.SchemaColumnType.INTEGER, @@ -421,7 +421,9 @@ async def async_to_literal( # This means the local path is empty. Don't try to overwrite the remote data logger.debug(f"Skipping upload for {python_val} because it was never downloaded.") else: - remote_path = await ctx.file_access.async_put_data(python_val.local_path, remote_path, is_multipart=True) + remote_path = await ctx.file_access.async_put_data( + python_val.local_path, remote_path, is_multipart=True + ) return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type(python_type)))) remote_path = ctx.file_access.join(ctx.file_access.raw_output_prefix, ctx.file_access.get_random_string()) @@ -438,7 +440,9 @@ async def async_to_literal( writer = schema.open(type(python_val)) writer.write(python_val) if not h.handles_remote_io: - schema.remote_path = await ctx.file_access.async_put_data(schema.local_path, schema.remote_path, is_multipart=True) + schema.remote_path = await ctx.file_access.async_put_data( + schema.local_path, schema.remote_path, is_multipart=True + ) return Literal(scalar=Scalar(schema=Schema(schema.remote_path, self._get_schema_type(python_type)))) def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[FlyteSchema]) -> FlyteSchema: diff --git a/plugins/flytekit-modin/flytekitplugins/modin/schema.py b/plugins/flytekit-modin/flytekitplugins/modin/schema.py index 5d93ec6208..22f511ea78 100644 --- a/plugins/flytekit-modin/flytekitplugins/modin/schema.py +++ b/plugins/flytekit-modin/flytekitplugins/modin/schema.py @@ -89,7 +89,7 @@ def to_literal( ctx.file_access.raw_output_prefix, ctx.file_access.get_random_string(), ) - remote_path = await ctx.file_access.async_put_data(local_dir, remote_path, is_multipart=True) + remote_path = 
ctx.file_access.async_put_data(local_dir, remote_path, is_multipart=True) return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type()))) def to_python_value( diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 0a9fa7d080..b0f1a1200b 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -102,11 +102,11 @@ def source_folder(): def test_local_raw_fsspec(source_folder): # Test copying using raw fsspec local filesystem, should not create a nested folder with tempfile.TemporaryDirectory() as dest_tmpdir: - local._put(source_folder, dest_tmpdir, recursive=True) + local.put(source_folder, dest_tmpdir, recursive=True) new_temp_dir_2 = tempfile.mkdtemp() new_temp_dir_2 = os.path.join(new_temp_dir_2, "doesnotexist") - local._put(source_folder, new_temp_dir_2, recursive=True) + local.put(source_folder, new_temp_dir_2, recursive=True) files = local.find(new_temp_dir_2) assert len(files) == 2 @@ -132,7 +132,8 @@ async def test_local_provider(source_folder): assert len(files) == 2 -def test_async_file_system(): +@pytest.mark.asyncio +async def test_async_file_system(): remote_path = "test:///tmp/test.py" local_path = "test.py" @@ -162,9 +163,9 @@ async def _lsdir( fsspec.register_implementation("test", MockAsyncFileSystem) ctx = FlyteContextManager.current_context() - dst = ctx.file_access._put(local_path, remote_path) + dst = await ctx.file_access._put(local_path, remote_path) assert dst == remote_path - dst = ctx.file_access.get(remote_path, local_path) + dst = await ctx.file_access.get(remote_path, local_path) assert dst == local_path @@ -176,7 +177,7 @@ def test_s3_provider(source_folder): local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc ) doesnotexist = provider.join(provider.raw_output_prefix, provider.get_random_string()) - provider.async_put_data(source_folder, doesnotexist, is_multipart=True) + 
provider.put_data(source_folder, doesnotexist, is_multipart=True) fs = provider.get_filesystem_for_path(doesnotexist) files = fs.find(doesnotexist) assert len(files) == 2 @@ -379,7 +380,7 @@ def test_crawl_s3(source_folder): local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc ) s3_random_target = provider.join(provider.raw_output_prefix, provider.get_random_string()) - provider.async_put_data(source_folder, s3_random_target, is_multipart=True) + provider.put_data(source_folder, s3_random_target, is_multipart=True) ctx = FlyteContextManager.current_context() expected = {f"{s3_random_target}/original.txt", f"{s3_random_target}/nested/more.txt"} diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 5cf00150a9..093cb52d5c 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -924,7 +924,7 @@ class TestStructD(DataClassJsonMixin): assert ot == o -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.async_put_data") def test_dataclass_with_postponed_annotation(mock_put_data): remote_path = "s3://tmp/file" mock_put_data.return_value = remote_path @@ -950,7 +950,7 @@ class Data: assert dict_obj["f"]["path"] == remote_path -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.async_put_data") def test_optional_flytefile_in_dataclass(mock_upload_dir): mock_upload_dir.return_value = True @@ -1037,7 +1037,7 @@ class TestFileStruct(DataClassJsonMixin): assert o.i_prime == A(a=99) -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.async_put_data") def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir): @dataclass class 
A_optional_flytefile(DataClassJSONMixin): From ac8115ed2f6211a6826feed6c34714675d9d74f3 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 17 Oct 2024 12:02:30 -0700 Subject: [PATCH 03/15] unit tests Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 33 ++++++++++++++----- flytekit/types/iterator/json_iterator.py | 11 ++++--- flytekit/types/numpy/ndarray.py | 14 ++++---- flytekit/types/schema/types.py | 4 ++- .../types/structured/structured_dataset.py | 10 +++--- .../unit/core/test_data_persistence.py | 17 +++++----- 6 files changed, 56 insertions(+), 33 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 1350d128ae..80307fe60d 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -417,8 +417,13 @@ async def async_put_raw_data( # raw bytes if isinstance(lpath, bytes): fs = self.get_async_filesystem_for_path(to_path) - async with fs.open_async(to_path, "wb", **kwargs) as s: - s.write(lpath) + if isinstance(fs, AsyncFileSystem): + async with fs.open_async(to_path, "wb", **kwargs) as s: + s.write(lpath) + else: + with fs.open(to_path, "wb", **kwargs) as s: + s.write(lpath) + return to_path # If lpath is a buffered reader of some kind @@ -427,9 +432,14 @@ async def async_put_raw_data( raise FlyteAssertion("Buffered reader must be readable") fs = self.get_async_filesystem_for_path(to_path) lpath.seek(0) - async with fs.open_async(to_path, "wb", **kwargs) as s: - while data := lpath.read(read_chunk_size_bytes): - s.write(data) + if isinstance(fs, AsyncFileSystem): + async with fs.open_async(to_path, "wb", **kwargs) as s: + while data := lpath.read(read_chunk_size_bytes): + s.write(data) + else: + with fs.open(to_path, "wb", **kwargs) as s: + while data := lpath.read(read_chunk_size_bytes): + s.write(data) return to_path if isinstance(lpath, io.StringIO): @@ -437,13 +447,20 @@ async def async_put_raw_data( raise FlyteAssertion("Buffered reader must be readable") fs 
= self.get_async_filesystem_for_path(to_path) lpath.seek(0) - async with fs.open_async(to_path, "wb", **kwargs) as s: - while data_str := lpath.read(read_chunk_size_bytes): - s.write(data_str.encode(encoding)) + if isinstance(fs, AsyncFileSystem): + async with fs.open_async(to_path, "wb", **kwargs) as s: + while data_str := lpath.read(read_chunk_size_bytes): + s.write(data_str.encode(encoding)) + else: + with fs.open(to_path, "wb", **kwargs) as s: + while data_str := lpath.read(read_chunk_size_bytes): + s.write(data_str.encode(encoding)) return to_path raise FlyteAssertion(f"Unsupported lpath type {type(lpath)}") + put_raw_data = loop_manager.synced(async_put_raw_data) + @staticmethod def get_random_string() -> str: return UUID(int=random.getrandbits(128)).hex diff --git a/flytekit/types/iterator/json_iterator.py b/flytekit/types/iterator/json_iterator.py index 3852d74c9f..bbf0853568 100644 --- a/flytekit/types/iterator/json_iterator.py +++ b/flytekit/types/iterator/json_iterator.py @@ -6,8 +6,8 @@ from flytekit import FlyteContext, Literal, LiteralType from flytekit.core.type_engine import ( + AsyncTypeTransformer, TypeEngine, - TypeTransformer, TypeTransformerFailedError, ) from flytekit.models.core import types as _core_types @@ -34,7 +34,7 @@ def __next__(self): raise StopIteration("File handler is exhausted") -class JSONIteratorTransformer(TypeTransformer[Iterator[JSON]]): +class JSONIteratorTransformer(AsyncTypeTransformer[Iterator[JSON]]): """ A JSON iterator that handles conversion between an iterator/generator and a JSONL file. 
""" @@ -54,7 +54,7 @@ def get_literal_type(self, t: Type[Iterator[JSON]]) -> LiteralType: metadata={"format": self.JSON_ITERATOR_METADATA}, ) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: Iterator[JSON], @@ -83,9 +83,10 @@ def to_literal( ) ) - return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=ctx.file_access.async_put_raw_data(uri)))) + uri = await ctx.file_access.async_put_raw_data(uri) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=uri))) - def to_python_value( + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[Iterator[JSON]] ) -> JSONIterator: try: diff --git a/flytekit/types/numpy/ndarray.py b/flytekit/types/numpy/ndarray.py index 0df6321b80..91c9dc3019 100644 --- a/flytekit/types/numpy/ndarray.py +++ b/flytekit/types/numpy/ndarray.py @@ -9,8 +9,8 @@ from flytekit.core.context_manager import FlyteContext from flytekit.core.hash import HashMethod from flytekit.core.type_engine import ( + AsyncTypeTransformer, TypeEngine, - TypeTransformer, TypeTransformerFailedError, ) from flytekit.models.core import types as _core_types @@ -41,7 +41,7 @@ def extract_metadata(t: Type[np.ndarray]) -> Tuple[Type[np.ndarray], Dict[str, b return t, metadata -class NumpyArrayTransformer(TypeTransformer[np.ndarray]): +class NumpyArrayTransformer(AsyncTypeTransformer[np.ndarray]): """ TypeTransformer that supports np.ndarray as a native type. 
""" @@ -59,7 +59,7 @@ def get_literal_type(self, t: Type[np.ndarray]) -> LiteralType: ) ) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: np.ndarray, @@ -84,10 +84,12 @@ def to_literal( arr=python_val, allow_pickle=metadata.get("allow_pickle", False), ) - remote_path = ctx.file_access.async_put_raw_data(local_path) + remote_path = await ctx.file_access.async_put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[np.ndarray]) -> np.ndarray: + async def async_to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[np.ndarray] + ) -> np.ndarray: try: uri = lv.scalar.blob.uri except AttributeError: @@ -96,7 +98,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: expected_python_type, metadata = extract_metadata(expected_python_type) local_path = ctx.file_access.get_random_local_path() - ctx.file_access.get_data(uri, local_path, is_multipart=False) + await ctx.file_access.async_get_data(uri, local_path, is_multipart=False) # load numpy array from a file return np.load( diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 9bdafb1d52..28a2c542ef 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -462,7 +462,9 @@ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[ else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[FlyteSchema]) -> FlyteSchema: + async def async_to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[FlyteSchema] + ) -> FlyteSchema: # Handle dataclass attribute access if lv.scalar and lv.scalar.binary: return self.from_binary_idl(lv.scalar.binary, expected_python_type) diff --git 
a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 81a7ebb549..d5acda3d6f 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -19,7 +19,7 @@ from flytekit import lazy_module from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError from flytekit.deck.renderer import Renderable from flytekit.loggers import developer_logger, logger from flytekit.models import literals @@ -399,7 +399,7 @@ def get_supported_types(): class DuplicateHandlerError(ValueError): ... -class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): +class StructuredDatasetTransformerEngine(AsyncTypeTransformer[StructuredDataset]): """ Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. If you are bringing a custom data frame type, or any data frame type, to flytekit, instead of @@ -594,7 +594,7 @@ def register_for_protocol( def assert_type(self, t: Type[StructuredDataset], v: typing.Any): return - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: Union[StructuredDataset, typing.Any], @@ -647,7 +647,7 @@ def to_literal( if not uri: raise ValueError(f"If dataframe is not specified, then the uri should be specified. 
{python_val}") if not ctx.file_access.is_remote(uri): - uri = ctx.file_access.async_put_raw_data(uri) + uri = await ctx.file_access.async_put_raw_data(uri) sd_model = literals.StructuredDataset( uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type=sdt), @@ -745,7 +745,7 @@ def from_binary_idl( else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") - def to_python_value( + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset ) -> T | StructuredDataset: """ diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index 771cdff073..face41de0c 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -63,17 +63,17 @@ def test_write_folder_put_raw(mock_uuid_class): df.to_parquet(bio2, engine="pyarrow") # Write foo/a.txt by specifying the upload prefix and a file name - fs.async_put_raw_data(sio, upload_prefix="foo", file_name="a.txt") + fs.put_raw_data(sio, upload_prefix="foo", file_name="a.txt") # Write bar/00000 by specifying the folder in the filename - fs.async_put_raw_data(bio, file_name="bar/00000") + fs.put_raw_data(bio, file_name="bar/00000") # Write pd.parquet and baz by specifying an empty string upload prefix - fs.async_put_raw_data(bio2, upload_prefix="", file_name="pd.parquet") - fs.async_put_raw_data(bio, upload_prefix="", file_name="baz/00000") + fs.put_raw_data(bio2, upload_prefix="", file_name="pd.parquet") + fs.put_raw_data(bio, upload_prefix="", file_name="baz/00000") # Write sio again with known folder but random file name - fs.async_put_raw_data(sio, upload_prefix="baz") + fs.put_raw_data(sio, upload_prefix="baz") paths = [str(p) for p in pathlib.Path(raw).rglob("*")] assert len(paths) == 9 @@ -92,7 +92,8 @@ def test_write_folder_put_raw(mock_uuid_class): assert sorted(paths) == sorted(expected) -def 
test_write_large_put_raw(): +@pytest.mark.asyncio +async def test_write_large_put_raw(): """ Test that writes a large'ish file setting block size and read size. """ @@ -107,7 +108,7 @@ def test_write_large_put_raw(): sio.seek(0) # Write foo/a.txt by specifying the upload prefix and a file name - fs.async_put_raw_data(sio, upload_prefix="foo", file_name="a.txt", block_size=5, read_chunk_size_bytes=1) + await fs.async_put_raw_data(sio, upload_prefix="foo", file_name="a.txt", block_size=5, read_chunk_size_bytes=1) output_file = os.path.join(raw, "foo", "a.txt") with open(output_file, "rb") as f: assert f.read() == arbitrary_text.encode("utf-8") @@ -130,7 +131,7 @@ def test_write_known_location(): # Write foo/a.txt by specifying the upload prefix and a file name known_dest_dir = tempfile.mkdtemp() set_path = fs.join(known_dest_dir, "a.txt") - output_path = fs.async_put_raw_data(sio, upload_prefix=known_dest_dir, file_name="a.txt", skip_raw_data_prefix=True) + output_path = fs.put_raw_data(sio, upload_prefix=known_dest_dir, file_name="a.txt", skip_raw_data_prefix=True) assert output_path == set_path with open(output_path, "rb") as f: assert f.read() == arbitrary_text.encode("utf-8") From 22603f6a948d6e2e59d7d59d89d6ad4f92e6eae2 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 17 Oct 2024 12:20:59 -0700 Subject: [PATCH 04/15] revert some put raw data Signed-off-by: Yee Hing Tong --- plugins/flytekit-modin/flytekitplugins/modin/schema.py | 2 +- .../flytekitplugins/onnxpytorch/schema.py | 2 +- .../flytekitplugins/onnxscikitlearn/schema.py | 2 +- .../flytekitplugins/onnxtensorflow/schema.py | 2 +- .../flytekit-polars/flytekitplugins/polars/sd_transformers.py | 4 ++-- plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py | 2 +- plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/plugins/flytekit-modin/flytekitplugins/modin/schema.py 
b/plugins/flytekit-modin/flytekitplugins/modin/schema.py index 22f511ea78..0504c38746 100644 --- a/plugins/flytekit-modin/flytekitplugins/modin/schema.py +++ b/plugins/flytekit-modin/flytekitplugins/modin/schema.py @@ -89,7 +89,7 @@ def to_literal( ctx.file_access.raw_output_prefix, ctx.file_access.get_random_string(), ) - remote_path = ctx.file_access.async_put_data(local_dir, remote_path, is_multipart=True) + remote_path = ctx.file_access.put_data(local_dir, remote_path, is_multipart=True) return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type()))) def to_python_value( diff --git a/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py index db82c7c278..78793b84d3 100644 --- a/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py +++ b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py @@ -100,7 +100,7 @@ def to_literal( if config: local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) - remote_path = ctx.file_access.async_put_raw_data(local_path) + remote_path = ctx.file_access.put_raw_data(local_path) else: raise TypeTransformerFailedError(f"{python_type}'s config is None") diff --git a/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py index 8963aa1580..ea85c0b6fb 100644 --- a/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py +++ b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py @@ -119,7 +119,7 @@ def to_literal( if config: local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) - remote_path = ctx.file_access.async_put_raw_data(local_path) + remote_path = ctx.file_access.put_raw_data(local_path) else: raise TypeTransformerFailedError(f"{python_type}'s config is None") diff --git 
a/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py index 6bfcb60067..2e7c6cc579 100644 --- a/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py +++ b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py @@ -90,7 +90,7 @@ def to_literal( if config: local_path = to_onnx(ctx, python_val.model, config.__dict__.copy()) - remote_path = ctx.file_access.async_put_raw_data(local_path) + remote_path = ctx.file_access.put_raw_data(local_path) else: raise TypeTransformerFailedError(f"{python_type}'s config is None") diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index ed71913e25..474901544d 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -75,7 +75,7 @@ def encode( output_uri = structured_dataset.uri else: remote_fn = "00000" # 00000 is our default unnamed parquet filename - output_uri = ctx.file_access.async_put_raw_data(output_bytes, file_name=remote_fn) + output_uri = ctx.file_access.put_raw_data(output_bytes, file_name=remote_fn) return literals.StructuredDataset(uri=output_uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) @@ -137,7 +137,7 @@ def encode( output_bytes = io.BytesIO() remote_fn = "00000" # 00000 is our default unnamed parquet filename _write_method(output_bytes) - output_uri = ctx.file_access.async_put_raw_data(output_bytes, file_name=remote_fn) + output_uri = ctx.file_access.put_raw_data(output_bytes, file_name=remote_fn) return literals.StructuredDataset(uri=output_uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) diff --git a/plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py b/plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py index fea73fdbf6..28e6537aa6 
100644 --- a/plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py +++ b/plugins/flytekit-vaex/flytekitplugins/vaex/sd_transformers.py @@ -30,7 +30,7 @@ def encode( local_dir = ctx.file_access.get_random_local_directory() local_path = os.path.join(local_dir, f"{0:05}") df.export_parquet(local_path) - path = ctx.file_access.async_put_raw_data(local_dir) + path = ctx.file_access.put_raw_data(local_dir) return literals.StructuredDataset( uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type), diff --git a/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py b/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py index a9aca16945..82b1d3b616 100644 --- a/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py +++ b/plugins/flytekit-whylogs/flytekitplugins/whylogs/schema.py @@ -30,7 +30,7 @@ def to_literal( ) -> Literal: local_dir = ctx.file_access.get_random_local_path() python_val.write(local_dir) - remote_path = ctx.file_access.async_put_raw_data(local_dir) + remote_path = ctx.file_access.put_raw_data(local_dir) return Literal(scalar=Scalar(blob=Blob(uri=remote_path, metadata=BlobMetadata(type=self._TYPE_INFO)))) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[DatasetProfileView]) -> T: From 97811458801d186ed8573b7003c3caf0e8c6eeb5 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 17 Oct 2024 12:40:22 -0700 Subject: [PATCH 05/15] remove more async put data Signed-off-by: Yee Hing Tong --- flytekit/extras/pytorch/checkpoint.py | 2 +- flytekit/extras/pytorch/native.py | 12 +++++----- flytekit/extras/sklearn/native.py | 2 +- flytekit/extras/tensorflow/model.py | 12 +++++----- flytekit/extras/tensorflow/record.py | 4 ++-- .../extras/pytorch/test_transformations.py | 24 +++++++++---------- 6 files changed, 28 insertions(+), 28 deletions(-) diff --git a/flytekit/extras/pytorch/checkpoint.py b/flytekit/extras/pytorch/checkpoint.py index 942795fd54..dfb21f5932 100644 --- 
a/flytekit/extras/pytorch/checkpoint.py +++ b/flytekit/extras/pytorch/checkpoint.py @@ -98,7 +98,7 @@ def to_literal( # save checkpoint to a file torch.save(to_save, local_path) - remote_path = ctx.file_access.async_put_raw_data(local_path) + remote_path = ctx.file_access.put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def to_python_value( diff --git a/flytekit/extras/pytorch/native.py b/flytekit/extras/pytorch/native.py index dab6803f3c..dd895bc993 100644 --- a/flytekit/extras/pytorch/native.py +++ b/flytekit/extras/pytorch/native.py @@ -4,7 +4,7 @@ import torch from flytekit.core.context_manager import FlyteContext -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType @@ -12,7 +12,7 @@ T = TypeVar("T") -class PyTorchTypeTransformer(TypeTransformer[T]): +class PyTorchTypeTransformer(AsyncTypeTransformer[T]): def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType( blob=_core_types.BlobType( @@ -21,7 +21,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: ) ) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: T, @@ -44,17 +44,17 @@ def to_literal( # save pytorch tensor/module to a file torch.save(python_val, local_path) - remote_path = ctx.file_access.async_put_raw_data(local_path) + remote_path = await ctx.file_access.async_put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: try: uri = 
lv.scalar.blob.uri except AttributeError: TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") local_path = ctx.file_access.get_random_local_path() - ctx.file_access.get_data(uri, local_path, is_multipart=False) + await ctx.file_access.async_get_data(uri, local_path, is_multipart=False) # cpu <-> gpu conversion if torch.cuda.is_available(): diff --git a/flytekit/extras/sklearn/native.py b/flytekit/extras/sklearn/native.py index 568e32558b..37426fdfa4 100644 --- a/flytekit/extras/sklearn/native.py +++ b/flytekit/extras/sklearn/native.py @@ -42,7 +42,7 @@ def to_literal( # save sklearn estimator to a file joblib.dump(python_val, local_path) - remote_path = ctx.file_access.async_put_raw_data(local_path) + remote_path = ctx.file_access.put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: diff --git a/flytekit/extras/tensorflow/model.py b/flytekit/extras/tensorflow/model.py index e5d08188d5..b9fbf24d4b 100644 --- a/flytekit/extras/tensorflow/model.py +++ b/flytekit/extras/tensorflow/model.py @@ -4,13 +4,13 @@ import tensorflow as tf from flytekit.core.context_manager import FlyteContext -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]): +class TensorFlowModelTransformer(AsyncTypeTransformer[tf.keras.Model]): TENSORFLOW_FORMAT = "TensorFlowModel" def __init__(self): @@ -24,7 +24,7 @@ def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType: ) ) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, 
python_val: tf.keras.Model, @@ -44,10 +44,10 @@ def to_literal( # save model in SavedModel format tf.keras.models.save_model(python_val, local_path) - remote_path = ctx.file_access.async_put_raw_data(local_path) + remote_path = await ctx.file_access.async_put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) - def to_python_value( + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.Model] ) -> tf.keras.Model: try: @@ -56,7 +56,7 @@ def to_python_value( TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") local_path = ctx.file_access.get_random_local_path() - ctx.file_access.get_data(uri, local_path, is_multipart=True) + await ctx.file_access.async_get_data(uri, local_path, is_multipart=True) # load model return tf.keras.models.load_model(local_path) diff --git a/flytekit/extras/tensorflow/record.py b/flytekit/extras/tensorflow/record.py index ff072f3af2..3e86b6b2ee 100644 --- a/flytekit/extras/tensorflow/record.py +++ b/flytekit/extras/tensorflow/record.py @@ -84,7 +84,7 @@ def to_literal( local_path = os.path.join(local_dir, "0000.tfrecord") with tf.io.TFRecordWriter(local_path) as writer: writer.write(python_val.SerializeToString()) - remote_path = ctx.file_access.async_put_raw_data(local_path) + remote_path = ctx.file_access.put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def to_python_value( @@ -150,7 +150,7 @@ def to_literal( local_path = f"{local_dir}/part_{i}.tfrecord" with tf.io.TFRecordWriter(local_path) as writer: writer.write(val.SerializeToString()) - remote_path = ctx.file_access.async_put_raw_data(local_dir) + remote_path = ctx.file_access.put_raw_data(local_dir) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) def to_python_value( diff --git a/tests/flytekit/unit/extras/pytorch/test_transformations.py 
b/tests/flytekit/unit/extras/pytorch/test_transformations.py index 4bc81bb3a0..b0db1232f2 100644 --- a/tests/flytekit/unit/extras/pytorch/test_transformations.py +++ b/tests/flytekit/unit/extras/pytorch/test_transformations.py @@ -49,18 +49,18 @@ def test_get_literal_type(transformer, python_type, format): @pytest.mark.parametrize( "transformer,python_type,format,python_val", [ - ( - PyTorchTensorTransformer(), - torch.Tensor, - PyTorchTensorTransformer.PYTORCH_FORMAT, - torch.tensor([[1, 2], [3, 4]]), - ), - ( - PyTorchModuleTransformer(), - torch.nn.Module, - PyTorchModuleTransformer.PYTORCH_FORMAT, - torch.nn.Linear(2, 2), - ), + # ( + # PyTorchTensorTransformer(), + # torch.Tensor, + # PyTorchTensorTransformer.PYTORCH_FORMAT, + # torch.tensor([[1, 2], [3, 4]]), + # ), + # ( + # PyTorchModuleTransformer(), + # torch.nn.Module, + # PyTorchModuleTransformer.PYTORCH_FORMAT, + # torch.nn.Linear(2, 2), + # ), ( PyTorchCheckpointTransformer(), PyTorchCheckpoint, From f87ed1ebfa624567e079d35041269928ecd408be Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 17 Oct 2024 15:00:32 -0700 Subject: [PATCH 06/15] forgot to get async Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 80307fe60d..4031cb52b1 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -215,6 +215,7 @@ def get_filesystem( def get_async_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> AsyncFileSystem: protocol = get_protocol(path) + return self.get_filesystem(protocol, anonymous=anonymous, path=path, asynchronous=True, **kwargs) def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem: @@ -289,7 +290,7 @@ def exists(self, path: str) -> bool: @retry_request async def get(self, from_path: str, to_path: str, recursive: 
bool = False, **kwargs): - file_system = self.get_filesystem_for_path(from_path) + file_system = self.get_async_filesystem_for_path(from_path) if recursive: from_path, to_path = self.recursive_paths(from_path, to_path) try: From 4dda739356617457e306fc60bdfe1bb1528c1d92 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 17 Oct 2024 17:16:29 -0700 Subject: [PATCH 07/15] pass along kwargs, add loop to create filesystem, add private download function Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 24 ++++++--- flytekit/types/file/file.py | 3 ++ tests/flytekit/unit/core/test_data.py | 75 +++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 6 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 4031cb52b1..e0e5ff7ef4 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -18,6 +18,7 @@ """ +import asyncio import io import os import pathlib @@ -210,13 +211,15 @@ def get_filesystem( storage_options = get_fsspec_storage_options( protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs ) + kwargs.update(storage_options) - return fsspec.filesystem(protocol, **storage_options) + return fsspec.filesystem(protocol, **kwargs) def get_async_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> AsyncFileSystem: protocol = get_protocol(path) + loop = asyncio.get_event_loop() - return self.get_filesystem(protocol, anonymous=anonymous, path=path, asynchronous=True, **kwargs) + return self.get_filesystem(protocol, anonymous=anonymous, path=path, asynchronous=True, loop=loop, **kwargs) def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem: protocol = get_protocol(path) @@ -301,7 +304,10 @@ async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwa self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True ) 
logger.info(f"Getting {from_path} to {to_path}") - dst = file_system.get(from_path, to_path, recursive=recursive, **kwargs) + if isinstance(file_system, AsyncFileSystem): + dst = await file_system._get(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 + else: + dst = file_system.get(from_path, to_path, recursive=recursive, **kwargs) if isinstance(dst, (str, pathlib.Path)): return dst return to_path @@ -309,10 +315,13 @@ async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwa logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") if not file_system.exists(from_path): raise FlyteDataNotFoundException(from_path) - file_system = self.get_filesystem(get_protocol(from_path), anonymous=True) + file_system = self.get_filesystem(get_protocol(from_path), anonymous=True, asynchronous=True) if file_system is not None: logger.debug(f"Attempting anonymous get with {file_system}") - return file_system.get(from_path, to_path, recursive=recursive, **kwargs) + if isinstance(file_system, AsyncFileSystem): + return await file_system._get(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 + else: + return file_system.get(from_path, to_path, recursive=recursive, **kwargs) raise oe @retry_request @@ -605,7 +614,10 @@ async def async_get_data(self, remote_path: str, local_path: str, is_multipart: f"Original exception: {str(ex)}" ) - get_data = loop_manager.synced(async_get_data) + # get_data = loop_manager.synced(async_get_data) + + def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs): + loop_manager.run_sync(self.async_get_data, remote_path, local_path, is_multipart, **kwargs) async def async_put_data( self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False, **kwargs diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index e952cc0f15..bf08a1f535 100644 --- a/flytekit/types/file/file.py +++ 
b/flytekit/types/file/file.py @@ -309,6 +309,9 @@ def remote_source(self) -> str: def download(self) -> str: return self.__fspath__() + async def _download(self) -> str: + return self.__fspath__() + @contextmanager def open( self, diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index b0f1a1200b..bb731919aa 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -4,6 +4,7 @@ import tempfile from uuid import UUID import typing +import asyncio import fsspec import mock import pytest @@ -17,6 +18,7 @@ from flytekit.types.file import FlyteFile from flytekit.utils.asyn import loop_manager from flytekit.models.literals import Literal +from flytekit.utils.asyn import run_sync local = fsspec.filesystem("file") root = os.path.abspath(os.sep) @@ -494,3 +496,76 @@ def test_async_local_copy_to_s3(): print(f"Time taken: {end_time - start_time}") print(f"Wall time taken: {end_wall_time - start_wall_time}") print(f"Process time taken: {end_process_time - start_process_time}") + + +async def download_files(ffs: typing.List[FlyteFile]): + futures = [asyncio.create_task(ff._download()) for ff in ffs] + return await asyncio.gather(*futures, return_exceptions=True) + + +@pytest.mark.sandbox_test +def test_async_download_from_s3(): + import time + import datetime + + f1 = "/Users/ytong/go/src/github.com/unionai/debugyt/user/ytong/src/yt_dbg/fr/rand.file" + f2 = "/Users/ytong/go/src/github.com/unionai/debugyt/user/ytong/src/yt_dbg/fr/rand2.file" + f3 = "/Users/ytong/go/src/github.com/unionai/debugyt/user/ytong/src/yt_dbg/fr/rand3.file" + + ff1 = FlyteFile(path=f1) + ff2 = FlyteFile(path=f2) + ff3 = FlyteFile(path=f3) + ff = [ff1, ff2, ff3] + + ctx = FlyteContextManager.current_context() + dc = Config.for_sandbox().data_config + random_folder = UUID(int=random.getrandbits(64)).hex + raw_output = f"s3://my-s3-bucket/testing/upload_test/{random_folder}" + print(f"Uploading to {raw_output}") + provider = 
FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + + with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx: + lit = TypeEngine.to_literal(ctx, ff, typing.List[FlyteFile], TypeEngine.to_literal_type(typing.List[FlyteFile])) + print(f"Literal is {lit}") + python_list = TypeEngine.to_python_value(ctx, lit, typing.List[FlyteFile]) + + print(f"Serial File list: {python_list}") + + start_time = datetime.datetime.now(datetime.timezone.utc) + start_wall_time = time.perf_counter() + start_process_time = time.process_time() + + for local_file in python_list: + print(f"Downloading {local_file.remote_source} to {local_file.path}") + local_file.download() + + end_time = datetime.datetime.now(datetime.timezone.utc) + end_wall_time = time.perf_counter() + end_process_time = time.process_time() + + print(f"Time taken (serial download): {end_time - start_time}") + print(f"Wall time taken (serial download): {end_wall_time - start_wall_time}") + print(f"Process time taken (serial download): {end_process_time - start_process_time}") + + with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx: + lit = TypeEngine.to_literal(ctx, ff, typing.List[FlyteFile], TypeEngine.to_literal_type(typing.List[FlyteFile])) + print(f"Literal is {lit}") + python_list = TypeEngine.to_python_value(ctx, lit, typing.List[FlyteFile]) + + print(f"Async file list: {python_list}") + + start_time = datetime.datetime.now(datetime.timezone.utc) + start_wall_time = time.perf_counter() + start_process_time = time.process_time() + + res = run_sync(download_files, python_list) + print(f"Result is: {res}") + + end_time = datetime.datetime.now(datetime.timezone.utc) + end_wall_time = time.perf_counter() + end_process_time = time.process_time() + + print(f"Time taken (async): {end_time - start_time}") + print(f"Wall time taken (async): {end_wall_time - start_wall_time}") + print(f"Process time taken (async): {end_process_time - 
start_process_time}") + From 0ca6052b4e9b280dd0f8b0934a1e43e124bf93cf Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 18 Oct 2024 16:15:56 -0700 Subject: [PATCH 08/15] update comments, add test and make schema async Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 18 +++----------- flytekit/types/pickle/pickle.py | 4 ++++ flytekit/types/schema/types_pandas.py | 10 ++++---- .../unit/core/test_data_persistence.py | 17 +++++++++++++ .../extras/pytorch/test_transformations.py | 24 +++++++++---------- 5 files changed, 41 insertions(+), 32 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index e0e5ff7ef4..1b0be6b2bd 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -347,9 +347,6 @@ async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kw if "metadata" not in kwargs: kwargs["metadata"] = {} kwargs["metadata"].update(self._execution_metadata) - """ - Need to check here for async fs or sync - """ if isinstance(file_system, AsyncFileSystem): dst = await file_system._put(from_path, to_path, recursive=recursive, **kwargs) # pylint: disable=W0212 else: @@ -394,10 +391,6 @@ async def async_put_raw_data( :return: Returns the final path data was written to. """ # First figure out what the destination path should be, then call put. 
- """ - update maybe delete the get async file system function - Do this first and then make the local file system async - """ upload_prefix = self.get_random_string() if upload_prefix is None else upload_prefix to_path = self.join(self.raw_output_prefix, upload_prefix) if not skip_raw_data_prefix else upload_prefix if file_name: @@ -469,6 +462,7 @@ async def async_put_raw_data( raise FlyteAssertion(f"Unsupported lpath type {type(lpath)}") + # Public synchronous version put_raw_data = loop_manager.synced(async_put_raw_data) @staticmethod @@ -614,10 +608,7 @@ async def async_get_data(self, remote_path: str, local_path: str, is_multipart: f"Original exception: {str(ex)}" ) - # get_data = loop_manager.synced(async_get_data) - - def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs): - loop_manager.run_sync(self.async_get_data, remote_path, local_path, is_multipart, **kwargs) + get_data = loop_manager.synced(async_get_data) async def async_put_data( self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False, **kwargs @@ -630,9 +621,6 @@ async def async_put_data( :param remote_path: :param is_multipart: """ - """ - write a test to confirm that a local path that's a folder is using async - """ try: local_path = str(local_path) with timeit(f"Upload data to {remote_path}"): @@ -650,7 +638,7 @@ async def async_put_data( f"Original exception: {str(ex)}" ) from ex - # Public synchronous version of async_put_data + # Public synchronous version put_data = loop_manager.synced(async_put_data) diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index a3aa93662a..20d732a83e 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -62,6 +62,10 @@ async def to_pickle(cls, ctx: FlyteContext, python_val: typing.Any) -> str: return await ctx.file_access.async_put_raw_data(uri) + """ + this also needs to be updated, or both not updated + """ + @classmethod def 
from_pickle(cls, uri: str) -> typing.Any: ctx = FlyteContextManager.current_context() diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index 38789eaebf..bff3572cfe 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -5,7 +5,7 @@ import pandas from flytekit import FlyteContext -from flytekit.core.type_engine import T, TypeEngine, TypeTransformer +from flytekit.core.type_engine import AsyncTypeTransformer, T, TypeEngine from flytekit.models.literals import Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType from flytekit.types.schema import LocalIOSchemaReader, LocalIOSchemaWriter, SchemaEngine, SchemaFormat, SchemaHandler @@ -75,7 +75,7 @@ def _write(self, df: T, path: os.PathLike, **kwargs): return self._parquet_engine.write(df, to_file=path, **kwargs) -class PandasDataFrameTransformer(TypeTransformer[pandas.DataFrame]): +class PandasDataFrameTransformer(AsyncTypeTransformer[pandas.DataFrame]): """ Transforms a pd.DataFrame to Schema without column types. 
""" @@ -91,7 +91,7 @@ def _get_schema_type() -> SchemaType: def get_literal_type(self, t: Type[pandas.DataFrame]) -> LiteralType: return LiteralType(schema=self._get_schema_type()) - async def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: pandas.DataFrame, @@ -108,13 +108,13 @@ async def to_literal( remote_path = await ctx.file_access.async_put_data(local_dir, remote_path, is_multipart=True) return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type()))) - def to_python_value( + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[pandas.DataFrame] ) -> pandas.DataFrame: if not (lv and lv.scalar and lv.scalar.schema): return pandas.DataFrame() local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(lv.scalar.schema.uri, local_dir, is_multipart=True) + await ctx.file_access.async_get_data(lv.scalar.schema.uri, local_dir, is_multipart=True) r = PandasSchemaReader(local_dir=local_dir, cols=None, fmt=SchemaFormat.PARQUET) return r.all() diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index face41de0c..d992ed1fa5 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,5 +1,6 @@ import io import os +import fsspec import pathlib import random import string @@ -11,6 +12,7 @@ from azure.identity import ClientSecretCredential, DefaultAzureCredential from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.local_fsspec import FlyteLocalFileSystem def test_get_manual_random_remote_path(): @@ -190,3 +192,18 @@ def test_initialise_azure_file_provider_with_default_credential(): fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") assert fp.get_filesystem().account_name == "accountname" assert isinstance(fp.get_filesystem().sync_credential, DefaultAzureCredential) + + +def 
test_get_file_system(): + # Test that custom args are not swallowed by get_filesystem + + class MockFileSystem(FlyteLocalFileSystem): + def __init__(self, *args, **kwargs): + assert "test_arg" in kwargs + del kwargs["test_arg"] + super().__init__(*args, **kwargs) + + fsspec.register_implementation("testgetfs", MockFileSystem) + + fp = FileAccessProvider("/tmp", "s3://my-bucket") + fp.get_filesystem("testgetfs", test_arg="test_arg") diff --git a/tests/flytekit/unit/extras/pytorch/test_transformations.py b/tests/flytekit/unit/extras/pytorch/test_transformations.py index b0db1232f2..4bc81bb3a0 100644 --- a/tests/flytekit/unit/extras/pytorch/test_transformations.py +++ b/tests/flytekit/unit/extras/pytorch/test_transformations.py @@ -49,18 +49,18 @@ def test_get_literal_type(transformer, python_type, format): @pytest.mark.parametrize( "transformer,python_type,format,python_val", [ - # ( - # PyTorchTensorTransformer(), - # torch.Tensor, - # PyTorchTensorTransformer.PYTORCH_FORMAT, - # torch.tensor([[1, 2], [3, 4]]), - # ), - # ( - # PyTorchModuleTransformer(), - # torch.nn.Module, - # PyTorchModuleTransformer.PYTORCH_FORMAT, - # torch.nn.Linear(2, 2), - # ), + ( + PyTorchTensorTransformer(), + torch.Tensor, + PyTorchTensorTransformer.PYTORCH_FORMAT, + torch.tensor([[1, 2], [3, 4]]), + ), + ( + PyTorchModuleTransformer(), + torch.nn.Module, + PyTorchModuleTransformer.PYTORCH_FORMAT, + torch.nn.Linear(2, 2), + ), ( PyTorchCheckpointTransformer(), PyTorchCheckpoint, From c158c910719170df300226eed7a5985572e194ea Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 18 Oct 2024 16:17:38 -0700 Subject: [PATCH 09/15] make FlytePickle async, not sure Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 2 +- flytekit/types/pickle/pickle.py | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 34fe483c74..e4fdc50ca8 100644 --- a/flytekit/core/type_engine.py +++ 
b/flytekit/core/type_engine.py @@ -2064,7 +2064,7 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p from flytekit.types.pickle import FlytePickle uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file") - return FlytePickle.from_pickle(uri) + return await FlytePickle.from_pickle(uri) try: return json.loads(_json_format.MessageToJson(lv.scalar.generic)) diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 20d732a83e..7b4c99cae6 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -62,18 +62,14 @@ async def to_pickle(cls, ctx: FlyteContext, python_val: typing.Any) -> str: return await ctx.file_access.async_put_raw_data(uri) - """ - this also needs to be updated, or both not updated - """ - @classmethod - def from_pickle(cls, uri: str) -> typing.Any: + async def from_pickle(cls, uri: str) -> typing.Any: ctx = FlyteContextManager.current_context() # Deserialize the pickle, and return data in the pickle, # and download pickle file to local first if file is not in the local file systems. 
if ctx.file_access.is_remote(uri): local_path = ctx.file_access.get_random_local_path() - ctx.file_access.get_data(uri, local_path, False) + await ctx.file_access.async_get_data(uri, local_path, False) uri = local_path with open(uri, "rb") as infile: data = cloudpickle.load(infile) @@ -92,7 +88,7 @@ def assert_type(self, t: Type[T], v: T): async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: uri = lv.scalar.blob.uri - return FlytePickle.from_pickle(uri) + return await FlytePickle.from_pickle(uri) async def async_to_literal( self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType From 1aaa94f229affd5c1364b511ee3bfb6bc381ed54 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 18 Oct 2024 16:18:17 -0700 Subject: [PATCH 10/15] lint Signed-off-by: Yee Hing Tong --- tests/flytekit/unit/core/test_data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index bb731919aa..42e74f453c 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -568,4 +568,3 @@ def test_async_download_from_s3(): print(f"Time taken (async): {end_time - start_time}") print(f"Wall time taken (async): {end_wall_time - start_wall_time}") print(f"Process time taken (async): {end_process_time - start_process_time}") - From 0d9e298033f3a34626e38c299fb8cc8bb71dac96 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 18 Oct 2024 16:26:24 -0700 Subject: [PATCH 11/15] oops Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 1b0be6b2bd..9c1838defc 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -217,7 +217,7 @@ def get_filesystem( def get_async_filesystem_for_path(self, path: str = "", anonymous: bool = 
False, **kwargs) -> AsyncFileSystem: protocol = get_protocol(path) - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() return self.get_filesystem(protocol, anonymous=anonymous, path=path, asynchronous=True, loop=loop, **kwargs) From 7641697ba72c63c76b5a88de01c1b6d01f2155e2 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 21 Oct 2024 17:39:02 -0700 Subject: [PATCH 12/15] make getting the filesystem to be async and await it Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 9c1838defc..751ffb8b27 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -215,7 +215,9 @@ def get_filesystem( return fsspec.filesystem(protocol, **kwargs) - def get_async_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> AsyncFileSystem: + async def get_async_filesystem_for_path( + self, path: str = "", anonymous: bool = False, **kwargs + ) -> Union[AsyncFileSystem, fsspec.AbstractFileSystem]: protocol = get_protocol(path) loop = asyncio.get_running_loop() @@ -293,7 +295,7 @@ def exists(self, path: str) -> bool: @retry_request async def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): - file_system = self.get_async_filesystem_for_path(from_path) + file_system = await self.get_async_filesystem_for_path(from_path) if recursive: from_path, to_path = self.recursive_paths(from_path, to_path) try: @@ -330,7 +332,7 @@ async def _put(self, from_path: str, to_path: str, recursive: bool = False, **kw More of an internal function to be called by put_data and put_raw_data This does not need a separate sync function. 
""" - file_system = self.get_async_filesystem_for_path(to_path) + file_system = await self.get_async_filesystem_for_path(to_path) from_path = self.strip_file_header(from_path) if recursive: # Only check this for the local filesystem @@ -419,7 +421,7 @@ async def async_put_raw_data( # raw bytes if isinstance(lpath, bytes): - fs = self.get_async_filesystem_for_path(to_path) + fs = await self.get_async_filesystem_for_path(to_path) if isinstance(fs, AsyncFileSystem): async with fs.open_async(to_path, "wb", **kwargs) as s: s.write(lpath) @@ -433,7 +435,7 @@ async def async_put_raw_data( if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = self.get_async_filesystem_for_path(to_path) + fs = await self.get_async_filesystem_for_path(to_path) lpath.seek(0) if isinstance(fs, AsyncFileSystem): async with fs.open_async(to_path, "wb", **kwargs) as s: @@ -448,7 +450,7 @@ async def async_put_raw_data( if isinstance(lpath, io.StringIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = self.get_async_filesystem_for_path(to_path) + fs = await self.get_async_filesystem_for_path(to_path) lpath.seek(0) if isinstance(fs, AsyncFileSystem): async with fs.open_async(to_path, "wb", **kwargs) as s: From 540a8d22c0d3d568e682c70bb3163bec6fd50da5 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 22 Oct 2024 09:47:44 -0700 Subject: [PATCH 13/15] try reverting pytorch change Signed-off-by: Yee Hing Tong --- flytekit/extras/pytorch/native.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/flytekit/extras/pytorch/native.py b/flytekit/extras/pytorch/native.py index dd895bc993..4afce9aa4b 100644 --- a/flytekit/extras/pytorch/native.py +++ b/flytekit/extras/pytorch/native.py @@ -4,7 +4,7 @@ import torch from flytekit.core.context_manager import FlyteContext -from flytekit.core.type_engine import AsyncTypeTransformer, 
TypeEngine, TypeTransformerFailedError +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType @@ -12,7 +12,7 @@ T = TypeVar("T") -class PyTorchTypeTransformer(AsyncTypeTransformer[T]): +class PyTorchTypeTransformer(TypeTransformer[T]): def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType( blob=_core_types.BlobType( @@ -21,7 +21,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: ) ) - async def async_to_literal( + def to_literal( self, ctx: FlyteContext, python_val: T, @@ -44,17 +44,17 @@ async def async_to_literal( # save pytorch tensor/module to a file torch.save(python_val, local_path) - remote_path = await ctx.file_access.async_put_raw_data(local_path) + remote_path = ctx.file_access.put_raw_data(local_path) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) - async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: try: uri = lv.scalar.blob.uri except AttributeError: TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") local_path = ctx.file_access.get_random_local_path() - await ctx.file_access.async_get_data(uri, local_path, is_multipart=False) + ctx.file_access.get_data(uri, local_path, is_multipart=False) # cpu <-> gpu conversion if torch.cuda.is_available(): From 1b9a74eca2d202eb3490ca940156b09533c908da Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 22 Oct 2024 14:52:21 -0700 Subject: [PATCH 14/15] runners need to be process specific Signed-off-by: Yee Hing Tong --- flytekit/utils/asyn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/utils/asyn.py b/flytekit/utils/asyn.py index 
c447db052f..d1edb67436 100644 --- a/flytekit/utils/asyn.py +++ b/flytekit/utils/asyn.py @@ -82,7 +82,7 @@ def run_sync(self, coro_func: Callable[..., Awaitable[T]], *args, **kwargs) -> T """ This should be called from synchronous functions to run an async function. """ - name = threading.current_thread().name + name = threading.current_thread().name + f"PID:{os.getpid()}" coro = coro_func(*args, **kwargs) if name not in self._runner_map: if len(self._runner_map) > 500: From 4a8c936ca20d9bfe41e68b7690413019cd078126 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 22 Oct 2024 14:56:40 -0700 Subject: [PATCH 15/15] add test Signed-off-by: Yee Hing Tong --- tests/flytekit/unit/utils/test_asyn.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/flytekit/unit/utils/test_asyn.py b/tests/flytekit/unit/utils/test_asyn.py index db74ac6f53..b8ce75b2a7 100644 --- a/tests/flytekit/unit/utils/test_asyn.py +++ b/tests/flytekit/unit/utils/test_asyn.py @@ -1,3 +1,4 @@ +import os import threading import pytest import asyncio @@ -116,3 +117,8 @@ def test_recursive_calling(): main_ctx.vals["depth"] = 0 assert res == "world" sync_function(6, 6) + + # Check to make sure that the names of the runners have the PID in them. This makes the loop manager work with + # things like pytorch elastic. + for k in loop_manager._runner_map.keys(): + assert str(os.getpid()) in k