Merge pull request #4918 from robberphex/callback
support LifeCycleCallback
This commit is contained in:
		
						commit
						69c17fafbb
					
				| @ -5,6 +5,7 @@ | |||||||
| - [make `PatternLogFormatter` handling the arguments](https://github.com/beego/beego/pull/4914/files) | - [make `PatternLogFormatter` handling the arguments](https://github.com/beego/beego/pull/4914/files) | ||||||
| - [Add httplib OpenTelemetry Filter](https://github.com/beego/beego/pull/4888, https://github.com/beego/beego/pull/4915) | - [Add httplib OpenTelemetry Filter](https://github.com/beego/beego/pull/4888, https://github.com/beego/beego/pull/4915) | ||||||
| - [Support NewBeegoRequestWithCtx in httplib](https://github.com/beego/beego/pull/4895) | - [Support NewBeegoRequestWithCtx in httplib](https://github.com/beego/beego/pull/4895) | ||||||
|  | - [Support lifecycle callback](https://github.com/beego/beego/pull/4918) | ||||||
| 
 | 
 | ||||||
| # v2.0.2 | # v2.0.2 | ||||||
| See v2.0.2-beta.1 | See v2.0.2-beta.1 | ||||||
|  | |||||||
| @ -105,8 +105,17 @@ func init() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ServerOption configures how we set up the connection. | ||||||
|  | type ServerOption func(*Server) | ||||||
|  | 
 | ||||||
|  | func WithShutdownCallback(shutdownCallback func()) ServerOption { | ||||||
|  | 	return func(srv *Server) { | ||||||
|  | 		srv.shutdownCallbacks = append(srv.shutdownCallbacks, shutdownCallback) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // NewServer returns a new graceServer. | // NewServer returns a new graceServer. | ||||||
| func NewServer(addr string, handler http.Handler) (srv *Server) { | func NewServer(addr string, handler http.Handler, opts ...ServerOption) (srv *Server) { | ||||||
| 	regLock.Lock() | 	regLock.Lock() | ||||||
| 	defer regLock.Unlock() | 	defer regLock.Unlock() | ||||||
| 
 | 
 | ||||||
| @ -148,6 +157,10 @@ func NewServer(addr string, handler http.Handler) (srv *Server) { | |||||||
| 		Handler:        handler, | 		Handler:        handler, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	for _, opt := range opts { | ||||||
|  | 		opt(srv) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	runningServersOrder = append(runningServersOrder, addr) | 	runningServersOrder = append(runningServersOrder, addr) | ||||||
| 	runningServers[addr] = srv | 	runningServers[addr] = srv | ||||||
| 	return srv | 	return srv | ||||||
|  | |||||||
| @ -27,24 +27,34 @@ type Server struct { | |||||||
| 	state             uint8 | 	state             uint8 | ||||||
| 	Network           string | 	Network           string | ||||||
| 	terminalChan      chan error | 	terminalChan      chan error | ||||||
|  | 	shutdownCallbacks []func() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Serve accepts incoming connections on the Listener l | // Serve accepts incoming connections on the Listener l | ||||||
| // and creates a new service goroutine for each. | // and creates a new service goroutine for each. | ||||||
| // The service goroutines read requests and then call srv.Handler to reply to them. | // The service goroutines read requests and then call srv.Handler to reply to them. | ||||||
| func (srv *Server) Serve() (err error) { | func (srv *Server) Serve() (err error) { | ||||||
|  | 	return srv.internalServe(srv.ln) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (srv *Server) ServeWithListener(ln net.Listener) (err error) { | ||||||
|  | 	go srv.handleSignals() | ||||||
|  | 	return srv.internalServe(ln) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (srv *Server) internalServe(ln net.Listener) (err error) { | ||||||
| 	srv.state = StateRunning | 	srv.state = StateRunning | ||||||
| 	defer func() { srv.state = StateTerminate }() | 	defer func() { srv.state = StateTerminate }() | ||||||
| 
 | 
 | ||||||
| 	// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS | 	// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS | ||||||
| 	// immediately return ErrServerClosed. Make sure the program doesn't exit | 	// immediately return ErrServerClosed. Make sure the program doesn't exit | ||||||
| 	// and waits instead for Shutdown to return. | 	// and waits instead for Shutdown to return. | ||||||
| 	if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { | 	if err = srv.Server.Serve(ln); err != nil && err != http.ErrServerClosed { | ||||||
| 		log.Println(syscall.Getpid(), "Server.Serve() error:", err) | 		log.Println(syscall.Getpid(), "Server.Serve() error:", err) | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") | 	log.Println(syscall.Getpid(), ln.Addr(), "Listener closed.") | ||||||
| 	// wait for Shutdown to return | 	// wait for Shutdown to return | ||||||
| 	if shutdownErr := <-srv.terminalChan; shutdownErr != nil { | 	if shutdownErr := <-srv.terminalChan; shutdownErr != nil { | ||||||
| 		return shutdownErr | 		return shutdownErr | ||||||
| @ -95,6 +105,15 @@ func (srv *Server) ListenAndServe() (err error) { | |||||||
| // | // | ||||||
| // If srv.Addr is blank, ":https" is used. | // If srv.Addr is blank, ":https" is used. | ||||||
| func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { | func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { | ||||||
|  | 	ln, err := srv.ListenTLS(certFile, keyFile) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return srv.ServeTLS(ln) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (srv *Server) ListenTLS(certFile string, keyFile string) (net.Listener, error) { | ||||||
| 	addr := srv.Addr | 	addr := srv.Addr | ||||||
| 	if addr == "" { | 	if addr == "" { | ||||||
| 		addr = ":https" | 		addr = ":https" | ||||||
| @ -108,20 +127,35 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	srv.TLSConfig.Certificates = make([]tls.Certificate, 1) | 	srv.TLSConfig.Certificates = make([]tls.Certificate, 1) | ||||||
| 	srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) | 	cert, err := tls.LoadX509KeyPair(certFile, keyFile) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | 	srv.TLSConfig.Certificates[0] = cert | ||||||
| 
 | 
 | ||||||
| 	go srv.handleSignals() | 	go srv.handleSignals() | ||||||
| 
 | 
 | ||||||
| 	ln, err := srv.getListener(addr) | 	ln, err := srv.getListener(addr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Println(err) | 		log.Println(err) | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) | ||||||
|  | 	return tlsListener, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls | ||||||
|  | // Serve to handle requests on incoming mutual TLS connections. | ||||||
|  | func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { | ||||||
|  | 	ln, err := srv.ListenMutualTLS(certFile, keyFile, trustFile) | ||||||
|  | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) |  | ||||||
| 
 | 
 | ||||||
|  | 	return srv.ServeTLS(ln) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (srv *Server) ServeTLS(ln net.Listener) error { | ||||||
| 	if srv.isChild { | 	if srv.isChild { | ||||||
| 		process, err := os.FindProcess(os.Getppid()) | 		process, err := os.FindProcess(os.Getppid()) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @ -134,13 +168,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	go srv.handleSignals() | ||||||
| 	log.Println(os.Getpid(), srv.Addr) | 	log.Println(os.Getpid(), srv.Addr) | ||||||
| 	return srv.Serve() | 	return srv.internalServe(ln) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls | func (srv *Server) ListenMutualTLS(certFile string, keyFile string, trustFile string) (net.Listener, error) { | ||||||
| // Serve to handle requests on incoming mutual TLS connections. |  | ||||||
| func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { |  | ||||||
| 	addr := srv.Addr | 	addr := srv.Addr | ||||||
| 	if addr == "" { | 	if addr == "" { | ||||||
| 		addr = ":https" | 		addr = ":https" | ||||||
| @ -154,16 +187,17 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	srv.TLSConfig.Certificates = make([]tls.Certificate, 1) | 	srv.TLSConfig.Certificates = make([]tls.Certificate, 1) | ||||||
| 	srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) | 	cert, err := tls.LoadX509KeyPair(certFile, keyFile) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | 	srv.TLSConfig.Certificates[0] = cert | ||||||
| 	srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert | 	srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert | ||||||
| 	pool := x509.NewCertPool() | 	pool := x509.NewCertPool() | ||||||
| 	data, err := ioutil.ReadFile(trustFile) | 	data, err := ioutil.ReadFile(trustFile) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Println(err) | 		log.Println(err) | ||||||
| 		return err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	pool.AppendCertsFromPEM(data) | 	pool.AppendCertsFromPEM(data) | ||||||
| 	srv.TLSConfig.ClientCAs = pool | 	srv.TLSConfig.ClientCAs = pool | ||||||
| @ -173,24 +207,10 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) | |||||||
| 	ln, err := srv.getListener(addr) | 	ln, err := srv.getListener(addr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Println(err) | 		log.Println(err) | ||||||
| 		return err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) | 	tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) | ||||||
| 
 | 	return tlsListener, nil | ||||||
| 	if srv.isChild { |  | ||||||
| 		process, err := os.FindProcess(os.Getppid()) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Println(err) |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		err = process.Signal(syscall.SIGTERM) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	log.Println(os.Getpid(), srv.Addr) |  | ||||||
| 	return srv.Serve() |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // getListener either opens a new socket to listen on, or takes the acceptor socket | // getListener either opens a new socket to listen on, or takes the acceptor socket | ||||||
| @ -292,6 +312,9 @@ func (srv *Server) shutdown() { | |||||||
| 		ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) | 		ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) | ||||||
| 		defer cancel() | 		defer cancel() | ||||||
| 	} | 	} | ||||||
|  | 	for _, shutdownCallback := range srv.shutdownCallbacks { | ||||||
|  | 		shutdownCallback() | ||||||
|  | 	} | ||||||
| 	srv.terminalChan <- srv.Server.Shutdown(ctx) | 	srv.terminalChan <- srv.Server.Shutdown(ctx) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -132,6 +132,10 @@ func (c *ControllerInfo) GetPattern() string { | |||||||
| 	return c.pattern | 	return c.pattern | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (c *ControllerInfo) GetMethod() map[string]string { | ||||||
|  | 	return c.methods | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption { | func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption { | ||||||
| 	return func(c *ControllerInfo) { | 	return func(c *ControllerInfo) { | ||||||
| 		c.methods = parseMappingMethods(ctrlInterface, mappingMethod) | 		c.methods = parseMappingMethods(ctrlInterface, mappingMethod) | ||||||
| @ -1315,6 +1319,32 @@ func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetAllControllerInfo get all ControllerInfo | ||||||
|  | func (p *ControllerRegister) GetAllControllerInfo() (routerInfos []*ControllerInfo) { | ||||||
|  | 	for _, webTree := range p.routers { | ||||||
|  | 		composeControllerInfos(webTree, &routerInfos) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func composeControllerInfos(tree *Tree, routerInfos *[]*ControllerInfo) { | ||||||
|  | 	if tree.fixrouters != nil { | ||||||
|  | 		for _, subTree := range tree.fixrouters { | ||||||
|  | 			composeControllerInfos(subTree, routerInfos) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if tree.wildcard != nil { | ||||||
|  | 		composeControllerInfos(tree.wildcard, routerInfos) | ||||||
|  | 	} | ||||||
|  | 	if tree.leaves != nil { | ||||||
|  | 		for _, l := range tree.leaves { | ||||||
|  | 			if c, ok := l.runObject.(*ControllerInfo); ok { | ||||||
|  | 				*routerInfos = append(*routerInfos, c) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func toURL(params map[string]string) string { | func toURL(params map[string]string) string { | ||||||
| 	if len(params) == 0 { | 	if len(params) == 0 { | ||||||
| 		return "" | 		return "" | ||||||
|  | |||||||
| @ -19,6 +19,8 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | 	"reflect" | ||||||
|  | 	"sort" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| @ -1097,3 +1099,32 @@ func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) { | |||||||
| 	handler := NewControllerRegister() | 	handler := NewControllerRegister() | ||||||
| 	handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer) | 	handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestGetAllControllerInfo(t *testing.T) { | ||||||
|  | 	handler := NewControllerRegister() | ||||||
|  | 	handler.Add("/level1", &TestController{}, WithRouterMethods(&TestController{}, "get:Get")) | ||||||
|  | 	handler.Add("/level1/level2", &TestController{}, WithRouterMethods(&TestController{}, "get:Get")) | ||||||
|  | 	handler.Add("/:name1", &TestController{}, WithRouterMethods(&TestController{}, "post:Post")) | ||||||
|  | 
 | ||||||
|  | 	var actualPatterns []string | ||||||
|  | 	var actualMethods []string | ||||||
|  | 	for _, controllerInfo := range handler.GetAllControllerInfo() { | ||||||
|  | 		actualPatterns = append(actualPatterns, controllerInfo.GetPattern()) | ||||||
|  | 		for _, httpMethod := range controllerInfo.GetMethod() { | ||||||
|  | 			actualMethods = append(actualMethods, httpMethod) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	sort.Strings(actualPatterns) | ||||||
|  | 	expectedPatterns := []string{"/level1", "/level1/level2", "/:name1"} | ||||||
|  | 	sort.Strings(expectedPatterns) | ||||||
|  | 	if !reflect.DeepEqual(actualPatterns, expectedPatterns) { | ||||||
|  | 		t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedPatterns, actualPatterns) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	sort.Strings(actualMethods) | ||||||
|  | 	expectedMethods := []string{"Get", "Get", "Post"} | ||||||
|  | 	sort.Strings(expectedMethods) | ||||||
|  | 	if !reflect.DeepEqual(actualMethods, expectedMethods) { | ||||||
|  | 		t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedMethods, actualMethods) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -52,6 +52,7 @@ type HttpServer struct { | |||||||
| 	Handlers           *ControllerRegister | 	Handlers           *ControllerRegister | ||||||
| 	Server             *http.Server | 	Server             *http.Server | ||||||
| 	Cfg                *Config | 	Cfg                *Config | ||||||
|  | 	LifeCycleCallbacks []LifeCycleCallback | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // NewHttpSever returns a new beego application. | // NewHttpSever returns a new beego application. | ||||||
| @ -76,6 +77,13 @@ func NewHttpServerWithCfg(cfg *Config) *HttpServer { | |||||||
| // MiddleWare function for http.Handler | // MiddleWare function for http.Handler | ||||||
| type MiddleWare func(http.Handler) http.Handler | type MiddleWare func(http.Handler) http.Handler | ||||||
| 
 | 
 | ||||||
|  | // LifeCycleCallback configures callback. | ||||||
|  | // Developer can implement this interface to add custom logic to server lifecycle | ||||||
|  | type LifeCycleCallback interface { | ||||||
|  | 	AfterStart(app *HttpServer) | ||||||
|  | 	BeforeShutdown(app *HttpServer) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Run beego application. | // Run beego application. | ||||||
| func (app *HttpServer) Run(addr string, mws ...MiddleWare) { | func (app *HttpServer) Run(addr string, mws ...MiddleWare) { | ||||||
| 	initBeforeHTTPRun() | 	initBeforeHTTPRun() | ||||||
| @ -98,12 +106,18 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { | |||||||
| 
 | 
 | ||||||
| 	// run cgi server | 	// run cgi server | ||||||
| 	if app.Cfg.Listen.EnableFcgi { | 	if app.Cfg.Listen.EnableFcgi { | ||||||
|  | 		for _, lifeCycleCallback := range app.LifeCycleCallbacks { | ||||||
|  | 			lifeCycleCallback.AfterStart(app) | ||||||
|  | 		} | ||||||
| 		if app.Cfg.Listen.EnableStdIo { | 		if app.Cfg.Listen.EnableStdIo { | ||||||
| 			if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O | 			if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O | ||||||
| 				logs.Info("Use FCGI via standard I/O") | 				logs.Info("Use FCGI via standard I/O") | ||||||
| 			} else { | 			} else { | ||||||
| 				logs.Critical("Cannot use FCGI via standard I/O", err) | 				logs.Critical("Cannot use FCGI via standard I/O", err) | ||||||
| 			} | 			} | ||||||
|  | 			for _, lifeCycleCallback := range app.LifeCycleCallbacks { | ||||||
|  | 				lifeCycleCallback.BeforeShutdown(app) | ||||||
|  | 			} | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		if app.Cfg.Listen.HTTPPort == 0 { | 		if app.Cfg.Listen.HTTPPort == 0 { | ||||||
| @ -116,11 +130,14 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { | |||||||
| 			l, err = net.Listen("tcp", addr) | 			l, err = net.Listen("tcp", addr) | ||||||
| 		} | 		} | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logs.Critical("Listen: ", err) | 			logs.Critical("Listen for Fcgi: ", err) | ||||||
| 		} | 		} | ||||||
| 		if err = fcgi.Serve(l, app.Handlers); err != nil { | 		if err = fcgi.Serve(l, app.Handlers); err != nil { | ||||||
| 			logs.Critical("fcgi.Serve: ", err) | 			logs.Critical("fcgi.Serve: ", err) | ||||||
| 		} | 		} | ||||||
|  | 		for _, lifeCycleCallback := range app.LifeCycleCallbacks { | ||||||
|  | 			lifeCycleCallback.BeforeShutdown(app) | ||||||
|  | 		} | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -137,6 +154,14 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { | |||||||
| 
 | 
 | ||||||
| 	// run graceful mode | 	// run graceful mode | ||||||
| 	if app.Cfg.Listen.Graceful { | 	if app.Cfg.Listen.Graceful { | ||||||
|  | 		var opts []grace.ServerOption | ||||||
|  | 		for _, lifeCycleCallback := range app.LifeCycleCallbacks { | ||||||
|  | 			lifeCycleCallbackDup := lifeCycleCallback | ||||||
|  | 			opts = append(opts, grace.WithShutdownCallback(func() { | ||||||
|  | 				lifeCycleCallbackDup.BeforeShutdown(app) | ||||||
|  | 			})) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		httpsAddr := app.Cfg.Listen.HTTPSAddr | 		httpsAddr := app.Cfg.Listen.HTTPSAddr | ||||||
| 		app.Server.Addr = httpsAddr | 		app.Server.Addr = httpsAddr | ||||||
| 		if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS { | 		if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS { | ||||||
| @ -146,15 +171,16 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { | |||||||
| 					httpsAddr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort) | 					httpsAddr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort) | ||||||
| 					app.Server.Addr = httpsAddr | 					app.Server.Addr = httpsAddr | ||||||
| 				} | 				} | ||||||
| 				server := grace.NewServer(httpsAddr, app.Server.Handler) | 				server := grace.NewServer(httpsAddr, app.Server.Handler, opts...) | ||||||
| 				server.Server.ReadTimeout = app.Server.ReadTimeout | 				server.Server.ReadTimeout = app.Server.ReadTimeout | ||||||
| 				server.Server.WriteTimeout = app.Server.WriteTimeout | 				server.Server.WriteTimeout = app.Server.WriteTimeout | ||||||
|  | 				var ln net.Listener | ||||||
| 				if app.Cfg.Listen.EnableMutualHTTPS { | 				if app.Cfg.Listen.EnableMutualHTTPS { | ||||||
| 					if err := server.ListenAndServeMutualTLS(app.Cfg.Listen.HTTPSCertFile, | 					if ln, err = server.ListenMutualTLS(app.Cfg.Listen.HTTPSCertFile, | ||||||
| 						app.Cfg.Listen.HTTPSKeyFile, | 						app.Cfg.Listen.HTTPSKeyFile, | ||||||
| 						app.Cfg.Listen.TrustCaFile); err != nil { | 						app.Cfg.Listen.TrustCaFile); err != nil { | ||||||
| 						logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) | 						logs.Critical("ListenMutualTLS: ", err, fmt.Sprintf("%d", os.Getpid())) | ||||||
| 						time.Sleep(100 * time.Microsecond) | 						return | ||||||
| 					} | 					} | ||||||
| 				} else { | 				} else { | ||||||
| 					if app.Cfg.Listen.AutoTLS { | 					if app.Cfg.Listen.AutoTLS { | ||||||
| @ -166,24 +192,40 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { | |||||||
| 						app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} | 						app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} | ||||||
| 						app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", "" | 						app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", "" | ||||||
| 					} | 					} | ||||||
| 					if err := server.ListenAndServeTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil { | 					if ln, err = server.ListenTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil { | ||||||
| 						logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) | 						logs.Critical("ListenTLS: ", err, fmt.Sprintf("%d", os.Getpid())) | ||||||
| 						time.Sleep(100 * time.Microsecond) | 						return | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
|  | 				for _, callback := range app.LifeCycleCallbacks { | ||||||
|  | 					callback.AfterStart(app) | ||||||
|  | 				} | ||||||
|  | 				if err = server.ServeTLS(ln); err != nil { | ||||||
|  | 					logs.Critical("ServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) | ||||||
|  | 					time.Sleep(100 * time.Microsecond) | ||||||
|  | 				} | ||||||
| 				endRunning <- true | 				endRunning <- true | ||||||
| 			}() | 			}() | ||||||
| 		} | 		} | ||||||
| 		if app.Cfg.Listen.EnableHTTP { | 		if app.Cfg.Listen.EnableHTTP { | ||||||
| 			go func() { | 			go func() { | ||||||
| 				server := grace.NewServer(addr, app.Server.Handler) | 				server := grace.NewServer(addr, app.Server.Handler, opts...) | ||||||
| 				server.Server.ReadTimeout = app.Server.ReadTimeout | 				server.Server.ReadTimeout = app.Server.ReadTimeout | ||||||
| 				server.Server.WriteTimeout = app.Server.WriteTimeout | 				server.Server.WriteTimeout = app.Server.WriteTimeout | ||||||
| 				if app.Cfg.Listen.ListenTCP4 { | 				if app.Cfg.Listen.ListenTCP4 { | ||||||
| 					server.Network = "tcp4" | 					server.Network = "tcp4" | ||||||
| 				} | 				} | ||||||
| 				if err := server.ListenAndServe(); err != nil { | 				ln, err := net.Listen(server.Network, app.Server.Addr) | ||||||
| 					logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) | 				if err != nil { | ||||||
|  | 					logs.Critical("Listen for HTTP[graceful mode]: ", err) | ||||||
|  | 					endRunning <- true | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 				for _, callback := range app.LifeCycleCallbacks { | ||||||
|  | 					callback.AfterStart(app) | ||||||
|  | 				} | ||||||
|  | 				if err := server.ServeWithListener(ln); err != nil { | ||||||
|  | 					logs.Critical("ServeWithListener: ", err, fmt.Sprintf("%d", os.Getpid())) | ||||||
| 					time.Sleep(100 * time.Microsecond) | 					time.Sleep(100 * time.Microsecond) | ||||||
| 				} | 				} | ||||||
| 				endRunning <- true | 				endRunning <- true | ||||||
| @ -239,13 +281,13 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { | |||||||
| 			if app.Cfg.Listen.ListenTCP4 { | 			if app.Cfg.Listen.ListenTCP4 { | ||||||
| 				ln, err := net.Listen("tcp4", app.Server.Addr) | 				ln, err := net.Listen("tcp4", app.Server.Addr) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					logs.Critical("ListenAndServe: ", err) | 					logs.Critical("Listen for HTTP[normal mode]: ", err) | ||||||
| 					time.Sleep(100 * time.Microsecond) | 					time.Sleep(100 * time.Microsecond) | ||||||
| 					endRunning <- true | 					endRunning <- true | ||||||
| 					return | 					return | ||||||
| 				} | 				} | ||||||
| 				if err = app.Server.Serve(ln); err != nil { | 				if err = app.Server.Serve(ln); err != nil { | ||||||
| 					logs.Critical("ListenAndServe: ", err) | 					logs.Critical("Serve: ", err) | ||||||
| 					time.Sleep(100 * time.Microsecond) | 					time.Sleep(100 * time.Microsecond) | ||||||
| 					endRunning <- true | 					endRunning <- true | ||||||
| 					return | 					return | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user