diff --git a/go/pkg/diskcache/BUILD.bazel b/go/pkg/diskcache/BUILD.bazel index f0a365b8..84832ba9 100644 --- a/go/pkg/diskcache/BUILD.bazel +++ b/go/pkg/diskcache/BUILD.bazel @@ -15,6 +15,7 @@ go_library( "@com_github_bazelbuild_remote_apis//build/bazel/remote/execution/v2:remote_execution_go_proto", "@com_github_golang_glog//:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_x_sync//errgroup:go_default_library", ], ) diff --git a/go/pkg/diskcache/diskcache.go b/go/pkg/diskcache/diskcache.go index 196865f8..23e3ea19 100644 --- a/go/pkg/diskcache/diskcache.go +++ b/go/pkg/diskcache/diskcache.go @@ -16,6 +16,8 @@ import ( "time" "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" + "golang.org/x/sync/errgroup" + log "github.com/golang/glog" ) @@ -100,7 +102,7 @@ type DiskCache struct { testGcTicks chan uint64 } -func New(ctx context.Context, root string, maxCapacityBytes uint64) *DiskCache { +func New(ctx context.Context, root string, maxCapacityBytes uint64) (*DiskCache, error) { res := &DiskCache{ root: root, maxCapacityBytes: maxCapacityBytes, @@ -112,21 +114,25 @@ func New(ctx context.Context, root string, maxCapacityBytes uint64) *DiskCache { shutdown: make(chan bool), } heap.Init(res.queue) - _ = os.MkdirAll(root, os.ModePerm) + if err := os.MkdirAll(root, os.ModePerm); err != nil { + return nil, err + } // We use Git's directory/file naming structure as inspiration: // https://git-scm.com/book/en/v2/Git-Internals-Git-Objects#:~:text=The%20subdirectory%20is%20named%20with%20the%20first%202%20characters%20of%20the%20SHA%2D1%2C%20and%20the%20filename%20is%20the%20remaining%2038%20characters. - var wg sync.WaitGroup - wg.Add(256) + eg, eCtx := errgroup.WithContext(ctx) for i := 0; i < 256; i++ { prefixDir := filepath.Join(root, fmt.Sprintf("%02x", i)) - go func() { - defer wg.Done() - _ = os.MkdirAll(prefixDir, os.ModePerm) - _ = filepath.WalkDir(prefixDir, func(path string, d fs.DirEntry, err error) error { + eg.Go(func() error { + if eCtx.Err() != nil { + return eCtx.Err() + } + if err := os.MkdirAll(prefixDir, os.ModePerm); err != nil { + return err + } + return filepath.WalkDir(prefixDir, func(path string, d fs.DirEntry, err error) error { // We log and continue on all errors, because cache read errors are not critical. if err != nil { - log.Errorf("Error reading cache directory: %v", err) - return nil + return fmt.Errorf("error reading cache directory: %v", err) } if d.IsDir() { return nil @@ -134,13 +140,11 @@ func New(ctx context.Context, root string, maxCapacityBytes uint64) *DiskCache { subdir := filepath.Base(filepath.Dir(path)) k, err := res.getKeyFromFileName(subdir + d.Name()) if err != nil { - log.Errorf("Error parsing cached file name %s: %v", path, err) - return nil + return fmt.Errorf("error parsing cached file name %s: %v", path, err) } - atime, err := GetLastAccessTime(path) + atime, err := getLastAccessTime(path) if err != nil { - log.Errorf("Error getting last accessed time of %s: %v", path, err) - return nil + return fmt.Errorf("error getting last accessed time of %s: %v", path, err) } it := &qitem{ key: k, @@ -148,8 +152,7 @@ func New(ctx context.Context, root string, maxCapacityBytes uint64) *DiskCache { } size, err := res.getItemSize(k) if err != nil { - log.Errorf("Error getting file size of %s: %v", path, err) - return nil + return fmt.Errorf("error getting file size of %s: %v", path, err) } res.store.Store(k, it) atomic.AddInt64(&res.sizeBytes, size) @@ -158,11 +161,13 @@ func New(ctx context.Context, root string, maxCapacityBytes uint64) *DiskCache { res.mu.Unlock() return nil }) - }() + }) + } + if err := eg.Wait(); err != nil { + return nil, err } - wg.Wait() go res.gc() - return res + return res, nil } func (d *DiskCache) getItemSize(k key) (int64, error) { @@ -284,18 +289,13 @@ func copyFile(src, dst string, size int64) error { return err } defer out.Close() - _, err = io.Copy(out, in) + n, err := io.Copy(out, in) if err != nil { return err } - // Required sanity check: sometimes the copy pretends to succeed, but doesn't, if - // the file is being concurrently deleted. - dstInfo, err := os.Stat(dst) - if err != nil { - return err - } - if dstInfo.Size() != size { - return fmt.Errorf("copy of %s to %s failed: src/dst size mismatch: wanted %d, got %d", src, dst, size, dstInfo.Size()) + // Required sanity check: if the file is being concurrently deleted, we may not always copy everything. + if n != size { + return fmt.Errorf("copy of %s to %s failed: src/dst size mismatch: wanted %d, got %d", src, dst, size, n) } return nil } @@ -309,15 +309,23 @@ func (d *DiskCache) LoadCas(dg digest.Digest, path string) bool { } it := iUntyped.(*qitem) it.mu.RLock() - if err := copyFile(d.getPath(k), path, dg.Size); err != nil { + err := copyFile(d.getPath(k), path, dg.Size) + it.mu.RUnlock() + if err != nil { // It is not possible to prevent a race with GC; hence, we return false on copy errors. - it.mu.RUnlock() return false } - it.mu.RUnlock() d.mu.Lock() d.queue.Bump(it) d.mu.Unlock() return true } + +func getLastAccessTime(path string) (time.Time, error) { + info, err := os.Stat(path) + if err != nil { + return time.Time{}, err + } + return FileInfoToAccessTime(info), nil +} diff --git a/go/pkg/diskcache/diskcache_test.go b/go/pkg/diskcache/diskcache_test.go index 5ac93c07..3a704488 100644 --- a/go/pkg/diskcache/diskcache_test.go +++ b/go/pkg/diskcache/diskcache_test.go @@ -41,7 +41,10 @@ func TestStoreLoadCasPerm(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { root := t.TempDir() - d := New(context.Background(), filepath.Join(root, "cache"), 20) + d, err := New(context.Background(), filepath.Join(root, "cache"), 20) + if err != nil { + t.Errorf("New: %v", err) + } defer d.Shutdown() fname, _ := testutil.CreateFile(t, tc.executable, "12345") srcInfo, err := os.Stat(fname) @@ -79,7 +82,10 @@ func TestStoreLoadCasPerm(t *testing.T) { func TestLoadCasNotFound(t *testing.T) { root := t.TempDir() - d := New(context.Background(), filepath.Join(root, "cache"), 20) + d, err := New(context.Background(), filepath.Join(root, "cache"), 20) + if err != nil { + t.Errorf("New: %v", err) + } defer d.Shutdown() newName := filepath.Join(root, "new") dg := digest.NewFromBlob([]byte("bla")) @@ -90,7 +96,10 @@ func TestLoadCasNotFound(t *testing.T) { func TestGcOldestCas(t *testing.T) { root := t.TempDir() - d := New(context.Background(), filepath.Join(root, "cache"), 20) + d, err := New(context.Background(), filepath.Join(root, "cache"), 20) + if err != nil { + t.Errorf("New: %v", err) + } defer d.Shutdown() d.testGcTicks = make(chan uint64, 1) for i := 0; i < 5; i++ { @@ -123,11 +132,11 @@ func TestGcOldestCas(t *testing.T) { func isSystemLastAccessTimeAccurate(t *testing.T) bool { t.Helper() fname, _ := testutil.CreateFile(t, false, "foo") - lat, _ := GetLastAccessTime(fname) + lat, _ := getLastAccessTime(fname) if _, err := os.ReadFile(fname); err != nil { t.Fatalf("%v", err) } - newLat, _ := GetLastAccessTime(fname) + newLat, _ := getLastAccessTime(fname) return lat.Before(newLat) } @@ -140,7 +149,10 @@ func TestInitFromExistingCas(t *testing.T) { return } root := t.TempDir() - d := New(context.Background(), filepath.Join(root, "cache"), 20) + d, err := New(context.Background(), filepath.Join(root, "cache"), 20) + if err != nil { + t.Errorf("New: %v", err) + } for i := 0; i < 4; i++ { fname, _ := testutil.CreateFile(t, false, fmt.Sprintf("aaa %d", i)) dg, err := digest.NewFromFile(fname) @@ -159,7 +171,10 @@ func TestInitFromExistingCas(t *testing.T) { d.Shutdown() // Re-initialize from existing files. - d = New(context.Background(), filepath.Join(root, "cache"), 20) + d, err = New(context.Background(), filepath.Join(root, "cache"), 20) + if err != nil { + t.Errorf("New: %v", err) + } defer d.Shutdown() d.testGcTicks = make(chan uint64, 1) @@ -169,7 +184,7 @@ func TestInitFromExistingCas(t *testing.T) { t.Errorf("expected %s to be cached", dg) } fname, _ := testutil.CreateFile(t, false, "aaa 4") - dg, err := digest.NewFromFile(fname) + dg, err = digest.NewFromFile(fname) if err != nil { t.Fatalf("digest.NewFromFile failed: %v", err) } @@ -198,7 +213,10 @@ func TestThreadSafetyCas(t *testing.T) { nFiles := 10 attempts := 5000 // All blobs are size 5 exactly. We will have half the byte capacity we need. - d := New(context.Background(), filepath.Join(root, "cache"), uint64(nFiles*5)/2) + d, err := New(context.Background(), filepath.Join(root, "cache"), uint64(nFiles*5)/2) + if err != nil { + t.Errorf("New: %v", err) + } d.testGcTicks = make(chan uint64, attempts) defer d.Shutdown() var files []string diff --git a/go/pkg/diskcache/sys_darwin.go b/go/pkg/diskcache/sys_darwin.go index 076416af..1d51d7cd 100644 --- a/go/pkg/diskcache/sys_darwin.go +++ b/go/pkg/diskcache/sys_darwin.go @@ -2,15 +2,11 @@ package diskcache import ( - "os" + "io/fs" "syscall" "time" ) -func GetLastAccessTime(path string) (time.Time, error) { - info, err := os.Stat(path) - if err != nil { - return time.Time{}, err - } - return time.Unix(info.Sys().(*syscall.Stat_t).Atimespec.Unix()), nil +func FileInfoToAccessTime(info fs.FileInfo) time.Time { + return time.Unix(info.Sys().(*syscall.Stat_t).Atimespec.Unix()) } diff --git a/go/pkg/diskcache/sys_linux.go b/go/pkg/diskcache/sys_linux.go index 1c79836c..7414d4b5 100644 --- a/go/pkg/diskcache/sys_linux.go +++ b/go/pkg/diskcache/sys_linux.go @@ -2,15 +2,11 @@ package diskcache import ( - "os" + "io/fs" "syscall" "time" ) -func GetLastAccessTime(path string) (time.Time, error) { - info, err := os.Stat(path) - if err != nil { - return time.Time{}, err - } - return time.Unix(info.Sys().(*syscall.Stat_t).Atim.Unix()), nil +func FileInfoToAccessTime(info fs.FileInfo) time.Time { + return time.Unix(info.Sys().(*syscall.Stat_t).Atim.Unix()) } diff --git a/go/pkg/diskcache/sys_windows.go b/go/pkg/diskcache/sys_windows.go index 92f35d4a..319a7988 100644 --- a/go/pkg/diskcache/sys_windows.go +++ b/go/pkg/diskcache/sys_windows.go @@ -2,17 +2,13 @@ package diskcache import ( - "os" + "io/fs" "syscall" "time" ) // This will return correct values only if `fsutil behavior set disablelastaccess 0` is set. // Tracking of last access time is disabled by default on Windows. -func GetLastAccessTime(path string) (time.Time, error) { - info, err := os.Stat(path) - if err != nil { - return time.Time{}, err - } - return time.Unix(0, info.Sys().(*syscall.Win32FileAttributeData).LastAccessTime.Nanoseconds()), nil +func FileInfoToAccessTime(info fs.FileInfo) time.Time { + return time.Unix(0, info.Sys().(*syscall.Win32FileAttributeData).LastAccessTime.Nanoseconds()) }