diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst
index d5ee4d2134a..99bd347942a 100644
--- a/docs/source/reference/config.rst
+++ b/docs/source/reference/config.rst
@@ -24,6 +24,10 @@ Available fields and semantics:
  #
  # Ref: https://docs.skypilot.co/en/latest/examples/managed-jobs.html#customizing-job-controller-resources
  jobs:
+    # Bucket to store managed jobs mount files and tmp files. Bucket must already exist.
+    # Optional. If not set, SkyPilot will create a new bucket for each managed job launch.
+    # Supports s3://, gs://, https://<azure_storage_account>.blob.core.windows.net/<container>, r2://, cos://<region>/<bucket>
+    bucket: s3://my-bucket/
    controller:
      resources:  # same spec as 'resources' in a task YAML
        cloud: gcp
diff --git a/sky/data/mounting_utils.py b/sky/data/mounting_utils.py
index b713a1f1cc5..d2a95a3c20b 100644
--- a/sky/data/mounting_utils.py
+++ b/sky/data/mounting_utils.py
@@ -31,12 +31,19 @@ def get_s3_mount_install_cmd() -> str:
    return install_cmd


-def get_s3_mount_cmd(bucket_name: str, mount_path: str) -> str:
+# pylint: disable=invalid-name
+def get_s3_mount_cmd(bucket_name: str,
+                     mount_path: str,
+                     _bucket_sub_path: Optional[str] = None) -> str:
    """Returns a command to mount an S3 bucket using goofys."""
+    if _bucket_sub_path is None:
+        _bucket_sub_path = ''
+    else:
+        _bucket_sub_path = f':{_bucket_sub_path}'
    mount_cmd = ('goofys -o allow_other '
                 f'--stat-cache-ttl {_STAT_CACHE_TTL} '
                 f'--type-cache-ttl {_TYPE_CACHE_TTL} '
-                 f'{bucket_name} {mount_path}')
+                 f'{bucket_name}{_bucket_sub_path} {mount_path}')
    return mount_cmd


@@ -50,15 +57,20 @@ def get_gcs_mount_install_cmd() -> str:
    return install_cmd


-def get_gcs_mount_cmd(bucket_name: str, mount_path: str) -> str:
+# pylint: disable=invalid-name
+def get_gcs_mount_cmd(bucket_name: str,
+                      mount_path: str,
+                      _bucket_sub_path: Optional[str] = None) -> str:
    """Returns a command to mount a GCS bucket using gcsfuse."""
-
+    bucket_sub_path_arg = f'--only-dir {_bucket_sub_path} '\
+        if _bucket_sub_path else ''
    mount_cmd = ('gcsfuse -o allow_other '
                 '--implicit-dirs '
                 f'--stat-cache-capacity {_STAT_CACHE_CAPACITY} '
                 f'--stat-cache-ttl {_STAT_CACHE_TTL} '
                 f'--type-cache-ttl {_TYPE_CACHE_TTL} '
                 f'--rename-dir-limit {_RENAME_DIR_LIMIT} '
+                 f'{bucket_sub_path_arg}'
                 f'{bucket_name} {mount_path}')
    return mount_cmd


@@ -79,10 +91,12 @@ def get_az_mount_install_cmd() -> str:
    return install_cmd


+# pylint: disable=invalid-name
def get_az_mount_cmd(container_name: str,
                     storage_account_name: str,
                     mount_path: str,
-                     storage_account_key: Optional[str] = None) -> str:
+                     storage_account_key: Optional[str] = None,
+                     _bucket_sub_path: Optional[str] = None) -> str:
    """Returns a command to mount an AZ Container using blobfuse2.

    Args:
        container_name: Name of the mounting container.
        storage_account_name: Name of the storage account the given container
@@ -91,6 +105,7 @@ def get_az_mount_cmd(container_name: str,
            belongs to.
        mount_path: Path where the container will be mounting.
        storage_account_key: Access key for the given storage account.
+        _bucket_sub_path: Sub path of the mounting container.

    Returns:
        str: Command used to mount AZ container with blobfuse2.
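The mounting helpers above all follow the same pattern: the optional `_bucket_sub_path` narrows the mount to a prefix inside the bucket while leaving the default behavior unchanged. A minimal usage sketch (bucket name, mount path, and sub path below are placeholders, assuming a dev checkout where `sky.data.mounting_utils` is importable):

```python
# Illustrative only: shows the effect of the new _bucket_sub_path argument
# on the generated mount commands. The bucket and paths are placeholders.
from sky.data import mounting_utils

# Without a sub path the behavior is unchanged: goofys mounts the whole bucket.
print(mounting_utils.get_s3_mount_cmd('my-bucket', '/mnt/data'))
# ... goofys ... my-bucket /mnt/data

# With a sub path, goofys mounts only bucket:prefix.
print(
    mounting_utils.get_s3_mount_cmd('my-bucket',
                                    '/mnt/data',
                                    _bucket_sub_path='job-1234/workdir'))
# ... goofys ... my-bucket:job-1234/workdir /mnt/data

# gcsfuse restricts the mount to the prefix via --only-dir.
print(
    mounting_utils.get_gcs_mount_cmd('my-bucket',
                                     '/mnt/data',
                                     _bucket_sub_path='job-1234/workdir'))
# ... gcsfuse ... --only-dir job-1234/workdir my-bucket /mnt/data
```

goofys encodes the prefix as `bucket:prefix`, while gcsfuse and blobfuse2 use `--only-dir` and `--subdirectory=` respectively.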
@@ -107,25 +122,38 @@ def get_az_mount_cmd(container_name: str,
    cache_path = _BLOBFUSE_CACHE_DIR.format(
        storage_account_name=storage_account_name,
        container_name=container_name)
+    if _bucket_sub_path is None:
+        bucket_sub_path_arg = ''
+    else:
+        bucket_sub_path_arg = f'--subdirectory={_bucket_sub_path}/ '
    mount_cmd = (f'AZURE_STORAGE_ACCOUNT={storage_account_name} '
                 f'{key_env_var} '
                 f'blobfuse2 {mount_path} --allow-other --no-symlinks '
                 '-o umask=022 -o default_permissions '
                 f'--tmp-path {cache_path} '
+                 f'{bucket_sub_path_arg}'
                 f'--container-name {container_name}')
    return mount_cmd


-def get_r2_mount_cmd(r2_credentials_path: str, r2_profile_name: str,
-                     endpoint_url: str, bucket_name: str,
-                     mount_path: str) -> str:
+# pylint: disable=invalid-name
+def get_r2_mount_cmd(r2_credentials_path: str,
+                     r2_profile_name: str,
+                     endpoint_url: str,
+                     bucket_name: str,
+                     mount_path: str,
+                     _bucket_sub_path: Optional[str] = None) -> str:
    """Returns a command to install R2 mount utility goofys."""
+    if _bucket_sub_path is None:
+        _bucket_sub_path = ''
+    else:
+        _bucket_sub_path = f':{_bucket_sub_path}'
    mount_cmd = (f'AWS_SHARED_CREDENTIALS_FILE={r2_credentials_path} '
                 f'AWS_PROFILE={r2_profile_name} goofys -o allow_other '
                 f'--stat-cache-ttl {_STAT_CACHE_TTL} '
                 f'--type-cache-ttl {_TYPE_CACHE_TTL} '
                 f'--endpoint {endpoint_url} '
-                 f'{bucket_name} {mount_path}')
+                 f'{bucket_name}{_bucket_sub_path} {mount_path}')
    return mount_cmd


@@ -137,9 +165,12 @@ def get_cos_mount_install_cmd() -> str:
    return install_cmd


-def get_cos_mount_cmd(rclone_config_data: str, rclone_config_path: str,
-                      bucket_rclone_profile: str, bucket_name: str,
-                      mount_path: str) -> str:
+def get_cos_mount_cmd(rclone_config_data: str,
+                      rclone_config_path: str,
+                      bucket_rclone_profile: str,
+                      bucket_name: str,
+                      mount_path: str,
+                      _bucket_sub_path: Optional[str] = None) -> str:
    """Returns a command to mount an IBM COS bucket using rclone."""
    # creates a fusermount soft link on older (<22) Ubuntu systems for
    # rclone's mount utility.
@@ -151,10 +182,14 @@ def get_cos_mount_cmd(rclone_config_data: str, rclone_config_path: str,
                                'mkdir -p ~/.config/rclone/ && '
                                f'echo "{rclone_config_data}" >> '
                                f'{rclone_config_path}')
+    # Mount only the configured sub path when one is given; otherwise mount
+    # the whole bucket.
+    if _bucket_sub_path is None:
+        sub_path_arg = f'{bucket_name}'
+    else:
+        sub_path_arg = f'{bucket_name}/{_bucket_sub_path}'
    # --daemon will keep the mounting process running in the background.
    mount_cmd = (f'{configure_rclone_profile} && '
                 'rclone mount '
-                 f'{bucket_rclone_profile}:{bucket_name} {mount_path} '
+                 f'{bucket_rclone_profile}:{sub_path_arg} {mount_path} '
                 '--daemon')
    return mount_cmd


@@ -252,7 +287,7 @@ def get_mounting_script(
    script = textwrap.dedent(f"""
        #!/usr/bin/env bash
        set -e
-
+        {command_runner.ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD}
        MOUNT_PATH={mount_path}
diff --git a/sky/data/storage.py b/sky/data/storage.py
index 188c97b9545..018cb2797ca 100644
--- a/sky/data/storage.py
+++ b/sky/data/storage.py
@@ -200,6 +200,45 @@ def get_endpoint_url(cls, store: 'AbstractStore', path: str) -> str:
        bucket_endpoint_url = f'{store_type.store_prefix()}{path}'
        return bucket_endpoint_url

+    @classmethod
+    def get_fields_from_store_url(
+        cls, store_url: str
+    ) -> Tuple['StoreType', Type['AbstractStore'], str, str, Optional[str],
+               Optional[str]]:
+        """Returns the store type, store class, bucket name, and sub path from
+        a store URL, and the storage account name and region if applicable.
+
+        Args:
+            store_url: str; The store URL.
+ """ + # The full path from the user config of IBM COS contains the region, + # and Azure Blob Storage contains the storage account name, we need to + # pass these information to the store constructor. + storage_account_name = None + region = None + for store_type in StoreType: + if store_url.startswith(store_type.store_prefix()): + if store_type == StoreType.AZURE: + storage_account_name, bucket_name, sub_path = \ + data_utils.split_az_path(store_url) + store_cls: Type['AbstractStore'] = AzureBlobStore + elif store_type == StoreType.IBM: + bucket_name, sub_path, region = data_utils.split_cos_path( + store_url) + store_cls = IBMCosStore + elif store_type == StoreType.R2: + bucket_name, sub_path = data_utils.split_r2_path(store_url) + store_cls = R2Store + elif store_type == StoreType.GCS: + bucket_name, sub_path = data_utils.split_gcs_path(store_url) + store_cls = GcsStore + elif store_type == StoreType.S3: + bucket_name, sub_path = data_utils.split_s3_path(store_url) + store_cls = S3Store + return store_type, store_cls,bucket_name, \ + sub_path, storage_account_name, region + raise ValueError(f'Unknown store URL: {store_url}') + class StorageMode(enum.Enum): MOUNT = 'MOUNT' @@ -226,25 +265,29 @@ def __init__(self, name: str, source: Optional[SourceType], region: Optional[str] = None, - is_sky_managed: Optional[bool] = None): + is_sky_managed: Optional[bool] = None, + _bucket_sub_path: Optional[str] = None): self.name = name self.source = source self.region = region self.is_sky_managed = is_sky_managed + self._bucket_sub_path = _bucket_sub_path def __repr__(self): return (f'StoreMetadata(' f'\n\tname={self.name},' f'\n\tsource={self.source},' f'\n\tregion={self.region},' - f'\n\tis_sky_managed={self.is_sky_managed})') + f'\n\tis_sky_managed={self.is_sky_managed},' + f'\n\t_bucket_sub_path={self._bucket_sub_path})') def __init__(self, name: str, source: Optional[SourceType], region: Optional[str] = None, is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: Optional[bool] = True): + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): # pylint: disable=invalid-name """Initialize AbstractStore Args: @@ -258,7 +301,11 @@ def __init__(self, there. This is set to false when the Storage object is created not for direct use, e.g. for 'sky storage delete', or the storage is being re-used, e.g., for `sky start` on a stopped cluster. - + _bucket_sub_path: str; The prefix of the bucket directory to be + created in the store, e.g. if _bucket_sub_path=my-dir, the files + will be uploaded to s3:///my-dir/. + This only works if source is a local directory. + # TODO(zpoint): Add support for non-local source. Raises: StorageBucketCreateError: If bucket creation fails StorageBucketGetError: If fetching existing bucket fails @@ -269,10 +316,29 @@ def __init__(self, self.region = region self.is_sky_managed = is_sky_managed self.sync_on_reconstruction = sync_on_reconstruction + + # To avoid mypy error + self._bucket_sub_path: Optional[str] = None + # Trigger the setter to strip any leading/trailing slashes. + self.bucket_sub_path = _bucket_sub_path # Whether sky is responsible for the lifecycle of the Store. 
self._validate() self.initialize() + @property + def bucket_sub_path(self) -> Optional[str]: + """Get the bucket_sub_path.""" + return self._bucket_sub_path + + @bucket_sub_path.setter + # pylint: disable=invalid-name + def bucket_sub_path(self, bucket_sub_path: Optional[str]) -> None: + """Set the bucket_sub_path, stripping any leading/trailing slashes.""" + if bucket_sub_path is not None: + self._bucket_sub_path = bucket_sub_path.strip('/') + else: + self._bucket_sub_path = None + @classmethod def from_metadata(cls, metadata: StoreMetadata, **override_args): """Create a Store from a StoreMetadata object. @@ -280,19 +346,26 @@ def from_metadata(cls, metadata: StoreMetadata, **override_args): Used when reconstructing Storage and Store objects from global_user_state. """ - return cls(name=override_args.get('name', metadata.name), - source=override_args.get('source', metadata.source), - region=override_args.get('region', metadata.region), - is_sky_managed=override_args.get('is_sky_managed', - metadata.is_sky_managed), - sync_on_reconstruction=override_args.get( - 'sync_on_reconstruction', True)) + return cls( + name=override_args.get('name', metadata.name), + source=override_args.get('source', metadata.source), + region=override_args.get('region', metadata.region), + is_sky_managed=override_args.get('is_sky_managed', + metadata.is_sky_managed), + sync_on_reconstruction=override_args.get('sync_on_reconstruction', + True), + # backward compatibility + _bucket_sub_path=override_args.get( + '_bucket_sub_path', + metadata._bucket_sub_path # pylint: disable=protected-access + ) if hasattr(metadata, '_bucket_sub_path') else None) def get_metadata(self) -> StoreMetadata: return self.StoreMetadata(name=self.name, source=self.source, region=self.region, - is_sky_managed=self.is_sky_managed) + is_sky_managed=self.is_sky_managed, + _bucket_sub_path=self._bucket_sub_path) def initialize(self): """Initializes the Store object on the cloud. @@ -320,7 +393,11 @@ def upload(self) -> None: raise NotImplementedError def delete(self) -> None: - """Removes the Storage object from the cloud.""" + """Removes the Storage from the cloud.""" + raise NotImplementedError + + def _delete_sub_path(self) -> None: + """Removes objects from the sub path in the bucket.""" raise NotImplementedError def get_handle(self) -> StorageHandle: @@ -464,13 +541,19 @@ def remove_store(self, store: AbstractStore) -> None: if storetype in self.sky_stores: del self.sky_stores[storetype] - def __init__(self, - name: Optional[str] = None, - source: Optional[SourceType] = None, - stores: Optional[Dict[StoreType, AbstractStore]] = None, - persistent: Optional[bool] = True, - mode: StorageMode = StorageMode.MOUNT, - sync_on_reconstruction: bool = True) -> None: + def __init__( + self, + name: Optional[str] = None, + source: Optional[SourceType] = None, + stores: Optional[Dict[StoreType, AbstractStore]] = None, + persistent: Optional[bool] = True, + mode: StorageMode = StorageMode.MOUNT, + sync_on_reconstruction: bool = True, + # pylint: disable=invalid-name + _is_sky_managed: Optional[bool] = None, + # pylint: disable=invalid-name + _bucket_sub_path: Optional[str] = None + ) -> None: """Initializes a Storage object. Three fields are required: the name of the storage, the source @@ -508,6 +591,18 @@ def __init__(self, there. This is set to false when the Storage object is created not for direct use, e.g. for 'sky storage delete', or the storage is being re-used, e.g., for `sky start` on a stopped cluster. 
+ _is_sky_managed: Optional[bool]; Indicates if the storage is managed + by Sky. Without this argument, the controller's behavior differs + from the local machine. For example, if a bucket does not exist: + Local Machine (is_sky_managed=True) → + Controller (is_sky_managed=False). + With this argument, the controller aligns with the local machine, + ensuring it retains the is_sky_managed information from the YAML. + During teardown, if is_sky_managed is True, the controller should + delete the bucket. Otherwise, it might mistakenly delete only the + sub-path, assuming is_sky_managed is False. + _bucket_sub_path: Optional[str]; The subdirectory to use for the + storage object. """ self.name: str self.source = source @@ -515,6 +610,8 @@ def __init__(self, self.mode = mode assert mode in StorageMode self.sync_on_reconstruction = sync_on_reconstruction + self._is_sky_managed = _is_sky_managed + self._bucket_sub_path = _bucket_sub_path # TODO(romilb, zhwu): This is a workaround to support storage deletion # for spot. Once sky storage supports forced management for external @@ -577,6 +674,12 @@ def __init__(self, elif self.source.startswith('oci://'): self.add_store(StoreType.OCI) + def get_bucket_sub_path_prefix(self, blob_path: str) -> str: + """Adds the bucket sub path prefix to the blob path.""" + if self._bucket_sub_path is not None: + return f'{blob_path}/{self._bucket_sub_path}' + return blob_path + @staticmethod def _validate_source( source: SourceType, mode: StorageMode, @@ -787,34 +890,40 @@ def _add_store_from_metadata( store = S3Store.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.GCS: store = GcsStore.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.AZURE: assert isinstance(s_metadata, AzureBlobStore.AzureBlobStoreMetadata) store = AzureBlobStore.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.R2: store = R2Store.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.IBM: store = IBMCosStore.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) elif s_type == StoreType.OCI: store = OciStore.from_metadata( s_metadata, source=self.source, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + _bucket_sub_path=self._bucket_sub_path) else: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Unknown store type: {s_type}') @@ -834,7 +943,6 @@ def _add_store_from_metadata( 'to be reconstructed while the corresponding ' 'bucket was externally deleted.') continue - self._add_store(store, is_reconstructed=True) @classmethod @@ -890,6 +998,7 @@ def add_store(self, f'storage account {storage_account_name!r}.') else: logger.info(f'Storage type {store_type} already exists.') + return self.stores[store_type] 
store_cls: Type[AbstractStore] @@ -909,21 +1018,24 @@ def add_store(self, with ux_utils.print_exception_no_traceback(): raise exceptions.StorageSpecError( f'{store_type} not supported as a Store.') - - # Initialize store object and get/create bucket try: store = store_cls( name=self.name, source=self.source, region=region, - sync_on_reconstruction=self.sync_on_reconstruction) + sync_on_reconstruction=self.sync_on_reconstruction, + is_sky_managed=self._is_sky_managed, + _bucket_sub_path=self._bucket_sub_path) except exceptions.StorageBucketCreateError: # Creation failed, so this must be sky managed store. Add failure # to state. logger.error(f'Could not create {store_type} store ' f'with name {self.name}.') - global_user_state.set_storage_status(self.name, - StorageStatus.INIT_FAILED) + try: + global_user_state.set_storage_status(self.name, + StorageStatus.INIT_FAILED) + except ValueError as e: + logger.error(f'Error setting storage status: {e}') raise except exceptions.StorageBucketGetError: # Bucket get failed, so this is not sky managed. Do not update state @@ -1039,12 +1151,15 @@ def warn_for_git_dir(source: str): def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage': common_utils.validate_schema(config, schemas.get_storage_schema(), 'Invalid storage YAML: ') - name = config.pop('name', None) source = config.pop('source', None) store = config.pop('store', None) mode_str = config.pop('mode', None) force_delete = config.pop('_force_delete', None) + # pylint: disable=invalid-name + _is_sky_managed = config.pop('_is_sky_managed', None) + # pylint: disable=invalid-name + _bucket_sub_path = config.pop('_bucket_sub_path', None) if force_delete is None: force_delete = False @@ -1064,7 +1179,9 @@ def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage': storage_obj = cls(name=name, source=source, persistent=persistent, - mode=mode) + mode=mode, + _is_sky_managed=_is_sky_managed, + _bucket_sub_path=_bucket_sub_path) if store is not None: storage_obj.add_store(StoreType(store.upper())) @@ -1072,7 +1189,7 @@ def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage': storage_obj.force_delete = force_delete return storage_obj - def to_yaml_config(self) -> Dict[str, str]: + def to_yaml_config(self) -> Dict[str, Any]: config = {} def add_if_not_none(key: str, value: Optional[Any]): @@ -1088,13 +1205,18 @@ def add_if_not_none(key: str, value: Optional[Any]): add_if_not_none('source', self.source) stores = None + is_sky_managed = self._is_sky_managed if self.stores: stores = ','.join([store.value for store in self.stores]) + is_sky_managed = list(self.stores.values())[0].is_sky_managed add_if_not_none('store', stores) + add_if_not_none('_is_sky_managed', is_sky_managed) add_if_not_none('persistent', self.persistent) add_if_not_none('mode', self.mode.value) if self.force_delete: config['_force_delete'] = True + if self._bucket_sub_path is not None: + config['_bucket_sub_path'] = self._bucket_sub_path return config @@ -1116,7 +1238,8 @@ def __init__(self, source: str, region: Optional[str] = _DEFAULT_REGION, is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: bool = True): + sync_on_reconstruction: bool = True, + _bucket_sub_path: Optional[str] = None): self.client: 'boto3.client.Client' self.bucket: 'StorageHandle' # TODO(romilb): This is purely a stopgap fix for @@ -1129,7 +1252,7 @@ def __init__(self, f'{self._DEFAULT_REGION} for bucket {name!r}.') region = self._DEFAULT_REGION super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + 
sync_on_reconstruction, _bucket_sub_path) def _validate(self): if self.source is not None and isinstance(self.source, str): @@ -1293,6 +1416,9 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_s3_bucket(self.name) if deleted_by_skypilot: msg_str = f'Deleted S3 bucket {self.name}.' @@ -1302,6 +1428,19 @@ def delete(self) -> None: logger.info(f'{colorama.Fore.GREEN}{msg_str}' f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + deleted_by_skypilot = self._delete_s3_bucket_sub_path( + self.name, self._bucket_sub_path) + if deleted_by_skypilot: + msg_str = f'Removed objects from S3 bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + else: + msg_str = f'Failed to remove objects from S3 bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + logger.info(f'{colorama.Fore.GREEN}{msg_str}' + f'{colorama.Style.RESET_ALL}') + def get_handle(self) -> StorageHandle: return aws.resource('s3').Bucket(self.name) @@ -1332,9 +1471,11 @@ def get_file_sync_command(base_dir_path, file_names): for file_name in file_names ]) base_dir_path = shlex.quote(base_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = ('aws s3 sync --no-follow-symlinks --exclude="*" ' f'{includes} {base_dir_path} ' - f's3://{self.name}') + f's3://{self.name}{sub_path}') return sync_command def get_dir_sync_command(src_dir_path, dest_dir_name): @@ -1346,9 +1487,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): for file_name in excluded_list ]) src_dir_path = shlex.quote(src_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = (f'aws s3 sync --no-follow-symlinks {excludes} ' f'{src_dir_path} ' - f's3://{self.name}/{dest_dir_name}') + f's3://{self.name}{sub_path}/{dest_dir_name}') return sync_command # Generate message for upload @@ -1466,7 +1609,8 @@ def mount_command(self, mount_path: str) -> str: """ install_cmd = mounting_utils.get_s3_mount_install_cmd() mount_cmd = mounting_utils.get_s3_mount_cmd(self.bucket.name, - mount_path) + mount_path, + self._bucket_sub_path) return mounting_utils.get_mounting_command(mount_path, install_cmd, mount_cmd) @@ -1516,6 +1660,27 @@ def _create_s3_bucket(self, ) from e return aws.resource('s3').Bucket(bucket_name) + def _execute_s3_remove_command(self, command: str, bucket_name: str, + hint_operating: str, + hint_failed: str) -> bool: + try: + with rich_utils.safe_status( + ux_utils.spinner_message(hint_operating)): + subprocess.check_output(command.split(' '), + stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + if 'NoSuchBucket' in e.output.decode('utf-8'): + logger.debug( + _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( + bucket_name=bucket_name)) + return False + else: + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketDeleteError( + f'{hint_failed}' + f'Detailed error: {e.output}') + return True + def _delete_s3_bucket(self, bucket_name: str) -> bool: """Deletes S3 bucket, including all objects in bucket @@ -1533,29 +1698,28 @@ def _delete_s3_bucket(self, bucket_name: str) -> bool: # The fastest way to delete is to run `aws s3 rb --force`, # which removes the bucket by force. 
remove_command = f'aws s3 rb s3://{bucket_name} --force' - try: - with rich_utils.safe_status( - ux_utils.spinner_message( - f'Deleting S3 bucket [green]{bucket_name}')): - subprocess.check_output(remove_command.split(' '), - stderr=subprocess.STDOUT) - except subprocess.CalledProcessError as e: - if 'NoSuchBucket' in e.output.decode('utf-8'): - logger.debug( - _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( - bucket_name=bucket_name)) - return False - else: - with ux_utils.print_exception_no_traceback(): - raise exceptions.StorageBucketDeleteError( - f'Failed to delete S3 bucket {bucket_name}.' - f'Detailed error: {e.output}') + success = self._execute_s3_remove_command( + remove_command, bucket_name, + f'Deleting S3 bucket [green]{bucket_name}[/]', + f'Failed to delete S3 bucket {bucket_name}.') + if not success: + return False # Wait until bucket deletion propagates on AWS servers while data_utils.verify_s3_bucket(bucket_name): time.sleep(0.1) return True + def _delete_s3_bucket_sub_path(self, bucket_name: str, + sub_path: str) -> bool: + """Deletes the sub path from the bucket.""" + remove_command = f'aws s3 rm s3://{bucket_name}/{sub_path}/ --recursive' + return self._execute_s3_remove_command( + remove_command, bucket_name, f'Removing objects from S3 bucket ' + f'[green]{bucket_name}/{sub_path}[/]', + f'Failed to remove objects from S3 bucket {bucket_name}/{sub_path}.' + ) + class GcsStore(AbstractStore): """GcsStore inherits from Storage Object and represents the backend @@ -1569,11 +1733,12 @@ def __init__(self, source: str, region: Optional[str] = 'us-central1', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: Optional[bool] = True): + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): self.client: 'storage.Client' self.bucket: StorageHandle super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) def _validate(self): if self.source is not None and isinstance(self.source, str): @@ -1736,6 +1901,9 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_gcs_bucket(self.name) if deleted_by_skypilot: msg_str = f'Deleted GCS bucket {self.name}.' @@ -1745,6 +1913,19 @@ def delete(self) -> None: logger.info(f'{colorama.Fore.GREEN}{msg_str}' f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + deleted_by_skypilot = self._delete_gcs_bucket(self.name, + self._bucket_sub_path) + if deleted_by_skypilot: + msg_str = f'Deleted objects in GCS bucket ' \ + f'{self.name}/{self._bucket_sub_path}.' + else: + msg_str = f'GCS bucket {self.name} may have ' \ + 'been deleted externally.' 
+ logger.info(f'{colorama.Fore.GREEN}{msg_str}' + f'{colorama.Style.RESET_ALL}') + def get_handle(self) -> StorageHandle: return self.client.get_bucket(self.name) @@ -1818,9 +1999,11 @@ def get_file_sync_command(base_dir_path, file_names): sync_format = '|'.join(file_names) gsutil_alias, alias_gen = data_utils.get_gsutil_command() base_dir_path = shlex.quote(base_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = (f'{alias_gen}; {gsutil_alias} ' f'rsync -e -x \'^(?!{sync_format}$).*\' ' - f'{base_dir_path} gs://{self.name}') + f'{base_dir_path} gs://{self.name}{sub_path}') return sync_command def get_dir_sync_command(src_dir_path, dest_dir_name): @@ -1830,9 +2013,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name): excludes = '|'.join(excluded_list) gsutil_alias, alias_gen = data_utils.get_gsutil_command() src_dir_path = shlex.quote(src_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = (f'{alias_gen}; {gsutil_alias} ' f'rsync -e -r -x \'({excludes})\' {src_dir_path} ' - f'gs://{self.name}/{dest_dir_name}') + f'gs://{self.name}{sub_path}/{dest_dir_name}') return sync_command # Generate message for upload @@ -1937,7 +2122,8 @@ def mount_command(self, mount_path: str) -> str: """ install_cmd = mounting_utils.get_gcs_mount_install_cmd() mount_cmd = mounting_utils.get_gcs_mount_cmd(self.bucket.name, - mount_path) + mount_path, + self._bucket_sub_path) version_check_cmd = ( f'gcsfuse --version | grep -q {mounting_utils.GCSFUSE_VERSION}') return mounting_utils.get_mounting_command(mount_path, install_cmd, @@ -1977,19 +2163,33 @@ def _create_gcs_bucket(self, f'{new_bucket.storage_class}{colorama.Style.RESET_ALL}') return new_bucket - def _delete_gcs_bucket(self, bucket_name: str) -> bool: - """Deletes GCS bucket, including all objects in bucket + def _delete_gcs_bucket( + self, + bucket_name: str, + # pylint: disable=invalid-name + _bucket_sub_path: Optional[str] = None + ) -> bool: + """Deletes objects in GCS bucket Args: bucket_name: str; Name of bucket + _bucket_sub_path: str; Sub path in the bucket, if provided only + objects in the sub path will be deleted, else the whole bucket will + be deleted Returns: bool; True if bucket was deleted, False if it was deleted externally. """ - + if _bucket_sub_path is not None: + command_suffix = f'/{_bucket_sub_path}' + hint_text = 'objects in ' + else: + command_suffix = '' + hint_text = '' with rich_utils.safe_status( ux_utils.spinner_message( - f'Deleting GCS bucket [green]{bucket_name}')): + f'Deleting {hint_text}GCS bucket ' + f'[green]{bucket_name}{command_suffix}[/]')): try: self.client.get_bucket(bucket_name) except gcp.forbidden_exception() as e: @@ -2007,8 +2207,9 @@ def _delete_gcs_bucket(self, bucket_name: str) -> bool: return False try: gsutil_alias, alias_gen = data_utils.get_gsutil_command() - remove_obj_command = (f'{alias_gen};{gsutil_alias} ' - f'rm -r gs://{bucket_name}') + remove_obj_command = ( + f'{alias_gen};{gsutil_alias} ' + f'rm -r gs://{bucket_name}{command_suffix}') subprocess.check_output(remove_obj_command, stderr=subprocess.STDOUT, shell=True, @@ -2017,7 +2218,8 @@ def _delete_gcs_bucket(self, bucket_name: str) -> bool: except subprocess.CalledProcessError as e: with ux_utils.print_exception_no_traceback(): raise exceptions.StorageBucketDeleteError( - f'Failed to delete GCS bucket {bucket_name}.' + f'Failed to delete {hint_text}GCS bucket ' + f'{bucket_name}{command_suffix}.' 
f'Detailed error: {e.output}') @@ -2069,7 +2271,8 @@ def __init__(self, storage_account_name: str = '', region: Optional[str] = 'eastus', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: bool = True): + sync_on_reconstruction: bool = True, + _bucket_sub_path: Optional[str] = None): self.storage_client: 'storage.Client' self.resource_client: 'storage.Client' self.container_name: str @@ -2081,7 +2284,7 @@ def __init__(self, if region is None: region = 'eastus' super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) @classmethod def from_metadata(cls, metadata: AbstractStore.StoreMetadata, @@ -2231,6 +2434,17 @@ def initialize(self): """ self.storage_client = data_utils.create_az_client('storage') self.resource_client = data_utils.create_az_client('resource') + self._update_storage_account_name_and_resource() + + self.container_name, is_new_bucket = self._get_bucket() + if self.is_sky_managed is None: + # If is_sky_managed is not specified, then this is a new storage + # object (i.e., did not exist in global_user_state) and we should + # set the is_sky_managed property. + # If is_sky_managed is specified, then we take no action. + self.is_sky_managed = is_new_bucket + + def _update_storage_account_name_and_resource(self): self.storage_account_name, self.resource_group_name = ( self._get_storage_account_and_resource_group()) @@ -2241,13 +2455,13 @@ def initialize(self): self.storage_account_name, self.resource_group_name, self.storage_client, self.resource_client) - self.container_name, is_new_bucket = self._get_bucket() - if self.is_sky_managed is None: - # If is_sky_managed is not specified, then this is a new storage - # object (i.e., did not exist in global_user_state) and we should - # set the is_sky_managed property. - # If is_sky_managed is specified, then we take no action. 
- self.is_sky_managed = is_new_bucket + def update_storage_attributes(self, **kwargs: Dict[str, Any]): + assert 'storage_account_name' in kwargs, ( + 'only storage_account_name supported') + assert isinstance(kwargs['storage_account_name'], + str), ('storage_account_name must be a string') + self.storage_account_name = kwargs['storage_account_name'] + self._update_storage_account_name_and_resource() @staticmethod def get_default_storage_account_name(region: Optional[str]) -> str: @@ -2518,6 +2732,9 @@ def upload(self): def delete(self) -> None: """Deletes the storage.""" + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_az_bucket(self.name) if deleted_by_skypilot: msg_str = (f'Deleted AZ Container {self.name!r} under storage ' @@ -2528,6 +2745,32 @@ def delete(self) -> None: logger.info(f'{colorama.Fore.GREEN}{msg_str}' f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + try: + container_url = data_utils.AZURE_CONTAINER_URL.format( + storage_account_name=self.storage_account_name, + container_name=self.name) + container_client = data_utils.create_az_client( + client_type='container', + container_url=container_url, + storage_account_name=self.storage_account_name, + resource_group_name=self.resource_group_name) + # List and delete blobs in the specified directory + blobs = container_client.list_blobs( + name_starts_with=self._bucket_sub_path + '/') + for blob in blobs: + container_client.delete_blob(blob.name) + logger.info( + f'Deleted objects from sub path {self._bucket_sub_path} ' + f'in container {self.name}.') + except Exception as e: # pylint: disable=broad-except + logger.error( + f'Failed to delete objects from sub path ' + f'{self._bucket_sub_path} in container {self.name}. 
' + f'Details: {common_utils.format_exception(e, use_bracket=True)}' + ) + def get_handle(self) -> StorageHandle: """Returns the Storage Handle object.""" return self.storage_client.blob_containers.get( @@ -2554,13 +2797,15 @@ def get_file_sync_command(base_dir_path, file_names) -> str: includes_list = ';'.join(file_names) includes = f'--include-pattern "{includes_list}"' base_dir_path = shlex.quote(base_dir_path) + container_path = (f'{self.container_name}/{self._bucket_sub_path}' + if self._bucket_sub_path else self.container_name) sync_command = (f'az storage blob sync ' f'--account-name {self.storage_account_name} ' f'--account-key {self.storage_account_key} ' f'{includes} ' '--delete-destination false ' f'--source {base_dir_path} ' - f'--container {self.container_name}') + f'--container {container_path}') return sync_command def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: @@ -2571,8 +2816,11 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: [file_name.rstrip('*') for file_name in excluded_list]) excludes = f'--exclude-path "{excludes_list}"' src_dir_path = shlex.quote(src_dir_path) - container_path = (f'{self.container_name}/{dest_dir_name}' - if dest_dir_name else self.container_name) + container_path = (f'{self.container_name}/{self._bucket_sub_path}' + if self._bucket_sub_path else + f'{self.container_name}') + if dest_dir_name: + container_path = f'{container_path}/{dest_dir_name}' sync_command = (f'az storage blob sync ' f'--account-name {self.storage_account_name} ' f'--account-key {self.storage_account_key} ' @@ -2695,6 +2943,7 @@ def _get_bucket(self) -> Tuple[str, bool]: f'{self.storage_account_name!r}.' 'Details: ' f'{common_utils.format_exception(e, use_bracket=True)}') + # If the container cannot be found in both private and public settings, # the container is to be created by Sky. However, creation is skipped # if Store object is being reconstructed for deletion or re-mount with @@ -2725,7 +2974,8 @@ def mount_command(self, mount_path: str) -> str: mount_cmd = mounting_utils.get_az_mount_cmd(self.container_name, self.storage_account_name, mount_path, - self.storage_account_key) + self.storage_account_key, + self._bucket_sub_path) return mounting_utils.get_mounting_command(mount_path, install_cmd, mount_cmd) @@ -2824,11 +3074,12 @@ def __init__(self, source: str, region: Optional[str] = 'auto', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: Optional[bool] = True): + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): self.client: 'boto3.client.Client' self.bucket: 'StorageHandle' super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) def _validate(self): if self.source is not None and isinstance(self.source, str): @@ -2933,6 +3184,9 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + deleted_by_skypilot = self._delete_r2_bucket(self.name) if deleted_by_skypilot: msg_str = f'Deleted R2 bucket {self.name}.' 
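Unlike the object-store CLIs used for S3/R2, the Azure sub-path cleanup above deletes blobs one by one after a prefix-scoped listing. A standalone sketch of that pattern with the `azure-storage-blob` SDK (the container URL, prefix, and credential handling below are placeholders, not taken from this PR):

```python
# Standalone sketch of the prefix-scoped deletion used by
# AzureBlobStore._delete_sub_path(); URL and prefix are placeholders and a
# real credential is required in practice.
from azure.identity import DefaultAzureCredential
from azure.storage.blob import ContainerClient

container_client = ContainerClient.from_container_url(
    'https://myaccount.blob.core.windows.net/my-container',
    credential=DefaultAzureCredential())
# Only blobs under 'job-1234/' are removed; the rest of the container is kept.
for blob in container_client.list_blobs(name_starts_with='job-1234/'):
    container_client.delete_blob(blob.name)
```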
@@ -2942,6 +3196,19 @@ def delete(self) -> None:
        logger.info(f'{colorama.Fore.GREEN}{msg_str}'
                    f'{colorama.Style.RESET_ALL}')

+    def _delete_sub_path(self) -> None:
+        assert self._bucket_sub_path is not None, 'bucket_sub_path is not set'
+        deleted_by_skypilot = self._delete_r2_bucket_sub_path(
+            self.name, self._bucket_sub_path)
+        if deleted_by_skypilot:
+            msg_str = f'Removed objects from R2 bucket ' \
+                      f'{self.name}/{self._bucket_sub_path}.'
+        else:
+            msg_str = f'Failed to remove objects from R2 bucket ' \
+                      f'{self.name}/{self._bucket_sub_path}.'
+        logger.info(f'{colorama.Fore.GREEN}{msg_str}'
+                    f'{colorama.Style.RESET_ALL}')
+
    def get_handle(self) -> StorageHandle:
        return cloudflare.resource('s3').Bucket(self.name)

@@ -2973,11 +3240,13 @@ def get_file_sync_command(base_dir_path, file_names):
            ])
            endpoint_url = cloudflare.create_endpoint()
            base_dir_path = shlex.quote(base_dir_path)
+            sub_path = (f'/{self._bucket_sub_path}'
+                        if self._bucket_sub_path else '')
            sync_command = ('AWS_SHARED_CREDENTIALS_FILE='
                            f'{cloudflare.R2_CREDENTIALS_PATH} '
                            'aws s3 sync --no-follow-symlinks --exclude="*" '
                            f'{includes} {base_dir_path} '
-                            f's3://{self.name} '
+                            f's3://{self.name}{sub_path} '
                            f'--endpoint {endpoint_url} '
                            f'--profile={cloudflare.R2_PROFILE_NAME}')
            return sync_command

@@ -2992,11 +3261,13 @@ def get_dir_sync_command(src_dir_path, dest_dir_name):
            ])
            endpoint_url = cloudflare.create_endpoint()
            src_dir_path = shlex.quote(src_dir_path)
+            sub_path = (f'/{self._bucket_sub_path}'
+                        if self._bucket_sub_path else '')
            sync_command = ('AWS_SHARED_CREDENTIALS_FILE='
                            f'{cloudflare.R2_CREDENTIALS_PATH} '
                            f'aws s3 sync --no-follow-symlinks {excludes} '
                            f'{src_dir_path} '
-                            f's3://{self.name}/{dest_dir_name} '
+                            f's3://{self.name}{sub_path}/{dest_dir_name} '
                            f'--endpoint {endpoint_url} '
                            f'--profile={cloudflare.R2_PROFILE_NAME}')
            return sync_command

@@ -3127,11 +3398,9 @@ def mount_command(self, mount_path: str) -> str:
        endpoint_url = cloudflare.create_endpoint()
        r2_credential_path = cloudflare.R2_CREDENTIALS_PATH
        r2_profile_name = cloudflare.R2_PROFILE_NAME
-        mount_cmd = mounting_utils.get_r2_mount_cmd(r2_credential_path,
-                                                    r2_profile_name,
-                                                    endpoint_url,
-                                                    self.bucket.name,
-                                                    mount_path)
+        mount_cmd = mounting_utils.get_r2_mount_cmd(
+            r2_credential_path, r2_profile_name, endpoint_url, self.bucket.name,
+            mount_path, self._bucket_sub_path)
        return mounting_utils.get_mounting_command(mount_path, install_cmd,
                                                   mount_cmd)

@@ -3164,6 +3433,43 @@ def _create_r2_bucket(self,
                    f'{self.name} but failed.') from e
        return cloudflare.resource('s3').Bucket(bucket_name)

+    def _execute_r2_remove_command(self, command: str, bucket_name: str,
+                                   hint_operating: str,
+                                   hint_failed: str) -> bool:
+        try:
+            with rich_utils.safe_status(
+                    ux_utils.spinner_message(hint_operating)):
+                # The command carries an env-var assignment prefix, so it is
+                # run through the shell as a single string rather than split.
+                subprocess.check_output(command,
+                                        stderr=subprocess.STDOUT,
+                                        shell=True)
+        except subprocess.CalledProcessError as e:
+            if 'NoSuchBucket' in e.output.decode('utf-8'):
+                logger.debug(
+                    _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format(
+                        bucket_name=bucket_name))
+                return False
+            else:
+                with ux_utils.print_exception_no_traceback():
+                    raise exceptions.StorageBucketDeleteError(
+                        f'{hint_failed}'
+                        f'Detailed error: {e.output}')
+        return True
+
+    def _delete_r2_bucket_sub_path(self, bucket_name: str,
+                                   sub_path: str) -> bool:
+        """Deletes the sub path from the bucket."""
+        endpoint_url = cloudflare.create_endpoint()
+        remove_command = (
+            f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} '
+            f'aws s3 rm 
s3://{bucket_name}/{sub_path}/ --recursive ' + f'--endpoint {endpoint_url} ' + f'--profile={cloudflare.R2_PROFILE_NAME}') + return self._execute_r2_remove_command( + remove_command, bucket_name, + f'Removing objects from R2 bucket {bucket_name}/{sub_path}', + f'Failed to remove objects from R2 bucket {bucket_name}/{sub_path}.' + ) + def _delete_r2_bucket(self, bucket_name: str) -> bool: """Deletes R2 bucket, including all objects in bucket @@ -3186,24 +3492,12 @@ def _delete_r2_bucket(self, bucket_name: str) -> bool: f'aws s3 rb s3://{bucket_name} --force ' f'--endpoint {endpoint_url} ' f'--profile={cloudflare.R2_PROFILE_NAME}') - try: - with rich_utils.safe_status( - ux_utils.spinner_message( - f'Deleting R2 bucket {bucket_name}')): - subprocess.check_output(remove_command, - stderr=subprocess.STDOUT, - shell=True) - except subprocess.CalledProcessError as e: - if 'NoSuchBucket' in e.output.decode('utf-8'): - logger.debug( - _BUCKET_EXTERNALLY_DELETED_DEBUG_MESSAGE.format( - bucket_name=bucket_name)) - return False - else: - with ux_utils.print_exception_no_traceback(): - raise exceptions.StorageBucketDeleteError( - f'Failed to delete R2 bucket {bucket_name}.' - f'Detailed error: {e.output}') + + success = self._execute_r2_remove_command( + remove_command, bucket_name, f'Deleting R2 bucket {bucket_name}', + f'Failed to delete R2 bucket {bucket_name}.') + if not success: + return False # Wait until bucket deletion propagates on AWS servers while data_utils.verify_r2_bucket(bucket_name): @@ -3222,11 +3516,12 @@ def __init__(self, source: str, region: Optional[str] = 'us-east', is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: bool = True): + sync_on_reconstruction: bool = True, + _bucket_sub_path: Optional[str] = None): self.client: 'storage.Client' self.bucket: 'StorageHandle' super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) self.bucket_rclone_profile = \ Rclone.generate_rclone_bucket_profile_name( self.name, Rclone.RcloneClouds.IBM) @@ -3371,10 +3666,22 @@ def upload(self): f'Upload failed for store {self.name}') from e def delete(self) -> None: + if self._bucket_sub_path is not None and not self.is_sky_managed: + return self._delete_sub_path() + self._delete_cos_bucket() logger.info(f'{colorama.Fore.GREEN}Deleted COS bucket {self.name}.' 
f'{colorama.Style.RESET_ALL}') + def _delete_sub_path(self) -> None: + assert self._bucket_sub_path is not None, 'bucket_sub_path is not set' + bucket = self.s3_resource.Bucket(self.name) + try: + self._delete_cos_bucket_objects(bucket, self._bucket_sub_path + '/') + except ibm.ibm_botocore.exceptions.ClientError as e: + if e.__class__.__name__ == 'NoSuchBucket': + logger.debug('bucket already removed') + def get_handle(self) -> StorageHandle: return self.s3_resource.Bucket(self.name) @@ -3415,10 +3722,13 @@ def get_dir_sync_command(src_dir_path, dest_dir_name) -> str: # .git directory is excluded from the sync # wrapping src_dir_path with "" to support path with spaces src_dir_path = shlex.quote(src_dir_path) + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') sync_command = ( 'rclone copy --exclude ".git/*" ' f'{src_dir_path} ' - f'{self.bucket_rclone_profile}:{self.name}/{dest_dir_name}') + f'{self.bucket_rclone_profile}:{self.name}{sub_path}' + f'/{dest_dir_name}') return sync_command def get_file_sync_command(base_dir_path, file_names) -> str: @@ -3444,9 +3754,12 @@ def get_file_sync_command(base_dir_path, file_names) -> str: for file_name in file_names ]) base_dir_path = shlex.quote(base_dir_path) - sync_command = ('rclone copy ' - f'{includes} {base_dir_path} ' - f'{self.bucket_rclone_profile}:{self.name}') + sub_path = (f'/{self._bucket_sub_path}' + if self._bucket_sub_path else '') + sync_command = ( + 'rclone copy ' + f'{includes} {base_dir_path} ' + f'{self.bucket_rclone_profile}:{self.name}{sub_path}') return sync_command # Generate message for upload @@ -3531,6 +3844,7 @@ def _get_bucket(self) -> Tuple[StorageHandle, bool]: Rclone.RcloneClouds.IBM, self.region, # type: ignore ) + if not bucket_region and self.sync_on_reconstruction: # bucket doesn't exist return self._create_cos_bucket(self.name, self.region), True @@ -3577,7 +3891,8 @@ def mount_command(self, mount_path: str) -> str: Rclone.RCLONE_CONFIG_PATH, self.bucket_rclone_profile, self.bucket.name, - mount_path) + mount_path, + self._bucket_sub_path) return mounting_utils.get_mounting_command(mount_path, install_cmd, mount_cmd) @@ -3615,15 +3930,27 @@ def _create_cos_bucket(self, return self.bucket - def _delete_cos_bucket(self): - bucket = self.s3_resource.Bucket(self.name) - try: - bucket_versioning = self.s3_resource.BucketVersioning(self.name) - if bucket_versioning.status == 'Enabled': + def _delete_cos_bucket_objects(self, + bucket: Any, + prefix: Optional[str] = None): + bucket_versioning = self.s3_resource.BucketVersioning(bucket.name) + if bucket_versioning.status == 'Enabled': + if prefix is not None: + res = list( + bucket.object_versions.filter(Prefix=prefix).delete()) + else: res = list(bucket.object_versions.delete()) + else: + if prefix is not None: + res = list(bucket.objects.filter(Prefix=prefix).delete()) else: res = list(bucket.objects.delete()) - logger.debug(f'Deleted bucket\'s content:\n{res}') + logger.debug(f'Deleted bucket\'s content:\n{res}, prefix: {prefix}') + + def _delete_cos_bucket(self): + bucket = self.s3_resource.Bucket(self.name) + try: + self._delete_cos_bucket_objects(bucket) bucket.delete() bucket.wait_until_not_exists() except ibm.ibm_botocore.exceptions.ClientError as e: @@ -3644,7 +3971,8 @@ def __init__(self, source: str, region: Optional[str] = None, is_sky_managed: Optional[bool] = None, - sync_on_reconstruction: Optional[bool] = True): + sync_on_reconstruction: Optional[bool] = True, + _bucket_sub_path: Optional[str] = None): self.client: Any 
self.bucket: StorageHandle self.oci_config_file: str @@ -3656,7 +3984,8 @@ def __init__(self, region = oci.get_oci_config()['region'] super().__init__(name, source, region, is_sky_managed, - sync_on_reconstruction) + sync_on_reconstruction, _bucket_sub_path) + # TODO(zpoint): add _bucket_sub_path to the sync/mount/delete commands def _validate(self): if self.source is not None and isinstance(self.source, str): diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 0b2a5b08e1b..96651eddc39 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -268,12 +268,16 @@ # Used for translate local file mounts to cloud storage. Please refer to # sky/execution.py::_maybe_translate_local_file_mounts_and_sync_up for # more details. -WORKDIR_BUCKET_NAME = 'skypilot-workdir-{username}-{id}' -FILE_MOUNTS_BUCKET_NAME = 'skypilot-filemounts-folder-{username}-{id}' -FILE_MOUNTS_FILE_ONLY_BUCKET_NAME = 'skypilot-filemounts-files-{username}-{id}' +FILE_MOUNTS_BUCKET_NAME = 'skypilot-filemounts-{username}-{id}' FILE_MOUNTS_LOCAL_TMP_DIR = 'skypilot-filemounts-files-{id}' FILE_MOUNTS_REMOTE_TMP_DIR = '/tmp/sky-{}-filemounts-files' +# Used when an managed jobs are created and +# files are synced up to the cloud. +FILE_MOUNTS_WORKDIR_SUBPATH = 'job-{run_id}/workdir' +FILE_MOUNTS_SUBPATH = 'job-{run_id}/local-file-mounts/{i}' +FILE_MOUNTS_TMP_SUBPATH = 'job-{run_id}/tmp-files' + # The default idle timeout for SkyPilot controllers. This include spot # controller and sky serve controller. # TODO(tian): Refactor to controller_utils. Current blocker: circular import. diff --git a/sky/task.py b/sky/task.py index edd2fd211a3..bbf6d59b2ae 100644 --- a/sky/task.py +++ b/sky/task.py @@ -948,12 +948,22 @@ def _get_preferred_store( store_type = storage_lib.StoreType.from_cloud(storage_cloud_str) return store_type, storage_region - def sync_storage_mounts(self) -> None: + def sync_storage_mounts(self, force_sync: bool = False) -> None: """(INTERNAL) Eagerly syncs storage mounts to cloud storage. After syncing up, COPY-mode storage mounts are translated into regular file_mounts of the form ``{ /remote/path: {s3,gs,..}:// }``. + + Args: + force_sync: If True, forces the synchronization of storage mounts. + If the store object is added via storage.add_store(), + the sync will happen automatically via add_store. + However, if it is passed via the construction function + of storage, it is usually because the user passed an + intermediate bucket name in the config and we need to + construct from the user config. In this case, set + force_sync to True. """ for storage in self.storage_mounts.values(): if not storage.stores: @@ -961,6 +971,8 @@ def sync_storage_mounts(self) -> None: self.storage_plans[storage] = store_type storage.add_store(store_type, store_region) else: + if force_sync: + storage.sync_all_stores() # We will download the first store that is added to remote. 
self.storage_plans[storage] = list(storage.stores.keys())[0] @@ -977,6 +989,7 @@ def sync_storage_mounts(self) -> None: else: assert storage.name is not None, storage blob_path = 's3://' + storage.name + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -987,6 +1000,7 @@ def sync_storage_mounts(self) -> None: else: assert storage.name is not None, storage blob_path = 'gs://' + storage.name + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -1005,6 +1019,7 @@ def sync_storage_mounts(self) -> None: blob_path = data_utils.AZURE_CONTAINER_URL.format( storage_account_name=storage_account_name, container_name=storage.name) + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -1015,6 +1030,7 @@ def sync_storage_mounts(self) -> None: blob_path = storage.source else: blob_path = 'r2://' + storage.name + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({ mnt_path: blob_path, }) @@ -1030,6 +1046,7 @@ def sync_storage_mounts(self) -> None: cos_region = data_utils.Rclone.get_region_from_rclone( storage.name, data_utils.Rclone.RcloneClouds.IBM) blob_path = f'cos://{cos_region}/{storage.name}' + blob_path = storage.get_bucket_sub_path_prefix(blob_path) self.update_file_mounts({mnt_path: blob_path}) elif store_type is storage_lib.StoreType.OCI: if storage.source is not None and not isinstance( diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 0166a16ff16..39623085bbb 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -649,10 +649,27 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', still sync up any storage mounts with local source paths (which do not undergo translation). """ + # ================================================================ # Translate the workdir and local file mounts to cloud file mounts. # ================================================================ + def _sub_path_join(sub_path: Optional[str], path: str) -> str: + if sub_path is None: + return path + return os.path.join(sub_path, path).strip('/') + + def assert_no_bucket_creation(store: storage_lib.AbstractStore) -> None: + if store.is_sky_managed: + # Bucket was created, this should not happen since use configured + # the bucket and we assumed it already exists. + store.delete() + with ux_utils.print_exception_no_traceback(): + raise exceptions.StorageBucketCreateError( + f'Jobs bucket {store.name!r} does not exist. ' + 'Please check jobs.bucket configuration in ' + 'your SkyPilot config.') + run_id = common_utils.get_usage_run_id()[:8] original_file_mounts = task.file_mounts if task.file_mounts else {} original_storage_mounts = task.storage_mounts if task.storage_mounts else {} @@ -679,11 +696,27 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', ux_utils.spinner_message( f'Translating {msg} to SkyPilot Storage...')) + # Get the bucket name for the workdir and file mounts, + # we store all these files in same bucket from config. 
+ bucket_wth_prefix = skypilot_config.get_nested(('jobs', 'bucket'), None) + store_kwargs: Dict[str, Any] = {} + if bucket_wth_prefix is None: + store_type = store_cls = sub_path = None + storage_account_name = region = None + bucket_name = constants.FILE_MOUNTS_BUCKET_NAME.format( + username=common_utils.get_cleaned_username(), id=run_id) + else: + store_type, store_cls, bucket_name, sub_path, storage_account_name, \ + region = storage_lib.StoreType.get_fields_from_store_url( + bucket_wth_prefix) + if storage_account_name is not None: + store_kwargs['storage_account_name'] = storage_account_name + if region is not None: + store_kwargs['region'] = region + # Step 1: Translate the workdir to SkyPilot storage. new_storage_mounts = {} if task.workdir is not None: - bucket_name = constants.WORKDIR_BUCKET_NAME.format( - username=common_utils.get_cleaned_username(), id=run_id) workdir = task.workdir task.workdir = None if (constants.SKY_REMOTE_WORKDIR in original_file_mounts or @@ -691,14 +724,28 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', raise ValueError( f'Cannot mount {constants.SKY_REMOTE_WORKDIR} as both the ' 'workdir and file_mounts contains it as the target.') - new_storage_mounts[ - constants. - SKY_REMOTE_WORKDIR] = storage_lib.Storage.from_yaml_config({ - 'name': bucket_name, - 'source': workdir, - 'persistent': False, - 'mode': 'COPY', - }) + bucket_sub_path = _sub_path_join( + sub_path, + constants.FILE_MOUNTS_WORKDIR_SUBPATH.format(run_id=run_id)) + stores = None + if store_type is not None: + assert store_cls is not None + with sky_logging.silent(): + stores = { + store_type: store_cls(name=bucket_name, + source=workdir, + _bucket_sub_path=bucket_sub_path, + **store_kwargs) + } + assert_no_bucket_creation(stores[store_type]) + + storage_obj = storage_lib.Storage(name=bucket_name, + source=workdir, + persistent=False, + mode=storage_lib.StorageMode.COPY, + stores=stores, + _bucket_sub_path=bucket_sub_path) + new_storage_mounts[constants.SKY_REMOTE_WORKDIR] = storage_obj # Check of the existence of the workdir in file_mounts is done in # the task construction. logger.info(f' {colorama.Style.DIM}Workdir: {workdir!r} ' @@ -716,27 +763,37 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', if os.path.isfile(os.path.abspath(os.path.expanduser(src))): copy_mounts_with_file_in_src[dst] = src continue - bucket_name = constants.FILE_MOUNTS_BUCKET_NAME.format( - username=common_utils.get_cleaned_username(), - id=f'{run_id}-{i}', - ) - new_storage_mounts[dst] = storage_lib.Storage.from_yaml_config({ - 'name': bucket_name, - 'source': src, - 'persistent': False, - 'mode': 'COPY', - }) + bucket_sub_path = _sub_path_join( + sub_path, constants.FILE_MOUNTS_SUBPATH.format(i=i, run_id=run_id)) + stores = None + if store_type is not None: + assert store_cls is not None + with sky_logging.silent(): + store = store_cls(name=bucket_name, + source=src, + _bucket_sub_path=bucket_sub_path, + **store_kwargs) + + stores = {store_type: store} + assert_no_bucket_creation(stores[store_type]) + storage_obj = storage_lib.Storage(name=bucket_name, + source=src, + persistent=False, + mode=storage_lib.StorageMode.COPY, + stores=stores, + _bucket_sub_path=bucket_sub_path) + new_storage_mounts[dst] = storage_obj logger.info(f' {colorama.Style.DIM}Folder : {src!r} ' f'-> storage: {bucket_name!r}.{colorama.Style.RESET_ALL}') # Step 3: Translate local file mounts with file in src to SkyPilot storage. 
# Hard link the files in src to a temporary directory, and upload folder. + file_mounts_tmp_subpath = _sub_path_join( + sub_path, constants.FILE_MOUNTS_TMP_SUBPATH.format(run_id=run_id)) local_fm_path = os.path.join( tempfile.gettempdir(), constants.FILE_MOUNTS_LOCAL_TMP_DIR.format(id=run_id)) os.makedirs(local_fm_path, exist_ok=True) - file_bucket_name = constants.FILE_MOUNTS_FILE_ONLY_BUCKET_NAME.format( - username=common_utils.get_cleaned_username(), id=run_id) file_mount_remote_tmp_dir = constants.FILE_MOUNTS_REMOTE_TMP_DIR.format( path) if copy_mounts_with_file_in_src: @@ -745,14 +802,27 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', src_to_file_id[src] = i os.link(os.path.abspath(os.path.expanduser(src)), os.path.join(local_fm_path, f'file-{i}')) - - new_storage_mounts[ - file_mount_remote_tmp_dir] = storage_lib.Storage.from_yaml_config({ - 'name': file_bucket_name, - 'source': local_fm_path, - 'persistent': False, - 'mode': 'MOUNT', - }) + stores = None + if store_type is not None: + assert store_cls is not None + with sky_logging.silent(): + stores = { + store_type: store_cls( + name=bucket_name, + source=local_fm_path, + _bucket_sub_path=file_mounts_tmp_subpath, + **store_kwargs) + } + assert_no_bucket_creation(stores[store_type]) + storage_obj = storage_lib.Storage( + name=bucket_name, + source=local_fm_path, + persistent=False, + mode=storage_lib.StorageMode.MOUNT, + stores=stores, + _bucket_sub_path=file_mounts_tmp_subpath) + + new_storage_mounts[file_mount_remote_tmp_dir] = storage_obj if file_mount_remote_tmp_dir in original_storage_mounts: with ux_utils.print_exception_no_traceback(): raise ValueError( @@ -762,8 +832,9 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', sources = list(src_to_file_id.keys()) sources_str = '\n '.join(sources) logger.info(f' {colorama.Style.DIM}Files (listed below) ' - f' -> storage: {file_bucket_name}:' + f' -> storage: {bucket_name}:' f'\n {sources_str}{colorama.Style.RESET_ALL}') + rich_utils.force_update_status( ux_utils.spinner_message('Uploading translated local files/folders')) task.update_storage_mounts(new_storage_mounts) @@ -779,7 +850,7 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', ux_utils.spinner_message('Uploading local sources to storage[/] ' '[dim]View storages: sky storage ls')) try: - task.sync_storage_mounts() + task.sync_storage_mounts(force_sync=bucket_wth_prefix is not None) except (ValueError, exceptions.NoCloudAccessError) as e: if 'No enabled cloud for storage' in str(e) or isinstance( e, exceptions.NoCloudAccessError): @@ -809,10 +880,11 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', # file_mount_remote_tmp_dir will only exist when there are files in # the src for copy mounts. 
         storage_obj = task.storage_mounts[file_mount_remote_tmp_dir]
-        store_type = list(storage_obj.stores.keys())[0]
-        store_object = storage_obj.stores[store_type]
+        curr_store_type = list(storage_obj.stores.keys())[0]
+        store_object = storage_obj.stores[curr_store_type]
         bucket_url = storage_lib.StoreType.get_endpoint_url(
-            store_object, file_bucket_name)
+            store_object, bucket_name)
+        bucket_url += f'/{file_mounts_tmp_subpath}'
         for dst, src in copy_mounts_with_file_in_src.items():
             file_id = src_to_file_id[src]
             new_file_mounts[dst] = bucket_url + f'/file-{file_id}'
@@ -829,8 +901,8 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task',
             store_types = list(storage_obj.stores.keys())
             assert len(store_types) == 1, (
                 'We only support one store type for now.', storage_obj.stores)
-            store_type = store_types[0]
-            store_object = storage_obj.stores[store_type]
+            curr_store_type = store_types[0]
+            store_object = storage_obj.stores[curr_store_type]
             storage_obj.source = storage_lib.StoreType.get_endpoint_url(
                 store_object, storage_obj.name)
             storage_obj.force_delete = True
@@ -847,8 +919,8 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task',
             store_types = list(storage_obj.stores.keys())
             assert len(store_types) == 1, (
                 'We only support one store type for now.', storage_obj.stores)
-            store_type = store_types[0]
-            store_object = storage_obj.stores[store_type]
+            curr_store_type = store_types[0]
+            store_object = storage_obj.stores[curr_store_type]
             source = storage_lib.StoreType.get_endpoint_url(
                 store_object, storage_obj.name)
             new_storage = storage_lib.Storage.from_yaml_config({
diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py
index 851e77a57fc..a424ae074b9 100644
--- a/sky/utils/schemas.py
+++ b/sky/utils/schemas.py
@@ -299,6 +299,12 @@ def get_storage_schema():
                     mode.value for mode in storage.StorageMode
                 ]
             },
+            '_is_sky_managed': {
+                'type': 'boolean',
+            },
+            '_bucket_sub_path': {
+                'type': 'string',
+            },
             '_force_delete': {
                 'type': 'boolean',
             }
@@ -721,6 +727,11 @@ def get_config_schema():
                        'resources': resources_schema,
                    }
                },
+                'bucket': {
+                    'type': 'string',
+                    'pattern': '^(https|s3|gs|r2|cos)://.+',
+                    'required': [],
+                }
            }
        }
    cloud_configs = {
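Reviewer note: the 'bucket' entry added to the jobs schema above only validates the URL scheme; everything after '://' is accepted, and an optional sub path after the bucket name is split out later by StoreType.get_fields_from_store_url, while bucket existence is enforced separately (the smoke test below expects a StorageBucketCreateError when the bucket is missing). A minimal sketch of what the pattern admits, using plain re matching; the bucket URLs are illustrative examples only:

import re

_JOBS_BUCKET_PATTERN = r'^(https|s3|gs|r2|cos)://.+'

for url in ('s3://my-jobs-bucket',          # accepted
            'gs://my-jobs-bucket/team-a',    # accepted, sub path kept
            'r2://my-jobs-bucket',           # accepted
            'my-jobs-bucket'):               # rejected: missing URL scheme
    print(f'{url!r:35} -> {bool(re.match(_JOBS_BUCKET_PATTERN, url))}')
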
diff --git a/tests/smoke_tests/test_managed_job.py b/tests/smoke_tests/test_managed_job.py
index c8ef5c1a502..f39dba6f47e 100644
--- a/tests/smoke_tests/test_managed_job.py
+++ b/tests/smoke_tests/test_managed_job.py
@@ -23,6 +23,7 @@
 # > pytest tests/smoke_tests/test_managed_job.py --generic-cloud aws
 
 import pathlib
+import re
 import tempfile
 import time
 
@@ -742,14 +743,70 @@ def test_managed_jobs_storage(generic_cloud: str):
             # Check if file was written to the mounted output bucket
             output_check_cmd
         ],
-        (f'sky jobs cancel -y -n {name}',
-         f'; sky storage delete {output_storage_name} || true'),
+        (f'sky jobs cancel -y -n {name}'
+         f'; sky storage delete {output_storage_name} -y || true'),
        # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
        timeout=20 * 60,
    )
    smoke_tests_utils.run_one_test(test)
 
+
+@pytest.mark.aws
+def test_managed_jobs_intermediate_storage(generic_cloud: str):
+    """Test managed jobs with a user-specified intermediate (jobs) bucket."""
+    name = smoke_tests_utils.get_cluster_name()
+    yaml_str = pathlib.Path(
+        'examples/managed_job_with_storage.yaml').read_text()
+    timestamp = int(time.time())
+    storage_name = f'sky-test-{timestamp}'
+    output_storage_name = f'sky-test-output-{timestamp}'
+
+    yaml_str_user_config = pathlib.Path(
+        'tests/test_yamls/use_intermediate_bucket_config.yaml').read_text()
+    intermediate_storage_name = f'intermediate-smoke-test-{timestamp}'
+
+    yaml_str = yaml_str.replace('sky-workdir-zhwu', storage_name)
+    yaml_str = yaml_str.replace('sky-output-bucket', output_storage_name)
+    yaml_str_user_config = re.sub(r'bucket-jobs-[\w\d]+',
+                                  intermediate_storage_name,
+                                  yaml_str_user_config)
+
+    with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f_user_config:
+        f_user_config.write(yaml_str_user_config)
+        f_user_config.flush()
+        user_config_path = f_user_config.name
+        with tempfile.NamedTemporaryFile(suffix='.yaml', mode='w') as f_task:
+            f_task.write(yaml_str)
+            f_task.flush()
+            file_path = f_task.name
+
+            test = smoke_tests_utils.Test(
+                'managed_jobs_intermediate_storage',
+                [
+                    *smoke_tests_utils.STORAGE_SETUP_COMMANDS,
+                    # Verify the launch fails with the correct error - run only once
+                    f'err=$(sky jobs launch -n {name} --cloud {generic_cloud} {file_path} -y 2>&1); ret=$?; echo "$err" ; [ $ret -eq 0 ] || ! echo "$err" | grep "StorageBucketCreateError: Jobs bucket \'{intermediate_storage_name}\' does not exist. Please check jobs.bucket configuration in your SkyPilot config." > /dev/null && exit 1 || exit 0',
+                    f'aws s3api create-bucket --bucket {intermediate_storage_name}',
+                    f'sky jobs launch -n {name} --cloud {generic_cloud} {file_path} -y',
+                    # Wait for the job to succeed now that the bucket exists.
+                    smoke_tests_utils.
+                    get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+                        job_name=name,
+                        job_status=[sky.ManagedJobStatus.SUCCEEDED],
+                        timeout=60 + smoke_tests_utils.BUMP_UP_SECONDS),
+                    # Check that the intermediate bucket still exists; it is not deleted because it is user-specified.
+                    f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{intermediate_storage_name}\')].Name" --output text | wc -l) -eq 1 ]',
+                ],
+                (f'sky jobs cancel -y -n {name}'
+                 f'; aws s3 rb s3://{intermediate_storage_name} --force'
+                 f'; sky storage delete {output_storage_name} -y || true'),
+                env={'SKYPILOT_CONFIG': user_config_path},
+                # Increase timeout since sky jobs queue -r can be blocked by other spot tests.
+                timeout=20 * 60,
+            )
+            smoke_tests_utils.run_one_test(test)
+
+
 # ---------- Testing spot TPU ----------
 @pytest.mark.gcp
 @pytest.mark.managed_jobs
diff --git a/tests/smoke_tests/test_mount_and_storage.py b/tests/smoke_tests/test_mount_and_storage.py
index aa61282aa11..89a849ad090 100644
--- a/tests/smoke_tests/test_mount_and_storage.py
+++ b/tests/smoke_tests/test_mount_and_storage.py
@@ -19,6 +19,7 @@
 # Change cloud for generic tests to aws
 # > pytest tests/smoke_tests/test_mount_and_storage.py --generic-cloud aws
 
+import json
 import os
 import pathlib
 import shlex
@@ -37,6 +38,7 @@
 import sky
 from sky import global_user_state
 from sky import skypilot_config
+from sky.adaptors import azure
 from sky.adaptors import cloudflare
 from sky.adaptors import ibm
 from sky.data import data_utils
@@ -629,21 +631,69 @@ def cli_delete_cmd(store_type, bucket_name,
                 Rclone.RcloneClouds.IBM)
             return f'rclone purge {bucket_rclone_profile}:{bucket_name} && rclone config delete {bucket_rclone_profile}'
 
+    @classmethod
+    def list_all_files(cls, store_type, bucket_name):
+        cmd = cls.cli_ls_cmd(store_type, bucket_name, recursive=True)
+        if store_type == storage_lib.StoreType.GCS:
+            try:
+                out = subprocess.check_output(cmd,
+                                              shell=True,
+                                              stderr=subprocess.PIPE)
+                files = [line[5:] for line in out.decode('utf-8').splitlines()]
+            except subprocess.CalledProcessError as e:
+                error_output = e.stderr.decode('utf-8')
+                if "One or more URLs matched no objects" in error_output:
+                    files = []
+                else:
+                    raise
+        elif store_type == storage_lib.StoreType.AZURE:
+            out = subprocess.check_output(cmd, shell=True)
+            try:
+                blobs = json.loads(out.decode('utf-8'))
+                files = [blob['name'] for blob in blobs]
+            except json.JSONDecodeError:
+                files = []
+        elif store_type == storage_lib.StoreType.IBM:
+            # rclone ls format: "  1234 path/to/file"
+            out = subprocess.check_output(cmd, shell=True)
+            files = []
+            for line in out.decode('utf-8').splitlines():
+                # Skip empty lines
+                if not line.strip():
+                    continue
+                # Split by whitespace and get the file path (last column)
+                parts = line.strip().split(
+                    None, 1)  # Split into max 2 parts (size and path)
+                if len(parts) == 2:
+                    files.append(parts[1])
+        else:
+            out = subprocess.check_output(cmd, shell=True)
+            files = [
+                line.split()[-1] for line in out.decode('utf-8').splitlines()
+            ]
+        return files
+
     @staticmethod
-    def cli_ls_cmd(store_type, bucket_name, suffix=''):
+    def cli_ls_cmd(store_type, bucket_name, suffix='', recursive=False):
         if store_type == storage_lib.StoreType.S3:
             if suffix:
                 url = f's3://{bucket_name}/{suffix}'
             else:
                 url = f's3://{bucket_name}'
-            return f'aws s3 ls {url}'
+            cmd = f'aws s3 ls {url}'
+            if recursive:
+                cmd += ' --recursive'
+            return cmd
         if store_type == storage_lib.StoreType.GCS:
             if suffix:
                 url = f'gs://{bucket_name}/{suffix}'
             else:
                 url = f'gs://{bucket_name}'
+            if recursive:
+                url = f'"{url}/**"'
             return f'gsutil ls {url}'
         if store_type == storage_lib.StoreType.AZURE:
+            # Azure ls is recursive by default.
             default_region = 'eastus'
             config_storage_account = skypilot_config.get_nested(
                 ('azure', 'storage_account'), None)
@@ -665,8 +715,10 @@ def cli_ls_cmd(store_type, bucket_name, suffix=''):
                 url = f's3://{bucket_name}/{suffix}'
             else:
                 url = f's3://{bucket_name}'
-            return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls {url} --endpoint {endpoint_url} --profile=r2'
+            recursive_flag = '--recursive' if recursive else ''
+            return f'AWS_SHARED_CREDENTIALS_FILE={cloudflare.R2_CREDENTIALS_PATH} aws s3 ls {url} --endpoint {endpoint_url} --profile=r2 {recursive_flag}'
         if store_type == storage_lib.StoreType.IBM:
+            # rclone ls is recursive by default.
             bucket_rclone_profile = Rclone.generate_rclone_bucket_profile_name(
                 bucket_name, Rclone.RcloneClouds.IBM)
             return f'rclone ls {bucket_rclone_profile}:{bucket_name}/{suffix}'
@@ -764,6 +816,12 @@ def tmp_source(self, tmp_path):
         circle_link.symlink_to(tmp_dir, target_is_directory=True)
         yield str(tmp_dir)
 
+    @pytest.fixture
+    def tmp_sub_path(self):
+        tmp_dir1 = uuid.uuid4().hex[:8]
+        tmp_dir2 = uuid.uuid4().hex[:8]
+        yield "/".join([tmp_dir1, tmp_dir2])
+
     @staticmethod
     def generate_bucket_name():
         # Creates a temporary bucket name
@@ -783,13 +841,15 @@ def yield_storage_object(
             stores: Optional[Dict[storage_lib.StoreType,
                                   storage_lib.AbstractStore]] = None,
             persistent: Optional[bool] = True,
-            mode: storage_lib.StorageMode = storage_lib.StorageMode.MOUNT):
+            mode: storage_lib.StorageMode = storage_lib.StorageMode.MOUNT,
+            _bucket_sub_path: Optional[str] = None):
         # Creates a temporary storage object. Stores must be added in the test.
         storage_obj = storage_lib.Storage(name=name,
                                           source=source,
                                           stores=stores,
                                           persistent=persistent,
-                                          mode=mode)
+                                          mode=mode,
+                                          _bucket_sub_path=_bucket_sub_path)
         yield storage_obj
         handle = global_user_state.get_handle_from_storage_name(
             storage_obj.name)
@@ -856,6 +916,15 @@ def tmp_local_storage_obj(self, tmp_bucket_name, tmp_source):
         yield from self.yield_storage_object(name=tmp_bucket_name,
                                              source=tmp_source)
 
+    @pytest.fixture
+    def tmp_local_storage_obj_with_sub_path(self, tmp_bucket_name, tmp_source,
+                                            tmp_sub_path):
+        # Creates a temporary storage object with a sub path. Stores must be added in the test.
+        list_source = [tmp_source, tmp_source + '/tmp-file']
+        yield from self.yield_storage_object(name=tmp_bucket_name,
+                                             source=list_source,
+                                             _bucket_sub_path=tmp_sub_path)
+
     @pytest.fixture
     def tmp_local_list_storage_obj(self, tmp_bucket_name, tmp_source):
         # Creates a temp storage object which uses a list of paths as source.
@@ -1014,6 +1083,59 @@ def test_new_bucket_creation_and_deletion(self, tmp_local_storage_obj,
         out = subprocess.check_output(['sky', 'storage', 'ls'])
         assert tmp_local_storage_obj.name not in out.decode('utf-8')
 
+    @pytest.mark.no_fluidstack
+    @pytest.mark.parametrize('store_type', [
+        pytest.param(storage_lib.StoreType.S3, marks=pytest.mark.aws),
+        pytest.param(storage_lib.StoreType.GCS, marks=pytest.mark.gcp),
+        pytest.param(storage_lib.StoreType.AZURE, marks=pytest.mark.azure),
+        pytest.param(storage_lib.StoreType.IBM, marks=pytest.mark.ibm),
+        pytest.param(storage_lib.StoreType.R2, marks=pytest.mark.cloudflare)
+    ])
+    def test_bucket_sub_path(self, tmp_local_storage_obj_with_sub_path,
+                             store_type):
+        # Creates a new bucket with a local source under a bucket sub path,
+        # uploads files to it, and deletes it.
+        tmp_local_storage_obj_with_sub_path.add_store(store_type)
+
+        # Check the files under the bucket and filter by the sub path prefix.
+        files = self.list_all_files(store_type,
+                                    tmp_local_storage_obj_with_sub_path.name)
+        assert len(files) > 0
+        if store_type == storage_lib.StoreType.GCS:
+            assert all([
+                file.startswith(
+                    tmp_local_storage_obj_with_sub_path.name + '/' +
+                    tmp_local_storage_obj_with_sub_path._bucket_sub_path)
+                for file in files
+            ])
+        else:
+            assert all([
+                file.startswith(
+                    tmp_local_storage_obj_with_sub_path._bucket_sub_path)
+                for file in files
+            ])
+
+        # Check the bucket is now empty; all files under the sub path should be deleted.
+        store = tmp_local_storage_obj_with_sub_path.stores[store_type]
+        store.is_sky_managed = False
+        if store_type == storage_lib.StoreType.AZURE:
+            azure.assign_storage_account_iam_role(
+                storage_account_name=store.storage_account_name,
+                resource_group_name=store.resource_group_name)
+        store.delete()
+        files = self.list_all_files(store_type,
+                                    tmp_local_storage_obj_with_sub_path.name)
+        assert len(files) == 0
+
+        # Now, delete the entire bucket
+        store.is_sky_managed = True
+        tmp_local_storage_obj_with_sub_path.delete()
+
+        # Run sky storage ls to check if storage object is deleted
+        out = subprocess.check_output(['sky', 'storage', 'ls'])
+        assert tmp_local_storage_obj_with_sub_path.name not in out.decode(
+            'utf-8')
+
     @pytest.mark.no_fluidstack
     @pytest.mark.xdist_group('multiple_bucket_deletion')
     @pytest.mark.parametrize('store_type', [
diff --git a/tests/test_yamls/intermediate_bucket.yaml b/tests/test_yamls/intermediate_bucket.yaml
new file mode 100644
index 00000000000..fe9aafd0675
--- /dev/null
+++ b/tests/test_yamls/intermediate_bucket.yaml
@@ -0,0 +1,21 @@
+name: intermediate-bucket
+
+file_mounts:
+  /setup.py: ./setup.py
+  /sky: .
+  /train-00001-of-01024: gs://cloud-tpu-test-datasets/fake_imagenet/train-00001-of-01024
+
+workdir: .
+
+
+setup: |
+  echo "running setup"
+
+run: |
+  echo "listing workdir"
+  ls .
+  echo "listing file_mounts"
+  ls /setup.py
+  ls /sky
+  ls /train-00001-of-01024
+  echo "task run finish"
diff --git a/tests/test_yamls/use_intermediate_bucket_config.yaml b/tests/test_yamls/use_intermediate_bucket_config.yaml
new file mode 100644
index 00000000000..cdfb5fbabc1
--- /dev/null
+++ b/tests/test_yamls/use_intermediate_bucket_config.yaml
@@ -0,0 +1,2 @@
+jobs:
+  bucket: "s3://bucket-jobs-s3"
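Reviewer note: when jobs.bucket is set, the three translation steps in controller_utils above reuse a single bucket and differ only in the _bucket_sub_path they pass to the store. The sketch below shows how the per-job prefixes compose. The FILE_MOUNTS_*_SUBPATH format strings and _sub_path_join are defined elsewhere in this PR, so the literal values and the helper body here are assumptions for illustration, not the actual definitions.

from typing import Optional

# Hypothetical stand-ins for the constants referenced in the translation code;
# the real format strings are defined elsewhere in this PR and may differ.
FILE_MOUNTS_WORKDIR_SUBPATH = 'job-{run_id}/workdir'
FILE_MOUNTS_SUBPATH = 'job-{run_id}/local-file-mounts/{i}'
FILE_MOUNTS_TMP_SUBPATH = 'job-{run_id}/tmp-files'


def _sub_path_join(sub_path: Optional[str], path: str) -> str:
    # Assumed behavior: prefix the user's sub path (taken from jobs.bucket,
    # e.g. s3://my-jobs-bucket/team-a) onto the per-job sub path.
    if sub_path is None:
        return path
    return f"{sub_path.strip('/')}/{path.strip('/')}"


run_id = 'abc123'         # example run id
user_sub_path = 'team-a'  # example prefix parsed from jobs.bucket

print(_sub_path_join(user_sub_path,
                     FILE_MOUNTS_WORKDIR_SUBPATH.format(run_id=run_id)))
# -> team-a/job-abc123/workdir            (workdir, COPY mode)
print(_sub_path_join(user_sub_path,
                     FILE_MOUNTS_SUBPATH.format(i=0, run_id=run_id)))
# -> team-a/job-abc123/local-file-mounts/0 (folder file_mounts, per index)
print(_sub_path_join(user_sub_path,
                     FILE_MOUNTS_TMP_SUBPATH.format(run_id=run_id)))
# -> team-a/job-abc123/tmp-files           (single-file mounts, MOUNT mode)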