Skip to content

Commit

Permalink
Merge pull request #853 from planetscale/dbussink/fix-shell-signal-ha…
Browse files Browse the repository at this point in the history
…ndling

Fix the MySQL shell signal handling
  • Loading branch information
dbussink authored Apr 18, 2024
2 parents 3db79dc + b77a1db commit e37d329
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 10 deletions.
19 changes: 17 additions & 2 deletions cmd/pscale/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,23 @@ func main() {
}

func realMain() int {
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

return cmd.Execute(ctx, version, commit, date)
sigc := make(chan os.Signal, 1)
signals := []os.Signal{os.Interrupt}
signal.Notify(sigc, signals...)
defer func() {
signal.Stop(sigc)
cancel()
}()
go func() {
select {
case <-sigc:
cancel()
case <-ctx.Done():
}
}()

return cmd.Execute(ctx, sigc, signals, version, commit, date)
}
8 changes: 4 additions & 4 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ var rootCmd = &cobra.Command{

// Execute executes the command and returns the exit status of the finished
// command.
func Execute(ctx context.Context, ver, commit, buildDate string) int {
func Execute(ctx context.Context, sigc chan os.Signal, signals []os.Signal, ver, commit, buildDate string) int {
var format printer.Format
var debug bool

Expand Down Expand Up @@ -106,7 +106,7 @@ func Execute(ctx context.Context, ver, commit, buildDate string) int {
}()
}

err := runCmd(ctx, ver, commit, buildDate, &format, &debug)
err := runCmd(ctx, ver, commit, buildDate, &format, &debug, sigc, signals)
if err == nil {
return 0
}
Expand All @@ -130,7 +130,7 @@ func Execute(ctx context.Context, ver, commit, buildDate string) int {

// runCmd adds all child commands to the root command, sets flags
// appropriately, and runs the root command.
func runCmd(ctx context.Context, ver, commit, buildDate string, format *printer.Format, debug *bool) error {
func runCmd(ctx context.Context, ver, commit, buildDate string, format *printer.Format, debug *bool, sigc chan os.Signal, signals []os.Signal) error {
cobra.OnInitialize(initConfig)

rootCmd.PersistentFlags().StringVar(&cfgFile, "config",
Expand Down Expand Up @@ -220,7 +220,7 @@ func runCmd(ctx context.Context, ver, commit, buildDate string, format *printer.
rootCmd.AddCommand(password.PasswordCmd(ch))
rootCmd.AddCommand(ping.PingCmd(ch))
rootCmd.AddCommand(region.RegionCmd(ch))
rootCmd.AddCommand(shell.ShellCmd(ch))
rootCmd.AddCommand(shell.ShellCmd(ch, sigc, signals...))
rootCmd.AddCommand(signup.SignupCmd(ch))
rootCmd.AddCommand(token.TokenCmd(ch))
rootCmd.AddCommand(version.VersionCmd(ch, ver, commit, buildDate))
Expand Down
53 changes: 49 additions & 4 deletions internal/cmd/shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"os"
"os/signal"
"path/filepath"
"time"

Expand All @@ -22,7 +23,7 @@ import (
"github.com/planetscale/cli/internal/proxyutil"
)

func ShellCmd(ch *cmdutil.Helper) *cobra.Command {
func ShellCmd(ch *cmdutil.Helper, sigc chan os.Signal, signals ...os.Signal) *cobra.Command {
var flags struct {
localAddr string
remoteAddr string
Expand Down Expand Up @@ -212,7 +213,7 @@ second argument:
}()

go func() {
errCh <- m.Run(ctx, mysqlArgs...)
errCh <- m.Run(ctx, sigc, signals, mysqlArgs...)
}()

go func() {
Expand Down Expand Up @@ -255,7 +256,7 @@ type mysql struct {
}

// Run runs the `mysql` client with the given arguments.
func (m *mysql) Run(ctx context.Context, args ...string) error {
func (m *mysql) Run(ctx context.Context, sigc chan os.Signal, signals []os.Signal, args ...string) error {
c := exec.CommandContext(ctx, m.mysqlPath, args...)
if m.dir != "" {
c.Dir = m.dir
Expand All @@ -271,7 +272,51 @@ func (m *mysql) Run(ctx context.Context, args ...string) error {
c.Stderr = os.Stderr
c.Stdin = os.Stdin

return c.Run()
// Set up a new channel for signals received while MySQL is active.
// This is registered for all signals so we forward them all to MySQL,
// so we behave as much as possible like a regular MySQL.
// When we exit this function, we stop the custom signal receiver.
msig := make(chan os.Signal, 1)
signal.Notify(msig)
defer signal.Stop(msig)

// We stop handling signals for our default setup from the CLI. This
// is needed, so we stop handling for example the default os.Interrupt
// that stops the shell and we forward it to MySQL.
// When we exit from this function, we restore the signals as they were.
signal.Stop(sigc)
defer signal.Notify(sigc, signals...)

err := c.Start()
if err != nil {
return err
}

wait := make(chan error, 1)
go func() {
wait <- c.Wait()
close(wait)
}()

for {
select {
case sig := <-msig:
if err := c.Process.Signal(sig); err != nil {
// If we failed to send a signal to the process, just in case
// it's still alive, make sure we kill it.
_ = c.Process.Signal(os.Kill)
return err
}
case err := <-wait:
if err != nil {
// If we failed to wait for the process, just in case
// we send a hard kill to ensure the MySQL subprocess
// gets killed.
c.Process.Signal(os.Kill)
}
return err
}
}
}

func formatMySQLBranch(database string, branch *ps.DatabaseBranch) string {
Expand Down

0 comments on commit e37d329

Please sign in to comment.