Merge pull request #4918 from robberphex/callback

support LifeCycleCallback
This commit is contained in:
jianzhiyao 2022-04-29 11:55:21 +08:00 committed by GitHub
commit 69c17fafbb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 194 additions and 54 deletions

View File

@ -5,6 +5,7 @@
- [make `PatternLogFormatter` handling the arguments](https://github.com/beego/beego/pull/4914/files) - [make `PatternLogFormatter` handling the arguments](https://github.com/beego/beego/pull/4914/files)
- [Add httplib OpenTelemetry Filter](https://github.com/beego/beego/pull/4888, https://github.com/beego/beego/pull/4915) - [Add httplib OpenTelemetry Filter](https://github.com/beego/beego/pull/4888, https://github.com/beego/beego/pull/4915)
- [Support NewBeegoRequestWithCtx in httplib](https://github.com/beego/beego/pull/4895) - [Support NewBeegoRequestWithCtx in httplib](https://github.com/beego/beego/pull/4895)
- [Support lifecycle callback](https://github.com/beego/beego/pull/4918)
# v2.0.2 # v2.0.2
See v2.0.2-beta.1 See v2.0.2-beta.1

View File

@ -105,8 +105,17 @@ func init() {
} }
} }
// ServerOption configures how we set up the connection.
type ServerOption func(*Server)
func WithShutdownCallback(shutdownCallback func()) ServerOption {
return func(srv *Server) {
srv.shutdownCallbacks = append(srv.shutdownCallbacks, shutdownCallback)
}
}
// NewServer returns a new graceServer. // NewServer returns a new graceServer.
func NewServer(addr string, handler http.Handler) (srv *Server) { func NewServer(addr string, handler http.Handler, opts ...ServerOption) (srv *Server) {
regLock.Lock() regLock.Lock()
defer regLock.Unlock() defer regLock.Unlock()
@ -148,6 +157,10 @@ func NewServer(addr string, handler http.Handler) (srv *Server) {
Handler: handler, Handler: handler,
} }
for _, opt := range opts {
opt(srv)
}
runningServersOrder = append(runningServersOrder, addr) runningServersOrder = append(runningServersOrder, addr)
runningServers[addr] = srv runningServers[addr] = srv
return srv return srv

View File

@ -20,31 +20,41 @@ import (
// Server embedded http.Server // Server embedded http.Server
type Server struct { type Server struct {
*http.Server *http.Server
ln net.Listener ln net.Listener
SignalHooks map[int]map[os.Signal][]func() SignalHooks map[int]map[os.Signal][]func()
sigChan chan os.Signal sigChan chan os.Signal
isChild bool isChild bool
state uint8 state uint8
Network string Network string
terminalChan chan error terminalChan chan error
shutdownCallbacks []func()
} }
// Serve accepts incoming connections on the Listener l // Serve accepts incoming connections on the Listener l
// and creates a new service goroutine for each. // and creates a new service goroutine for each.
// The service goroutines read requests and then call srv.Handler to reply to them. // The service goroutines read requests and then call srv.Handler to reply to them.
func (srv *Server) Serve() (err error) { func (srv *Server) Serve() (err error) {
return srv.internalServe(srv.ln)
}
func (srv *Server) ServeWithListener(ln net.Listener) (err error) {
go srv.handleSignals()
return srv.internalServe(ln)
}
func (srv *Server) internalServe(ln net.Listener) (err error) {
srv.state = StateRunning srv.state = StateRunning
defer func() { srv.state = StateTerminate }() defer func() { srv.state = StateTerminate }()
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
// immediately return ErrServerClosed. Make sure the program doesn't exit // immediately return ErrServerClosed. Make sure the program doesn't exit
// and waits instead for Shutdown to return. // and waits instead for Shutdown to return.
if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { if err = srv.Server.Serve(ln); err != nil && err != http.ErrServerClosed {
log.Println(syscall.Getpid(), "Server.Serve() error:", err) log.Println(syscall.Getpid(), "Server.Serve() error:", err)
return err return err
} }
log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") log.Println(syscall.Getpid(), ln.Addr(), "Listener closed.")
// wait for Shutdown to return // wait for Shutdown to return
if shutdownErr := <-srv.terminalChan; shutdownErr != nil { if shutdownErr := <-srv.terminalChan; shutdownErr != nil {
return shutdownErr return shutdownErr
@ -95,6 +105,15 @@ func (srv *Server) ListenAndServe() (err error) {
// //
// If srv.Addr is blank, ":https" is used. // If srv.Addr is blank, ":https" is used.
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
ln, err := srv.ListenTLS(certFile, keyFile)
if err != nil {
return err
}
return srv.ServeTLS(ln)
}
func (srv *Server) ListenTLS(certFile string, keyFile string) (net.Listener, error) {
addr := srv.Addr addr := srv.Addr
if addr == "" { if addr == "" {
addr = ":https" addr = ":https"
@ -108,20 +127,35 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
} }
srv.TLSConfig.Certificates = make([]tls.Certificate, 1) srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil { if err != nil {
return return nil, err
} }
srv.TLSConfig.Certificates[0] = cert
go srv.handleSignals() go srv.handleSignals()
ln, err := srv.getListener(addr) ln, err := srv.getListener(addr)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return nil, err
}
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
return tlsListener, nil
}
// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
// Serve to handle requests on incoming mutual TLS connections.
func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
ln, err := srv.ListenMutualTLS(certFile, keyFile, trustFile)
if err != nil {
return err return err
} }
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
return srv.ServeTLS(ln)
}
func (srv *Server) ServeTLS(ln net.Listener) error {
if srv.isChild { if srv.isChild {
process, err := os.FindProcess(os.Getppid()) process, err := os.FindProcess(os.Getppid())
if err != nil { if err != nil {
@ -134,13 +168,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
} }
} }
go srv.handleSignals()
log.Println(os.Getpid(), srv.Addr) log.Println(os.Getpid(), srv.Addr)
return srv.Serve() return srv.internalServe(ln)
} }
// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls func (srv *Server) ListenMutualTLS(certFile string, keyFile string, trustFile string) (net.Listener, error) {
// Serve to handle requests on incoming mutual TLS connections.
func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
addr := srv.Addr addr := srv.Addr
if addr == "" { if addr == "" {
addr = ":https" addr = ":https"
@ -154,16 +187,17 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string)
} }
srv.TLSConfig.Certificates = make([]tls.Certificate, 1) srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil { if err != nil {
return return nil, err
} }
srv.TLSConfig.Certificates[0] = cert
srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
pool := x509.NewCertPool() pool := x509.NewCertPool()
data, err := ioutil.ReadFile(trustFile) data, err := ioutil.ReadFile(trustFile)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return err return nil, err
} }
pool.AppendCertsFromPEM(data) pool.AppendCertsFromPEM(data)
srv.TLSConfig.ClientCAs = pool srv.TLSConfig.ClientCAs = pool
@ -173,24 +207,10 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string)
ln, err := srv.getListener(addr) ln, err := srv.getListener(addr)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return err return nil, err
} }
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
return tlsListener, nil
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
log.Println(err)
return err
}
err = process.Signal(syscall.SIGTERM)
if err != nil {
return err
}
}
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
} }
// getListener either opens a new socket to listen on, or takes the acceptor socket // getListener either opens a new socket to listen on, or takes the acceptor socket
@ -292,6 +312,9 @@ func (srv *Server) shutdown() {
ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel() defer cancel()
} }
for _, shutdownCallback := range srv.shutdownCallbacks {
shutdownCallback()
}
srv.terminalChan <- srv.Server.Shutdown(ctx) srv.terminalChan <- srv.Server.Shutdown(ctx)
} }

