Merge branch 'filter-sort' of https://gitclone.com/github.com/flycash/beego into session-filter

This commit is contained in:
Anker Jam 2021-01-08 21:50:45 +08:00
commit 4491894c00
19 changed files with 305 additions and 82 deletions

32
.github/workflows/golangci-lint.yml vendored Normal file
View File

@ -0,0 +1,32 @@
name: golangci-lint
on:
push:
tags:
- v*
branches:
- master
- main
pull_request:
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:
# Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version.
version: v1.29
# Optional: working directory, useful for monorepos
# working-directory: ./
# Optional: golangci-lint command line arguments.
args: --timeout=5m --print-issued-lines=true --print-linter-name=true --uniq-by-line=true
# Optional: show only new issues if it's a pull request. The default value is `false`.
only-new-issues: true
# Optional: if set to true then the action will use pre-installed Go
# skip-go-installation: true

View File

@ -74,7 +74,7 @@ func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare {
// beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
return (*App)(web.Router(rootpath, c, web.SetRouterMethods(c, mappingMethods...))) return (*App)(web.Router(rootpath, c, mappingMethods...))
} }
// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful // UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful

View File

@ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister {
// Add("/api",&RestController{},"get,post:ApiFunc" // Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
(*web.ControllerRegister)(p).Add(pattern, c, web.SetRouterMethods(c, mappingMethods...)) (*web.ControllerRegister)(p).Add(pattern, c, web.WithRouterMethods(c, mappingMethods...))
} }
// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller // Include only when the Runmode is dev will generate router file in the router/auto.go from the controller

View File

@ -301,6 +301,28 @@ func TestAddFilter(t *testing.T) {
assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains)) assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains))
} }
func TestFilterChainOrder(t *testing.T) {
req := Get("http://beego.me")
req.AddFilters(func(next Filter) Filter {
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
return NewHttpResponseWithJsonBody("first"), nil
}
})
req.AddFilters(func(next Filter) Filter {
return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) {
return NewHttpResponseWithJsonBody("second"), nil
}
})
resp, err := req.DoRequestWithCtx(context.Background())
assert.Nil(t, err)
data := make([]byte, 5)
_, _ = resp.Body.Read(data)
assert.Equal(t, "first", string(data))
}
func TestHead(t *testing.T) { func TestHead(t *testing.T) {
req := Head("http://beego.me") req := Head("http://beego.me")
assert.NotNil(t, req) assert.NotNil(t, req)

View File

@ -948,9 +948,10 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
errTyp := true unregister := true
one := true one := true
isPtr := true isPtr := true
name := ""
if val.Kind() == reflect.Ptr { if val.Kind() == reflect.Ptr {
fn := "" fn := ""
@ -963,19 +964,17 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
case reflect.Struct: case reflect.Struct:
isPtr = false isPtr = false
fn = getFullName(typ) fn = getFullName(typ)
name = getTableName(reflect.New(typ))
} }
} else { } else {
fn = getFullName(ind.Type()) fn = getFullName(ind.Type())
name = getTableName(ind)
} }
errTyp = fn != mi.fullName unregister = fn != mi.fullName
} }
if errTyp { if unregister {
if one { RegisterModel(container)
panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName))
} else {
panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName))
}
} }
rlimit := qs.limit rlimit := qs.limit
@ -1040,6 +1039,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
if qs.distinct { if qs.distinct {
sqlSelect += " DISTINCT" sqlSelect += " DISTINCT"
} }
if qs.aggregate != "" {
sels = qs.aggregate
}
query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s",
sqlSelect, sels, Q, mi.table, Q, sqlSelect, sels, Q, mi.table, Q,
specifyIndexes, join, where, groupBy, orderBy, limit) specifyIndexes, join, where, groupBy, orderBy, limit)
@ -1064,16 +1066,20 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
} }
} }
defer rs.Close()
slice := ind
if unregister {
mi, _ = modelCache.get(name)
tCols = mi.fields.dbcols
colsNum = len(tCols)
}
refs := make([]interface{}, colsNum) refs := make([]interface{}, colsNum)
for i := range refs { for i := range refs {
var ref interface{} var ref interface{}
refs[i] = &ref refs[i] = &ref
} }
defer rs.Close()
slice := ind
var cnt int64 var cnt int64
for rs.Next() { for rs.Next() {
if one && cnt == 0 || !one { if one && cnt == 0 || !one {

View File

@ -332,10 +332,6 @@ end:
// register register models to model cache // register register models to model cache
func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) { func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) {
if mc.done {
err = fmt.Errorf("register must be run before BootStrap")
return
}
for _, model := range models { for _, model := range models {
val := reflect.ValueOf(model) val := reflect.ValueOf(model)
@ -352,7 +348,9 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
err = fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ) err = fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)
return return
} }
if val.Elem().Kind() == reflect.Slice {
val = reflect.New(val.Elem().Type().Elem())
}
table := getTableName(val) table := getTableName(val)
if prefixOrSuffixStr != "" { if prefixOrSuffixStr != "" {
@ -371,8 +369,7 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
} }
if _, ok := mc.get(table); ok { if _, ok := mc.get(table); ok {
err = fmt.Errorf("<orm.RegisterModel> table name `%s` repeat register, must be unique\n", table) return nil
return
} }
mi := newModelInfo(val) mi := newModelInfo(val)
@ -389,12 +386,6 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m
} }
} }
} }
if mi.fields.pk == nil {
err = fmt.Errorf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name)
return
}
} }
mi.table = table mi.table = table

