diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3d8c526..5182a8e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - uses: actions/setup-go@v3 + - uses: actions/setup-go@v4 with: go-version-file: ./go.mod cache: true diff --git a/xhttp/serve.go b/xhttp/serve.go index 93ccc25..639600f 100644 --- a/xhttp/serve.go +++ b/xhttp/serve.go @@ -1,11 +1,15 @@ // Package xhttp implements http helpers. +// package xhttp import ( "context" + "errors" "log" "net" "net/http" + "sync" + "sync/atomic" "time" "oss.terrastruct.com/util-go/xcontext" @@ -22,23 +26,66 @@ func NewServer(log *log.Logger, h http.Handler) *http.Server { } } +type safeServer struct { + *http.Server + running int32 + mu sync.Mutex +} + +func newSafeServer(s *http.Server) *safeServer { + return &safeServer{ + Server: s, + } +} + +func (s *safeServer) ListenAndServe(l net.Listener) error { + s.mu.Lock() + defer s.mu.Unlock() + + if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { + return errors.New("server is already running") + } + defer atomic.StoreInt32(&s.running, 0) + + return s.Serve(l) +} + +func (s *safeServer) Shutdown(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if atomic.LoadInt32(&s.running) == 0 { + return nil + } + + return s.Server.Shutdown(ctx) +} + func Serve(ctx context.Context, shutdownTimeout time.Duration, s *http.Server, l net.Listener) error { s.BaseContext = func(net.Listener) context.Context { return ctx } - done := make(chan error, 1) + ss := newSafeServer(s) + + serverClosed := make(chan struct{}) + var serverError error go func() { - done <- s.Serve(l) + serverError = ss.ListenAndServe(l) + close(serverClosed) }() select { - case err := <-done: - return err + case <-serverClosed: + return serverError case <-ctx.Done(): - ctx = xcontext.WithoutCancel(ctx) - ctx, cancel := context.WithTimeout(ctx, shutdownTimeout) + shutdownCtx, cancel := context.WithTimeout(xcontext.WithoutCancel(ctx), shutdownTimeout) defer cancel() - return s.Shutdown(ctx) + err := ss.Shutdown(shutdownCtx) + <-serverClosed // Wait for server to exit + if err != nil { + return err + } + return serverError } } diff --git a/xmain/xmain.go b/xmain/xmain.go index 9a08041..9714764 100644 --- a/xmain/xmain.go +++ b/xmain/xmain.go @@ -182,6 +182,39 @@ func (ms *State) WritePath(fp string, p []byte) error { return os.WriteFile(fp, p, 0644) } +func (ms *State) AtomicWritePath(fp string, p []byte) error { + if fp == "-" { + return ms.WritePath(fp, p) + } + + dir := filepath.Dir(fp) + base := filepath.Base(fp) + tempFile, err := os.CreateTemp(dir, "tmp-"+base+"-") + if err != nil { + return err + } + defer func() { + // Clean up temporary file if it still exists and there's an error + if err != nil { + os.Remove(tempFile.Name()) + } + }() + + if _, err = tempFile.Write(p); err != nil { + return err + } + + if err = tempFile.Close(); err != nil { + return err + } + + if err = os.Rename(tempFile.Name(), fp); err != nil { + return err + } + + return nil +} + // AbsPath joins the PWD with fp to give the absolute path to fp. func (ms *State) AbsPath(fp string) string { if fp == "-" || filepath.IsAbs(fp) {