rev#2: address mypy violations
matthieu-d4r committed Oct 22, 2024
1 parent 87c01a1 commit 83370c0
Showing 3 changed files with 25 additions and 11 deletions.
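Most of the changes below follow a few recurring mypy patterns: widening override signatures so they match the torch base classes, annotating attributes and empty containers that mypy cannot infer on its own, and normalizing os.PathLike inputs to plain strings before they reach parse_s3_uri. As a minimal sketch of the attribute case (a hypothetical Example class, not the project code), this is the inference problem behind the self.path change:

import os
from typing import Union


class Example:
    def __init__(self) -> None:
        # mypy infers the attribute type as None from this first assignment ...
        self.path = None

    def set_path(self, p: Union[str, os.PathLike]) -> None:
        # ... so a later assignment is rejected:
        # error: Incompatible types in assignment (expression has type
        # "Union[str, PathLike[Any]]", variable has type "None")
        self.path = p

Declaring the attribute up front, as the diff does with self.path: Union[str, os.PathLike] = "", gives mypy a concrete type to check every later assignment against.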
21 changes: 16 additions & 5 deletions s3torchconnector/src/s3torchconnector/dcp/fsdp_filesystem.py
@@ -20,21 +20,20 @@
from s3torchconnectorclient._mountpoint_s3_client import S3Exception

logger = logging.getLogger(__name__)
-Mode = Literal["wb", "rb"]


-class S3FileSystem(FileSystemBase):
+class S3FileSystem(FileSystem):
    def __init__(self, region: str, s3_client: Optional[S3Client] = None) -> None:
-        self.path = None
+        self.path: Union[str, os.PathLike] = ""
        self.region = region
        self.client = s3_client if s3_client is not None else S3Client(region)
        self.checkpoint = S3Checkpoint(region)

    @override
    @contextmanager
    def create_stream(
-        self, path: Union[str, os.PathLike], mode: Mode
-    ) -> Generator[io.BufferedIOBase, None, None]:
+        self, path: Union[str, os.PathLike], mode: str
+    ) -> Generator[io.IOBase, None, None]:
        """
        Create a stream for reading or writing to S3.
@@ -48,6 +47,8 @@ def create_stream(
        Raises:
            ValueError: If the mode is not 'rb' or 'wb'.
        """
+        path = _path_to_str_or_pathlike(path)
+
        if mode == "wb":  # write mode
            logger.debug("create_stream writable for %s", path)
            with self.checkpoint.writer(path) as stream:
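The mode widening above comes from mypy's override checking: the torch base class declares create_stream's mode parameter as a plain str, and an override may not narrow a parameter type (a Liskov-substitution violation). A standalone illustration of the rule (the Base and Narrowed classes here are hypothetical, not the torch API):

import io
from typing import Literal


class Base:
    def create_stream(self, path: str, mode: str) -> io.IOBase:
        raise NotImplementedError


class Narrowed(Base):
    # mypy: Argument 2 of "create_stream" is incompatible with supertype "Base".
    # Parameter types in an override must be at least as wide as in the base class,
    # which is why the Literal["wb", "rb"] alias was dropped in favour of str.
    def create_stream(self, path: str, mode: Literal["rb", "wb"]) -> io.IOBase:
        raise NotImplementedError

The stream type is aligned the same way: the torch base signature uses Generator[io.IOBase, None, None], and although mypy does allow a return type to narrow, keeping the override identical to the base avoids drift.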
@@ -108,6 +109,10 @@ def rename(
            S3Exception: If there is an error with the S3 client.
        """
        logger.debug("rename %s to %s", old_path, new_path)
+
+        old_path = _path_to_str_or_pathlike(old_path)
+        new_path = _path_to_str_or_pathlike(new_path)
+
        bucket_name, old_key = parse_s3_uri(old_path)
        _, new_key = parse_s3_uri(new_path)

@@ -128,6 +133,7 @@ def mkdir(self, path: Union[str, os.PathLike]) -> None:
    def exists(self, path: Union[str, os.PathLike]) -> bool:
        logger.debug("exists %s", path)

+        path = _path_to_str_or_pathlike(path)
        bucket, key = parse_s3_uri(path)
        try:
            self.client.head_object(bucket, key)
@@ -140,6 +146,7 @@ def exists(self, path: Union[str, os.PathLike]) -> bool:
    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        logger.debug("remove %s", path)

+        path = _path_to_str_or_pathlike(path)
        bucket, key = parse_s3_uri(path)
        try:
            self.client.delete_object(bucket, key)
@@ -215,3 +222,7 @@ def __init__(self, region: str, path: Union[str, os.PathLike]) -> None:
    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return S3FileSystem.validate_checkpoint_id(checkpoint_id)
+
+
+def _path_to_str_or_pathlike(path: Union[str, os.PathLike]) -> str:
+    return path if isinstance(path, str) else str(path)
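The new module-level _path_to_str_or_pathlike helper normalizes whatever DCP hands in (a str or an os.PathLike) into the plain string that parse_s3_uri and the checkpoint reader/writer work with. A quick illustrative check of its contract, assuming the helper is imported into a test module:

import pathlib

# Strings pass through untouched.
assert _path_to_str_or_pathlike("s3://bucket/prefix/key") == "s3://bucket/prefix/key"

# PathLike objects are coerced with str(). Note that pathlib collapses the
# double slash of an S3 URI, which is one reason callers normally pass strings.
assert _path_to_str_or_pathlike(pathlib.PurePosixPath("s3://bucket/key")) == "s3:/bucket/key"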
3 changes: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

import pickle
from typing import List, Optional

@@ -54,6 +55,6 @@ def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        return plans

-    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
+    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:  # type: ignore
        # TODO: Check expected bucket, prefix etc. in metadata
        pass
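The # type: ignore on read_data reflects that the method is still a stub: it is annotated to return Future[None], but the body falls through and implicitly returns None, which mypy reports as a missing return. A mypy-clean alternative for a stub, sketched here rather than taken from the project, is to hand back an already-completed future:

from torch.distributed.checkpoint import LoadPlan, LoadPlanner
from torch.futures import Future


class ReaderStub:
    """Hypothetical stand-in for the storage reader, for illustration only."""

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        # Satisfy the declared return type instead of suppressing the check.
        fut: Future[None] = Future()
        fut.set_result(None)
        return fut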
12 changes: 7 additions & 5 deletions s3torchconnector/src/s3torchconnector/dcp/s3_storage_writer.py
@@ -1,18 +1,19 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

import dataclasses
import pickle
import queue
import threading
import os
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Dict, Any
import logging


from attr import dataclass
-from packaging.metadata import Metadata
-from torch.distributed.checkpoint import StorageWriter, SavePlan, SavePlanner
+from torch.distributed.checkpoint import Metadata, StorageWriter, SavePlan, SavePlanner
from torch.distributed.checkpoint.filesystem import _split_by_size_and_type
+from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
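The Metadata import swap matters for the finish override further down: the base StorageWriter.finish takes torch's Metadata (defined in torch.distributed.checkpoint.metadata), whereas the old import pulled an unrelated Metadata class out of packaging.metadata, so mypy saw the override's first parameter as incompatible with the supertype. For orientation, the relevant slice of the base interface looks roughly like this (a paraphrase, not copied from torch):

from typing import List

from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.storage import WriteResult


class StorageWriterShape:
    """Rough sketch of the part of torch's StorageWriter that finish must match."""

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        ...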

@@ -37,7 +38,7 @@ def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        # TODO: add implementation
        pass

-    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
+    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:  # type: ignore
        # TODO: add implementation
        pass

@@ -82,6 +83,7 @@ def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
            for i, plan in enumerate(plans)
        ]
        return new_plans
+        return plans

    def write_data(
        self, plan: SavePlan, planner: SavePlanner
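The return plans added to prepare_global_plan is the usual cure for mypy's "Missing return statement": the surrounding control flow is collapsed in this view, but wherever return new_plans sits on a conditional branch, the fall-through path also needs an explicit return to satisfy the declared List[SavePlan] return type. A generic illustration with hypothetical names (not the project code):

from typing import List


def prepare(plans: List[str], rewrite: bool) -> List[str]:
    if rewrite:
        new_plans = [f"__{i}_{plan}" for i, plan in enumerate(plans)]
        return new_plans
    # Without this line, mypy reports: Missing return statement
    return plans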
@@ -138,7 +140,7 @@ def gen_object_key():
    def finish(self, metadata: Metadata, results: List[List[WriteResult]]):
        if self.is_coordinator:
            # Save metadata from coordinator node
-            s3_storage_metadata = dict()
+            s3_storage_metadata: Dict[Union[str, MetadataIndex], Union[Any, str]] = {}
            for wr_list in results:
                s3_storage_metadata.update(
                    {wr.index: wr.storage_data for wr in wr_list}
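Finally, the annotation on s3_storage_metadata addresses the rule that an empty dict() (or {}) gives mypy nothing to infer a type from. A minimal standalone reproduction of the error and the fix (not the project code):

from typing import Any, Dict, Union

# error: Need type annotation for "untyped" (hint: "untyped: Dict[<type>, <type>] = ...")
untyped = dict()

# Fixed: the explicit annotation tells mypy what keys and values to expect,
# so later insertions with mixed key types still type-check.
typed: Dict[Union[str, int], Any] = {}
typed[0] = "storage-data"
typed["metadata"] = object()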
