Skip to content

Commit

Permalink
Fix syscall path execution
Browse files Browse the repository at this point in the history
Fix archive cleanup if hash is not valid
Limit the archive write bytes
  • Loading branch information
vapopov committed Oct 15, 2024
1 parent a3fa2c3 commit 1eeb647
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions lib/autoupdate/client_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (u *ClientUpdater) Update(ctx context.Context, toolsVersion string) error {
if err != nil {
return trace.Wrap(err)
}
archivePath, err := u.downloadArchive(signalCtx, u.toolsDir, archiveURL, hash)
archivePath, archiveHash, err := u.downloadArchive(signalCtx, u.toolsDir, archiveURL)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -248,6 +248,9 @@ func (u *ClientUpdater) Update(ctx context.Context, toolsVersion string) error {
slog.WarnContext(ctx, "failed to remove archive", "error", err)
}
}()
if archiveHash != hash {
return trace.BadParameter("hash of archive does not match downloaded archive")
}

pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffix)
extractDir := filepath.Join(u.toolsDir, pkgName)
Expand Down Expand Up @@ -291,7 +294,7 @@ func (u *ClientUpdater) Exec() (int, error) {
return cmd.ProcessState.ExitCode(), nil
}

if err := syscall.Exec(path, os.Args, env); err != nil {
if err := syscall.Exec(path, append([]string{path}, os.Args[1:]...), env); err != nil {
return 0, trace.Wrap(err)
}

Expand Down Expand Up @@ -324,47 +327,44 @@ func (u *ClientUpdater) downloadHash(ctx context.Context, url string) (string, e
return raw, nil
}

func (u *ClientUpdater) downloadArchive(ctx context.Context, downloadDir string, url string, hash string) (string, error) {
func (u *ClientUpdater) downloadArchive(ctx context.Context, downloadDir string, url string) (string, string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", trace.Wrap(err)
return "", "", trace.Wrap(err)
}
resp, err := u.client.Do(req)
if err != nil {
return "", trace.Wrap(err)
return "", "", trace.Wrap(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", trace.BadParameter("bad status when downloading archive: %v", resp.StatusCode)
return "", "", trace.BadParameter("bad status when downloading archive: %v", resp.StatusCode)
}

if resp.ContentLength != -1 {
if err := checkFreeSpace(u.toolsDir, uint64(resp.ContentLength)); err != nil {
return "", trace.Wrap(err)
return "", "", trace.Wrap(err)
}
}

// Caller of this function will remove this file after the atomic swap has
// occurred.
f, err := os.CreateTemp(downloadDir, "tmp-")
if err != nil {
return "", trace.Wrap(err)
return "", "", trace.Wrap(err)
}

h := sha256.New()
pw := &progressWriter{n: 0, limit: resp.ContentLength}
body := io.TeeReader(io.TeeReader(resp.Body, h), pw)

// It is a little inefficient to download the file to disk and then re-load
// it into memory to unarchive later, but this is safer as it allows {tsh,
// tctl} to validate the hash before trying to operate on the archive.
_, err = io.Copy(f, body)
// it into memory to unarchive later, but this is safer as it allows client
// tools to validate the hash before trying to operate on the archive.
_, err = io.CopyN(f, body, resp.ContentLength)
if err != nil {
return "", trace.Wrap(err)
}
if fmt.Sprintf("%x", h.Sum(nil)) != hash {
return "", trace.BadParameter("hash of archive does not match downloaded archive")
return "", "", trace.Wrap(err)
}

return f.Name(), nil
return f.Name(), fmt.Sprintf("%x", h.Sum(nil)), nil
}

0 comments on commit 1eeb647

Please sign in to comment.