add callback
This commit is contained in:
parent
684d4e030b
commit
01880adad1
@ -5,6 +5,7 @@
|
||||
- [make `PatternLogFormatter` handling the arguments](https://github.com/beego/beego/pull/4914/files)
|
||||
- [Add httplib OpenTelemetry Filter](https://github.com/beego/beego/pull/4888)
|
||||
- [Support NewBeegoRequestWithCtx in httplib](https://github.com/beego/beego/pull/4895)
|
||||
- [Support lifecycle callback](https://github.com/beego/beego/pull/4918)
|
||||
|
||||
# v2.0.2
|
||||
See v2.0.2-beta.1
|
||||
|
||||
@ -105,8 +105,17 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// ServerOption configures how we set up the connection.
|
||||
type ServerOption func(*Server)
|
||||
|
||||
func WithShutdownCallback(shutdownCallback func()) ServerOption {
|
||||
return func(srv *Server) {
|
||||
srv.shutdownCallbacks = append(srv.shutdownCallbacks, shutdownCallback)
|
||||
}
|
||||
}
|
||||
|
||||
// NewServer returns a new graceServer.
|
||||
func NewServer(addr string, handler http.Handler) (srv *Server) {
|
||||
func NewServer(addr string, handler http.Handler, opts ...ServerOption) (srv *Server) {
|
||||
regLock.Lock()
|
||||
defer regLock.Unlock()
|
||||
|
||||
@ -148,6 +157,10 @@ func NewServer(addr string, handler http.Handler) (srv *Server) {
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(srv)
|
||||
}
|
||||
|
||||
runningServersOrder = append(runningServersOrder, addr)
|
||||
runningServers[addr] = srv
|
||||
return srv
|
||||
|
||||
@ -20,31 +20,41 @@ import (
|
||||
// Server embedded http.Server
|
||||
type Server struct {
|
||||
*http.Server
|
||||
ln net.Listener
|
||||
SignalHooks map[int]map[os.Signal][]func()
|
||||
sigChan chan os.Signal
|
||||
isChild bool
|
||||
state uint8
|
||||
Network string
|
||||
terminalChan chan error
|
||||
ln net.Listener
|
||||
SignalHooks map[int]map[os.Signal][]func()
|
||||
sigChan chan os.Signal
|
||||
isChild bool
|
||||
state uint8
|
||||
Network string
|
||||
terminalChan chan error
|
||||
shutdownCallbacks []func()
|
||||
}
|
||||
|
||||
// Serve accepts incoming connections on the Listener l
|
||||
// and creates a new service goroutine for each.
|
||||
// The service goroutines read requests and then call srv.Handler to reply to them.
|
||||
func (srv *Server) Serve() (err error) {
|
||||
return srv.internalServe(srv.ln)
|
||||
}
|
||||
|
||||
func (srv *Server) ServeWithListener(ln net.Listener) (err error) {
|
||||
go srv.handleSignals()
|
||||
return srv.internalServe(ln)
|
||||
}
|
||||
|
||||
func (srv *Server) internalServe(ln net.Listener) (err error) {
|
||||
srv.state = StateRunning
|
||||
defer func() { srv.state = StateTerminate }()
|
||||
|
||||
// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS
|
||||
// immediately return ErrServerClosed. Make sure the program doesn't exit
|
||||
// and waits instead for Shutdown to return.
|
||||
if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed {
|
||||
if err = srv.Server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
log.Println(syscall.Getpid(), "Server.Serve() error:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.")
|
||||
log.Println(syscall.Getpid(), ln.Addr(), "Listener closed.")
|
||||
// wait for Shutdown to return
|
||||
if shutdownErr := <-srv.terminalChan; shutdownErr != nil {
|
||||
return shutdownErr
|
||||
@ -95,6 +105,15 @@ func (srv *Server) ListenAndServe() (err error) {
|
||||
//
|
||||
// If srv.Addr is blank, ":https" is used.
|
||||
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
|
||||
ln, err := srv.ListenTLS(certFile, keyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return srv.ServeTLS(ln)
|
||||
}
|
||||
|
||||
func (srv *Server) ListenTLS(certFile string, keyFile string) (net.Listener, error) {
|
||||
addr := srv.Addr
|
||||
if addr == "" {
|
||||
addr = ":https"
|
||||
@ -108,20 +127,35 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
|
||||
}
|
||||
|
||||
srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
|
||||
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
srv.TLSConfig.Certificates[0] = cert
|
||||
|
||||
go srv.handleSignals()
|
||||
|
||||
ln, err := srv.getListener(addr)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
|
||||
return tlsListener, nil
|
||||
}
|
||||
|
||||
// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
|
||||
// Serve to handle requests on incoming mutual TLS connections.
|
||||
func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
|
||||
ln, err := srv.ListenMutualTLS(certFile, keyFile, trustFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
|
||||
|
||||
return srv.ServeTLS(ln)
|
||||
}
|
||||
|
||||
func (srv *Server) ServeTLS(ln net.Listener) error {
|
||||
if srv.isChild {
|
||||
process, err := os.FindProcess(os.Getppid())
|
||||
if err != nil {
|
||||
@ -134,13 +168,12 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
go srv.handleSignals()
|
||||
log.Println(os.Getpid(), srv.Addr)
|
||||
return srv.Serve()
|
||||
return srv.internalServe(ln)
|
||||
}
|
||||
|
||||
// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls
|
||||
// Serve to handle requests on incoming mutual TLS connections.
|
||||
func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) {
|
||||
func (srv *Server) ListenMutualTLS(certFile string, keyFile string, trustFile string) (net.Listener, error) {
|
||||
addr := srv.Addr
|
||||
if addr == "" {
|
||||
addr = ":https"
|
||||
@ -154,16 +187,17 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string)
|
||||
}
|
||||
|
||||
srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
|
||||
srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
srv.TLSConfig.Certificates[0] = cert
|
||||
srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
pool := x509.NewCertPool()
|
||||
data, err := ioutil.ReadFile(trustFile)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
pool.AppendCertsFromPEM(data)
|
||||
srv.TLSConfig.ClientCAs = pool
|
||||
@ -173,24 +207,10 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string)
|
||||
ln, err := srv.getListener(addr)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
|
||||
|
||||
if srv.isChild {
|
||||
process, err := os.FindProcess(os.Getppid())
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
err = process.Signal(syscall.SIGTERM)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Println(os.Getpid(), srv.Addr)
|
||||
return srv.Serve()
|
||||
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)
|
||||
return tlsListener, nil
|
||||
}
|
||||
|
||||
// getListener either opens a new socket to listen on, or takes the acceptor socket
|
||||
@ -292,6 +312,9 @@ func (srv *Server) shutdown() {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
for _, shutdownCallback := range srv.shutdownCallbacks {
|
||||
shutdownCallback()
|
||||
}
|
||||
srv.terminalChan <- srv.Server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
|
||||
@ -132,6 +132,10 @@ func (c *ControllerInfo) GetPattern() string {
|
||||
return c.pattern
|
||||
}
|
||||
|
||||
func (c *ControllerInfo) GetMethod() map[string]string {
|
||||
return c.methods
|
||||
}
|
||||
|
||||
func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption {
|
||||
return func(c *ControllerInfo) {
|
||||
c.methods = parseMappingMethods(ctrlInterface, mappingMethod)
|
||||
@ -1315,6 +1319,32 @@ func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo
|
||||
return
|
||||
}
|
||||
|
||||
// GetAllControllerInfo get all ControllerInfo
|
||||
func (p *ControllerRegister) GetAllControllerInfo() (routerInfos []*ControllerInfo) {
|
||||
for _, webTree := range p.routers {
|
||||
composeControllerInfos(webTree, &routerInfos)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func composeControllerInfos(tree *Tree, routerInfos *[]*ControllerInfo) {
|
||||
if tree.fixrouters != nil {
|
||||
for _, subTree := range tree.fixrouters {
|
||||
composeControllerInfos(subTree, routerInfos)
|
||||
}
|
||||
}
|
||||
if tree.wildcard != nil {
|
||||
composeControllerInfos(tree.wildcard, routerInfos)
|
||||
}
|
||||
if tree.leaves != nil {
|
||||
for _, l := range tree.leaves {
|
||||
if c, ok := l.runObject.(*ControllerInfo); ok {
|
||||
*routerInfos = append(*routerInfos, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toURL(params map[string]string) string {
|
||||
if len(params) == 0 {
|
||||
return ""
|
||||
|
||||
@ -19,6 +19,8 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -1097,3 +1099,32 @@ func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) {
|
||||
handler := NewControllerRegister()
|
||||
handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer)
|
||||
}
|
||||
|
||||
func TestGetAllControllerInfo(t *testing.T) {
|
||||
handler := NewControllerRegister()
|
||||
handler.Add("/level1", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"))
|
||||
handler.Add("/level1/level2", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"))
|
||||
handler.Add("/:name1", &TestController{}, WithRouterMethods(&TestController{}, "post:Post"))
|
||||
|
||||
var actualPatterns []string
|
||||
var actualMethods []string
|
||||
for _, controllerInfo := range handler.GetAllControllerInfo() {
|
||||
actualPatterns = append(actualPatterns, controllerInfo.GetPattern())
|
||||
for _, httpMethod := range controllerInfo.GetMethod() {
|
||||
actualMethods = append(actualMethods, httpMethod)
|
||||
}
|
||||
}
|
||||
sort.Strings(actualPatterns)
|
||||
expectedPatterns := []string{"/level1", "/level1/level2", "/:name1"}
|
||||
sort.Strings(expectedPatterns)
|
||||
if !reflect.DeepEqual(actualPatterns, expectedPatterns) {
|
||||
t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedPatterns, actualPatterns)
|
||||
}
|
||||
|
||||
sort.Strings(actualMethods)
|
||||
expectedMethods := []string{"Get", "Get", "Post"}
|
||||
sort.Strings(expectedMethods)
|
||||
if !reflect.DeepEqual(actualMethods, expectedMethods) {
|
||||
t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedMethods, actualMethods)
|
||||
}
|
||||
}
|
||||
|
||||
@ -49,9 +49,10 @@ func init() {
|
||||
|
||||
// HttpServer defines beego application with a new PatternServeMux.
|
||||
type HttpServer struct {
|
||||
Handlers *ControllerRegister
|
||||
Server *http.Server
|
||||
Cfg *Config
|
||||
Handlers *ControllerRegister
|
||||
Server *http.Server
|
||||
Cfg *Config
|
||||
LifeCycleCallbacks []LifeCycleCallback
|
||||
}
|
||||
|
||||
// NewHttpSever returns a new beego application.
|
||||
@ -76,6 +77,13 @@ func NewHttpServerWithCfg(cfg *Config) *HttpServer {
|
||||
// MiddleWare function for http.Handler
|
||||
type MiddleWare func(http.Handler) http.Handler
|
||||
|
||||
// LifeCycleCallback configures callback.
|
||||
// Developer can implement this interface to add custom logic to server lifecycle
|
||||
type LifeCycleCallback interface {
|
||||
AfterStart(app *HttpServer)
|
||||
BeforeShutdown(app *HttpServer)
|
||||
}
|
||||
|
||||
// Run beego application.
|
||||
func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
|
||||
initBeforeHTTPRun()
|
||||
@ -98,12 +106,18 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
|
||||
|
||||
// run cgi server
|
||||
if app.Cfg.Listen.EnableFcgi {
|
||||
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
|
||||
lifeCycleCallback.AfterStart(app)
|
||||
}
|
||||
if app.Cfg.Listen.EnableStdIo {
|
||||
if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O
|
||||
logs.Info("Use FCGI via standard I/O")
|
||||
} else {
|
||||
logs.Critical("Cannot use FCGI via standard I/O", err)
|
||||
}
|
||||
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
|
||||
lifeCycleCallback.BeforeShutdown(app)
|
||||
}
|
||||
return
|
||||
}
|
||||
if app.Cfg.Listen.HTTPPort == 0 {
|
||||
@ -116,11 +130,14 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
|
||||
l, err = net.Listen("tcp", addr)
|
||||
}
|
||||
if err != nil {
|
||||
logs.Critical("Listen: ", err)
|
||||
logs.Critical("Listen for Fcgi: ", err)
|
||||
}
|
||||
if err = fcgi.Serve(l, app.Handlers); err != nil {
|
||||
logs.Critical("fcgi.Serve: ", err)
|
||||
}
|
||||
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
|
||||
lifeCycleCallback.BeforeShutdown(app)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -137,6 +154,14 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
|
||||
|
||||
// run graceful mode
|
||||
if app.Cfg.Listen.Graceful {
|
||||
var opts []grace.ServerOption
|
||||
for _, lifeCycleCallback := range app.LifeCycleCallbacks {
|
||||
lifeCycleCallbackDup := lifeCycleCallback
|
||||
opts = append(opts, grace.WithShutdownCallback(func() {
|
||||
lifeCycleCallbackDup.BeforeShutdown(app)
|
||||
}))
|
||||
}
|
||||
|
||||
httpsAddr := app.Cfg.Listen.HTTPSAddr
|
||||
app.Server.Addr = httpsAddr
|
||||
if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS {
|
||||
@ -146,15 +171,16 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
|
||||
httpsAddr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort)
|
||||
app.Server.Addr = httpsAddr
|
||||
}
|
||||
server := grace.NewServer(httpsAddr, app.Server.Handler)
|
||||
server := grace.NewServer(httpsAddr, app.Server.Handler, opts...)
|
||||
server.Server.ReadTimeout = app.Server.ReadTimeout
|
||||
server.Server.WriteTimeout = app.Server.WriteTimeout
|
||||
var ln net.Listener
|
||||
if app.Cfg.Listen.EnableMutualHTTPS {
|
||||
if err := server.ListenAndServeMutualTLS(app.Cfg.Listen.HTTPSCertFile,
|
||||
if ln, err = server.ListenMutualTLS(app.Cfg.Listen.HTTPSCertFile,
|
||||
app.Cfg.Listen.HTTPSKeyFile,
|
||||
app.Cfg.Listen.TrustCaFile); err != nil {
|
||||
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
logs.Critical("ListenMutualTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if app.Cfg.Listen.AutoTLS {
|
||||
@ -166,24 +192,40 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
|
||||
app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate}
|
||||
app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", ""
|
||||
}
|
||||
if err := server.ListenAndServeTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil {
|
||||
logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
if ln, err = server.ListenTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil {
|
||||
logs.Critical("ListenTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, callback := range app.LifeCycleCallbacks {
|
||||
callback.AfterStart(app)
|
||||
}
|
||||
if err = server.ServeTLS(ln); err != nil {
|
||||
logs.Critical("ServeTLS: ", err, fmt.Sprintf("%d", os.Getpid()))
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
}
|
||||
endRunning <- true
|
||||
}()
|
||||
}
|
||||
if app.Cfg.Listen.EnableHTTP {
|
||||
go func() {
|
||||
server := grace.NewServer(addr, app.Server.Handler)
|
||||
server := grace.NewServer(addr, app.Server.Handler, opts...)
|
||||
server.Server.ReadTimeout = app.Server.ReadTimeout
|
||||
server.Server.WriteTimeout = app.Server.WriteTimeout
|
||||
if app.Cfg.Listen.ListenTCP4 {
|
||||
server.Network = "tcp4"
|
||||
}
|
||||
if err := server.ListenAndServe(); err != nil {
|
||||
logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid()))
|
||||
ln, err := net.Listen(server.Network, app.Server.Addr)
|
||||
if err != nil {
|
||||
logs.Critical("Listen for HTTP[graceful mode]: ", err)
|
||||
endRunning <- true
|
||||
return
|
||||
}
|
||||
for _, callback := range app.LifeCycleCallbacks {
|
||||
callback.AfterStart(app)
|
||||
}
|
||||
if err := server.ServeWithListener(ln); err != nil {
|
||||
logs.Critical("ServeWithListener: ", err, fmt.Sprintf("%d", os.Getpid()))
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
}
|
||||
endRunning <- true
|
||||
@ -239,13 +281,13 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
|
||||
if app.Cfg.Listen.ListenTCP4 {
|
||||
ln, err := net.Listen("tcp4", app.Server.Addr)
|
||||
if err != nil {
|
||||
logs.Critical("ListenAndServe: ", err)
|
||||
logs.Critical("Listen for HTTP[normal mode]: ", err)
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
endRunning <- true
|
||||
return
|
||||
}
|
||||
if err = app.Server.Serve(ln); err != nil {
|
||||
logs.Critical("ListenAndServe: ", err)
|
||||
logs.Critical("Serve: ", err)
|
||||
time.Sleep(100 * time.Microsecond)
|
||||
endRunning <- true
|
||||
return
|
||||
|
||||
@ -4,4 +4,4 @@ sonar.projectKey=beego_beego
|
||||
# relative paths to source directories. More details and properties are described
|
||||
# in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/
|
||||
sonar.sources=.
|
||||
sonar.exclusions=**/*_test.go
|
||||
sonar.exclusions=**/*_test.go
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user