View File

@ -132,6 +132,10 @@ func (c *ControllerInfo) GetPattern() string {
return c.pattern return c.pattern
} }
func (c *ControllerInfo) GetMethod() map[string]string {
return c.methods
}
func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption { func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption {
return func(c *ControllerInfo) { return func(c *ControllerInfo) {
c.methods = parseMappingMethods(ctrlInterface, mappingMethod) c.methods = parseMappingMethods(ctrlInterface, mappingMethod)
@ -1315,6 +1319,32 @@ func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo
return return
} }
// GetAllControllerInfo get all ControllerInfo
func (p *ControllerRegister) GetAllControllerInfo() (routerInfos []*ControllerInfo) {
for _, webTree := range p.routers {
composeControllerInfos(webTree, &routerInfos)
}
return
}
func composeControllerInfos(tree *Tree, routerInfos *[]*ControllerInfo) {
if tree.fixrouters != nil {
for _, subTree := range tree.fixrouters {
composeControllerInfos(subTree, routerInfos)
}
}
if tree.wildcard != nil {
composeControllerInfos(tree.wildcard, routerInfos)
}
if tree.leaves != nil {
for _, l := range tree.leaves {
if c, ok := l.runObject.(*ControllerInfo); ok {
*routerInfos = append(*routerInfos, c)
}
}
}
}
func toURL(params map[string]string) string { func toURL(params map[string]string) string {
if len(params) == 0 { if len(params) == 0 {
return "" return ""

View File

@ -19,6 +19,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"sort"
"strings" "strings"
"testing" "testing"
@ -1097,3 +1099,32 @@ func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) {
handler := NewControllerRegister() handler := NewControllerRegister()
handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer) handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer)
} }
func TestGetAllControllerInfo(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/level1", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"))
handler.Add("/level1/level2", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"))
handler.Add("/:name1", &TestController{}, WithRouterMethods(&TestController{}, "post:Post"))
var actualPatterns []string
var actualMethods []string
for _, controllerInfo := range handler.GetAllControllerInfo() {
actualPatterns = append(actualPatterns, controllerInfo.GetPattern())
for _, httpMethod := range controllerInfo.GetMethod() {
actualMethods = append(actualMethods, httpMethod)
}
}
sort.Strings(actualPatterns)
expectedPatterns := []string{"/level1", "/level1/level2", "/:name1"}
sort.Strings(expectedPatterns)
if !reflect.DeepEqual(actualPatterns, expectedPatterns) {
t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedPatterns, actualPatterns)
}
sort.Strings(actualMethods)
expectedMethods := []string{"Get", "Get", "Post"}
sort.Strings(expectedMethods)
if !reflect.DeepEqual(actualMethods, expectedMethods) {
t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedMethods, actualMethods)
}
}

