Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race issue #23

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
- uses: actions/setup-go@v4
with:
go-version-file: ./go.mod
cache: true
Expand Down
61 changes: 54 additions & 7 deletions xhttp/serve.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
// Package xhttp implements http helpers.
//
package xhttp

import (
"context"
"errors"
"log"
"net"
"net/http"
"sync"
"sync/atomic"
"time"

"oss.terrastruct.com/util-go/xcontext"
Expand All @@ -22,23 +26,66 @@ func NewServer(log *log.Logger, h http.Handler) *http.Server {
}
}

type safeServer struct {
*http.Server
running int32
mu sync.Mutex
}

func newSafeServer(s *http.Server) *safeServer {
return &safeServer{
Server: s,
}
}

func (s *safeServer) ListenAndServe(l net.Listener) error {
s.mu.Lock()
defer s.mu.Unlock()

if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errors.New("server is already running")
}
defer atomic.StoreInt32(&s.running, 0)

return s.Serve(l)
}

func (s *safeServer) Shutdown(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()

if atomic.LoadInt32(&s.running) == 0 {
return nil
}

return s.Server.Shutdown(ctx)
}

func Serve(ctx context.Context, shutdownTimeout time.Duration, s *http.Server, l net.Listener) error {
s.BaseContext = func(net.Listener) context.Context {
return ctx
}

done := make(chan error, 1)
ss := newSafeServer(s)

serverClosed := make(chan struct{})
var serverError error
go func() {
done <- s.Serve(l)
serverError = ss.ListenAndServe(l)
close(serverClosed)
}()

select {
case err := <-done:
return err
case <-serverClosed:
return serverError
case <-ctx.Done():
ctx = xcontext.WithoutCancel(ctx)
ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
shutdownCtx, cancel := context.WithTimeout(xcontext.WithoutCancel(ctx), shutdownTimeout)
defer cancel()
return s.Shutdown(ctx)
err := ss.Shutdown(shutdownCtx)
<-serverClosed // Wait for server to exit
if err != nil {
return err
}
return serverError
}
}
33 changes: 33 additions & 0 deletions xmain/xmain.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,39 @@ func (ms *State) WritePath(fp string, p []byte) error {
return os.WriteFile(fp, p, 0644)
}

func (ms *State) AtomicWritePath(fp string, p []byte) error {
if fp == "-" {
return ms.WritePath(fp, p)
}

dir := filepath.Dir(fp)
base := filepath.Base(fp)
tempFile, err := os.CreateTemp(dir, "tmp-"+base+"-")
if err != nil {
return err
}
defer func() {
// Clean up temporary file if it still exists and there's an error
if err != nil {
os.Remove(tempFile.Name())
}
}()

if _, err = tempFile.Write(p); err != nil {
return err
}

if err = tempFile.Close(); err != nil {
return err
}

if err = os.Rename(tempFile.Name(), fp); err != nil {
return err
}

return nil
}

// AbsPath joins the PWD with fp to give the absolute path to fp.
func (ms *State) AbsPath(fp string) string {
if fp == "-" || filepath.IsAbs(fp) {
Expand Down
Loading