add callback

This commit is contained in:
luyanbo 2022-04-21 09:22:20 +08:00
parent 684d4e030b
commit 01880adad1
7 changed files with 194 additions and 54 deletions

View File

@ -5,6 +5,7 @@
- [make `PatternLogFormatter` handling the arguments](https://github.com/beego/beego/pull/4914/files)
- [Add httplib OpenTelemetry Filter](https://github.com/beego/beego/pull/4888)
- [Support NewBeegoRequestWithCtx in httplib](https://github.com/beego/beego/pull/4895)
- [Support lifecycle callback](https://github.com/beego/beego/pull/4918)
# v2.0.2
See v2.0.2-beta.1

View File

@ -105,8 +105,17 @@ func init() {
}
}
// ServerOption configures how we set up the connection.
type ServerOption func(*Server)
func WithShutdownCallback(shutdownCallback func()) ServerOption {
return func(srv *Server) {
srv.shutdownCallbacks = append(srv.shutdownCallbacks, shutdownCallback)
}
}
// NewServer returns a new graceServer.
func NewServer(addr string, handler http.Handler) (srv *Server) {
func NewServer(addr string, handler http.Handler, opts ...ServerOption) (srv *Server) {
regLock.Lock()
defer regLock.Unlock()
@ -148,6 +157,10 @@ func NewServer(addr string, handler http.Handler) (srv *Server) {
Handler: handler,
}
for _, opt := range opts {
opt(srv)
}
runningServersOrder = append(runningServersOrder, addr)
runningServers[addr] = srv
return srv

View File

@ -20,31 +20,41 @@ import (
// Server embedded http.Server
type Server struct {
*http.Server
ln net.Listener
SignalHooks map[int]map[os.Signal][]func()
sigChan chan os.Signal
isChild bool
state uint8
Network string
terminalChan chan error
ln net.Listener
SignalHooks map[int]map[os.Signal][]func()
sigChan chan os.Signal
isChild bool
state uint8
Network string
terminalChan chan error
shutdownCallbacks []func()
}
// Serve accepts incoming connections on the Listener l
// and creates a new service goroutine for each.
// The service goroutines read requests and then call srv.Handler to reply to them.
func (srv *Server) Serve() (err error) {
return srv.internalServe(srv.ln)
}
func (srv *Server) ServeWithListener(ln net.Listener) (err error) {
go srv.handleSignals()
return srv.internalServe(ln)
}
func (srv *Server) internalServe(ln net.Listener) (err error) {
srv.state = StateRunning
defer func() { srv.state = StateTerminate }()
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
// immediately return ErrServerClosed. Make sure the program doesn't exit
// and waits instead for Shutdown to return.
if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed {
if err = srv.Server.Serve(ln); err != nil && err != http.ErrServerClosed {
log.Println(syscall.Getpid(), "Server.Serve() error:", err)
return err
}
log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.")
log.Println(syscall.Getpid(), ln.Addr(), "Listener closed.")
// wait for Shutdown to return
if shutdownErr := <-srv.terminalChan; shutdownErr != nil {
return shutdownErr
@ -95,6 +105,15 @@ func (srv *Server) ListenAndServe() (err error) {
//
// If srv.Addr is blank, ":https" is used.
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
ln, err := srv.ListenTLS(certFile, keyFile)
if err != nil {
return err
}
return srv.ServeTLS(ln)
}
func (srv *Server) ListenTLS(certFile string, keyFile string) (net.Listener, error) {
addr := srv.Addr
if addr == "" {
addr = ":https"
@ -108,20 +127,35 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
}
srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return
return nil, err
}
srv.TLSConfig.Certificates[0] = cert
go srv.handleSignals()
ln, err := srv.getListener(addr)
if err != nil {
log.Println(err)
return nil, err
}
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
return tlsListener, nil
}
// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
// Serve to handle requests on incoming mutual TLS connections.
func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
ln, err := srv.ListenMutualTLS(certFile, keyFile, trustFile)
if err != nil {
return err
}
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
return srv.ServeTLS(ln)
}
func (srv *Server) ServeTLS(ln net.Listener) error {
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
@ -134,13 +168,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
}
}
go srv.handleSignals()
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
return srv.internalServe(ln)
}
// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
// Serve to handle requests on incoming mutual TLS connections.
func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
func (srv *Server) ListenMutualTLS(certFile string, keyFile string, trustFile string) (net.Listener, error) {
addr := srv.Addr
if addr == "" {
addr = ":https"
@ -154,16 +187,17 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string)
}
srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return
return nil, err
}
srv.TLSConfig.Certificates[0] = cert
srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
pool := x509.NewCertPool()
data, err := ioutil.ReadFile(trustFile)
if err != nil {
log.Println(err)
return err
return nil, err
}
pool.AppendCertsFromPEM(data)
srv.TLSConfig.ClientCAs = pool
@ -173,24 +207,10 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string)
ln, err := srv.getListener(addr)
if err != nil {
log.Println(err)
return err
return nil, err
}
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
if srv.isChild {
process, err := os.FindProcess(os.Getppid())
if err != nil {
log.Println(err)
return err
}
err = process.Signal(syscall.SIGTERM)
if err != nil {
return err
}
}
log.Println(os.Getpid(), srv.Addr)
return srv.Serve()
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
return tlsListener, nil
}
// getListener either opens a new socket to listen on, or takes the acceptor socket
@ -292,6 +312,9 @@ func (srv *Server) shutdown() {
ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
}
for _, shutdownCallback := range srv.shutdownCallbacks {
shutdownCallback()
}
srv.terminalChan <- srv.Server.Shutdown(ctx)
}

View File

@ -132,6 +132,10 @@ func (c *ControllerInfo) GetPattern() string {
return c.pattern
}
func (c *ControllerInfo) GetMethod() map[string]string {
return c.methods
}
func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption {
return func(c *ControllerInfo) {
c.methods = parseMappingMethods(ctrlInterface, mappingMethod)
@ -1315,6 +1319,32 @@ func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo
return
}
// GetAllControllerInfo get all ControllerInfo
func (p *ControllerRegister) GetAllControllerInfo() (routerInfos []*ControllerInfo) {
for _, webTree := range p.routers {
composeControllerInfos(webTree, &routerInfos)
}
return
}
func composeControllerInfos(tree *Tree, routerInfos *[]*ControllerInfo) {
if tree.fixrouters != nil {
for _, subTree := range tree.fixrouters {
composeControllerInfos(subTree, routerInfos)
}
}
if tree.wildcard != nil {
composeControllerInfos(tree.wildcard, routerInfos)
}
if tree.leaves != nil {
for _, l := range tree.leaves {
if c, ok := l.runObject.(*ControllerInfo); ok {
*routerInfos = append(*routerInfos, c)
}
}
}
}
func toURL(params map[string]string) string {
if len(params) == 0 {
return ""

View File

@ -19,6 +19,8 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"sort"
"strings"
"testing"
@ -1097,3 +1099,32 @@ func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) {
handler := NewControllerRegister()
handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer)
}
func TestGetAllControllerInfo(t *testing.T) {
handler := NewControllerRegister()
handler.Add("/level1", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"))
handler.Add("/level1/level2", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"))
handler.Add("/:name1", &TestController{}, WithRouterMethods(&TestController{}, "post:Post"))
var actualPatterns []string
var actualMethods []string
for _, controllerInfo := range handler.GetAllControllerInfo() {
actualPatterns = append(actualPatterns, controllerInfo.GetPattern())
for _, httpMethod := range controllerInfo.GetMethod() {
actualMethods = append(actualMethods, httpMethod)
}
}
sort.Strings(actualPatterns)
expectedPatterns := []string{"/level1", "/level1/level2", "/:name1"}
sort.Strings(expectedPatterns)
if !reflect.DeepEqual(actualPatterns, expectedPatterns) {
t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedPatterns, actualPatterns)
}
sort.Strings(actualMethods)
expectedMethods := []string{"Get", "Get", "Post"}
sort.Strings(expectedMethods)
if !reflect.DeepEqual(actualMethods, expectedMethods) {
t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedMethods, actualMethods)
}
}

View File

@ -49,9 +49,10 @@ func init() {
// HttpServer defines beego application with a new PatternServeMux.
type HttpServer struct {
Handlers *ControllerRegister
Server *http.Server
Cfg *Config
Handlers *ControllerRegister
Server *http.Server
Cfg *Config
LifeCycleCallbacks []LifeCycleCallback
}
// NewHttpSever returns a new beego application.
@ -76,6 +77,13 @@ func NewHttpServerWithCfg(cfg *Config) *HttpServer {
// MiddleWare function for http.Handler
type MiddleWare func(http.Handler) http.Handler
// LifeCycleCallback configures callback.
// Developer can implement this interface to add custom logic to server lifecycle
type LifeCycleCallback interface {
AfterStart(app *HttpServer)
BeforeShutdown(app *HttpServer)
}
// Run beego application.
func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
initBeforeHTTPRun()
@ -98,12 +106,18 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
// run cgi server
if app.Cfg.Listen.EnableFcgi {
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
lifeCycleCallback.AfterStart(app)
}
if app.Cfg.Listen.EnableStdIo {
if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O
logs.Info("Use FCGI via standard I/O")
} else {
logs.Critical("Cannot use FCGI via standard I/O", err)
}
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
lifeCycleCallback.BeforeShutdown(app)
}
return
}
if app.Cfg.Listen.HTTPPort == 0 {
@ -116,11 +130,14 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
l, err = net.Listen("tcp", addr)
}
if err != nil {
logs.Critical("Listen: ", err)
logs.Critical("Listen for Fcgi: ", err)
}
if err = fcgi.Serve(l, app.Handlers); err != nil {
logs.Critical("fcgi.Serve: ", err)
}
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
lifeCycleCallback.BeforeShutdown(app)
}
return
}
@ -137,6 +154,14 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
// run graceful mode
if app.Cfg.Listen.Graceful {
var opts []grace.ServerOption
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
lifeCycleCallbackDup := lifeCycleCallback
opts = append(opts, grace.WithShutdownCallback(func() {
lifeCycleCallbackDup.BeforeShutdown(app)
}))
}
httpsAddr := app.Cfg.Listen.HTTPSAddr
app.Server.Addr = httpsAddr
if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS {
@ -146,15 +171,16 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
httpsAddr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort)
app.Server.Addr = httpsAddr
}
server := grace.NewServer(httpsAddr, app.Server.Handler)
server := grace.NewServer(httpsAddr, app.Server.Handler, opts...)
server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout
var ln net.Listener
if app.Cfg.Listen.EnableMutualHTTPS {
if err := server.ListenAndServeMutualTLS(app.Cfg.Listen.HTTPSCertFile,
if ln, err = server.ListenMutualTLS(app.Cfg.Listen.HTTPSCertFile,
app.Cfg.Listen.HTTPSKeyFile,
app.Cfg.Listen.TrustCaFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
logs.Critical("ListenMutualTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
return
}
} else {
if app.Cfg.Listen.AutoTLS {
@ -166,24 +192,40 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate}
app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", ""
}
if err := server.ListenAndServeTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
if ln, err = server.ListenTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil {
logs.Critical("ListenTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
return
}
}
for _, callback := range app.LifeCycleCallbacks {
callback.AfterStart(app)
}
if err = server.ServeTLS(ln); err != nil {
logs.Critical("ServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
}
endRunning <- true
}()
}
if app.Cfg.Listen.EnableHTTP {
go func() {
server := grace.NewServer(addr, app.Server.Handler)
server := grace.NewServer(addr, app.Server.Handler, opts...)
server.Server.ReadTimeout = app.Server.ReadTimeout
server.Server.WriteTimeout = app.Server.WriteTimeout
if app.Cfg.Listen.ListenTCP4 {
server.Network = "tcp4"
}
if err := server.ListenAndServe(); err != nil {
logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
ln, err := net.Listen(server.Network, app.Server.Addr)
if err != nil {
logs.Critical("Listen for HTTP[graceful mode]: ", err)
endRunning <- true
return
}
for _, callback := range app.LifeCycleCallbacks {
callback.AfterStart(app)
}
if err := server.ServeWithListener(ln); err != nil {
logs.Critical("ServeWithListener: ", err, fmt.Sprintf("%d", os.Getpid()))
time.Sleep(100 * time.Microsecond)
}
endRunning <- true
@ -239,13 +281,13 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
if app.Cfg.Listen.ListenTCP4 {
ln, err := net.Listen("tcp4", app.Server.Addr)
if err != nil {
logs.Critical("ListenAndServe: ", err)
logs.Critical("Listen for HTTP[normal mode]: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
return
}
if err = app.Server.Serve(ln); err != nil {
logs.Critical("ListenAndServe: ", err)
logs.Critical("Serve: ", err)
time.Sleep(100 * time.Microsecond)
endRunning <- true
return

View File

@ -4,4 +4,4 @@ sonar.projectKey=beego_beego
# relative paths to source directories. More details and properties are described
# in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/
sonar.sources=.
sonar.exclusions=**/*_test.go
sonar.exclusions=**/*_test.go