523 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			523 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2014 beego Author. All Rights Reserved.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //      http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
// Package grace provides graceful restart for net/http servers. On SIGHUP the
// process re-executes itself and hands its listening sockets to the child;
// once the child is listening it sends SIGTERM to the parent, which stops
// accepting new connections and finishes in-flight requests before exiting.
// SIGINT and SIGTERM shut a running server down the same graceful way.
//
// Description: http://grisha.org/blog/2014/06/03/graceful-restart-in-golang/
//
// Usage:
//
//	import (
//		"log"
//		"net/http"
//		"os"
//
//		"github.com/astaxie/beego/grace"
//	)
//
//	func handler(w http.ResponseWriter, r *http.Request) {
//		w.Write([]byte("WORLD!"))
//	}
//
//	func main() {
//		mux := http.NewServeMux()
//		mux.HandleFunc("/hello", handler)
//
//		err := grace.ListenAndServe("localhost:8080", mux)
//		if err != nil {
//			log.Println(err)
//		}
//		log.Println("Server on 8080 stopped")
//		os.Exit(0)
//	}
 | |
| package grace
 | |
| 
 | |
| import (
 | |
| 	"crypto/tls"
 | |
| 	"flag"
 | |
| 	"fmt"
 | |
| 	"log"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"os"
 | |
| 	"os/exec"
 | |
| 	"os/signal"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"syscall"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
// Hook phases: the first-level keys of graceServer.SignalHooks, selecting
// whether a hook runs before or after the built-in handling of a signal.
const (
	PRE_SIGNAL = iota
	POST_SIGNAL
)

// Server lifecycle states stored in graceServer.state. The numbering
// continues after the hook phases above, so STATE_INIT is 2.
const (
	STATE_INIT = iota + 2
	STATE_RUNNING
	STATE_SHUTTING_DOWN
	STATE_TERMINATE
)
 | |
| 
 | |
| var (
 | |
| 	regLock              *sync.Mutex
 | |
| 	runningServers       map[string]*graceServer
 | |
| 	runningServersOrder  []string
 | |
| 	socketPtrOffsetMap   map[string]uint
 | |
| 	runningServersForked bool
 | |
| 
 | |
| 	DefaultReadTimeOut    time.Duration
 | |
| 	DefaultWriteTimeOut   time.Duration
 | |
| 	DefaultMaxHeaderBytes int
 | |
| 	DefaultTimeout        time.Duration
 | |
| 
 | |
| 	isChild     bool
 | |
| 	socketOrder string
 | |
| )
 | |
| 
 | |
| func init() {
 | |
| 	regLock = &sync.Mutex{}
 | |
| 	flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
 | |
| 	flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
 | |
| 	runningServers = make(map[string]*graceServer)
 | |
| 	runningServersOrder = []string{}
 | |
| 	socketPtrOffsetMap = make(map[string]uint)
 | |
| 
 | |
| 	DefaultMaxHeaderBytes = 0
 | |
| 
 | |
| 	// after a restart the parent will finish ongoing requests before
 | |
| 	// shutting down. set to a negative value to disable
 | |
| 	DefaultTimeout = 60 * time.Second
 | |
| }
 | |
| 
 | |
| type graceServer struct {
 | |
| 	http.Server
 | |
| 	GraceListener    net.Listener
 | |
| 	SignalHooks      map[int]map[os.Signal][]func()
 | |
| 	tlsInnerListener *graceListener
 | |
| 	wg               sync.WaitGroup
 | |
| 	sigChan          chan os.Signal
 | |
| 	isChild          bool
 | |
| 	state            uint8
 | |
| 	Network          string
 | |
| }
 | |
| 
 | |
| // NewServer returns an intialized graceServer. Calling Serve on it will
 | |
| // actually "start" the server.
 | |
| func NewServer(addr string, handler http.Handler) (srv *graceServer) {
 | |
| 	regLock.Lock()
 | |
| 	defer regLock.Unlock()
 | |
| 	if !flag.Parsed() {
 | |
| 		flag.Parse()
 | |
| 	}
 | |
| 	if len(socketOrder) > 0 {
 | |
| 		for i, addr := range strings.Split(socketOrder, ",") {
 | |
| 			socketPtrOffsetMap[addr] = uint(i)
 | |
| 		}
 | |
| 	} else {
 | |
| 		socketPtrOffsetMap[addr] = uint(len(runningServersOrder))
 | |
| 	}
 | |
| 
 | |
| 	srv = &graceServer{
 | |
| 		wg:      sync.WaitGroup{},
 | |
| 		sigChan: make(chan os.Signal),
 | |
| 		isChild: isChild,
 | |
| 		SignalHooks: map[int]map[os.Signal][]func(){
 | |
| 			PRE_SIGNAL: map[os.Signal][]func(){
 | |
| 				syscall.SIGHUP:  []func(){},
 | |
| 				syscall.SIGUSR1: []func(){},
 | |
| 				syscall.SIGUSR2: []func(){},
 | |
| 				syscall.SIGINT:  []func(){},
 | |
| 				syscall.SIGTERM: []func(){},
 | |
| 				syscall.SIGTSTP: []func(){},
 | |
| 			},
 | |
| 			POST_SIGNAL: map[os.Signal][]func(){
 | |
| 				syscall.SIGHUP:  []func(){},
 | |
| 				syscall.SIGUSR1: []func(){},
 | |
| 				syscall.SIGUSR2: []func(){},
 | |
| 				syscall.SIGINT:  []func(){},
 | |
| 				syscall.SIGTERM: []func(){},
 | |
| 				syscall.SIGTSTP: []func(){},
 | |
| 			},
 | |
| 		},
 | |
| 		state:   STATE_INIT,
 | |
| 		Network: "tcp",
 | |
| 	}
 | |
| 
 | |
| 	srv.Server.Addr = addr
 | |
| 	srv.Server.ReadTimeout = DefaultReadTimeOut
 | |
| 	srv.Server.WriteTimeout = DefaultWriteTimeOut
 | |
| 	srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes
 | |
| 	srv.Server.Handler = handler
 | |
| 
 | |
| 	runningServersOrder = append(runningServersOrder, addr)
 | |
| 	runningServers[addr] = srv
 | |
| 
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // ListenAndServe listens on the TCP network address addr
 | |
| // and then calls Serve to handle requests on incoming connections.
 | |
| func ListenAndServe(addr string, handler http.Handler) error {
 | |
| 	server := NewServer(addr, handler)
 | |
| 	return server.ListenAndServe()
 | |
| }
 | |
| 
 | |
| // ListenAndServeTLS listens on the TCP network address addr and then calls
 | |
| // Serve to handle requests on incoming TLS connections.
 | |
| //
 | |
| // Filenames containing a certificate and matching private key for the server must be provided.
 | |
| // If the certificate is signed by a certificate authority,
 | |
| // the certFile should be the concatenation of the server's certificate followed by the CA's certificate.
 | |
| func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error {
 | |
| 	server := NewServer(addr, handler)
 | |
| 	return server.ListenAndServeTLS(certFile, keyFile)
 | |
| }
 | |
| 
 | |
| // Serve accepts incoming connections on the Listener l,
 | |
| // creating a new service goroutine for each.
 | |
| // The service goroutines read requests and then call srv.Handler to reply to them.
 | |
| func (srv *graceServer) Serve() (err error) {
 | |
| 	srv.state = STATE_RUNNING
 | |
| 	err = srv.Server.Serve(srv.GraceListener)
 | |
| 	log.Println(syscall.Getpid(), "Waiting for connections to finish...")
 | |
| 	srv.wg.Wait()
 | |
| 	srv.state = STATE_TERMINATE
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
 | |
| // to handle requests on incoming connections. If srv.Addr is blank, ":http" is
 | |
| // used.
 | |
| func (srv *graceServer) ListenAndServe() (err error) {
 | |
| 	addr := srv.Addr
 | |
| 	if addr == "" {
 | |
| 		addr = ":http"
 | |
| 	}
 | |
| 
 | |
| 	go srv.handleSignals()
 | |
| 
 | |
| 	l, err := srv.getListener(addr)
 | |
| 	if err != nil {
 | |
| 		log.Println(err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	srv.GraceListener = newGraceListener(l, srv)
 | |
| 
 | |
| 	if srv.isChild {
 | |
| 		syscall.Kill(syscall.Getppid(), syscall.SIGTERM)
 | |
| 	}
 | |
| 
 | |
| 	log.Println(syscall.Getpid(), srv.Addr)
 | |
| 	return srv.Serve()
 | |
| }
 | |
| 
 | |
| // ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
 | |
| // Serve to handle requests on incoming TLS connections.
 | |
| //
 | |
| // Filenames containing a certificate and matching private key for the server must
 | |
| // be provided. If the certificate is signed by a certificate authority, the
 | |
| // certFile should be the concatenation of the server's certificate followed by the
 | |
| // CA's certificate.
 | |
| //
 | |
| // If srv.Addr is blank, ":https" is used.
 | |
| func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error) {
 | |
| 	addr := srv.Addr
 | |
| 	if addr == "" {
 | |
| 		addr = ":https"
 | |
| 	}
 | |
| 
 | |
| 	config := &tls.Config{}
 | |
| 	if srv.TLSConfig != nil {
 | |
| 		*config = *srv.TLSConfig
 | |
| 	}
 | |
| 	if config.NextProtos == nil {
 | |
| 		config.NextProtos = []string{"http/1.1"}
 | |
| 	}
 | |
| 
 | |
| 	config.Certificates = make([]tls.Certificate, 1)
 | |
| 	config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	go srv.handleSignals()
 | |
| 
 | |
| 	l, err := srv.getListener(addr)
 | |
| 	if err != nil {
 | |
| 		log.Println(err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	srv.tlsInnerListener = newGraceListener(l, srv)
 | |
| 	srv.GraceListener = tls.NewListener(srv.tlsInnerListener, config)
 | |
| 
 | |
| 	if srv.isChild {
 | |
| 		syscall.Kill(syscall.Getppid(), syscall.SIGTERM)
 | |
| 	}
 | |
| 
 | |
| 	log.Println(syscall.Getpid(), srv.Addr)
 | |
| 	return srv.Serve()
 | |
| }
 | |
| 
 | |
| // getListener either opens a new socket to listen on, or takes the acceptor socket
 | |
| // it got passed when restarted.
 | |
| func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) {
 | |
| 	if srv.isChild {
 | |
| 		var ptrOffset uint = 0
 | |
| 		if len(socketPtrOffsetMap) > 0 {
 | |
| 			ptrOffset = socketPtrOffsetMap[laddr]
 | |
| 			log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
 | |
| 		}
 | |
| 
 | |
| 		f := os.NewFile(uintptr(3+ptrOffset), "")
 | |
| 		l, err = net.FileListener(f)
 | |
| 		if err != nil {
 | |
| 			err = fmt.Errorf("net.FileListener error: %v", err)
 | |
| 			return
 | |
| 		}
 | |
| 	} else {
 | |
| 		l, err = net.Listen(srv.Network, laddr)
 | |
| 		if err != nil {
 | |
| 			err = fmt.Errorf("net.Listen error: %v", err)
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // handleSignals listens for os Signals and calls any hooked in function that the
 | |
| // user had registered with the signal.
 | |
| func (srv *graceServer) handleSignals() {
 | |
| 	var sig os.Signal
 | |
| 
 | |
| 	signal.Notify(
 | |
| 		srv.sigChan,
 | |
| 		syscall.SIGHUP,
 | |
| 		syscall.SIGUSR1,
 | |
| 		syscall.SIGUSR2,
 | |
| 		syscall.SIGINT,
 | |
| 		syscall.SIGTERM,
 | |
| 		syscall.SIGTSTP,
 | |
| 	)
 | |
| 
 | |
| 	pid := syscall.Getpid()
 | |
| 	for {
 | |
| 		sig = <-srv.sigChan
 | |
| 		srv.signalHooks(PRE_SIGNAL, sig)
 | |
| 		switch sig {
 | |
| 		case syscall.SIGHUP:
 | |
| 			log.Println(pid, "Received SIGHUP. forking.")
 | |
| 			err := srv.fork()
 | |
| 			if err != nil {
 | |
| 				log.Println("Fork err:", err)
 | |
| 			}
 | |
| 		case syscall.SIGUSR1:
 | |
| 			log.Println(pid, "Received SIGUSR1.")
 | |
| 		case syscall.SIGUSR2:
 | |
| 			log.Println(pid, "Received SIGUSR2.")
 | |
| 			srv.serverTimeout(0 * time.Second)
 | |
| 		case syscall.SIGINT:
 | |
| 			log.Println(pid, "Received SIGINT.")
 | |
| 			srv.shutdown()
 | |
| 		case syscall.SIGTERM:
 | |
| 			log.Println(pid, "Received SIGTERM.")
 | |
| 			srv.shutdown()
 | |
| 		case syscall.SIGTSTP:
 | |
| 			log.Println(pid, "Received SIGTSTP.")
 | |
| 		default:
 | |
| 			log.Printf("Received %v: nothing i care about...\n", sig)
 | |
| 		}
 | |
| 		srv.signalHooks(POST_SIGNAL, sig)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) {
 | |
| 	if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
 | |
| 		return
 | |
| 	}
 | |
| 	for _, f := range srv.SignalHooks[ppFlag][sig] {
 | |
| 		f()
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // shutdown closes the listener so that no new connections are accepted. it also
 | |
| // starts a goroutine that will hammer (stop all running requests) the server
 | |
| // after DefaultTimeout.
 | |
| func (srv *graceServer) shutdown() {
 | |
| 	if srv.state != STATE_RUNNING {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	srv.state = STATE_SHUTTING_DOWN
 | |
| 	if DefaultTimeout >= 0 {
 | |
| 		go srv.serverTimeout(DefaultTimeout)
 | |
| 	}
 | |
| 	err := srv.GraceListener.Close()
 | |
| 	if err != nil {
 | |
| 		log.Println(syscall.Getpid(), "Listener.Close() error:", err)
 | |
| 	} else {
 | |
| 		log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
 | |
| 	}
 | |
| }
 | |
| 
 | |
// serverTimeout forces the server to finish shutting down after d, whether or
// not outstanding requests have completed. It does nothing unless the server
// is in STATE_SHUTTING_DOWN when called. If Read/WriteTimeout are not set or
// the max header size is very big, a connection could otherwise hang and keep
// Serve blocked forever.
//
// Serve blocks in srv.wg.Wait() until every accepted connection has been
// closed. After sleeping d, this function calls srv.wg.Done() in a loop until
// either Serve has set STATE_TERMINATE or Done panics because the counter
// would go negative. That releases the Wait in Serve, which causes
// ListenAndServe(TLS) to return.
//
// NOTE(review): the negative-counter panic is the intended exit of the loop
// and is swallowed by the deferred recover. Connections that are still open
// will call wg.Done again from graceConn.Close, which will likely panic
// outside this recover; presumably tolerated because the process exits soon
// after Serve returns. Confirm.
// NOTE(review): srv.state is read here while Serve and shutdown write it from
// other goroutines without synchronization.
func (srv *graceServer) serverTimeout(d time.Duration) {
	defer func() {
		// we are calling srv.wg.Done() until it panics which means we called
		// Done() when the counter was already at 0 and we're done.
		// (and thus Serve() will return and the parent will exit)
		if r := recover(); r != nil {
			log.Println("WaitGroup at 0", r)
		}
	}()
	// Only a server that shutdown() has started draining may be forced down.
	if srv.state != STATE_SHUTTING_DOWN {
		return
	}
	time.Sleep(d)
	log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
	// Drain the WaitGroup; Serve sets STATE_TERMINATE once its Wait returns.
	for {
		if srv.state == STATE_TERMINATE {
			break
		}
		srv.wg.Done()
	}
}
 | |
| 
 | |
| func (srv *graceServer) fork() (err error) {
 | |
| 	// only one server isntance should fork!
 | |
| 	regLock.Lock()
 | |
| 	defer regLock.Unlock()
 | |
| 	if runningServersForked {
 | |
| 		return
 | |
| 	}
 | |
| 	runningServersForked = true
 | |
| 
 | |
| 	var files = make([]*os.File, len(runningServers))
 | |
| 	var orderArgs = make([]string, len(runningServers))
 | |
| 	// get the accessor socket fds for _all_ server instances
 | |
| 	for _, srvPtr := range runningServers {
 | |
| 		// introspect.PrintTypeDump(srvPtr.EndlessListener)
 | |
| 		switch srvPtr.GraceListener.(type) {
 | |
| 		case *graceListener:
 | |
| 			// normal listener
 | |
| 			files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File()
 | |
| 		default:
 | |
| 			// tls listener
 | |
| 			files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
 | |
| 		}
 | |
| 		orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
 | |
| 	}
 | |
| 
 | |
| 	log.Println(files)
 | |
| 	path := os.Args[0]
 | |
| 	var args []string
 | |
| 	if len(os.Args) > 1 {
 | |
| 		for _, arg := range os.Args[1:] {
 | |
| 			if arg == "-graceful" {
 | |
| 				break
 | |
| 			}
 | |
| 			args = append(args, arg)
 | |
| 		}
 | |
| 	}
 | |
| 	args = append(args, "-graceful")
 | |
| 	if len(runningServers) > 1 {
 | |
| 		args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
 | |
| 		log.Println(args)
 | |
| 	}
 | |
| 	cmd := exec.Command(path, args...)
 | |
| 	cmd.Stdout = os.Stdout
 | |
| 	cmd.Stderr = os.Stderr
 | |
| 	cmd.ExtraFiles = files
 | |
| 	err = cmd.Start()
 | |
| 	if err != nil {
 | |
| 		log.Fatalf("Restart: Failed to launch, error: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return
 | |
| }
 | |
| 
 | |
| type graceListener struct {
 | |
| 	net.Listener
 | |
| 	stop    chan error
 | |
| 	stopped bool
 | |
| 	server  *graceServer
 | |
| }
 | |
| 
 | |
| func (gl *graceListener) Accept() (c net.Conn, err error) {
 | |
| 	tc, err := gl.Listener.(*net.TCPListener).AcceptTCP()
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	tc.SetKeepAlive(true)                  // see http.tcpKeepAliveListener
 | |
| 	tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
 | |
| 
 | |
| 	c = graceConn{
 | |
| 		Conn:   tc,
 | |
| 		server: gl.server,
 | |
| 	}
 | |
| 
 | |
| 	gl.server.wg.Add(1)
 | |
| 	return
 | |
| }
 | |
| 
 | |
// newGraceListener wraps l in a graceListener owned by srv and starts the
// goroutine that performs the actual close when Close is called.
func newGraceListener(l net.Listener, srv *graceServer) (el *graceListener) {
	el = &graceListener{
		Listener: l,
		stop:     make(chan error),
		server:   srv,
	}

	// Starting the listener for the stop signal here because Accept blocks on
	// el.Listener.(*net.TCPListener).AcceptTCP()
	// The goroutine will unblock it by closing the listeners fd
	//
	// Handshake with Close: Close sends nil on the unbuffered el.stop; this
	// goroutine marks the listener stopped, closes the wrapped listener, sends
	// that Close result back on the same channel, and exits.
	// NOTE(review): if Close is never called this goroutine stays blocked on
	// el.stop for the life of the process. `stopped` is written here and read
	// in Close without a lock; it is ordered only through the channel for the
	// caller that performed the handshake.
	go func() {
		_ = <-el.stop
		el.stopped = true
		el.stop <- el.Listener.Close()
	}()
	return
}
 | |
| 
 | |
| func (el *graceListener) Close() error {
 | |
| 	if el.stopped {
 | |
| 		return syscall.EINVAL
 | |
| 	}
 | |
| 	el.stop <- nil
 | |
| 	return <-el.stop
 | |
| }
 | |
| 
 | |
| func (el *graceListener) File() *os.File {
 | |
| 	// returns a dup(2) - FD_CLOEXEC flag *not* set
 | |
| 	tl := el.Listener.(*net.TCPListener)
 | |
| 	fl, _ := tl.File()
 | |
| 	return fl
 | |
| }
 | |
| 
 | |
| type graceConn struct {
 | |
| 	net.Conn
 | |
| 	server *graceServer
 | |
| }
 | |
| 
 | |
| func (c graceConn) Close() error {
 | |
| 	c.server.wg.Done()
 | |
| 	return c.Conn.Close()
 | |
| }
 |