Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

windows: handle pending control actions #268

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package service_test

import (
"fmt"
"os"
"testing"
"time"
Expand All @@ -22,22 +23,27 @@ func TestRunInterrupt(t *testing.T) {
t.Fatalf("New err: %s", err)
}

retChan := make(chan error)
go func() {
if err = s.Run(); err != nil {
retChan <- fmt.Errorf("Run() err: %w", err)
}
}()
go func() {
<-time.After(1 * time.Second)
interruptProcess(t)
}()

go func() {
for i := 0; i < 25 && p.numStopped == 0; i++ {
<-time.After(200 * time.Millisecond)
}
if p.numStopped == 0 {
t.Fatal("Run() hasn't been stopped")
retChan <- fmt.Errorf("Run() hasn't been stopped")
}
retChan <- nil
}()

if err = s.Run(); err != nil {
t.Fatalf("Run() err: %s", err)
if err = <-retChan; err != nil {
t.Fatal(err)
}
}

Expand Down
226 changes: 159 additions & 67 deletions service_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package service

import (
"errors"
"fmt"
"os"
"os/signal"
Expand Down Expand Up @@ -39,6 +40,11 @@ const (
errnoServiceDoesNotExist syscall.Errno = 1060
)

var (
errAlreadyRunning = errors.New("service already running")
errAlreadyStopped = errors.New("service already stopped")
)

type windowsService struct {
i Interface
*Config
Expand Down Expand Up @@ -149,11 +155,11 @@ func (l WindowsLogger) NInfof(eventID uint32, format string, a ...interface{}) e
var interactive = false

func init() {
var err error
interactive, err = svc.IsAnInteractiveSession()
isService, err := svc.IsWindowsService()
if err != nil {
panic(err)
}
interactive = !isService
}

func (ws *windowsService) String() string {
Expand All @@ -178,48 +184,63 @@ func (ws *windowsService) getError() error {
return ws.stopStartErr
}

func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, exitCode uint32) {
var err error
defer func() {
if err != nil {
ssec = true
ws.setError(err)
}
}()

// Signal that we're starting.
changes <- svc.Status{State: svc.StartPending}

if err := ws.i.Start(ws); err != nil {
ws.setError(err)
return true, 1
// Perform the actual start.
if initErr := ws.i.Start(ws); initErr != nil {
err = initErr
exitCode = 1
return
}

// Signal that we're ready.
changes <- svc.Status{
State: svc.Running,
Accepts: svc.AcceptStop | svc.AcceptShutdown,
}

changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
// Expect service change requests.
var stopMethod func(s Service) error
loop:
for {
c := <-r
for c := range r {
switch c.Cmd {
case svc.Interrogate:
changes <- c.CurrentStatus
case svc.Stop:
changes <- svc.Status{State: svc.StopPending}
if err := ws.i.Stop(ws); err != nil {
ws.setError(err)
return true, 2
}
break loop
case svc.Shutdown:
changes <- svc.Status{State: svc.StopPending}
var err error
if wsShutdown, ok := ws.i.(Shutdowner); ok {
err = wsShutdown.Shutdown(ws)
} else {
err = ws.i.Stop(ws)
}
if err != nil {
ws.setError(err)
return true, 2
if shutdowner, ok := ws.i.(Shutdowner); ok {
stopMethod = shutdowner.Shutdown
break loop
}
fallthrough
case svc.Stop:
stopMethod = ws.i.Stop
break loop
default:
continue loop
continue
}
}

return false, 0
// We were requested to stop,
// change state and proceed to do so.
changes <- svc.Status{State: svc.StopPending}
if stopErr := stopMethod(ws); stopErr != nil {
err = stopErr
exitCode = 2
return
}

// Calling function will set our state to Stopped.
return
}

func (ws *windowsService) Install() error {
Expand Down Expand Up @@ -308,19 +329,21 @@ func (ws *windowsService) Uninstall() error {
return err
}
defer m.Disconnect()

s, err := m.OpenService(ws.Name)
if err != nil {
return fmt.Errorf("service %s is not installed", ws.Name)
}
defer s.Close()
err = s.Delete()
if err != nil {

if err := s.Delete(); err != nil {
return err
}
err = eventlog.Remove(ws.Name)
if err != nil {

if err := eventlog.Remove(ws.Name); err != nil {
return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
}

return nil
}

Expand All @@ -346,7 +369,7 @@ func (ws *windowsService) Run() error {
return err
}

sigChan := make(chan os.Signal)
sigChan := make(chan os.Signal, 1)

signal.Notify(sigChan, os.Interrupt)

Expand Down Expand Up @@ -408,26 +431,28 @@ func (ws *windowsService) Start() error {
return err
}
defer s.Close()
return s.Start()
}

func (ws *windowsService) Stop() error {
m, err := mgr.Connect()
status, err := s.Query()
if err != nil {
return err
}
defer m.Disconnect()

s, err := m.OpenService(ws.Name)
if err != nil {
return err
switch status.State {
default:
err = errAlreadyRunning
case svc.StopPending:
err = waitForStateChange(s, status, svc.Stopped)
case svc.Stopped:
if startErr := s.Start(); startErr != nil {
return startErr
}
err = waitForStateChange(s, status, svc.Running)
}
defer s.Close()

return ws.stopWait(s)
return err
}

func (ws *windowsService) Restart() error {
func (ws *windowsService) Stop() error {
m, err := mgr.Connect()
if err != nil {
return err
Expand All @@ -440,37 +465,104 @@ func (ws *windowsService) Restart() error {
}
defer s.Close()

err = ws.stopWait(s)
status, err := s.Query()
if err != nil {
return err
}

return s.Start()
}

func (ws *windowsService) stopWait(s *mgr.Service) error {
// First stop the service. Then wait for the service to
// actually stop before starting it.
status, err := s.Control(svc.Stop)
if err != nil {
return err
switch status.State {
case svc.Stopped:
err = errAlreadyStopped
case svc.StopPending:
err = waitForStateChange(s, status, svc.Stopped)
default:
if _, stopErr := s.Control(svc.Stop); stopErr != nil {
return stopErr
}
err = waitForStateChange(s, status, svc.Stopped)
}

timeDuration := time.Millisecond * 50

timeout := time.After(getStopTimeout() + (timeDuration * 2))
tick := time.NewTicker(timeDuration)
defer tick.Stop()
return err
}

for status.State != svc.Stopped {
func (ws *windowsService) Restart() error {
if stopErr := ws.Stop(); stopErr != nil {
return stopErr
}
return ws.Start()
}

// statusInterval retreives a (bounded) duration from the status,
// or provides a default.
func statusInterval(status svc.Status) time.Duration {
// MSDN:
// "Do not wait longer than the wait hint. A good interval is
// one-tenth of the wait hint but not less than 1 second
// and not more than 10 seconds."
const (
lower = time.Second
upper = time.Second * 10
)

waitDuration := (time.Duration(status.WaitHint) * time.Millisecond) / 10
if waitDuration < lower {
waitDuration = lower
} else if waitDuration > upper {
waitDuration = upper
}
return waitDuration
}

// waitForStateChange polls the service until its state matches the desiredState,
// and error is encountered, or we timeout.
func waitForStateChange(s *mgr.Service, currentStatus svc.Status, desiredState svc.State) error {
const defaultAttempts = 10
var (
initialInterval = statusInterval(currentStatus)
queryTicker = time.NewTicker(initialInterval)
queryTimer *time.Timer
)
// If the service is providing hints,
// use them, otherwise use a default timeout.
if currentStatus.CheckPoint != 0 {
queryTimer = time.NewTimer(initialInterval)
} else {
queryTimer = time.NewTimer(initialInterval * defaultAttempts)
}
defer func() {
queryTicker.Stop()
queryTimer.Stop()
}()

var (
currentState = currentStatus.State
lastCheckpoint uint32
)
for currentState != desiredState {
select {
case <-tick.C:
status, err = s.Query()
if err != nil {
return err
case <-queryTicker.C:
currentStatus, queryErr := s.Query()
if queryErr != nil {
return queryErr
}

currentState = currentStatus.State
if currentState == desiredState {
return nil
}

if currentStatus.CheckPoint > lastCheckpoint {
// Service progressed,
// give it more time to complete.
if !queryTimer.Stop() {
<-queryTimer.C
}
queryTimer.Reset(statusInterval(currentStatus))
}
case <-timeout:
break
lastCheckpoint = currentStatus.CheckPoint
case <-queryTimer.C:
return fmt.Errorf("service did not enter desired state (%v) before we timed out",
desiredState)
}
}
return nil
Expand Down