View File

@ -255,6 +255,22 @@ func NewTM() *TM {
return obj return obj
} }
type DeptInfo struct {
ID int `orm:"column(id)"`
Created time.Time `orm:"auto_now_add"`
DeptName string
EmployeeName string
Salary int
}
type UnregisterModel struct {
ID int `orm:"column(id)"`
Created time.Time `orm:"auto_now_add"`
DeptName string
EmployeeName string
Salary int
}
type User struct { type User struct {
ID int `orm:"column(id)"` ID int `orm:"column(id)"`
UserName string `orm:"size(30);unique"` UserName string `orm:"size(30);unique"`

View File

@ -109,6 +109,9 @@ func getTableUnique(val reflect.Value) [][]string {
// get whether the table needs to be created for the database alias // get whether the table needs to be created for the database alias
func isApplicableTableForDB(val reflect.Value, db string) bool { func isApplicableTableForDB(val reflect.Value, db string) bool {
if !val.IsValid() {
return true
}
fun := val.MethodByName("IsApplicableTableForDB") fun := val.MethodByName("IsApplicableTableForDB")
if fun.IsValid() { if fun.IsValid() {
vals := fun.Call([]reflect.Value{reflect.ValueOf(db)}) vals := fun.Call([]reflect.Value{reflect.ValueOf(db)})

View File

@ -79,6 +79,7 @@ type querySet struct {
orm *ormBase orm *ormBase
ctx context.Context ctx context.Context
forContext bool forContext bool
aggregate string
} }
var _ QuerySeter = new(querySet) var _ QuerySeter = new(querySet)
@ -323,3 +324,9 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter {
o.orm = orm o.orm = orm
return o return o
} }
// aggregate func
func (o querySet) Aggregate(s string) QuerySeter {
o.aggregate = s
return &o
}

View File

@ -205,6 +205,7 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Index)) RegisterModel(new(Index))
RegisterModel(new(StrPk)) RegisterModel(new(StrPk))
RegisterModel(new(TM)) RegisterModel(new(TM))
RegisterModel(new(DeptInfo))
err := RunSyncdb("default", true, Debug) err := RunSyncdb("default", true, Debug)
throwFail(t, err) throwFail(t, err)
@ -232,6 +233,7 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(Index)) RegisterModel(new(Index))
RegisterModel(new(StrPk)) RegisterModel(new(StrPk))
RegisterModel(new(TM)) RegisterModel(new(TM))
RegisterModel(new(DeptInfo))
BootStrap() BootStrap()
@ -333,6 +335,73 @@ func TestTM(t *testing.T) {
throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC")) throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC"))
} }
func TestUnregisterModel(t *testing.T) {
data := []*DeptInfo{
{
DeptName: "A",
EmployeeName: "A1",
Salary: 1000,
},
{
DeptName: "A",
EmployeeName: "A2",
Salary: 2000,
},
{
DeptName: "B",
EmployeeName: "B1",
Salary: 2000,
},
{
DeptName: "B",
EmployeeName: "B2",
Salary: 4000,
},
{
DeptName: "B",
EmployeeName: "B3",
Salary: 3000,
},
}
qs := dORM.QueryTable("dept_info")
i, _ := qs.PrepareInsert()
for _, d := range data {
_, err := i.Insert(d)
if err != nil {
throwFail(t, err)
}
}
f := func() {
var res []UnregisterModel
n, err := dORM.QueryTable("dept_info").All(&res)
throwFail(t, err)
throwFail(t, AssertIs(n, 5))
throwFail(t, AssertIs(res[0].EmployeeName, "A1"))
type Sum struct {
DeptName string
Total int
}
var sun []Sum
qs.Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").OrderBy("dept_name").All(&sun)
throwFail(t, AssertIs(sun[0].DeptName, "A"))
throwFail(t, AssertIs(sun[0].Total, 3000))
type Max struct {
DeptName string
Max float64
}
var max []Max
qs.Aggregate("dept_name,max(salary) as max").GroupBy("dept_name").OrderBy("dept_name").All(&max)
throwFail(t, AssertIs(max[1].DeptName, "B"))
throwFail(t, AssertIs(max[1].Max, 4000))
}
for i := 0; i < 5; i++ {
f()
}
}
func TestNullDataTypes(t *testing.T) { func TestNullDataTypes(t *testing.T) {
d := DataNull{} d := DataNull{}

View File

@ -405,6 +405,15 @@ type QuerySeter interface {
// Found int // Found int
// } // }
RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error)
// aggregate func.
// for example:
// type result struct {
// DeptName string
// Total int
// }
// var res []result
// o.QueryTable("dept_info").Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").All(&res)
Aggregate(s string) QuerySeter
} }
// QueryM2Mer model to model query struct // QueryM2Mer model to model query struct