View File

@ -49,9 +49,10 @@ func init() {
// HttpServer defines beego application with a new PatternServeMux. // HttpServer defines beego application with a new PatternServeMux.
type HttpServer struct { type HttpServer struct {
Handlers *ControllerRegister Handlers *ControllerRegister
Server *http.Server Server *http.Server
Cfg *Config Cfg *Config
LifeCycleCallbacks []LifeCycleCallback
} }
// NewHttpSever returns a new beego application. // NewHttpSever returns a new beego application.
@ -76,6 +77,13 @@ func NewHttpServerWithCfg(cfg *Config) *HttpServer {
// MiddleWare function for http.Handler // MiddleWare function for http.Handler
type MiddleWare func(http.Handler) http.Handler type MiddleWare func(http.Handler) http.Handler
// LifeCycleCallback configures callback.
// Developer can implement this interface to add custom logic to server lifecycle
type LifeCycleCallback interface {
AfterStart(app *HttpServer)
BeforeShutdown(app *HttpServer)
}
// Run beego application. // Run beego application.
func (app *HttpServer) Run(addr string, mws ...MiddleWare) { func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
initBeforeHTTPRun() initBeforeHTTPRun()
@ -98,12 +106,18 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
// run cgi server // run cgi server
if app.Cfg.Listen.EnableFcgi { if app.Cfg.Listen.EnableFcgi {
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
lifeCycleCallback.AfterStart(app)
}
if app.Cfg.Listen.EnableStdIo { if app.Cfg.Listen.EnableStdIo {
if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O
logs.Info("Use FCGI via standard I/O") logs.Info("Use FCGI via standard I/O")
} else { } else {
logs.Critical("Cannot use FCGI via standard I/O", err) logs.Critical("Cannot use FCGI via standard I/O", err)
} }
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
lifeCycleCallback.BeforeShutdown(app)
}
return return
} }
if app.Cfg.Listen.HTTPPort == 0 { if app.Cfg.Listen.HTTPPort == 0 {
@ -116,11 +130,14 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
l, err = net.Listen("tcp", addr) l, err = net.Listen("tcp", addr)
} }
if err != nil { if err != nil {
logs.Critical("Listen: ", err) logs.Critical("Listen for Fcgi: ", err)
} }
if err = fcgi.Serve(l, app.Handlers); err != nil { if err = fcgi.Serve(l, app.Handlers); err != nil {
logs.Critical("fcgi.Serve: ", err) logs.Critical("fcgi.Serve: ", err)
} }
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
lifeCycleCallback.BeforeShutdown(app)
}
return return
} }
@ -137,6 +154,14 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
// run graceful mode // run graceful mode
if app.Cfg.Listen.Graceful { if app.Cfg.Listen.Graceful {
var opts []grace.ServerOption
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
lifeCycleCallbackDup := lifeCycleCallback
opts = append(opts, grace.WithShutdownCallback(func() {
lifeCycleCallbackDup.BeforeShutdown(app)
}))
}
httpsAddr := app.Cfg.Listen.HTTPSAddr httpsAddr := app.Cfg.Listen.HTTPSAddr
app.Server.Addr = httpsAddr app.Server.Addr = httpsAddr
if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS { if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS {
@ -146,15 +171,16 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
httpsAddr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort) httpsAddr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort)
app.Server.Addr = httpsAddr app.Server.Addr = httpsAddr
} }
server := grace.NewServer(httpsAddr, app.Server.Handler) server := grace.NewServer(httpsAddr, app.Server.Handler, opts...)
server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout server.Server.WriteTimeout = app.Server.WriteTimeout
var ln net.Listener
if app.Cfg.Listen.EnableMutualHTTPS { if app.Cfg.Listen.EnableMutualHTTPS {
if err := server.ListenAndServeMutualTLS(app.Cfg.Listen.HTTPSCertFile, if ln, err = server.ListenMutualTLS(app.Cfg.Listen.HTTPSCertFile,
app.Cfg.Listen.HTTPSKeyFile, app.Cfg.Listen.HTTPSKeyFile,
app.Cfg.Listen.TrustCaFile); err != nil { app.Cfg.Listen.TrustCaFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) logs.Critical("ListenMutualTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond) return
} }
} else { } else {
if app.Cfg.Listen.AutoTLS { if app.Cfg.Listen.AutoTLS {
@ -166,24 +192,40 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate}
app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", "" app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", ""
} }
if err := server.ListenAndServeTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil { if ln, err = server.ListenTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) logs.Critical("ListenTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond) return
} }
} }
for _, callback := range app.LifeCycleCallbacks {
callback.AfterStart(app)
}
if err = server.ServeTLS(ln); err != nil {
logs.Critical("ServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
}
endRunning <- true endRunning <- true
}() }()
} }
if app.Cfg.Listen.EnableHTTP { if app.Cfg.Listen.EnableHTTP {
go func() { go func() {
server := grace.NewServer(addr, app.Server.Handler) server := grace.NewServer(addr, app.Server.Handler, opts...)
server.Server.ReadTimeout = app.Server.ReadTimeout server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout server.Server.WriteTimeout = app.Server.WriteTimeout
if app.Cfg.Listen.ListenTCP4 { if app.Cfg.Listen.ListenTCP4 {
server.Network = "tcp4" server.Network = "tcp4"
} }
if err := server.ListenAndServe(); err != nil { ln, err := net.Listen(server.Network, app.Server.Addr)
logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) if err != nil {
logs.Critical("Listen for HTTP[graceful mode]: ", err)
endRunning <- true
return
}
for _, callback := range app.LifeCycleCallbacks {
callback.AfterStart(app)
}
if err := server.ServeWithListener(ln); err != nil {
logs.Critical("ServeWithListener: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
} }
endRunning <- true endRunning <- true
@ -239,13 +281,13 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
if app.Cfg.Listen.ListenTCP4 { if app.Cfg.Listen.ListenTCP4 {
ln, err := net.Listen("tcp4", app.Server.Addr) ln, err := net.Listen("tcp4", app.Server.Addr)
if err != nil { if err != nil {
logs.Critical("ListenAndServe: ", err) logs.Critical("Listen for HTTP[normal mode]: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
return return
} }
if err = app.Server.Serve(ln); err != nil { if err = app.Server.Serve(ln); err != nil {
logs.Critical("ListenAndServe: ", err) logs.Critical("Serve: ", err)
time.Sleep(100 * time.Microsecond) time.Sleep(100 * time.Microsecond)
endRunning <- true endRunning <- true
return return