// Package goshutdowner runs registered cleanup functions when the process
// receives an interrupt/termination signal or when shutdown is triggered
// explicitly.
package goshutdowner

import (
	"context"
	"os"
	"os/signal"
	"sync"
	"syscall"
)

// Shutdowner collects shutdown functions and runs them concurrently once
// shutdown is triggered. Like a sync.Mutex, its zero value is ready to use and
// it must not be copied after first use:
//
//	var s goshutdowner.Shutdowner
//	s.Add(func() {
//		log.Println("Shutting down")
//	})
//
// The first call to any method installs a handler for os.Interrupt, SIGINT and
// SIGTERM; every received signal calls Shutdown. Because signal.Notify is used,
// those signals no longer terminate the process on their own.
type Shutdowner struct {
	initializer   sync.Once
	shutdown      bool // set by the first Shutdown call
	quit          chan os.Signal
	funcs         []ShutdownFunc
	wg            sync.WaitGroup // one count per func registered via Add
	mutex         sync.Mutex     // guards shutdown and funcs
	cancelContext context.Context
	cancelFunc    context.CancelFunc
}

// ShutdownFunc is a cleanup function registered with Shutdowner.Add.
type ShutdownFunc func()

// execute runs f in its own goroutine and returns when f finishes or ctx is
// canceled, whichever happens first. On cancellation f is abandoned rather
// than stopped: its goroutine keeps running in the background.
func (f ShutdownFunc) execute(ctx context.Context) {
	done := make(chan struct{})
	go func() {
		defer close(done)
		f()
	}()
	select {
	case <-ctx.Done():
	case <-done:
	}
}

// init lazily creates the cancel context and installs the signal handler on
// first use, which is what makes the zero value of Shutdowner usable.
func (s *Shutdowner) init() {
	s.initializer.Do(func() {
		s.mutex.Lock()
		defer s.mutex.Unlock()
		s.cancelContext, s.cancelFunc = context.WithCancel(context.Background())
		// signal.Notify never blocks on send, so the channel is buffered to
		// avoid dropping a second (force-shutdown) signal.
		s.quit = make(chan os.Signal, 2)
		signal.Notify(s.quit,
			os.Interrupt,
			syscall.SIGINT,
			syscall.SIGTERM, // e.g. docker stop
		)
		// The first signal starts a graceful shutdown, later ones cancel it.
		// This goroutine is never stopped.
		go func() {
			for range s.quit {
				s.Shutdown()
			}
		}()
	})
}

// start runs f in the background under the shutdown context and releases one
// WaitGroup count when execute returns. The caller must already have called
// s.wg.Add(1) for f.
func (s *Shutdowner) start(f ShutdownFunc) {
	go func() {
		defer s.wg.Done()
		f.execute(s.cancelContext)
	}()
}

// Add registers f to be called by Shutdown or when a shutdown signal is
// received. If shutdown has already been triggered, f is started immediately
// so that Wait still accounts for it.
func (s *Shutdowner) Add(f ShutdownFunc) {
	s.init()
	s.mutex.Lock()
	defer s.mutex.Unlock()
	s.wg.Add(1)
	if s.shutdown {
		// Shutdown will not iterate s.funcs again; run f now so its
		// WaitGroup count is released and Wait cannot block forever.
		s.start(f)
		return
	}
	s.funcs = append(s.funcs, f)
}

// Shutdown triggers shutdown without waiting for it to finish. The first call
// starts every registered func concurrently. Any later call cancels the
// graceful shutdown: funcs that are still running are abandoned (not stopped)
// and Wait stops waiting for them.
func (s *Shutdowner) Shutdown() {
	s.init()
	s.mutex.Lock()
	defer s.mutex.Unlock()
	if s.shutdown {
		s.cancelFunc()
		return
	}
	s.shutdown = true
	for _, f := range s.funcs {
		s.start(f)
	}
}

// Wait blocks until every registered func has finished or been abandoned by a
// repeated Shutdown. Funcs are counted when they are added, so calling Wait
// before shutdown blocks until shutdown is triggered and completes. With no
// funcs registered it returns immediately.
func (s *Shutdowner) Wait() {
	s.init()
	s.wg.Wait()
}

// ShutdownAndWait triggers shutdown and waits until it has finished.
// Shorthand for:
//
//	s.Shutdown()
//	s.Wait()
func (s *Shutdowner) ShutdownAndWait() {
	s.Shutdown()
	s.Wait()
}