View File

@ -112,13 +112,13 @@ func registerAdmin() error {
HttpServer: NewHttpServerWithCfg(BConfig), HttpServer: NewHttpServerWithCfg(BConfig),
} }
// keep in mind that all data should be html escaped to avoid XSS attack // keep in mind that all data should be html escaped to avoid XSS attack
beeAdminApp.Router("/", c, SetRouterMethods(c, "get:AdminIndex")) beeAdminApp.Router("/", c, "get:AdminIndex")
beeAdminApp.Router("/qps", c, SetRouterMethods(c, "get:QpsIndex")) beeAdminApp.Router("/qps", c, "get:QpsIndex")
beeAdminApp.Router("/prof", c, SetRouterMethods(c, "get:ProfIndex")) beeAdminApp.Router("/prof", c, "get:ProfIndex")
beeAdminApp.Router("/healthcheck", c, SetRouterMethods(c, "get:Healthcheck")) beeAdminApp.Router("/healthcheck", c, "get:Healthcheck")
beeAdminApp.Router("/task", c, SetRouterMethods(c, "get:TaskStatus")) beeAdminApp.Router("/task", c, "get:TaskStatus")
beeAdminApp.Router("/listconf", c, SetRouterMethods(c, "get:ListConf")) beeAdminApp.Router("/listconf", c, "get:ListConf")
beeAdminApp.Router("/metrics", c, SetRouterMethods(c, "get:PrometheusMetrics")) beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics")
go beeAdminApp.Run() go beeAdminApp.Run()
} }

View File

@ -15,9 +15,12 @@
package web package web
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -36,13 +39,46 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) {
ns := NewNamespace("/chain") ns := NewNamespace("/chain")
ns.Get("/*", func(ctx *context.Context) { ns.Get("/*", func(ctx *context.Context) {
ctx.Output.Body([]byte("hello")) _ = ctx.Output.Body([]byte("hello"))
}) })
r, _ := http.NewRequest("GET", "/chain/user", nil) r, _ := http.NewRequest("GET", "/chain/user", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
BeeApp.Handlers.Init()
BeeApp.Handlers.ServeHTTP(w, r) BeeApp.Handlers.ServeHTTP(w, r)
assert.Equal(t, "filter-chain", w.Header().Get("filter")) assert.Equal(t, "filter-chain", w.Header().Get("filter"))
} }
func TestControllerRegister_InsertFilterChain_Order(t *testing.T) {
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
return func(ctx *context.Context) {
ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano()))
time.Sleep(time.Millisecond * 10)
next(ctx)
}
})
InsertFilterChain("/abc", func(next FilterFunc) FilterFunc {
return func(ctx *context.Context) {
ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano()))
time.Sleep(time.Millisecond * 10)
next(ctx)
}
})
r, _ := http.NewRequest("GET", "/abc", nil)
w := httptest.NewRecorder()
BeeApp.Handlers.Init()
BeeApp.Handlers.ServeHTTP(w, r)
first := w.Header().Get("first")
second := w.Header().Get("second")
ft, _ := strconv.ParseInt(first, 10, 64)
st, _ := strconv.ParseInt(second, 10, 64)
assert.True(t, st > ft)
}

