diff --git a/CHANGELOG.md b/CHANGELOG.md index 20e8cff7..2cc19889 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - [make `PatternLogFormatter` handling the arguments](https://github.com/beego/beego/pull/4914/files) - [Add httplib OpenTelemetry Filter](https://github.com/beego/beego/pull/4888, https://github.com/beego/beego/pull/4915) - [Support NewBeegoRequestWithCtx in httplib](https://github.com/beego/beego/pull/4895) +- [Support lifecycle callback](https://github.com/beego/beego/pull/4918) # v2.0.2 See v2.0.2-beta.1 diff --git a/server/web/grace/grace.go b/server/web/grace/grace.go index 96ae10ef..bad7e61d 100644 --- a/server/web/grace/grace.go +++ b/server/web/grace/grace.go @@ -105,8 +105,17 @@ func init() { } } +// ServerOption configures how we set up the connection. +type ServerOption func(*Server) + +func WithShutdownCallback(shutdownCallback func()) ServerOption { + return func(srv *Server) { + srv.shutdownCallbacks = append(srv.shutdownCallbacks, shutdownCallback) + } +} + // NewServer returns a new graceServer. -func NewServer(addr string, handler http.Handler) (srv *Server) { +func NewServer(addr string, handler http.Handler, opts ...ServerOption) (srv *Server) { regLock.Lock() defer regLock.Unlock() @@ -148,6 +157,10 @@ func NewServer(addr string, handler http.Handler) (srv *Server) { Handler: handler, } + for _, opt := range opts { + opt(srv) + } + runningServersOrder = append(runningServersOrder, addr) runningServers[addr] = srv return srv diff --git a/server/web/grace/server.go b/server/web/grace/server.go index 5546546d..e9d36a00 100644 --- a/server/web/grace/server.go +++ b/server/web/grace/server.go @@ -20,31 +20,41 @@ import ( // Server embedded http.Server type Server struct { *http.Server - ln net.Listener - SignalHooks map[int]map[os.Signal][]func() - sigChan chan os.Signal - isChild bool - state uint8 - Network string - terminalChan chan error + ln net.Listener + SignalHooks map[int]map[os.Signal][]func() + sigChan chan os.Signal + isChild bool + state uint8 + Network string + terminalChan chan error + shutdownCallbacks []func() } // Serve accepts incoming connections on the Listener l // and creates a new service goroutine for each. // The service goroutines read requests and then call srv.Handler to reply to them. func (srv *Server) Serve() (err error) { + return srv.internalServe(srv.ln) +} + +func (srv *Server) ServeWithListener(ln net.Listener) (err error) { + go srv.handleSignals() + return srv.internalServe(ln) +} + +func (srv *Server) internalServe(ln net.Listener) (err error) { srv.state = StateRunning defer func() { srv.state = StateTerminate }() // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS // immediately return ErrServerClosed. Make sure the program doesn't exit // and waits instead for Shutdown to return. - if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { + if err = srv.Server.Serve(ln); err != nil && err != http.ErrServerClosed { log.Println(syscall.Getpid(), "Server.Serve() error:", err) return err } - log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") + log.Println(syscall.Getpid(), ln.Addr(), "Listener closed.") // wait for Shutdown to return if shutdownErr := <-srv.terminalChan; shutdownErr != nil { return shutdownErr @@ -95,6 +105,15 @@ func (srv *Server) ListenAndServe() (err error) { // // If srv.Addr is blank, ":https" is used. func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { + ln, err := srv.ListenTLS(certFile, keyFile) + if err != nil { + return err + } + + return srv.ServeTLS(ln) +} + +func (srv *Server) ListenTLS(certFile string, keyFile string) (net.Listener, error) { addr := srv.Addr if addr == "" { addr = ":https" @@ -108,20 +127,35 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { } srv.TLSConfig.Certificates = make([]tls.Certificate, 1) - srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { - return + return nil, err } + srv.TLSConfig.Certificates[0] = cert go srv.handleSignals() ln, err := srv.getListener(addr) if err != nil { log.Println(err) + return nil, err + } + tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + return tlsListener, nil +} + +// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming mutual TLS connections. +func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { + ln, err := srv.ListenMutualTLS(certFile, keyFile, trustFile) + if err != nil { return err } - srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + return srv.ServeTLS(ln) +} + +func (srv *Server) ServeTLS(ln net.Listener) error { if srv.isChild { process, err := os.FindProcess(os.Getppid()) if err != nil { @@ -134,13 +168,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { } } + go srv.handleSignals() log.Println(os.Getpid(), srv.Addr) - return srv.Serve() + return srv.internalServe(ln) } -// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls -// Serve to handle requests on incoming mutual TLS connections. -func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { +func (srv *Server) ListenMutualTLS(certFile string, keyFile string, trustFile string) (net.Listener, error) { addr := srv.Addr if addr == "" { addr = ":https" @@ -154,16 +187,17 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) } srv.TLSConfig.Certificates = make([]tls.Certificate, 1) - srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { - return + return nil, err } + srv.TLSConfig.Certificates[0] = cert srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert pool := x509.NewCertPool() data, err := ioutil.ReadFile(trustFile) if err != nil { log.Println(err) - return err + return nil, err } pool.AppendCertsFromPEM(data) srv.TLSConfig.ClientCAs = pool @@ -173,24 +207,10 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) ln, err := srv.getListener(addr) if err != nil { log.Println(err) - return err + return nil, err } - srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) - - if srv.isChild { - process, err := os.FindProcess(os.Getppid()) - if err != nil { - log.Println(err) - return err - } - err = process.Signal(syscall.SIGTERM) - if err != nil { - return err - } - } - - log.Println(os.Getpid(), srv.Addr) - return srv.Serve() + tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + return tlsListener, nil } // getListener either opens a new socket to listen on, or takes the acceptor socket @@ -292,6 +312,9 @@ func (srv *Server) shutdown() { ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) defer cancel() } + for _, shutdownCallback := range srv.shutdownCallbacks { + shutdownCallback() + } srv.terminalChan <- srv.Server.Shutdown(ctx) } diff --git a/server/web/router.go b/server/web/router.go index c4400b20..e37da0b7 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -132,6 +132,10 @@ func (c *ControllerInfo) GetPattern() string { return c.pattern } +func (c *ControllerInfo) GetMethod() map[string]string { + return c.methods +} + func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption { return func(c *ControllerInfo) { c.methods = parseMappingMethods(ctrlInterface, mappingMethod) @@ -1315,6 +1319,32 @@ func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo return } +// GetAllControllerInfo get all ControllerInfo +func (p *ControllerRegister) GetAllControllerInfo() (routerInfos []*ControllerInfo) { + for _, webTree := range p.routers { + composeControllerInfos(webTree, &routerInfos) + } + return +} + +func composeControllerInfos(tree *Tree, routerInfos *[]*ControllerInfo) { + if tree.fixrouters != nil { + for _, subTree := range tree.fixrouters { + composeControllerInfos(subTree, routerInfos) + } + } + if tree.wildcard != nil { + composeControllerInfos(tree.wildcard, routerInfos) + } + if tree.leaves != nil { + for _, l := range tree.leaves { + if c, ok := l.runObject.(*ControllerInfo); ok { + *routerInfos = append(*routerInfos, c) + } + } + } +} + func toURL(params map[string]string) string { if len(params) == 0 { return "" diff --git a/server/web/router_test.go b/server/web/router_test.go index 5009d24e..95e4937d 100644 --- a/server/web/router_test.go +++ b/server/web/router_test.go @@ -19,6 +19,8 @@ import ( "fmt" "net/http" "net/http/httptest" + "reflect" + "sort" "strings" "testing" @@ -1097,3 +1099,32 @@ func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) { handler := NewControllerRegister() handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer) } + +func TestGetAllControllerInfo(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/level1", &TestController{}, WithRouterMethods(&TestController{}, "get:Get")) + handler.Add("/level1/level2", &TestController{}, WithRouterMethods(&TestController{}, "get:Get")) + handler.Add("/:name1", &TestController{}, WithRouterMethods(&TestController{}, "post:Post")) + + var actualPatterns []string + var actualMethods []string + for _, controllerInfo := range handler.GetAllControllerInfo() { + actualPatterns = append(actualPatterns, controllerInfo.GetPattern()) + for _, httpMethod := range controllerInfo.GetMethod() { + actualMethods = append(actualMethods, httpMethod) + } + } + sort.Strings(actualPatterns) + expectedPatterns := []string{"/level1", "/level1/level2", "/:name1"} + sort.Strings(expectedPatterns) + if !reflect.DeepEqual(actualPatterns, expectedPatterns) { + t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedPatterns, actualPatterns) + } + + sort.Strings(actualMethods) + expectedMethods := []string{"Get", "Get", "Post"} + sort.Strings(expectedMethods) + if !reflect.DeepEqual(actualMethods, expectedMethods) { + t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedMethods, actualMethods) + } +} diff --git a/server/web/server.go b/server/web/server.go index ec9b6ef9..0011d455 100644 --- a/server/web/server.go +++ b/server/web/server.go @@ -49,9 +49,10 @@ func init() { // HttpServer defines beego application with a new PatternServeMux. type HttpServer struct { - Handlers *ControllerRegister - Server *http.Server - Cfg *Config + Handlers *ControllerRegister + Server *http.Server + Cfg *Config + LifeCycleCallbacks []LifeCycleCallback } // NewHttpSever returns a new beego application. @@ -76,6 +77,13 @@ func NewHttpServerWithCfg(cfg *Config) *HttpServer { // MiddleWare function for http.Handler type MiddleWare func(http.Handler) http.Handler +// LifeCycleCallback configures callback. +// Developer can implement this interface to add custom logic to server lifecycle +type LifeCycleCallback interface { + AfterStart(app *HttpServer) + BeforeShutdown(app *HttpServer) +} + // Run beego application. func (app *HttpServer) Run(addr string, mws ...MiddleWare) { initBeforeHTTPRun() @@ -98,12 +106,18 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { // run cgi server if app.Cfg.Listen.EnableFcgi { + for _, lifeCycleCallback := range app.LifeCycleCallbacks { + lifeCycleCallback.AfterStart(app) + } if app.Cfg.Listen.EnableStdIo { if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O logs.Info("Use FCGI via standard I/O") } else { logs.Critical("Cannot use FCGI via standard I/O", err) } + for _, lifeCycleCallback := range app.LifeCycleCallbacks { + lifeCycleCallback.BeforeShutdown(app) + } return } if app.Cfg.Listen.HTTPPort == 0 { @@ -116,11 +130,14 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { l, err = net.Listen("tcp", addr) } if err != nil { - logs.Critical("Listen: ", err) + logs.Critical("Listen for Fcgi: ", err) } if err = fcgi.Serve(l, app.Handlers); err != nil { logs.Critical("fcgi.Serve: ", err) } + for _, lifeCycleCallback := range app.LifeCycleCallbacks { + lifeCycleCallback.BeforeShutdown(app) + } return } @@ -137,6 +154,14 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { // run graceful mode if app.Cfg.Listen.Graceful { + var opts []grace.ServerOption + for _, lifeCycleCallback := range app.LifeCycleCallbacks { + lifeCycleCallbackDup := lifeCycleCallback + opts = append(opts, grace.WithShutdownCallback(func() { + lifeCycleCallbackDup.BeforeShutdown(app) + })) + } + httpsAddr := app.Cfg.Listen.HTTPSAddr app.Server.Addr = httpsAddr if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS { @@ -146,15 +171,16 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { httpsAddr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort) app.Server.Addr = httpsAddr } - server := grace.NewServer(httpsAddr, app.Server.Handler) + server := grace.NewServer(httpsAddr, app.Server.Handler, opts...) server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.WriteTimeout = app.Server.WriteTimeout + var ln net.Listener if app.Cfg.Listen.EnableMutualHTTPS { - if err := server.ListenAndServeMutualTLS(app.Cfg.Listen.HTTPSCertFile, + if ln, err = server.ListenMutualTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile, app.Cfg.Listen.TrustCaFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) + logs.Critical("ListenMutualTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + return } } else { if app.Cfg.Listen.AutoTLS { @@ -166,24 +192,40 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", "" } - if err := server.ListenAndServeTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) + if ln, err = server.ListenTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + return } } + for _, callback := range app.LifeCycleCallbacks { + callback.AfterStart(app) + } + if err = server.ServeTLS(ln); err != nil { + logs.Critical("ServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } endRunning <- true }() } if app.Cfg.Listen.EnableHTTP { go func() { - server := grace.NewServer(addr, app.Server.Handler) + server := grace.NewServer(addr, app.Server.Handler, opts...) server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.WriteTimeout = app.Server.WriteTimeout if app.Cfg.Listen.ListenTCP4 { server.Network = "tcp4" } - if err := server.ListenAndServe(); err != nil { - logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) + ln, err := net.Listen(server.Network, app.Server.Addr) + if err != nil { + logs.Critical("Listen for HTTP[graceful mode]: ", err) + endRunning <- true + return + } + for _, callback := range app.LifeCycleCallbacks { + callback.AfterStart(app) + } + if err := server.ServeWithListener(ln); err != nil { + logs.Critical("ServeWithListener: ", err, fmt.Sprintf("%d", os.Getpid())) time.Sleep(100 * time.Microsecond) } endRunning <- true @@ -239,13 +281,13 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { if app.Cfg.Listen.ListenTCP4 { ln, err := net.Listen("tcp4", app.Server.Addr) if err != nil { - logs.Critical("ListenAndServe: ", err) + logs.Critical("Listen for HTTP[normal mode]: ", err) time.Sleep(100 * time.Microsecond) endRunning <- true return } if err = app.Server.Serve(ln); err != nil { - logs.Critical("ListenAndServe: ", err) + logs.Critical("Serve: ", err) time.Sleep(100 * time.Microsecond) endRunning <- true return diff --git a/sonar-project.properties b/sonar-project.properties index 1a12fb33..dbd5d6ae 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -4,4 +4,4 @@ sonar.projectKey=beego_beego # relative paths to source directories. More details and properties are described # in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/ sonar.sources=. -sonar.exclusions=**/*_test.go \ No newline at end of file +sonar.exclusions=**/*_test.go