From 8953b4bd549077dfc4acc15678bded8e3855ccc1 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Else Date: Sat, 19 Jun 2021 23:18:56 +0200 Subject: [PATCH] Allow to cancel graceful shutdown --- example/example.go | 3 ++ shutdown.go | 74 +++++++++++++++++++++++++++++++++++----------- shutdown_test.go | 34 +++++++++++++++++++++ 3 files changed, 94 insertions(+), 17 deletions(-) diff --git a/example/example.go b/example/example.go index e4fce24..659b2a6 100644 --- a/example/example.go +++ b/example/example.go @@ -2,6 +2,7 @@ package main import ( "log" + "time" goshutdowner "git.jlel.se/jlelse/go-shutdowner" ) @@ -13,10 +14,12 @@ func main() { // Add a function to execute on shutdown sd.Add(func() { + time.Sleep(5 * time.Second) log.Println("Shutdown") }) log.Println("Started") + log.Println("Print CTRL + C once to gracefully shutdown, twice to cancel execution") // CTRL + C or otherwise interrupt the program diff --git a/shutdown.go b/shutdown.go index 31a2d38..2d02836 100644 --- a/shutdown.go +++ b/shutdown.go @@ -1,6 +1,7 @@ package goshutdowner import ( + "context" "os" "os/signal" "sync" @@ -14,30 +15,62 @@ import ( // log.Println("Shutting down") // }) type Shutdowner struct { - initialized bool - quit chan os.Signal - funcs []ShutdownFunc - wg sync.WaitGroup - mutex sync.RWMutex + initialized bool + shutdown bool + quit chan os.Signal + funcs []ShutdownFunc + wg sync.WaitGroup + mutex sync.Mutex + cancelContext context.Context + cancelFunc context.CancelFunc } type ShutdownFunc func() +// Internal method +func (f ShutdownFunc) execute(c context.Context) { + done := false + // Execute ShutdownFunc in goroutine and set done = true + go func() { + f() + done = true + }() + for { + // Check if context canceled or ShutdownFunc finished + select { + case <-c.Done(): + // Context canceled, return + return + default: + if done { + // ShutdownFunc finished, return + return + } + // Otherwise continue + } + } +} + // Internal method func (s *Shutdowner) init() { + // Check if already initialized if s.initialized { return } - s.quit = make(chan os.Signal, 1) + // Initialize cancel context and signal channel + s.cancelContext, s.cancelFunc = context.WithCancel(context.Background()) + s.quit = make(chan os.Signal, 2) signal.Notify(s.quit, os.Interrupt, syscall.SIGINT, syscall.SIGTERM, // e.g. Docker stop ) go func() { - <-s.quit - s.Shutdown() + for range s.quit { + s.Shutdown() + } }() + // Finished s.initialized = true } @@ -45,23 +78,30 @@ func (s *Shutdowner) init() { func (s *Shutdowner) Add(f ShutdownFunc) { s.init() s.mutex.Lock() + defer s.mutex.Unlock() s.wg.Add(1) s.funcs = append(s.funcs, f) - s.mutex.Unlock() } // Trigger shutdown directly func (s *Shutdowner) Shutdown() { s.init() - s.mutex.RLock() - for _, f := range s.funcs { - go func(f func()) { - defer s.wg.Done() - f() - }(f) + s.mutex.Lock() + defer s.mutex.Unlock() + if !s.shutdown { + // First time shutdown is called + for _, f := range s.funcs { + go func(f ShutdownFunc) { + defer s.wg.Done() + f.execute(s.cancelContext) + }(f) + } + s.shutdown = true + } else { + // Second time shutdown is called + // Cancel graceful shutdown + s.cancelFunc() } - s.mutex.RUnlock() - s.wg.Wait() } // Wait till all functions finished diff --git a/shutdown_test.go b/shutdown_test.go index 5614a93..d1201c8 100644 --- a/shutdown_test.go +++ b/shutdown_test.go @@ -3,6 +3,7 @@ package goshutdowner import ( "os" "testing" + "time" ) func Test_shutdowner(t *testing.T) { @@ -30,6 +31,7 @@ func Test_shutdowner(t *testing.T) { var s Shutdowner var testBool1 bool s.Add(func() { + time.Sleep(1 * time.Second) testBool1 = true }) s.quit <- os.Interrupt @@ -38,4 +40,36 @@ func Test_shutdowner(t *testing.T) { t.Fail() } }) + t.Run("Cancel shutdown", func(t *testing.T) { + var s Shutdowner + var testBool1 bool + s.Add(func() { + time.Sleep(10 * time.Second) + testBool1 = true + }) + go func() { + time.Sleep(1 * time.Second) + s.Shutdown() + }() + s.ShutdownAndWait() + if testBool1 == true { + t.Fail() + } + }) + t.Run("Cancel shutdown using signal", func(t *testing.T) { + var s Shutdowner + var testBool1 bool + s.Add(func() { + time.Sleep(10 * time.Second) + testBool1 = true + }) + go func() { + time.Sleep(1 * time.Second) + s.quit <- os.Interrupt + }() + s.ShutdownAndWait() + if testBool1 == true { + t.Fail() + } + }) }