View File

@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) {
// setup the handler // setup the handler
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/", &TestFlashController{}, SetRouterMethods(&TestFlashController{}, "get:TestWriteFlash")) handler.Add("/", &TestFlashController{}, WithRouterMethods(&TestFlashController{}, "get:TestWriteFlash"))
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
// get the Set-Cookie value // get the Set-Cookie value

View File

@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace {
// Router same as beego.Rourer // Router same as beego.Rourer
// refer: https://godoc.org/github.com/beego/beego/v2#Router // refer: https://godoc.org/github.com/beego/beego/v2#Router
func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace {
n.handlers.Add(rootpath, c, SetRouterMethods(c, mappingMethods...)) n.handlers.Add(rootpath, c, WithRouterMethods(c, mappingMethods...))
return n return n
} }

View File

@ -121,24 +121,30 @@ type ControllerInfo struct {
sessionOn bool sessionOn bool
} }
type ControllerOptions func(*ControllerInfo) type ControllerOption func(*ControllerInfo)
func (c *ControllerInfo) GetPattern() string { func (c *ControllerInfo) GetPattern() string {
return c.pattern return c.pattern
} }
func SetRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOptions { func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption {
return func(c *ControllerInfo) { return func(c *ControllerInfo) {
c.methods = parseMappingMethods(ctrlInterface, mappingMethod) c.methods = parseMappingMethods(ctrlInterface, mappingMethod)
} }
} }
func SetRouterSessionOn(sessionOn bool) ControllerOptions { func WithRouterSessionOn(sessionOn bool) ControllerOption {
return func(c *ControllerInfo) { return func(c *ControllerInfo) {
c.sessionOn = sessionOn c.sessionOn = sessionOn
} }
} }
type filterChainConfig struct {
pattern string
chain FilterChain
opts []FilterOpt
}
// ControllerRegister containers registered router rules, controller handlers and filters. // ControllerRegister containers registered router rules, controller handlers and filters.
type ControllerRegister struct { type ControllerRegister struct {
routers map[string]*Tree routers map[string]*Tree
@ -151,6 +157,9 @@ type ControllerRegister struct {
// the filter created by FilterChain // the filter created by FilterChain
chainRoot *FilterRouter chainRoot *FilterRouter
// keep registered chain and build it when serve http
filterChains []filterChainConfig
cfg *Config cfg *Config
} }
@ -171,11 +180,23 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
}, },
}, },
cfg: cfg, cfg: cfg,
filterChains: make([]filterChainConfig, 0, 4),
} }
res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false))
return res return res
} }
// Init will be executed when HttpServer start running
func (p *ControllerRegister) Init() {
for i := len(p.filterChains) - 1; i >= 0 ; i -- {
fc := p.filterChains[i]
root := p.chainRoot
filterFunc := fc.chain(root.filterFunc)
p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...)
p.chainRoot.next = root
}
}
// Add controller handler and pattern rules to ControllerRegister. // Add controller handler and pattern rules to ControllerRegister.
// usage: // usage:
// default methods is the same name as method // default methods is the same name as method
@ -186,7 +207,7 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister {
// Add("/api/delete",&RestController{},"delete:DeleteFood") // Add("/api/delete",&RestController{},"delete:DeleteFood")
// Add("/api",&RestController{},"get,post:ApiFunc" // Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOptions) { func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOption) {
p.addWithMethodParams(pattern, c, nil, opts...) p.addWithMethodParams(pattern, c, nil, opts...)
} }
@ -239,7 +260,7 @@ func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) {
} }
} }
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOptions) { func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOption) {
reflectVal := reflect.ValueOf(c) reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type() t := reflect.Indirect(reflectVal).Type()
@ -311,7 +332,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams))
} }
p.addWithMethodParams(a.Router, c, a.MethodParams, SetRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) p.addWithMethodParams(a.Router, c, a.MethodParams, WithRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method))
} }
} }
} }
@ -513,12 +534,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter
// } // }
// } // }
func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) {
root := p.chainRoot
filterFunc := chain(root.filterFunc)
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
p.chainRoot = newFilterRouter(pattern, filterFunc, opts...)
p.chainRoot.next = root
opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive))
p.filterChains = append(p.filterChains, filterChainConfig{
pattern: pattern,
chain: chain,
opts: opts,
})
} }
// add Filter into // add Filter into

