-
Notifications
You must be signed in to change notification settings - Fork 4
/
weights.py
133 lines (109 loc) · 4.44 KB
/
weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from collections import deque
import hashlib
import os
import shutil
import subprocess
import time
class WeightsDownloadCache:
    """
    Track and download weights files, keeping the most recently used on disk.

    Downloaded files live under ``base_dir`` and are named after a hash of
    their URL. Before each download the cache evicts least recently used
    entries until at least ``min_disk_free`` bytes are free (or nothing is
    left to evict). Recency is only updated by :meth:`ensure`, so call it
    every time the weights are used.

    Only paths recorded in ``lru_paths`` by this instance are considered
    cached; the cache does not scan ``base_dir`` for pre-existing files.
    """

    def __init__(
        self, min_disk_free: int = 10 * (2**30), base_dir: str = "/src/weights-cache"
    ):
        """
        Create the cache and make sure its base directory exists.

        :param min_disk_free: Minimum disk space required to start download, in bytes.
        :param base_dir: The base directory to store weights files.
        """
        self.min_disk_free = min_disk_free
        self.base_dir = base_dir

        self._hits = 0
        self._misses = 0

        # Least Recently Used (LRU) order of cached paths:
        # oldest entry on the left, most recently used on the right.
        self.lru_paths = deque()

        # exist_ok avoids the check-then-create race when several workers
        # initialise a cache over the same directory at the same time.
        os.makedirs(base_dir, exist_ok=True)

    def _remove_least_recent(self) -> None:
        """
        Remove the least recently used weights file from the cache and disk.

        :raises IndexError: If the cache is empty (callers check beforehand).
        """
        oldest = self.lru_paths.popleft()
        self._rm_disk(oldest)

    def cache_info(self) -> str:
        """
        Get cache information.

        :return: Hit/miss counters, base directory and number of tracked paths,
            formatted as a ``CacheInfo(...)`` string.
        """
        return f"CacheInfo(hits={self._hits}, misses={self._misses}, base_dir='{self.base_dir}', currsize={len(self.lru_paths)})"

    def _rm_disk(self, path: str) -> None:
        """
        Remove a weights file or directory from disk.

        A path that is neither a file nor a directory (e.g. already gone)
        is silently ignored.

        :param path: Path to remove.
        """
        if os.path.isfile(path):
            os.remove(path)
        elif os.path.isdir(path):
            shutil.rmtree(path)

    def _has_enough_space(self) -> bool:
        """
        Check if there's enough disk space.

        Also prints the current amount of free space on every call.

        :return: True if at least min_disk_free bytes are free, False otherwise.
        """
        disk_usage = shutil.disk_usage(self.base_dir)
        print(f"Free disk space: {disk_usage.free}")
        return disk_usage.free >= self.min_disk_free

    def ensure(self, url: str) -> str:
        """
        Ensure weights file is in the cache and return its path.

        This also updates the LRU cache to mark the weights as recently used.
        An entry that is tracked but no longer present on disk is treated as
        a miss and downloaded again instead of returning a dangling path.

        :param url: URL to download weights file from, if not in cache.
        :return: Path to weights.
        :raises subprocess.CalledProcessError: If the download command fails.
        """
        path = self.weights_path(url)

        if path in self.lru_paths:
            # Drop the entry; it is re-appended at the most-recent end below
            # (or stays dropped if the re-download of a stale entry fails).
            self.lru_paths.remove(path)
            if os.path.exists(path):
                self._hits += 1
                print("weights already in cache")
                self.lru_paths.append(path)
                return path
            # Tracked but missing on disk: fall through and handle as a miss.

        self._misses += 1
        print("weights not in cache")
        self.download_weights(url, path)
        self.lru_paths.append(path)  # Add file to end of cache
        return path

    def weights_path(self, url: str) -> str:
        """
        Generate the path to store a weights file, based on a hash of the URL.

        :param url: URL to download weights file from.
        :return: ``base_dir`` joined with the first 16 hex characters of the
            URL's SHA-256 digest.
        """
        hashed_url = hashlib.sha256(url.encode()).hexdigest()
        short_hash = hashed_url[:16]  # Use the first 16 characters of the hash
        return os.path.join(self.base_dir, short_hash)

    def download_weights(self, url: str, dest: str) -> None:
        """
        Download weights file from a URL, ensuring there's enough disk space.

        Least recently used entries are evicted until enough space is free or
        the cache is empty; the download is attempted either way.

        :param url: URL to download weights file from.
        :param dest: Path to store weights file.
        :raises subprocess.CalledProcessError: If ``pget`` exits with a
            non-zero status; ``dest`` is removed before re-raising.
        """
        print("Ensuring enough disk space...")
        # Keep the space check first so the free-space line is printed even
        # when there is nothing to evict.
        while not self._has_enough_space() and len(self.lru_paths) > 0:
            self._remove_least_recent()

        print(f"Downloading weights: {url}")

        # Monotonic clock: the reported duration is not affected by
        # wall-clock adjustments during a long download.
        started = time.monotonic()

        # maybe retry with the real url if this doesn't work
        url = url.replace(
            "replicate.delivery/pbxt", "replicate-files.object.lga1.coreweave.com"
        )
        try:
            print(f"downloading {url}")
            # NOTE(review): "-x" presumably makes pget extract an archive
            # into dest — confirm against the pget documentation.
            output = subprocess.check_output(["pget", "-x", url, dest], close_fds=True)
            print(output)
        except subprocess.CalledProcessError as e:
            # If download fails, clean up any partial result and re-raise
            # with the original traceback.
            print(e.output)
            self._rm_disk(dest)
            raise

        print(f"Downloaded weights in {time.monotonic() - started} seconds")