Skip to content

Commit

Permalink
Use files argument in deps methods
Browse files: browse the repository at this point in the history
  • Loading branch information
hagenw committed May 3, 2024
1 parent cc1f4d9 commit a06c508
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 37 deletions.
43 changes: 21 additions & 22 deletions audb/core/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ def check() -> bool:
for table in deps.tables:
if not os.path.exists(os.path.join(db_root, table)):
return False
for media in deps.media:
if not deps.removed(media):
removed = deps.removed(deps.media)
for n, media in enumerate(deps.media):
if not removed[n]:
path = os.path.join(db_root, media)
path = flavor.destination(path)
if not os.path.exists(path):
Expand Down Expand Up @@ -244,7 +245,7 @@ def _get_attachments_from_cache(
flavor,
verbose,
)
missing_attachments = [deps.archive(path) for path in missing_paths]
missing_attachments = deps.archive(missing_paths)
db_root_tmp = database_tmp_root(db_root)

def job(cache_root: str, file: str):
Expand Down Expand Up @@ -410,21 +411,18 @@ def _get_media_from_backend(
):
r"""Load media from backend."""
# figure out archives
archives = set()
archive_names = set()
for file in media:
archive_name = deps.archive(file)
archive_version = deps.version(file)
archives.add((archive_name, archive_version))
archive_names.add(archive_name)
names = deps.archive(media)
versions = deps.version(media)
archive_names = set(names)
archives = set([(name, version) for name, version in zip(names, versions)])
# collect all files that will be extracted,
# if we have more files than archives
if len(deps.files) > len(deps.archives):
files = list()
for file in deps.media:
archive = deps.archive(file)
if archive in archive_names:
files.append(file)
files = [
file
for file, archive in zip(deps.media, deps.archive(deps.media))
if archive in archive_names
]
media = files

# create folder tree to avoid race condition
Expand All @@ -449,22 +447,23 @@ def job(archive: str, version: str):
version,
tmp_root=db_root_tmp,
)
for file in files:
if flavor is not None:
bit_depth = deps.bit_depth(files)
channels = deps.channels(files)
sampling_rate = deps.sampling_rate(files)
for n, file in enumerate(files):
if os.name == "nt": # pragma: no cover
file = file.replace(os.sep, "/")
if flavor is not None:
bit_depth = deps.bit_depth(file)
channels = deps.channels(file)
sampling_rate = deps.sampling_rate(file)
src_path = os.path.join(db_root_tmp, file)
file = flavor.destination(file)
dst_path = os.path.join(db_root_tmp, file)
flavor(
src_path,
dst_path,
src_bit_depth=bit_depth,
src_channels=channels,
src_sampling_rate=sampling_rate,
src_bit_depth=bit_depth[n],
src_channels=channels[n],
src_sampling_rate=sampling_rate[n],
)
if src_path != dst_path:
os.remove(src_path)
Expand Down
21 changes: 13 additions & 8 deletions audb/core/load_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,16 @@ def _find_media(
media = []

def job(file: str):
if not deps.removed(file):
full_file = os.path.join(db_root, file)
if not os.path.exists(full_file):
media.append(file)
full_file = os.path.join(db_root, file)
if not os.path.exists(full_file):
media.append(file)

files = [
file for file, removed in zip(db.files, deps.removed(db.files)) if not removed
]
audeer.run_tasks(
job,
params=[([file], {}) for file in db.files],
params=[([file], {}) for file in files],
num_workers=num_workers,
progress_bar=verbose,
task_description="Find media",
Expand Down Expand Up @@ -165,9 +167,12 @@ def _get_media(
utils.mkdir_tree(media, db_root_tmp)

# figure out archives
archives = set()
for file in media:
archives.add((deps.archive(file), deps.version(file)))
archives = set(
[
(name, version)
for name, version in zip(deps.archive(media), deps.version(media))
]
)

def job(archive: str, version: str):
archive = backend.join(
Expand Down
17 changes: 10 additions & 7 deletions audb/core/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,12 @@ def _find_media(
verbose: bool,
) -> typing.Set[str]:
r"""Find archives with new, altered or removed media and update 'deps'."""
media_archives = set()
db_media = set(db.files)

# release dependencies to removed media
# and select the corresponding archives for upload
removed_files = set(deps.media) - db_media
for file in removed_files:
media_archives.add(deps.archive(file))
media_archives = deps.archive(list(removed_files))
deps._drop(removed_files)

# limit to relevant media
Expand Down Expand Up @@ -242,11 +240,16 @@ def job(file):
deps._add_media(add_media)

# select archives with new or altered files for upload
for file in deps.media:
if not deps.removed(file) and deps.version(file) == version:
media_archives.add(deps.archive(file))
media_archives += [
archive
for archive, _version, removed in zip(
deps.archive(deps.media),
deps.version(deps.media),
deps.removed(deps.media),
)
]

return media_archives
return set(media_archives)


def _find_tables(
Expand Down

0 comments on commit a06c508

Please sign in to comment.