View File

@ -97,7 +97,7 @@ func (jc *JSONController) Get() {
func TestPrefixUrlFor(t *testing.T){ func TestPrefixUrlFor(t *testing.T){
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/my/prefix/list", &PrefixTestController{}, "get:PrefixList") handler.Add("/my/prefix/list", &PrefixTestController{}, WithRouterMethods(&PrefixTestController{}, "get:PrefixList"))
if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` { if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` {
logs.Info(a) logs.Info(a)
@ -111,8 +111,8 @@ func TestPrefixUrlFor(t *testing.T){
func TestUrlFor(t *testing.T) { func TestUrlFor(t *testing.T) {
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "*:Param")) handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "*:Param"))
if a := handler.URLFor("TestController.List"); a != "/api/list" { if a := handler.URLFor("TestController.List"); a != "/api/list" {
logs.Info(a) logs.Info(a)
t.Errorf("TestController.List must equal to /api/list") t.Errorf("TestController.List must equal to /api/list")
@ -135,9 +135,9 @@ func TestUrlFor3(t *testing.T) {
func TestUrlFor2(t *testing.T) { func TestUrlFor2(t *testing.T) {
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
handler.Add("/v1/:username/edit", &TestController{}, SetRouterMethods(&TestController{}, "get:GetURL")) handler.Add("/v1/:username/edit", &TestController{}, WithRouterMethods(&TestController{}, "get:GetURL"))
handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:Param")) handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:Param"))
handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) handler.Add("/:year:int/:month:int/:title/:entid", &TestController{})
if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" {
logs.Info(handler.URLFor("TestController.GetURL")) logs.Info(handler.URLFor("TestController.GetURL"))
@ -167,7 +167,7 @@ func TestUserFunc(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List"))
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" { if w.Body.String() != "i am list" {
t.Errorf("user define func can't run") t.Errorf("user define func can't run")
@ -257,7 +257,7 @@ func TestRouteOk(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "get:GetParams")) handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "get:GetParams"))
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
body := w.Body.String() body := w.Body.String()
if body != "anderson+thomas+kungfu" { if body != "anderson+thomas+kungfu" {
@ -271,7 +271,7 @@ func TestManyRoute(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetManyRouter")) handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetManyRouter"))
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
body := w.Body.String() body := w.Body.String()
@ -288,7 +288,7 @@ func TestEmptyResponse(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/beego-empty.html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetEmptyBody")) handler.Add("/beego-empty.html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetEmptyBody"))
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if body := w.Body.String(); body != "" { if body := w.Body.String(); body != "" {
@ -783,8 +783,8 @@ func TestRouterSessionSet(t *testing.T) {
r, _ := http.NewRequest("GET", "/user", nil) r, _ := http.NewRequest("GET", "/user", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
SetRouterSessionOn(false)) WithRouterSessionOn(false))
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") != "" { if w.Header().Get("Set-Cookie") != "" {
t.Errorf("TestRotuerSessionSet failed") t.Errorf("TestRotuerSessionSet failed")
@ -794,8 +794,8 @@ func TestRouterSessionSet(t *testing.T) {
r, _ = http.NewRequest("GET", "/user", nil) r, _ = http.NewRequest("GET", "/user", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
handler = NewControllerRegister() handler = NewControllerRegister()
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
SetRouterSessionOn(true)) WithRouterSessionOn(true))
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") != "" { if w.Header().Get("Set-Cookie") != "" {
t.Errorf("TestRotuerSessionSet failed") t.Errorf("TestRotuerSessionSet failed")
@ -809,8 +809,8 @@ func TestRouterSessionSet(t *testing.T) {
r, _ = http.NewRequest("GET", "/user", nil) r, _ = http.NewRequest("GET", "/user", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
handler = NewControllerRegister() handler = NewControllerRegister()
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
SetRouterSessionOn(false)) WithRouterSessionOn(false))
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") != "" { if w.Header().Get("Set-Cookie") != "" {
t.Errorf("TestRotuerSessionSet failed") t.Errorf("TestRotuerSessionSet failed")
@ -820,8 +820,8 @@ func TestRouterSessionSet(t *testing.T) {
r, _ = http.NewRequest("GET", "/user", nil) r, _ = http.NewRequest("GET", "/user", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
handler = NewControllerRegister() handler = NewControllerRegister()
handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"),
SetRouterSessionOn(true)) WithRouterSessionOn(true))
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
if w.Header().Get("Set-Cookie") == "" { if w.Header().Get("Set-Cookie") == "" {
t.Errorf("TestRotuerSessionSet failed") t.Errorf("TestRotuerSessionSet failed")

View File

@ -84,7 +84,9 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
initBeforeHTTPRun() initBeforeHTTPRun()
// init...
app.initAddr(addr) app.initAddr(addr)
app.Handlers.Init()
addr = app.Cfg.Listen.HTTPAddr addr = app.Cfg.Listen.HTTPAddr
@ -266,8 +268,12 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) {
} }
// Router see HttpServer.Router // Router see HttpServer.Router
func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer { func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
return BeeApp.Router(rootpath, c, opts...) return RouterWithOpts(rootpath, c, WithRouterMethods(c, mappingMethods...))
}
func RouterWithOpts(rootpath string, c ControllerInterface, opts ...ControllerOption) *HttpServer {
return BeeApp.RouterWithOpts(rootpath, c, opts...)
} }
// Router adds a patterned controller handler to BeeApp. // Router adds a patterned controller handler to BeeApp.
@ -286,7 +292,11 @@ func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) *
// beego.Router("/api/create",&RestController{},"post:CreateFood") // beego.Router("/api/create",&RestController{},"post:CreateFood")
// beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/update",&RestController{},"put:UpdateFood")
// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood")
func (app *HttpServer) Router(rootPath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer { func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer {
return app.RouterWithOpts(rootPath, c, WithRouterMethods(c, mappingMethods...))
}
func (app *HttpServer) RouterWithOpts(rootPath string, c ControllerInterface, opts ...ControllerOption) *HttpServer {
app.Handlers.Add(rootPath, c, opts...) app.Handlers.Add(rootPath, c, opts...)
return app return app
} }

View File

@ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) {
var method = "GET" var method = "GET"
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
// Test original root // Test original root
testHelperFnContentCheck(t, handler, "Test original root", testHelperFnContentCheck(t, handler, "Test original root",
@ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) {
// Replace the root path TestPreUnregController action with the action from // Replace the root path TestPreUnregController action with the action from
// TestPostUnregController // TestPostUnregController
handler.Add("/", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot")) handler.Add("/", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot"))
// Test replacement root (expect change) // Test replacement root (expect change)
testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement)
@ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) {
var method = "GET" var method = "GET"
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
// Test original root // Test original root
testHelperFnContentCheck(t, handler, testHelperFnContentCheck(t, handler,
@ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) {
// Replace the "level1" path TestPreUnregController action with the action from // Replace the "level1" path TestPreUnregController action with the action from
// TestPostUnregController // TestPostUnregController
handler.Add("/level1", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) handler.Add("/level1", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1"))
// Test replacement root (expect no change from the original) // Test replacement root (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)
@ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) {
var method = "GET" var method = "GET"
handler := NewControllerRegister() handler := NewControllerRegister()
handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot"))
handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1"))
handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2"))
// Test original root // Test original root
testHelperFnContentCheck(t, handler, testHelperFnContentCheck(t, handler,
@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) {
// Replace the "/level1/level2" path TestPreUnregController action with the action from // Replace the "/level1/level2" path TestPreUnregController action with the action from
// TestPostUnregController // TestPostUnregController
handler.Add("/level1/level2", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) handler.Add("/level1/level2", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2"))
// Test replacement root (expect no change from the original) // Test replacement root (expect no change from the original)
testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal)