1
Fork 0

Allow to cancel graceful shutdown

This commit is contained in:
Jan-Lukas Else 2021-06-19 23:18:56 +02:00
parent 61dfc5d02b
commit 8953b4bd54
3 changed files with 94 additions and 17 deletions

View File

@ -2,6 +2,7 @@ package main
import ( import (
"log" "log"
"time"
goshutdowner "git.jlel.se/jlelse/go-shutdowner" goshutdowner "git.jlel.se/jlelse/go-shutdowner"
) )
@ -13,10 +14,12 @@ func main() {
// Add a function to execute on shutdown // Add a function to execute on shutdown
sd.Add(func() { sd.Add(func() {
time.Sleep(5 * time.Second)
log.Println("Shutdown") log.Println("Shutdown")
}) })
log.Println("Started") log.Println("Started")
log.Println("Print CTRL + C once to gracefully shutdown, twice to cancel execution")
// CTRL + C or otherwise interrupt the program // CTRL + C or otherwise interrupt the program

View File

@ -1,6 +1,7 @@
package goshutdowner package goshutdowner
import ( import (
"context"
"os" "os"
"os/signal" "os/signal"
"sync" "sync"
@ -14,30 +15,62 @@ import (
// log.Println("Shutting down") // log.Println("Shutting down")
// }) // })
type Shutdowner struct { type Shutdowner struct {
initialized bool initialized bool
quit chan os.Signal shutdown bool
funcs []ShutdownFunc quit chan os.Signal
wg sync.WaitGroup funcs []ShutdownFunc
mutex sync.RWMutex wg sync.WaitGroup
mutex sync.Mutex
cancelContext context.Context
cancelFunc context.CancelFunc
} }
type ShutdownFunc func() type ShutdownFunc func()
// Internal method
func (f ShutdownFunc) execute(c context.Context) {
done := false
// Execute ShutdownFunc in goroutine and set done = true
go func() {
f()
done = true
}()
for {
// Check if context canceled or ShutdownFunc finished
select {
case <-c.Done():
// Context canceled, return
return
default:
if done {
// ShutdownFunc finished, return
return
}
// Otherwise continue
}
}
}
// Internal method // Internal method
func (s *Shutdowner) init() { func (s *Shutdowner) init() {
// Check if already initialized
if s.initialized { if s.initialized {
return return
} }
s.quit = make(chan os.Signal, 1) // Initialize cancel context and signal channel
s.cancelContext, s.cancelFunc = context.WithCancel(context.Background())
s.quit = make(chan os.Signal, 2)
signal.Notify(s.quit, signal.Notify(s.quit,
os.Interrupt, os.Interrupt,
syscall.SIGINT, syscall.SIGINT,
syscall.SIGTERM, // e.g. Docker stop syscall.SIGTERM, // e.g. Docker stop
) )
go func() { go func() {
<-s.quit for range s.quit {
s.Shutdown() s.Shutdown()
}
}() }()
// Finished
s.initialized = true s.initialized = true
} }
@ -45,23 +78,30 @@ func (s *Shutdowner) init() {
func (s *Shutdowner) Add(f ShutdownFunc) { func (s *Shutdowner) Add(f ShutdownFunc) {
s.init() s.init()
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock()
s.wg.Add(1) s.wg.Add(1)
s.funcs = append(s.funcs, f) s.funcs = append(s.funcs, f)
s.mutex.Unlock()
} }
// Trigger shutdown directly // Trigger shutdown directly
func (s *Shutdowner) Shutdown() { func (s *Shutdowner) Shutdown() {
s.init() s.init()
s.mutex.RLock() s.mutex.Lock()
for _, f := range s.funcs { defer s.mutex.Unlock()
go func(f func()) { if !s.shutdown {
defer s.wg.Done() // First time shutdown is called
f() for _, f := range s.funcs {
}(f) go func(f ShutdownFunc) {
defer s.wg.Done()
f.execute(s.cancelContext)
}(f)
}
s.shutdown = true
} else {
// Second time shutdown is called
// Cancel graceful shutdown
s.cancelFunc()
} }
s.mutex.RUnlock()
s.wg.Wait()
} }
// Wait till all functions finished // Wait till all functions finished

View File

@ -3,6 +3,7 @@ package goshutdowner
import ( import (
"os" "os"
"testing" "testing"
"time"
) )
func Test_shutdowner(t *testing.T) { func Test_shutdowner(t *testing.T) {
@ -30,6 +31,7 @@ func Test_shutdowner(t *testing.T) {
var s Shutdowner var s Shutdowner
var testBool1 bool var testBool1 bool
s.Add(func() { s.Add(func() {
time.Sleep(1 * time.Second)
testBool1 = true testBool1 = true
}) })
s.quit <- os.Interrupt s.quit <- os.Interrupt
@ -38,4 +40,36 @@ func Test_shutdowner(t *testing.T) {
t.Fail() t.Fail()
} }
}) })
t.Run("Cancel shutdown", func(t *testing.T) {
var s Shutdowner
var testBool1 bool
s.Add(func() {
time.Sleep(10 * time.Second)
testBool1 = true
})
go func() {
time.Sleep(1 * time.Second)
s.Shutdown()
}()
s.ShutdownAndWait()
if testBool1 == true {
t.Fail()
}
})
t.Run("Cancel shutdown using signal", func(t *testing.T) {
var s Shutdowner
var testBool1 bool
s.Add(func() {
time.Sleep(10 * time.Second)
testBool1 = true
})
go func() {
time.Sleep(1 * time.Second)
s.quit <- os.Interrupt
}()
s.ShutdownAndWait()
if testBool1 == true {
t.Fail()
}
})
} }