Merge branch 'develop' of https://github.com/beego/beego into orm-mock
This commit is contained in:
		
						commit
						89eeedbf9d
					
				
							
								
								
									
										2
									
								
								.github/workflows/changelog.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/changelog.yml
									
									
									
									
										vendored
									
									
								
							| @ -8,7 +8,7 @@ on: | ||||
|   pull_request: | ||||
|     types: [opened, synchronize, reopened, labeled, unlabeled] | ||||
|     branches: | ||||
|       - master | ||||
|       - develop | ||||
| 
 | ||||
| jobs: | ||||
|   changelog: | ||||
|  | ||||
							
								
								
									
										32
									
								
								.github/workflows/golangci-lint.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								.github/workflows/golangci-lint.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,32 @@ | ||||
| name: golangci-lint | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - v* | ||||
|     branches: | ||||
|       - master | ||||
|       - main | ||||
|   pull_request: | ||||
| jobs: | ||||
|   golangci: | ||||
|     name: lint | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - uses: actions/checkout@v2 | ||||
|       - name: golangci-lint | ||||
|         uses: golangci/golangci-lint-action@v2 | ||||
|         with: | ||||
|           # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. | ||||
|           version: v1.29 | ||||
| 
 | ||||
|           # Optional: working directory, useful for monorepos | ||||
| #          working-directory: ./ | ||||
| 
 | ||||
|           # Optional: golangci-lint command line arguments. | ||||
|           args: --timeout=5m --print-issued-lines=true --print-linter-name=true --uniq-by-line=true | ||||
| 
 | ||||
|           # Optional: show only new issues if it's a pull request. The default value is `false`. | ||||
|           only-new-issues: true | ||||
| 
 | ||||
|           # Optional: if set to true then the action will use pre-installed Go | ||||
|           # skip-go-installation: true | ||||
| @ -1,6 +1,11 @@ | ||||
| # developing | ||||
| - Add sonar check and ignore test. [4432](https://github.com/beego/beego/pull/4432) [4433](https://github.com/beego/beego/pull/4433) | ||||
| - Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427) | ||||
| - Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398) | ||||
| - Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391) | ||||
| - Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385) | ||||
| - Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385) | ||||
| - Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) | ||||
| - Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) | ||||
| - Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294) | ||||
| - Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404) | ||||
| - Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424) | ||||
|  | ||||
| @ -74,7 +74,7 @@ func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare { | ||||
| //  beego.Router("/api/update",&RestController{},"put:UpdateFood") | ||||
| //  beego.Router("/api/delete",&RestController{},"delete:DeleteFood") | ||||
| func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { | ||||
| 	return (*App)(web.Router(rootpath, c, web.SetRouterMethods(c, mappingMethods...))) | ||||
| 	return (*App)(web.Router(rootpath, c, mappingMethods...)) | ||||
| } | ||||
| 
 | ||||
| // UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful | ||||
|  | ||||
| @ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister { | ||||
| //	Add("/api",&RestController{},"get,post:ApiFunc" | ||||
| //	Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") | ||||
| func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { | ||||
| 	(*web.ControllerRegister)(p).Add(pattern, c, web.SetRouterMethods(c, mappingMethods...)) | ||||
| 	(*web.ControllerRegister)(p).Add(pattern, c, web.WithRouterMethods(c, mappingMethods...)) | ||||
| } | ||||
| 
 | ||||
| // Include only when the Runmode is dev will generate router file in the router/auto.go from the controller | ||||
|  | ||||
| @ -301,6 +301,28 @@ func TestAddFilter(t *testing.T) { | ||||
| 	assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains)) | ||||
| } | ||||
| 
 | ||||
| func TestFilterChainOrder(t *testing.T) { | ||||
| 	req := Get("http://beego.me") | ||||
| 	req.AddFilters(func(next Filter) Filter { | ||||
| 		return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { | ||||
| 			return NewHttpResponseWithJsonBody("first"), nil | ||||
| 		} | ||||
| 	}) | ||||
| 
 | ||||
| 
 | ||||
| 	req.AddFilters(func(next Filter) Filter { | ||||
| 		return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { | ||||
| 			return NewHttpResponseWithJsonBody("second"), nil | ||||
| 		} | ||||
| 	}) | ||||
| 
 | ||||
| 	resp, err := req.DoRequestWithCtx(context.Background()) | ||||
| 	assert.Nil(t, err) | ||||
| 	data := make([]byte, 5) | ||||
| 	_, _ = resp.Body.Read(data) | ||||
| 	assert.Equal(t, "first", string(data)) | ||||
| } | ||||
| 
 | ||||
| func TestHead(t *testing.T) { | ||||
| 	req := Head("http://beego.me") | ||||
| 	assert.NotNil(t, req) | ||||
|  | ||||
							
								
								
									
										6
									
								
								client/orm/clauses/const.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								client/orm/clauses/const.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,6 @@ | ||||
| package clauses | ||||
| 
 | ||||
| const ( | ||||
| 	ExprSep                = "__" | ||||
| 	ExprDot                = "." | ||||
| ) | ||||
							
								
								
									
										103
									
								
								client/orm/clauses/order_clause/order.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								client/orm/clauses/order_clause/order.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,103 @@ | ||||
| package order_clause | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/beego/beego/v2/client/orm/clauses" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| type Sort int8 | ||||
| 
 | ||||
| const ( | ||||
| 	None       Sort = 0 | ||||
| 	Ascending  Sort = 1 | ||||
| 	Descending Sort = 2 | ||||
| ) | ||||
| 
 | ||||
| type Option func(order *Order) | ||||
| 
 | ||||
| type Order struct { | ||||
| 	column string | ||||
| 	sort   Sort | ||||
| 	isRaw  bool | ||||
| } | ||||
| 
 | ||||
| func Clause(options ...Option) *Order { | ||||
| 	o := &Order{} | ||||
| 	for _, option := range options { | ||||
| 		option(o) | ||||
| 	} | ||||
| 
 | ||||
| 	return o | ||||
| } | ||||
| 
 | ||||
| func (o *Order) GetColumn() string { | ||||
| 	return o.column | ||||
| } | ||||
| 
 | ||||
| func (o *Order) GetSort() Sort { | ||||
| 	return o.sort | ||||
| } | ||||
| 
 | ||||
| func (o *Order) SortString() string { | ||||
| 	switch o.GetSort() { | ||||
| 	case Ascending: | ||||
| 		return "ASC" | ||||
| 	case Descending: | ||||
| 		return "DESC" | ||||
| 	} | ||||
| 
 | ||||
| 	return `` | ||||
| } | ||||
| 
 | ||||
| func (o *Order) IsRaw() bool { | ||||
| 	return o.isRaw | ||||
| } | ||||
| 
 | ||||
| func ParseOrder(expressions ...string) []*Order { | ||||
| 	var orders []*Order | ||||
| 	for _, expression := range expressions { | ||||
| 		sort := Ascending | ||||
| 		column := strings.ReplaceAll(expression, clauses.ExprSep, clauses.ExprDot) | ||||
| 		if column[0] == '-' { | ||||
| 			sort = Descending | ||||
| 			column = column[1:] | ||||
| 		} | ||||
| 
 | ||||
| 		orders = append(orders, &Order{ | ||||
| 			column: column, | ||||
| 			sort:   sort, | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	return orders | ||||
| } | ||||
| 
 | ||||
| func Column(column string) Option { | ||||
| 	return func(order *Order) { | ||||
| 		order.column = strings.ReplaceAll(column, clauses.ExprSep, clauses.ExprDot) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func sort(sort Sort) Option { | ||||
| 	return func(order *Order) { | ||||
| 		order.sort = sort | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func SortAscending() Option { | ||||
| 	return sort(Ascending) | ||||
| } | ||||
| 
 | ||||
| func SortDescending() Option { | ||||
| 	return sort(Descending) | ||||
| } | ||||
| 
 | ||||
| func SortNone() Option { | ||||
| 	return sort(None) | ||||
| } | ||||
| 
 | ||||
| func Raw() Option { | ||||
| 	return func(order *Order) { | ||||
| 		order.isRaw = true | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										144
									
								
								client/orm/clauses/order_clause/order_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								client/orm/clauses/order_clause/order_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,144 @@ | ||||
| package order_clause | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func TestClause(t *testing.T) { | ||||
| 	var ( | ||||
| 		column = `a` | ||||
| 	) | ||||
| 
 | ||||
| 	o := Clause( | ||||
| 		Column(column), | ||||
| 	) | ||||
| 
 | ||||
| 	if o.GetColumn() != column { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSortAscending(t *testing.T) { | ||||
| 	o := Clause( | ||||
| 		SortAscending(), | ||||
| 	) | ||||
| 
 | ||||
| 	if o.GetSort() != Ascending { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSortDescending(t *testing.T) { | ||||
| 	o := Clause( | ||||
| 		SortDescending(), | ||||
| 	) | ||||
| 
 | ||||
| 	if o.GetSort() != Descending { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSortNone(t *testing.T) { | ||||
| 	o1 := Clause( | ||||
| 		SortNone(), | ||||
| 	) | ||||
| 
 | ||||
| 	if o1.GetSort() != None { | ||||
| 		t.Error() | ||||
| 	} | ||||
| 
 | ||||
| 	o2 := Clause() | ||||
| 
 | ||||
| 	if o2.GetSort() != None { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestRaw(t *testing.T) { | ||||
| 	o1 := Clause() | ||||
| 
 | ||||
| 	if o1.IsRaw() { | ||||
| 		t.Error() | ||||
| 	} | ||||
| 
 | ||||
| 	o2 := Clause( | ||||
| 		Raw(), | ||||
| 	) | ||||
| 
 | ||||
| 	if !o2.IsRaw() { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestColumn(t *testing.T) { | ||||
| 	o1 := Clause( | ||||
| 		Column(`aaa`), | ||||
| 	) | ||||
| 
 | ||||
| 	if o1.GetColumn() != `aaa` { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestParseOrder(t *testing.T) { | ||||
| 	orders := ParseOrder( | ||||
| 		`-user__status`, | ||||
| 		`status`, | ||||
| 		`user__status`, | ||||
| 	) | ||||
| 
 | ||||
| 	t.Log(orders) | ||||
| 
 | ||||
| 	if orders[0].GetSort() != Descending { | ||||
| 		t.Error() | ||||
| 	} | ||||
| 
 | ||||
| 	if orders[0].GetColumn() != `user.status` { | ||||
| 		t.Error() | ||||
| 	} | ||||
| 
 | ||||
| 	if orders[1].GetColumn() != `status` { | ||||
| 		t.Error() | ||||
| 	} | ||||
| 
 | ||||
| 	if orders[1].GetSort() != Ascending { | ||||
| 		t.Error() | ||||
| 	} | ||||
| 
 | ||||
| 	if orders[2].GetColumn() != `user.status` { | ||||
| 		t.Error() | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestOrder_GetColumn(t *testing.T) { | ||||
| 	o := Clause( | ||||
| 		Column(`user__id`), | ||||
| 	) | ||||
| 	if o.GetColumn() != `user.id` { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOrder_GetSort(t *testing.T) { | ||||
| 	o := Clause( | ||||
| 		SortDescending(), | ||||
| 	) | ||||
| 	if o.GetSort() != Descending { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOrder_IsRaw(t *testing.T) { | ||||
| 	o1 := Clause() | ||||
| 	if o1.IsRaw() { | ||||
| 		t.Error() | ||||
| 	} | ||||
| 
 | ||||
| 	o2 := Clause( | ||||
| 		Raw(), | ||||
| 	) | ||||
| 	if !o2.IsRaw() { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| @ -15,6 +15,7 @@ | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"flag" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| @ -141,6 +142,7 @@ func (d *commandSyncDb) Run() error { | ||||
| 		fmt.Printf("    %s\n", err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	ctx := context.Background() | ||||
| 	for i, mi := range modelCache.allOrdered() { | ||||
| 
 | ||||
| 		if !isApplicableTableForDB(mi.addrField, d.al.Name) { | ||||
| @ -154,7 +156,7 @@ func (d *commandSyncDb) Run() error { | ||||
| 			} | ||||
| 
 | ||||
| 			var fields []*fieldInfo | ||||
| 			columns, err := d.al.DbBaser.GetColumns(db, mi.table) | ||||
| 			columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table) | ||||
| 			if err != nil { | ||||
| 				if d.rtOnError { | ||||
| 					return err | ||||
| @ -188,7 +190,7 @@ func (d *commandSyncDb) Run() error { | ||||
| 			} | ||||
| 
 | ||||
| 			for _, idx := range indexes[mi.table] { | ||||
| 				if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { | ||||
| 				if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) { | ||||
| 					if !d.noInfo { | ||||
| 						fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) | ||||
| 					} | ||||
|  | ||||
							
								
								
									
										146
									
								
								client/orm/db.go
									
									
									
									
									
								
							
							
						
						
									
										146
									
								
								client/orm/db.go
									
									
									
									
									
								
							| @ -15,6 +15,7 @@ | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| @ -268,7 +269,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val | ||||
| } | ||||
| 
 | ||||
| // create insert sql preparation statement object. | ||||
| func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { | ||||
| func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { | ||||
| 	Q := d.ins.TableQuote() | ||||
| 
 | ||||
| 	dbcols := make([]string, 0, len(mi.fields.dbcols)) | ||||
| @ -289,12 +290,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, | ||||
| 
 | ||||
| 	d.ins.HasReturningID(mi, &query) | ||||
| 
 | ||||
| 	stmt, err := q.Prepare(query) | ||||
| 	stmt, err := q.PrepareContext(ctx, query) | ||||
| 	return stmt, query, err | ||||
| } | ||||
| 
 | ||||
| // insert struct with prepared statement and given struct reflect value. | ||||
| func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { | ||||
| func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { | ||||
| 	values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| @ -306,7 +307,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, | ||||
| 		err := row.Scan(&id) | ||||
| 		return id, err | ||||
| 	} | ||||
| 	res, err := stmt.Exec(values...) | ||||
| 	res, err := stmt.ExecContext(ctx, values...) | ||||
| 	if err == nil { | ||||
| 		return res.LastInsertId() | ||||
| 	} | ||||
| @ -314,7 +315,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, | ||||
| } | ||||
| 
 | ||||
| // query sql ,read records and persist in dbBaser. | ||||
| func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { | ||||
| func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { | ||||
| 	var whereCols []string | ||||
| 	var args []interface{} | ||||
| 
 | ||||
| @ -360,7 +361,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo | ||||
| 
 | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	row := q.QueryRow(query, args...) | ||||
| 	row := q.QueryRowContext(ctx, query, args...) | ||||
| 	if err := row.Scan(refs...); err != nil { | ||||
| 		if err == sql.ErrNoRows { | ||||
| 			return ErrNoRows | ||||
| @ -375,26 +376,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo | ||||
| } | ||||
| 
 | ||||
| // execute insert sql dbQuerier with given struct reflect.Value. | ||||
| func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { | ||||
| func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { | ||||
| 	names := make([]string, 0, len(mi.fields.dbcols)) | ||||
| 	values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	id, err := d.InsertValue(q, mi, false, names, values) | ||||
| 	id, err := d.InsertValue(ctx, q, mi, false, names, values) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	if len(autoFields) > 0 { | ||||
| 		err = d.ins.setval(q, mi, autoFields) | ||||
| 		err = d.ins.setval(ctx, q, mi, autoFields) | ||||
| 	} | ||||
| 	return id, err | ||||
| } | ||||
| 
 | ||||
| // multi-insert sql with given slice struct reflect.Value. | ||||
| func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { | ||||
| func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { | ||||
| 	var ( | ||||
| 		cnt    int64 | ||||
| 		nums   int | ||||
| @ -440,7 +441,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul | ||||
| 		} | ||||
| 
 | ||||
| 		if i > 1 && i%bulk == 0 || length == i { | ||||
| 			num, err := d.InsertValue(q, mi, true, names, values[:nums]) | ||||
| 			num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums]) | ||||
| 			if err != nil { | ||||
| 				return cnt, err | ||||
| 			} | ||||
| @ -451,7 +452,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul | ||||
| 
 | ||||
| 	var err error | ||||
| 	if len(autoFields) > 0 { | ||||
| 		err = d.ins.setval(q, mi, autoFields) | ||||
| 		err = d.ins.setval(ctx, q, mi, autoFields) | ||||
| 	} | ||||
| 
 | ||||
| 	return cnt, err | ||||
| @ -459,7 +460,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul | ||||
| 
 | ||||
| // execute insert sql with given struct and given values. | ||||
| // insert the given values, not the field values in struct. | ||||
| func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { | ||||
| func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { | ||||
| 	Q := d.ins.TableQuote() | ||||
| 
 | ||||
| 	marks := make([]string, len(names)) | ||||
| @ -482,7 +483,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	if isMulti || !d.ins.HasReturningID(mi, &query) { | ||||
| 		res, err := q.Exec(query, values...) | ||||
| 		res, err := q.ExecContext(ctx, query, values...) | ||||
| 		if err == nil { | ||||
| 			if isMulti { | ||||
| 				return res.RowsAffected() | ||||
| @ -498,7 +499,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s | ||||
| 		} | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	row := q.QueryRow(query, values...) | ||||
| 	row := q.QueryRowContext(ctx, query, values...) | ||||
| 	var id int64 | ||||
| 	err := row.Scan(&id) | ||||
| 	return id, err | ||||
| @ -507,7 +508,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s | ||||
| // InsertOrUpdate a row | ||||
| // If your primary key or unique column conflict will update | ||||
| // If no will insert | ||||
| func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { | ||||
| func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { | ||||
| 	args0 := "" | ||||
| 	iouStr := "" | ||||
| 	argsMap := map[string]string{} | ||||
| @ -590,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	if isMulti || !d.ins.HasReturningID(mi, &query) { | ||||
| 		res, err := q.Exec(query, values...) | ||||
| 		res, err := q.ExecContext(ctx, query, values...) | ||||
| 		if err == nil { | ||||
| 			if isMulti { | ||||
| 				return res.RowsAffected() | ||||
| @ -607,7 +608,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a | ||||
| 		return 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	row := q.QueryRow(query, values...) | ||||
| 	row := q.QueryRowContext(ctx, query, values...) | ||||
| 	var id int64 | ||||
| 	err = row.Scan(&id) | ||||
| 	if err != nil && err.Error() == `pq: syntax error at or near "ON"` { | ||||
| @ -617,7 +618,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a | ||||
| } | ||||
| 
 | ||||
| // execute update sql dbQuerier with given struct reflect.Value. | ||||
| func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { | ||||
| func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { | ||||
| 	pkName, pkValue, ok := getExistPk(mi, ind) | ||||
| 	if !ok { | ||||
| 		return 0, ErrMissPK | ||||
| @ -674,7 +675,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. | ||||
| 
 | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	res, err := q.Exec(query, setValues...) | ||||
| 	res, err := q.ExecContext(ctx, query, setValues...) | ||||
| 	if err == nil { | ||||
| 		return res.RowsAffected() | ||||
| 	} | ||||
| @ -683,7 +684,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. | ||||
| 
 | ||||
| // execute delete sql dbQuerier with given struct reflect.Value. | ||||
| // delete index is pk. | ||||
| func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { | ||||
| func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { | ||||
| 	var whereCols []string | ||||
| 	var args []interface{} | ||||
| 	// if specify cols length > 0, then use it for where condition. | ||||
| @ -712,7 +713,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. | ||||
| 	query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) | ||||
| 
 | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 	res, err := q.Exec(query, args...) | ||||
| 	res, err := q.ExecContext(ctx, query, args...) | ||||
| 	if err == nil { | ||||
| 		num, err := res.RowsAffected() | ||||
| 		if err != nil { | ||||
| @ -726,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. | ||||
| 					ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) | ||||
| 				} | ||||
| 			} | ||||
| 			err := d.deleteRels(q, mi, args, tz) | ||||
| 			err := d.deleteRels(ctx, q, mi, args, tz) | ||||
| 			if err != nil { | ||||
| 				return num, err | ||||
| 			} | ||||
| @ -738,7 +739,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. | ||||
| 
 | ||||
| // update table-related record by querySet. | ||||
| // need querySet not struct reflect.Value to update related records. | ||||
| func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { | ||||
| func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { | ||||
| 	columns := make([]string, 0, len(params)) | ||||
| 	values := make([]interface{}, 0, len(params)) | ||||
| 	for col, val := range params { | ||||
| @ -819,13 +820,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con | ||||
| 	} | ||||
| 
 | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 	var err error | ||||
| 	var res sql.Result | ||||
| 	if qs != nil && qs.forContext { | ||||
| 		res, err = q.ExecContext(qs.ctx, query, values...) | ||||
| 	} else { | ||||
| 		res, err = q.Exec(query, values...) | ||||
| 	} | ||||
| 	res, err := q.ExecContext(ctx, query, values...) | ||||
| 	if err == nil { | ||||
| 		return res.RowsAffected() | ||||
| 	} | ||||
| @ -834,13 +829,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con | ||||
| 
 | ||||
| // delete related records. | ||||
| // do UpdateBanch or DeleteBanch by condition of tables' relationship. | ||||
| func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { | ||||
| func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { | ||||
| 	for _, fi := range mi.fields.fieldsReverse { | ||||
| 		fi = fi.reverseFieldInfo | ||||
| 		switch fi.onDelete { | ||||
| 		case odCascade: | ||||
| 			cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) | ||||
| 			_, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) | ||||
| 			_, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| @ -850,7 +845,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * | ||||
| 			if fi.onDelete == odSetDefault { | ||||
| 				params[fi.column] = fi.initial.String() | ||||
| 			} | ||||
| 			_, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) | ||||
| 			_, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| @ -861,7 +856,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * | ||||
| } | ||||
| 
 | ||||
| // delete table-related records. | ||||
| func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { | ||||
| func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { | ||||
| 	tables := newDbTables(mi, d.ins) | ||||
| 	tables.skipEnd = true | ||||
| 
 | ||||
| @ -886,7 +881,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	var rs *sql.Rows | ||||
| 	r, err := q.Query(query, args...) | ||||
| 	r, err := q.QueryContext(ctx, query, args...) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| @ -920,19 +915,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con | ||||
| 	query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) | ||||
| 
 | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 	var res sql.Result | ||||
| 	if qs != nil && qs.forContext { | ||||
| 		res, err = q.ExecContext(qs.ctx, query, args...) | ||||
| 	} else { | ||||
| 		res, err = q.Exec(query, args...) | ||||
| 	} | ||||
| 	res, err := q.ExecContext(ctx, query, args...) | ||||
| 	if err == nil { | ||||
| 		num, err := res.RowsAffected() | ||||
| 		if err != nil { | ||||
| 			return 0, err | ||||
| 		} | ||||
| 		if num > 0 { | ||||
| 			err := d.deleteRels(q, mi, args, tz) | ||||
| 			err := d.deleteRels(ctx, q, mi, args, tz) | ||||
| 			if err != nil { | ||||
| 				return num, err | ||||
| 			} | ||||
| @ -943,14 +933,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con | ||||
| } | ||||
| 
 | ||||
| // read related records. | ||||
| func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { | ||||
| func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { | ||||
| 
 | ||||
| 	val := reflect.ValueOf(container) | ||||
| 	ind := reflect.Indirect(val) | ||||
| 
 | ||||
| 	errTyp := true | ||||
| 	unregister := true | ||||
| 	one := true | ||||
| 	isPtr := true | ||||
| 	name := "" | ||||
| 
 | ||||
| 	if val.Kind() == reflect.Ptr { | ||||
| 		fn := "" | ||||
| @ -963,19 +954,17 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi | ||||
| 			case reflect.Struct: | ||||
| 				isPtr = false | ||||
| 				fn = getFullName(typ) | ||||
| 				name = getTableName(reflect.New(typ)) | ||||
| 			} | ||||
| 		} else { | ||||
| 			fn = getFullName(ind.Type()) | ||||
| 			name = getTableName(ind) | ||||
| 		} | ||||
| 		errTyp = fn != mi.fullName | ||||
| 		unregister = fn != mi.fullName | ||||
| 	} | ||||
| 
 | ||||
| 	if errTyp { | ||||
| 		if one { | ||||
| 			panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) | ||||
| 		} else { | ||||
| 			panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) | ||||
| 		} | ||||
| 	if unregister { | ||||
| 		RegisterModel(container) | ||||
| 	} | ||||
| 
 | ||||
| 	rlimit := qs.limit | ||||
| @ -1040,6 +1029,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi | ||||
| 	if qs.distinct { | ||||
| 		sqlSelect += " DISTINCT" | ||||
| 	} | ||||
| 	if qs.aggregate != "" { | ||||
| 		sels = qs.aggregate | ||||
| 	} | ||||
| 	query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", | ||||
| 		sqlSelect, sels, Q, mi.table, Q, | ||||
| 		specifyIndexes, join, where, groupBy, orderBy, limit) | ||||
| @ -1050,18 +1042,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi | ||||
| 
 | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	var rs *sql.Rows | ||||
| 	var err error | ||||
| 	if qs != nil && qs.forContext { | ||||
| 		rs, err = q.QueryContext(qs.ctx, query, args...) | ||||
| 		if err != nil { | ||||
| 			return 0, err | ||||
| 		} | ||||
| 	} else { | ||||
| 		rs, err = q.Query(query, args...) | ||||
| 		if err != nil { | ||||
| 			return 0, err | ||||
| 		} | ||||
| 	rs, err := q.QueryContext(ctx, query, args...) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	defer rs.Close() | ||||
| 
 | ||||
| 	slice := ind | ||||
| 	if unregister { | ||||
| 		mi, _ = modelCache.get(name) | ||||
| 		tCols = mi.fields.dbcols | ||||
| 		colsNum = len(tCols) | ||||
| 	} | ||||
| 
 | ||||
| 	refs := make([]interface{}, colsNum) | ||||
| @ -1069,11 +1061,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi | ||||
| 		var ref interface{} | ||||
| 		refs[i] = &ref | ||||
| 	} | ||||
| 
 | ||||
| 	defer rs.Close() | ||||
| 
 | ||||
| 	slice := ind | ||||
| 
 | ||||
| 	var cnt int64 | ||||
| 	for rs.Next() { | ||||
| 		if one && cnt == 0 || !one { | ||||
| @ -1172,7 +1159,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi | ||||
| } | ||||
| 
 | ||||
| // excute count sql and return count result int64. | ||||
| func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { | ||||
| func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { | ||||
| 	tables := newDbTables(mi, d.ins) | ||||
| 	tables.parseRelated(qs.related, qs.relDepth) | ||||
| 
 | ||||
| @ -1194,12 +1181,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition | ||||
| 
 | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	var row *sql.Row | ||||
| 	if qs != nil && qs.forContext { | ||||
| 		row = q.QueryRowContext(qs.ctx, query, args...) | ||||
| 	} else { | ||||
| 		row = q.QueryRow(query, args...) | ||||
| 	} | ||||
| 	row := q.QueryRowContext(ctx, query, args...) | ||||
| 	err = row.Scan(&cnt) | ||||
| 	return | ||||
| } | ||||
| @ -1649,7 +1631,7 @@ setValue: | ||||
| } | ||||
| 
 | ||||
| // query sql, read values , save to *[]ParamList. | ||||
| func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { | ||||
| func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { | ||||
| 
 | ||||
| 	var ( | ||||
| 		maps  []Params | ||||
| @ -1732,7 +1714,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond | ||||
| 
 | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	rs, err := q.Query(query, args...) | ||||
| 	rs, err := q.QueryContext(ctx, query, args...) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| @ -1847,7 +1829,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool { | ||||
| } | ||||
| 
 | ||||
| // sync auto key | ||||
| func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { | ||||
| func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| @ -1892,10 +1874,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { | ||||
| } | ||||
| 
 | ||||
| // get all cloumns in table. | ||||
| func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { | ||||
| func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { | ||||
| 	columns := make(map[string][3]string) | ||||
| 	query := d.ins.ShowColumnsQuery(table) | ||||
| 	rows, err := db.Query(query) | ||||
| 	rows, err := db.QueryContext(ctx, query) | ||||
| 	if err != nil { | ||||
| 		return columns, err | ||||
| 	} | ||||
| @ -1934,7 +1916,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string { | ||||
| } | ||||
| 
 | ||||
| // not implement. | ||||
| func (d *dbBase) IndexExists(dbQuerier, string, string) bool { | ||||
| func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool { | ||||
| 	panic(ErrNotImplement) | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -15,6 +15,7 @@ | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| @ -93,8 +94,8 @@ func (d *dbBaseMysql) ShowColumnsQuery(table string) string { | ||||
| } | ||||
| 
 | ||||
| // execute sql to check index exist. | ||||
| func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { | ||||
| 	row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ | ||||
| func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { | ||||
| 	row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ | ||||
| 		"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) | ||||
| 	var cnt int | ||||
| 	row.Scan(&cnt) | ||||
| @ -105,7 +106,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool | ||||
| // If your primary key or unique column conflict will update | ||||
| // If no will insert | ||||
| // Add "`" for mysql sql building | ||||
| func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { | ||||
| func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { | ||||
| 	var iouStr string | ||||
| 	argsMap := map[string]string{} | ||||
| 
 | ||||
| @ -161,7 +162,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	if isMulti || !d.ins.HasReturningID(mi, &query) { | ||||
| 		res, err := q.Exec(query, values...) | ||||
| 		res, err := q.ExecContext(ctx, query, values...) | ||||
| 		if err == nil { | ||||
| 			if isMulti { | ||||
| 				return res.RowsAffected() | ||||
| @ -178,7 +179,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val | ||||
| 		return 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	row := q.QueryRow(query, values...) | ||||
| 	row := q.QueryRowContext(ctx, query, values...) | ||||
| 	var id int64 | ||||
| 	err = row.Scan(&id) | ||||
| 	return id, err | ||||
|  | ||||
| @ -15,6 +15,7 @@ | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 
 | ||||
| @ -89,8 +90,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string { | ||||
| } | ||||
| 
 | ||||
| // check index is exist | ||||
| func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { | ||||
| 	row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ | ||||
| func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { | ||||
| 	row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ | ||||
| 		"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ | ||||
| 		"AND  USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) | ||||
| 
 | ||||
| @ -124,7 +125,7 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde | ||||
| 
 | ||||
| // execute insert sql with given struct and given values. | ||||
| // insert the given values, not the field values in struct. | ||||
| func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { | ||||
| func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { | ||||
| 	Q := d.ins.TableQuote() | ||||
| 
 | ||||
| 	marks := make([]string, len(names)) | ||||
| @ -147,7 +148,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	if isMulti || !d.ins.HasReturningID(mi, &query) { | ||||
| 		res, err := q.Exec(query, values...) | ||||
| 		res, err := q.ExecContext(ctx, query, values...) | ||||
| 		if err == nil { | ||||
| 			if isMulti { | ||||
| 				return res.RowsAffected() | ||||
| @ -163,7 +164,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam | ||||
| 		} | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	row := q.QueryRow(query, values...) | ||||
| 	row := q.QueryRowContext(ctx, query, values...) | ||||
| 	var id int64 | ||||
| 	err := row.Scan(&id) | ||||
| 	return id, err | ||||
|  | ||||
| @ -15,6 +15,7 @@ | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| ) | ||||
| @ -140,7 +141,7 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { | ||||
| } | ||||
| 
 | ||||
| // sync auto key | ||||
| func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { | ||||
| func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { | ||||
| 	if len(autoFields) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| @ -151,7 +152,7 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string | ||||
| 			mi.table, name, | ||||
| 			Q, name, Q, | ||||
| 			Q, mi.table, Q) | ||||
| 		if _, err := db.Exec(query); err != nil { | ||||
| 		if _, err := db.ExecContext(ctx, query); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| @ -174,9 +175,9 @@ func (d *dbBasePostgres) DbTypes() map[string]string { | ||||
| } | ||||
| 
 | ||||
| // check index exist in postgresql. | ||||
| func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { | ||||
| func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { | ||||
| 	query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) | ||||
| 	row := db.QueryRow(query) | ||||
| 	row := db.QueryRowContext(ctx, query) | ||||
| 	var cnt int | ||||
| 	row.Scan(&cnt) | ||||
| 	return cnt > 0 | ||||
|  | ||||
| @ -15,6 +15,7 @@ | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| @ -73,11 +74,11 @@ type dbBaseSqlite struct { | ||||
| var _ dbBaser = new(dbBaseSqlite) | ||||
| 
 | ||||
| // override base db read for update behavior as SQlite does not support syntax | ||||
| func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { | ||||
| func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { | ||||
| 	if isForUpdate { | ||||
| 		DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") | ||||
| 	} | ||||
| 	return d.dbBase.Read(q, mi, ind, tz, cols, false) | ||||
| 	return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false) | ||||
| } | ||||
| 
 | ||||
| // get sqlite operator. | ||||
| @ -114,9 +115,9 @@ func (d *dbBaseSqlite) ShowTablesQuery() string { | ||||
| } | ||||
| 
 | ||||
| // get columns in sqlite. | ||||
| func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { | ||||
| func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { | ||||
| 	query := d.ins.ShowColumnsQuery(table) | ||||
| 	rows, err := db.Query(query) | ||||
| 	rows, err := db.QueryContext(ctx, query) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -140,9 +141,9 @@ func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { | ||||
| } | ||||
| 
 | ||||
| // check index exist in sqlite. | ||||
| func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { | ||||
| func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { | ||||
| 	query := fmt.Sprintf("PRAGMA index_list('%s')", table) | ||||
| 	rows, err := db.Query(query) | ||||
| 	rows, err := db.QueryContext(ctx, query) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
|  | ||||
| @ -16,6 +16,8 @@ package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/beego/beego/v2/client/orm/clauses" | ||||
| 	"github.com/beego/beego/v2/client/orm/clauses/order_clause" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| @ -421,7 +423,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { | ||||
| } | ||||
| 
 | ||||
| // generate order sql. | ||||
| func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { | ||||
| func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) { | ||||
| 	if len(orders) == 0 { | ||||
| 		return | ||||
| 	} | ||||
| @ -430,19 +432,25 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { | ||||
| 
 | ||||
| 	orderSqls := make([]string, 0, len(orders)) | ||||
| 	for _, order := range orders { | ||||
| 		asc := "ASC" | ||||
| 		if order[0] == '-' { | ||||
| 			asc = "DESC" | ||||
| 			order = order[1:] | ||||
| 		} | ||||
| 		exprs := strings.Split(order, ExprSep) | ||||
| 		column := order.GetColumn() | ||||
| 		clause := strings.Split(column, clauses.ExprDot) | ||||
| 
 | ||||
| 		index, _, fi, suc := t.parseExprs(t.mi, exprs) | ||||
| 		if !suc { | ||||
| 			panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) | ||||
| 		} | ||||
| 		if order.IsRaw() { | ||||
| 			if len(clause) == 2 { | ||||
| 				orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString())) | ||||
| 			} else if len(clause) == 1 { | ||||
| 				orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString())) | ||||
| 			} else { | ||||
| 				panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) | ||||
| 			} | ||||
| 		} else { | ||||
| 			index, _, fi, suc := t.parseExprs(t.mi, clause) | ||||
| 			if !suc { | ||||
| 				panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) | ||||
| 			} | ||||
| 
 | ||||
| 		orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) | ||||
| 			orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString())) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) | ||||
|  | ||||
| @ -15,6 +15,7 @@ | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| ) | ||||
| 
 | ||||
| @ -47,8 +48,8 @@ func (d *dbBaseTidb) ShowColumnsQuery(table string) string { | ||||
| } | ||||
| 
 | ||||
| // execute sql to check index exist. | ||||
| func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { | ||||
| 	row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ | ||||
| func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { | ||||
| 	row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ | ||||
| 		"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) | ||||
| 	var cnt int | ||||
| 	row.Scan(&cnt) | ||||
|  | ||||
| @ -66,6 +66,7 @@ func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // NOTE: this method is deprecated, context parameter will not take effect. | ||||
| func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { | ||||
| 	return nil | ||||
| } | ||||
| @ -74,6 +75,7 @@ func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // NOTE: this method is deprecated, context parameter will not take effect. | ||||
| func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @ -36,7 +36,6 @@ func TestDoNothingOrm(t *testing.T) { | ||||
| 
 | ||||
| 	assert.Nil(t, o.Driver()) | ||||
| 
 | ||||
| 	assert.Nil(t, o.QueryM2MWithCtx(nil, nil, "")) | ||||
| 	assert.Nil(t, o.QueryM2M(nil, "")) | ||||
| 	assert.Nil(t, o.ReadWithCtx(nil, nil)) | ||||
| 	assert.Nil(t, o.Read(nil)) | ||||
| @ -92,7 +91,6 @@ func TestDoNothingOrm(t *testing.T) { | ||||
| 	assert.Nil(t, err) | ||||
| 	assert.Equal(t, int64(0), i) | ||||
| 
 | ||||
| 	assert.Nil(t, o.QueryTableWithCtx(nil, nil)) | ||||
| 	assert.Nil(t, o.QueryTable(nil)) | ||||
| 
 | ||||
| 	assert.Nil(t, o.Read(nil)) | ||||
|  | ||||
| @ -27,7 +27,7 @@ import ( | ||||
| // this Filter's behavior looks a little bit strange | ||||
| // for example: | ||||
| // if we want to trace QuerySetter | ||||
| // actually we trace invoking "QueryTable" and "QueryTableWithCtx" | ||||
| // actually we trace invoking "QueryTable" | ||||
| // the method Begin*, Commit and Rollback are ignored. | ||||
| // When use using those methods, it means that they want to manager their transaction manually, so we won't handle them. | ||||
| type FilterChainBuilder struct { | ||||
|  | ||||
| @ -31,7 +31,7 @@ import ( | ||||
| // this Filter's behavior looks a little bit strange | ||||
| // for example: | ||||
| // if we want to records the metrics of QuerySetter | ||||
| // actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx" | ||||
| // actually we only records metrics of invoking "QueryTable" | ||||
| type FilterChainBuilder struct { | ||||
| 	summaryVec prometheus.ObserverVec | ||||
| 	AppName    string | ||||
|  | ||||
| @ -20,6 +20,7 @@ import ( | ||||
| 	"reflect" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/beego/beego/v2/core/logs" | ||||
| 	"github.com/beego/beego/v2/core/utils" | ||||
| ) | ||||
| 
 | ||||
| @ -161,36 +162,34 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac | ||||
| } | ||||
| 
 | ||||
| func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer { | ||||
| 	return f.QueryM2MWithCtx(context.Background(), md, name) | ||||
| } | ||||
| 
 | ||||
| func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { | ||||
| 
 | ||||
| 	mi, _ := modelCache.getByMd(md) | ||||
| 	inv := &Invocation{ | ||||
| 		Method:      "QueryM2MWithCtx", | ||||
| 		Method:      "QueryM2M", | ||||
| 		Args:        []interface{}{md, name}, | ||||
| 		Md:          md, | ||||
| 		mi:          mi, | ||||
| 		InsideTx:    f.insideTx, | ||||
| 		TxStartTime: f.txStartTime, | ||||
| 		f: func(c context.Context) []interface{} { | ||||
| 			res := f.ormer.QueryM2MWithCtx(c, md, name) | ||||
| 			res := f.ormer.QueryM2M(md, name) | ||||
| 			return []interface{}{res} | ||||
| 		}, | ||||
| 	} | ||||
| 	res := f.root(ctx, inv) | ||||
| 	res := f.root(context.Background(), inv) | ||||
| 	if res[0] == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return res[0].(QueryM2Mer) | ||||
| } | ||||
| 
 | ||||
| func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { | ||||
| 	return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName) | ||||
| // NOTE: this method is deprecated, context parameter will not take effect. | ||||
| func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { | ||||
| 	logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.") | ||||
| 	return f.QueryM2M(md, name) | ||||
| } | ||||
| 
 | ||||
| func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { | ||||
| func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { | ||||
| 	var ( | ||||
| 		name string | ||||
| 		md   interface{} | ||||
| @ -209,18 +208,18 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT | ||||
| 	} | ||||
| 
 | ||||
| 	inv := &Invocation{ | ||||
| 		Method:      "QueryTableWithCtx", | ||||
| 		Method:      "QueryTable", | ||||
| 		Args:        []interface{}{ptrStructOrTableName}, | ||||
| 		InsideTx:    f.insideTx, | ||||
| 		TxStartTime: f.txStartTime, | ||||
| 		Md:          md, | ||||
| 		mi:          mi, | ||||
| 		f: func(c context.Context) []interface{} { | ||||
| 			res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName) | ||||
| 			res := f.ormer.QueryTable(ptrStructOrTableName) | ||||
| 			return []interface{}{res} | ||||
| 		}, | ||||
| 	} | ||||
| 	res := f.root(ctx, inv) | ||||
| 	res := f.root(context.Background(), inv) | ||||
| 
 | ||||
| 	if res[0] == nil { | ||||
| 		return nil | ||||
| @ -228,6 +227,12 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT | ||||
| 	return res[0].(QuerySeter) | ||||
| } | ||||
| 
 | ||||
| // NOTE: this method is deprecated, context parameter will not take effect. | ||||
| func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter { | ||||
| 	logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.") | ||||
| 	return f.QueryTable(ptrStructOrTableName) | ||||
| } | ||||
| 
 | ||||
| func (f *filterOrmDecorator) DBStats() *sql.DBStats { | ||||
| 	inv := &Invocation{ | ||||
| 		Method:      "DBStats", | ||||
|  | ||||
| @ -268,7 +268,7 @@ func TestFilterOrmDecorator_QueryM2M(t *testing.T) { | ||||
| 	o := &filterMockOrm{} | ||||
| 	od := NewFilterOrmDecorator(o, func(next Filter) Filter { | ||||
| 		return func(ctx context.Context, inv *Invocation) []interface{} { | ||||
| 			assert.Equal(t, "QueryM2MWithCtx", inv.Method) | ||||
| 			assert.Equal(t, "QueryM2M", inv.Method) | ||||
| 			assert.Equal(t, 2, len(inv.Args)) | ||||
| 			assert.Equal(t, "FILTER_TEST", inv.GetTableName()) | ||||
| 			assert.False(t, inv.InsideTx) | ||||
| @ -284,7 +284,7 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) { | ||||
| 	o := &filterMockOrm{} | ||||
| 	od := NewFilterOrmDecorator(o, func(next Filter) Filter { | ||||
| 		return func(ctx context.Context, inv *Invocation) []interface{} { | ||||
| 			assert.Equal(t, "QueryTableWithCtx", inv.Method) | ||||
| 			assert.Equal(t, "QueryTable", inv.Method) | ||||
| 			assert.Equal(t, 1, len(inv.Args)) | ||||
| 			assert.Equal(t, "FILTER_TEST", inv.GetTableName()) | ||||
| 			assert.False(t, inv.InsideTx) | ||||
|  | ||||
| @ -332,10 +332,6 @@ end: | ||||
| 
 | ||||
| // register register models to model cache | ||||
| func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) { | ||||
| 	if mc.done { | ||||
| 		err = fmt.Errorf("register must be run before BootStrap") | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	for _, model := range models { | ||||
| 		val := reflect.ValueOf(model) | ||||
| @ -352,7 +348,9 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m | ||||
| 			err = fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if val.Elem().Kind() == reflect.Slice { | ||||
| 			val = reflect.New(val.Elem().Type().Elem()) | ||||
| 		} | ||||
| 		table := getTableName(val) | ||||
| 
 | ||||
| 		if prefixOrSuffixStr != "" { | ||||
| @ -371,8 +369,7 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m | ||||
| 		} | ||||
| 
 | ||||
| 		if _, ok := mc.get(table); ok { | ||||
| 			err = fmt.Errorf("<orm.RegisterModel> table name `%s` repeat register, must be unique\n", table) | ||||
| 			return | ||||
| 			return	nil | ||||
| 		} | ||||
| 
 | ||||
| 		mi := newModelInfo(val) | ||||
| @ -389,12 +386,6 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if mi.fields.pk == nil { | ||||
| 				err = fmt.Errorf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name) | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| 		} | ||||
| 
 | ||||
| 		mi.table = table | ||||
|  | ||||
| @ -255,6 +255,22 @@ func NewTM() *TM { | ||||
| 	return obj | ||||
| } | ||||
| 
 | ||||
| type DeptInfo struct { | ||||
| 	ID           int       `orm:"column(id)"` | ||||
| 	Created      time.Time `orm:"auto_now_add"` | ||||
| 	DeptName     string | ||||
| 	EmployeeName string | ||||
| 	Salary       int | ||||
| } | ||||
| 
 | ||||
| type UnregisterModel struct { | ||||
| 	ID           int       `orm:"column(id)"` | ||||
| 	Created      time.Time `orm:"auto_now_add"` | ||||
| 	DeptName     string | ||||
| 	EmployeeName string | ||||
| 	Salary       int | ||||
| } | ||||
| 
 | ||||
| type User struct { | ||||
| 	ID           int    `orm:"column(id)"` | ||||
| 	UserName     string `orm:"size(30);unique"` | ||||
| @ -476,45 +492,45 @@ var ( | ||||
| 	helpinfo = `need driver and source! | ||||
| 
 | ||||
| 	Default DB Drivers. | ||||
| 	 | ||||
| 
 | ||||
| 	  driver: url | ||||
| 	   mysql: https://github.com/go-sql-driver/mysql | ||||
| 	 sqlite3: https://github.com/mattn/go-sqlite3 | ||||
| 	postgres: https://github.com/lib/pq | ||||
| 	tidb: https://github.com/pingcap/tidb | ||||
| 	 | ||||
| 
 | ||||
| 	usage: | ||||
| 	 | ||||
| 
 | ||||
| 	go get -u github.com/beego/beego/v2/client/orm | ||||
| 	go get -u github.com/go-sql-driver/mysql | ||||
| 	go get -u github.com/mattn/go-sqlite3 | ||||
| 	go get -u github.com/lib/pq | ||||
| 	go get -u github.com/pingcap/tidb | ||||
| 	 | ||||
| 
 | ||||
| 	#### MySQL | ||||
| 	mysql -u root -e 'create database orm_test;' | ||||
| 	export ORM_DRIVER=mysql | ||||
| 	export ORM_SOURCE="root:@/orm_test?charset=utf8" | ||||
| 	go test -v github.com/beego/beego/v2/client/orm | ||||
| 	 | ||||
| 	 | ||||
| 
 | ||||
| 
 | ||||
| 	#### Sqlite3 | ||||
| 	export ORM_DRIVER=sqlite3 | ||||
| 	export ORM_SOURCE='file:memory_test?mode=memory' | ||||
| 	go test -v github.com/beego/beego/v2/client/orm | ||||
| 	 | ||||
| 	 | ||||
| 
 | ||||
| 
 | ||||
| 	#### PostgreSQL | ||||
| 	psql -c 'create database orm_test;' -U postgres | ||||
| 	export ORM_DRIVER=postgres | ||||
| 	export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" | ||||
| 	go test -v github.com/beego/beego/v2/client/orm | ||||
| 	 | ||||
| 
 | ||||
| 	#### TiDB | ||||
| 	export ORM_DRIVER=tidb | ||||
| 	export ORM_SOURCE='memory://test/test' | ||||
| 	go test -v github.com/beego/beego/v2/pgk/orm | ||||
| 	 | ||||
| 
 | ||||
| 	` | ||||
| ) | ||||
| 
 | ||||
|  | ||||
| @ -109,6 +109,9 @@ func getTableUnique(val reflect.Value) [][]string { | ||||
| 
 | ||||
| // get whether the table needs to be created for the database alias | ||||
| func isApplicableTableForDB(val reflect.Value, db string) bool { | ||||
| 	if !val.IsValid() { | ||||
| 		return true | ||||
| 	} | ||||
| 	fun := val.MethodByName("IsApplicableTableForDB") | ||||
| 	if fun.IsValid() { | ||||
| 		vals := fun.Call([]reflect.Value{reflect.ValueOf(db)}) | ||||
|  | ||||
| @ -58,6 +58,7 @@ import ( | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/beego/beego/v2/client/orm/clauses/order_clause" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
| @ -135,7 +136,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error { | ||||
| } | ||||
| func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) | ||||
| 	return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) | ||||
| } | ||||
| 
 | ||||
| // read data to model, like Read(), but use "SELECT FOR UPDATE" form | ||||
| @ -144,7 +145,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { | ||||
| } | ||||
| func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) | ||||
| 	return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true) | ||||
| } | ||||
| 
 | ||||
| // Try to read a row from the database, or insert one if it doesn't exist | ||||
| @ -154,7 +155,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo | ||||
| func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { | ||||
| 	cols = append([]string{col1}, cols...) | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) | ||||
| 	err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) | ||||
| 	if err == ErrNoRows { | ||||
| 		// Create | ||||
| 		id, err := o.InsertWithCtx(ctx, md) | ||||
| @ -179,7 +180,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) { | ||||
| } | ||||
| func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) | ||||
| 	id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) | ||||
| 	if err != nil { | ||||
| 		return id, err | ||||
| 	} | ||||
| @ -222,7 +223,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac | ||||
| 		for i := 0; i < sind.Len(); i++ { | ||||
| 			ind := reflect.Indirect(sind.Index(i)) | ||||
| 			mi, _ := o.getMiInd(ind.Interface(), false) | ||||
| 			id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) | ||||
| 			id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) | ||||
| 			if err != nil { | ||||
| 				return cnt, err | ||||
| 			} | ||||
| @ -233,7 +234,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac | ||||
| 		} | ||||
| 	} else { | ||||
| 		mi, _ := o.getMiInd(sind.Index(0).Interface(), false) | ||||
| 		return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) | ||||
| 		return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ) | ||||
| 	} | ||||
| 	return cnt, nil | ||||
| } | ||||
| @ -244,7 +245,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) ( | ||||
| } | ||||
| func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) | ||||
| 	id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...) | ||||
| 	if err != nil { | ||||
| 		return id, err | ||||
| 	} | ||||
| @ -261,7 +262,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { | ||||
| } | ||||
| func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) | ||||
| 	return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols) | ||||
| } | ||||
| 
 | ||||
| // delete model in database | ||||
| @ -271,7 +272,7 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { | ||||
| } | ||||
| func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) | ||||
| 	num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols) | ||||
| 	if err != nil { | ||||
| 		return num, err | ||||
| 	} | ||||
| @ -283,9 +284,6 @@ func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...str | ||||
| 
 | ||||
| // create a models to models queryer | ||||
| func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { | ||||
| 	return o.QueryM2MWithCtx(context.Background(), md, name) | ||||
| } | ||||
| func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	fi := o.getFieldInfo(mi, name) | ||||
| 
 | ||||
| @ -299,6 +297,12 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri | ||||
| 	return newQueryM2M(md, o, mi, fi, ind) | ||||
| } | ||||
| 
 | ||||
| // NOTE: this method is deprecated, context parameter will not take effect. | ||||
| func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { | ||||
| 	logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.") | ||||
| 	return o.QueryM2M(md, name) | ||||
| } | ||||
| 
 | ||||
| // load related models to md model. | ||||
| // args are limit, offset int and order string. | ||||
| // | ||||
| @ -351,7 +355,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s | ||||
| 	qs.relDepth = relDepth | ||||
| 
 | ||||
| 	if len(order) > 0 { | ||||
| 		qs.orders = []string{order} | ||||
| 		qs.orders = order_clause.ParseOrder(order) | ||||
| 	} | ||||
| 
 | ||||
| 	find := ind.FieldByIndex(fi.fieldIndex) | ||||
| @ -451,9 +455,6 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS | ||||
| // table name can be string or struct. | ||||
| // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), | ||||
| func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { | ||||
| 	return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName) | ||||
| } | ||||
| func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { | ||||
| 	var name string | ||||
| 	if table, ok := ptrStructOrTableName.(string); ok { | ||||
| 		name = nameStrategyMap[defaultNameStrategy](table) | ||||
| @ -469,7 +470,13 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in | ||||
| 	if qs == nil { | ||||
| 		panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name)) | ||||
| 	} | ||||
| 	return | ||||
| 	return qs | ||||
| } | ||||
| 
 | ||||
| // NOTE: this method is deprecated, context parameter will not take effect. | ||||
| func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { | ||||
| 	logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.") | ||||
| 	return o.QueryTable(ptrStructOrTableName) | ||||
| } | ||||
| 
 | ||||
| // return a raw query seter for raw sql string. | ||||
| @ -595,9 +602,8 @@ func NewOrm() Ormer { | ||||
| func NewOrmUsingDB(aliasName string) Ormer { | ||||
| 	if al, ok := dataBaseCache.get(aliasName); ok { | ||||
| 		return newDBWithAlias(al) | ||||
| 	} else { | ||||
| 		panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName)) | ||||
| 	} | ||||
| 	panic(fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", aliasName)) | ||||
| } | ||||
| 
 | ||||
| // NewOrmWithDB create a new ormer object with specify *sql.DB for query | ||||
|  | ||||
| @ -16,12 +16,13 @@ package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/beego/beego/v2/client/orm/clauses" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // ExprSep define the expression separation | ||||
| const ( | ||||
| 	ExprSep = "__" | ||||
| 	ExprSep = clauses.ExprSep | ||||
| ) | ||||
| 
 | ||||
| type condValue struct { | ||||
|  | ||||
| @ -85,20 +85,31 @@ func (d *stmtQueryLog) Close() error { | ||||
| } | ||||
| 
 | ||||
| func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { | ||||
| 	return d.ExecContext(context.Background(), args...) | ||||
| } | ||||
| 
 | ||||
| func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { | ||||
| 	a := time.Now() | ||||
| 	res, err := d.stmt.Exec(args...) | ||||
| 	res, err := d.stmt.ExecContext(ctx, args...) | ||||
| 	debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) | ||||
| 	return res, err | ||||
| } | ||||
| 
 | ||||
| func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { | ||||
| 	return d.QueryContext(context.Background(), args...) | ||||
| } | ||||
| 
 | ||||
| func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { | ||||
| 	a := time.Now() | ||||
| 	res, err := d.stmt.Query(args...) | ||||
| 	res, err := d.stmt.QueryContext(ctx, args...) | ||||
| 	debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) | ||||
| 	return res, err | ||||
| } | ||||
| 
 | ||||
| func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { | ||||
| 	return d.QueryRowContext(context.Background(), args...) | ||||
| } | ||||
| 
 | ||||
| func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { | ||||
| 	a := time.Now() | ||||
| 	res := d.stmt.QueryRow(args...) | ||||
| 	debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) | ||||
|  | ||||
| @ -15,6 +15,7 @@ | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| ) | ||||
| @ -31,6 +32,10 @@ var _ Inserter = new(insertSet) | ||||
| 
 | ||||
| // insert model ignore it's registered or not. | ||||
| func (o *insertSet) Insert(md interface{}) (int64, error) { | ||||
| 	return o.InsertWithCtx(context.Background(), md) | ||||
| } | ||||
| 
 | ||||
| func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { | ||||
| 	if o.closed { | ||||
| 		return 0, ErrStmtClosed | ||||
| 	} | ||||
| @ -44,7 +49,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { | ||||
| 	if name != o.mi.fullName { | ||||
| 		panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name)) | ||||
| 	} | ||||
| 	id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) | ||||
| 	id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ) | ||||
| 	if err != nil { | ||||
| 		return id, err | ||||
| 	} | ||||
| @ -70,11 +75,11 @@ func (o *insertSet) Close() error { | ||||
| } | ||||
| 
 | ||||
| // create new insert queryer. | ||||
| func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) { | ||||
| func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) { | ||||
| 	bi := new(insertSet) | ||||
| 	bi.orm = orm | ||||
| 	bi.mi = mi | ||||
| 	st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) | ||||
| 	st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| @ -14,7 +14,10 @@ | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import "reflect" | ||||
| import ( | ||||
| 	"context" | ||||
| 	"reflect" | ||||
| ) | ||||
| 
 | ||||
| // model to model struct | ||||
| type queryM2M struct { | ||||
| @ -33,6 +36,10 @@ type queryM2M struct { | ||||
| // | ||||
| // make sure the relation is defined in post model struct tag. | ||||
| func (o *queryM2M) Add(mds ...interface{}) (int64, error) { | ||||
| 	return o.AddWithCtx(context.Background(), mds...) | ||||
| } | ||||
| 
 | ||||
| func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { | ||||
| 	fi := o.fi | ||||
| 	mi := fi.relThroughModelInfo | ||||
| 	mfi := fi.reverseFieldInfo | ||||
| @ -96,11 +103,15 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { | ||||
| 	} | ||||
| 	names = append(names, otherNames...) | ||||
| 	values = append(values, otherValues...) | ||||
| 	return dbase.InsertValue(orm.db, mi, true, names, values) | ||||
| 	return dbase.InsertValue(ctx, orm.db, mi, true, names, values) | ||||
| } | ||||
| 
 | ||||
| // remove models following the origin model relationship | ||||
| func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { | ||||
| 	return o.RemoveWithCtx(context.Background(), mds...) | ||||
| } | ||||
| 
 | ||||
| func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { | ||||
| 	fi := o.fi | ||||
| 	qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) | ||||
| 
 | ||||
| @ -109,21 +120,33 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { | ||||
| 
 | ||||
| // check model is existed in relationship of origin model | ||||
| func (o *queryM2M) Exist(md interface{}) bool { | ||||
| 	return o.ExistWithCtx(context.Background(), md) | ||||
| } | ||||
| 
 | ||||
| func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool { | ||||
| 	fi := o.fi | ||||
| 	return o.qs.Filter(fi.reverseFieldInfo.name, o.md). | ||||
| 		Filter(fi.reverseFieldInfoTwo.name, md).Exist() | ||||
| 		Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx) | ||||
| } | ||||
| 
 | ||||
| // clean all models in related of origin model | ||||
| func (o *queryM2M) Clear() (int64, error) { | ||||
| 	return o.ClearWithCtx(context.Background()) | ||||
| } | ||||
| 
 | ||||
| func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) { | ||||
| 	fi := o.fi | ||||
| 	return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() | ||||
| 	return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx) | ||||
| } | ||||
| 
 | ||||
| // count all related models of origin model | ||||
| func (o *queryM2M) Count() (int64, error) { | ||||
| 	return o.CountWithCtx(context.Background()) | ||||
| } | ||||
| 
 | ||||
| func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) { | ||||
| 	fi := o.fi | ||||
| 	return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() | ||||
| 	return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx) | ||||
| } | ||||
| 
 | ||||
| var _ QueryM2Mer = new(queryM2M) | ||||
|  | ||||
| @ -18,6 +18,7 @@ import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"github.com/beego/beego/v2/client/orm/clauses/order_clause" | ||||
| 	"github.com/beego/beego/v2/client/orm/hints" | ||||
| ) | ||||
| 
 | ||||
| @ -64,21 +65,20 @@ func ColValue(opt operator, value interface{}) interface{} { | ||||
| 
 | ||||
| // real query struct | ||||
| type querySet struct { | ||||
| 	mi         *modelInfo | ||||
| 	cond       *Condition | ||||
| 	related    []string | ||||
| 	relDepth   int | ||||
| 	limit      int64 | ||||
| 	offset     int64 | ||||
| 	groups     []string | ||||
| 	orders     []string | ||||
| 	distinct   bool | ||||
| 	forUpdate  bool | ||||
| 	useIndex   int | ||||
| 	indexes    []string | ||||
| 	orm        *ormBase | ||||
| 	ctx        context.Context | ||||
| 	forContext bool | ||||
| 	mi        *modelInfo | ||||
| 	cond      *Condition | ||||
| 	related   []string | ||||
| 	relDepth  int | ||||
| 	limit     int64 | ||||
| 	offset    int64 | ||||
| 	groups    []string | ||||
| 	orders    []*order_clause.Order | ||||
| 	distinct  bool | ||||
| 	forUpdate bool | ||||
| 	useIndex  int | ||||
| 	indexes   []string | ||||
| 	orm       *ormBase | ||||
| 	aggregate string | ||||
| } | ||||
| 
 | ||||
| var _ QuerySeter = new(querySet) | ||||
| @ -139,8 +139,20 @@ func (o querySet) GroupBy(exprs ...string) QuerySeter { | ||||
| 
 | ||||
| // add ORDER expression. | ||||
| // "column" means ASC, "-column" means DESC. | ||||
| func (o querySet) OrderBy(exprs ...string) QuerySeter { | ||||
| 	o.orders = exprs | ||||
| func (o querySet) OrderBy(expressions ...string) QuerySeter { | ||||
| 	if len(expressions) <= 0 { | ||||
| 		return &o | ||||
| 	} | ||||
| 	o.orders = order_clause.ParseOrder(expressions...) | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // add ORDER expression. | ||||
| func (o querySet) OrderClauses(orders ...*order_clause.Order) QuerySeter { | ||||
| 	if len(orders) <= 0 { | ||||
| 		return &o | ||||
| 	} | ||||
| 	o.orders = orders | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| @ -210,23 +222,39 @@ func (o querySet) GetCond() *Condition { | ||||
| 
 | ||||
| // return QuerySeter execution result number | ||||
| func (o *querySet) Count() (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) | ||||
| 	return o.CountWithCtx(context.Background()) | ||||
| } | ||||
| 
 | ||||
| func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // check result empty or not after QuerySeter executed | ||||
| func (o *querySet) Exist() bool { | ||||
| 	cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) | ||||
| 	return o.ExistWithCtx(context.Background()) | ||||
| } | ||||
| 
 | ||||
| func (o *querySet) ExistWithCtx(ctx context.Context) bool { | ||||
| 	cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) | ||||
| 	return cnt > 0 | ||||
| } | ||||
| 
 | ||||
| // execute update with parameters | ||||
| func (o *querySet) Update(values Params) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) | ||||
| 	return o.UpdateWithCtx(context.Background(), values) | ||||
| } | ||||
| 
 | ||||
| func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // execute delete | ||||
| func (o *querySet) Delete() (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) | ||||
| 	return o.DeleteWithCtx(context.Background()) | ||||
| } | ||||
| 
 | ||||
| func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // return a insert queryer. | ||||
| @ -235,20 +263,32 @@ func (o *querySet) Delete() (int64, error) { | ||||
| // 	i,err := sq.PrepareInsert() | ||||
| // 	i.Add(&user1{},&user2{}) | ||||
| func (o *querySet) PrepareInsert() (Inserter, error) { | ||||
| 	return newInsertSet(o.orm, o.mi) | ||||
| 	return o.PrepareInsertWithCtx(context.Background()) | ||||
| } | ||||
| 
 | ||||
| func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) { | ||||
| 	return newInsertSet(ctx, o.orm, o.mi) | ||||
| } | ||||
| 
 | ||||
| // query all data and map to containers. | ||||
| // cols means the columns when querying. | ||||
| func (o *querySet) All(container interface{}, cols ...string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) | ||||
| 	return o.AllWithCtx(context.Background(), container, cols...) | ||||
| } | ||||
| 
 | ||||
| func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) | ||||
| } | ||||
| 
 | ||||
| // query one row data and map to containers. | ||||
| // cols means the columns when querying. | ||||
| func (o *querySet) One(container interface{}, cols ...string) error { | ||||
| 	return o.OneWithCtx(context.Background(), container, cols...) | ||||
| } | ||||
| 
 | ||||
| func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error { | ||||
| 	o.limit = 1 | ||||
| 	num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) | ||||
| 	num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @ -266,19 +306,31 @@ func (o *querySet) One(container interface{}, cols ...string) error { | ||||
| // expres means condition expression. | ||||
| // it converts data to []map[column]value. | ||||
| func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) | ||||
| 	return o.ValuesWithCtx(context.Background(), results, exprs...) | ||||
| } | ||||
| 
 | ||||
| func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // query all data and map to [][]interface | ||||
| // it converts data to [][column_index]value | ||||
| func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) | ||||
| 	return o.ValuesListWithCtx(context.Background(), results, exprs...) | ||||
| } | ||||
| 
 | ||||
| func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // query all data and map to []interface. | ||||
| // it's designed for one row record set, auto change to []value, not [][column]value. | ||||
| func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) | ||||
| 	return o.ValuesFlatWithCtx(context.Background(), result, expr) | ||||
| } | ||||
| 
 | ||||
| func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // query all rows into map[string]interface with specify key and value column name. | ||||
| @ -309,13 +361,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) | ||||
| 	panic(ErrNotImplement) | ||||
| } | ||||
| 
 | ||||
| // set context to QuerySeter. | ||||
| func (o querySet) WithContext(ctx context.Context) QuerySeter { | ||||
| 	o.ctx = ctx | ||||
| 	o.forContext = true | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // create new QuerySeter. | ||||
| func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { | ||||
| 	o := new(querySet) | ||||
| @ -323,3 +368,9 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { | ||||
| 	o.orm = orm | ||||
| 	return o | ||||
| } | ||||
| 
 | ||||
| // aggregate func | ||||
| func (o querySet) Aggregate(s string) QuerySeter { | ||||
| 	o.aggregate = s | ||||
| 	return &o | ||||
| } | ||||
|  | ||||
| @ -21,6 +21,7 @@ import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"github.com/beego/beego/v2/client/orm/clauses/order_clause" | ||||
| 	"io/ioutil" | ||||
| 	"math" | ||||
| 	"os" | ||||
| @ -205,6 +206,7 @@ func TestSyncDb(t *testing.T) { | ||||
| 	RegisterModel(new(Index)) | ||||
| 	RegisterModel(new(StrPk)) | ||||
| 	RegisterModel(new(TM)) | ||||
| 	RegisterModel(new(DeptInfo)) | ||||
| 
 | ||||
| 	err := RunSyncdb("default", true, Debug) | ||||
| 	throwFail(t, err) | ||||
| @ -232,6 +234,7 @@ func TestRegisterModels(t *testing.T) { | ||||
| 	RegisterModel(new(Index)) | ||||
| 	RegisterModel(new(StrPk)) | ||||
| 	RegisterModel(new(TM)) | ||||
| 	RegisterModel(new(DeptInfo)) | ||||
| 
 | ||||
| 	BootStrap() | ||||
| 
 | ||||
| @ -333,6 +336,73 @@ func TestTM(t *testing.T) { | ||||
| 	throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC")) | ||||
| } | ||||
| 
 | ||||
| func TestUnregisterModel(t *testing.T) { | ||||
| 	data := []*DeptInfo{ | ||||
| 		{ | ||||
| 			DeptName:     "A", | ||||
| 			EmployeeName: "A1", | ||||
| 			Salary:       1000, | ||||
| 		}, | ||||
| 		{ | ||||
| 			DeptName:     "A", | ||||
| 			EmployeeName: "A2", | ||||
| 			Salary:       2000, | ||||
| 		}, | ||||
| 		{ | ||||
| 			DeptName:     "B", | ||||
| 			EmployeeName: "B1", | ||||
| 			Salary:       2000, | ||||
| 		}, | ||||
| 		{ | ||||
| 			DeptName:     "B", | ||||
| 			EmployeeName: "B2", | ||||
| 			Salary:       4000, | ||||
| 		}, | ||||
| 		{ | ||||
| 			DeptName:     "B", | ||||
| 			EmployeeName: "B3", | ||||
| 			Salary:       3000, | ||||
| 		}, | ||||
| 	} | ||||
| 	qs := dORM.QueryTable("dept_info") | ||||
| 	i, _ := qs.PrepareInsert() | ||||
| 	for _, d := range data { | ||||
| 		_, err := i.Insert(d) | ||||
| 		if err != nil { | ||||
| 			throwFail(t, err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	f := func() { | ||||
| 		var res []UnregisterModel | ||||
| 		n, err := dORM.QueryTable("dept_info").All(&res) | ||||
| 		throwFail(t, err) | ||||
| 		throwFail(t, AssertIs(n, 5)) | ||||
| 		throwFail(t, AssertIs(res[0].EmployeeName, "A1")) | ||||
| 
 | ||||
| 		type Sum struct { | ||||
| 			DeptName string | ||||
| 			Total    int | ||||
| 		} | ||||
| 		var sun []Sum | ||||
| 		qs.Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").OrderBy("dept_name").All(&sun) | ||||
| 		throwFail(t, AssertIs(sun[0].DeptName, "A")) | ||||
| 		throwFail(t, AssertIs(sun[0].Total, 3000)) | ||||
| 
 | ||||
| 		type Max struct { | ||||
| 			DeptName string | ||||
| 			Max      float64 | ||||
| 		} | ||||
| 		var max []Max | ||||
| 		qs.Aggregate("dept_name,max(salary) as max").GroupBy("dept_name").OrderBy("dept_name").All(&max) | ||||
| 		throwFail(t, AssertIs(max[1].DeptName, "B")) | ||||
| 		throwFail(t, AssertIs(max[1].Max, 4000)) | ||||
| 	} | ||||
| 	for i := 0; i < 5; i++ { | ||||
| 		f() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNullDataTypes(t *testing.T) { | ||||
| 	d := DataNull{} | ||||
| 
 | ||||
| @ -1077,6 +1147,26 @@ func TestOrderBy(t *testing.T) { | ||||
| 	num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() | ||||
| 	throwFail(t, err) | ||||
| 	throwFail(t, AssertIs(num, 1)) | ||||
| 
 | ||||
| 	num, err = qs.OrderClauses( | ||||
| 		order_clause.Clause( | ||||
| 			order_clause.Column(`profile__age`), | ||||
| 			order_clause.SortDescending(), | ||||
| 		), | ||||
| 	).Filter("user_name", "astaxie").Count() | ||||
| 	throwFail(t, err) | ||||
| 	throwFail(t, AssertIs(num, 1)) | ||||
| 
 | ||||
| 	if IsMysql { | ||||
| 		num, err = qs.OrderClauses( | ||||
| 			order_clause.Clause( | ||||
| 				order_clause.Column(`rand()`), | ||||
| 				order_clause.Raw(), | ||||
| 			), | ||||
| 		).Filter("user_name", "astaxie").Count() | ||||
| 		throwFail(t, err) | ||||
| 		throwFail(t, AssertIs(num, 1)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestAll(t *testing.T) { | ||||
| @ -1163,6 +1253,19 @@ func TestValues(t *testing.T) { | ||||
| 		throwFail(t, AssertIs(maps[2]["Profile"], nil)) | ||||
| 	} | ||||
| 
 | ||||
| 	num, err = qs.OrderClauses( | ||||
| 		order_clause.Clause( | ||||
| 			order_clause.Column("Id"), | ||||
| 			order_clause.SortAscending(), | ||||
| 		), | ||||
| 	).Values(&maps) | ||||
| 	throwFail(t, err) | ||||
| 	throwFail(t, AssertIs(num, 3)) | ||||
| 	if num == 3 { | ||||
| 		throwFail(t, AssertIs(maps[0]["UserName"], "slene")) | ||||
| 		throwFail(t, AssertIs(maps[2]["Profile"], nil)) | ||||
| 	} | ||||
| 
 | ||||
| 	num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") | ||||
| 	throwFail(t, err) | ||||
| 	throwFail(t, AssertIs(num, 3)) | ||||
| @ -2717,3 +2820,23 @@ func TestCondition(t *testing.T) { | ||||
| 	throwFail(t, AssertIs(!cycleFlag, true)) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func TestContextCanceled(t *testing.T) { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 
 | ||||
| 	user := User{UserName: "slene"} | ||||
| 
 | ||||
| 	err := dORM.ReadWithCtx(ctx, &user, "UserName") | ||||
| 	throwFail(t, err) | ||||
| 
 | ||||
| 	cancel() | ||||
| 	err = dORM.ReadWithCtx(ctx, &user, "UserName") | ||||
| 	throwFail(t, AssertIs(err, context.Canceled)) | ||||
| 
 | ||||
| 	ctx, cancel = context.WithCancel(context.Background()) | ||||
| 	cancel() | ||||
| 
 | ||||
| 	qs := dORM.QueryTable(user) | ||||
| 	_, err = qs.Filter("UserName", "slene").CountWithCtx(ctx) | ||||
| 	throwFail(t, AssertIs(err, context.Canceled)) | ||||
| } | ||||
|  | ||||
| @ -20,6 +20,7 @@ import ( | ||||
| 	"reflect" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/beego/beego/v2/client/orm/clauses/order_clause" | ||||
| 	"github.com/beego/beego/v2/core/utils" | ||||
| ) | ||||
| 
 | ||||
| @ -196,12 +197,16 @@ type DQL interface { | ||||
| 	// 	post := Post{Id: 4} | ||||
| 	// 	m2m := Ormer.QueryM2M(&post, "Tags") | ||||
| 	QueryM2M(md interface{}, name string) QueryM2Mer | ||||
| 	// NOTE: this method is deprecated, context parameter will not take effect. | ||||
| 	// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx | ||||
| 	QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer | ||||
| 
 | ||||
| 	// return a QuerySeter for table operations. | ||||
| 	// table name can be string or struct. | ||||
| 	// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), | ||||
| 	QueryTable(ptrStructOrTableName interface{}) QuerySeter | ||||
| 	// NOTE: this method is deprecated, context parameter will not take effect. | ||||
| 	// Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx | ||||
| 	QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter | ||||
| 
 | ||||
| 	DBStats() *sql.DBStats | ||||
| @ -230,6 +235,7 @@ type TxOrmer interface { | ||||
| // Inserter insert prepared statement | ||||
| type Inserter interface { | ||||
| 	Insert(interface{}) (int64, error) | ||||
| 	InsertWithCtx(context.Context, interface{}) (int64, error) | ||||
| 	Close() error | ||||
| } | ||||
| 
 | ||||
| @ -289,6 +295,28 @@ type QuerySeter interface { | ||||
| 	// for example: | ||||
| 	//	qs.OrderBy("-status") | ||||
| 	OrderBy(exprs ...string) QuerySeter | ||||
| 	// add ORDER expression by order clauses | ||||
| 	// for example: | ||||
| 	//	OrderClauses( | ||||
| 	//		order_clause.Clause( | ||||
| 	//			order.Column("Id"), | ||||
| 	//			order.SortAscending(), | ||||
| 	//		), | ||||
| 	//		order_clause.Clause( | ||||
| 	//			order.Column("status"), | ||||
| 	//			order.SortDescending(), | ||||
| 	//		), | ||||
| 	//	) | ||||
| 	//	OrderClauses(order_clause.Clause( | ||||
| 	//		order_clause.Column(`user__status`), | ||||
| 	//		order_clause.SortDescending(),//default None | ||||
| 	//	)) | ||||
| 	//	OrderClauses(order_clause.Clause( | ||||
| 	//		order_clause.Column(`random()`), | ||||
| 	//		order_clause.SortNone(),//default None | ||||
| 	//		order_clause.Raw(),//default false.if true, do not check field is valid or not | ||||
| 	//	)) | ||||
| 	OrderClauses(orders ...*order_clause.Order) QuerySeter | ||||
| 	// add FORCE INDEX expression. | ||||
| 	// for example: | ||||
| 	//	qs.ForceIndex(`idx_name1`,`idx_name2`) | ||||
| @ -327,9 +355,11 @@ type QuerySeter interface { | ||||
| 	// for example: | ||||
| 	//	num, err = qs.Filter("profile__age__gt", 28).Count() | ||||
| 	Count() (int64, error) | ||||
| 	CountWithCtx(context.Context) (int64, error) | ||||
| 	// check result empty or not after QuerySeter executed | ||||
| 	// the same as QuerySeter.Count > 0 | ||||
| 	Exist() bool | ||||
| 	ExistWithCtx(context.Context) bool | ||||
| 	// execute update with parameters | ||||
| 	// for example: | ||||
| 	//	num, err = qs.Filter("user_name", "slene").Update(Params{ | ||||
| @ -339,11 +369,13 @@ type QuerySeter interface { | ||||
| 	//		"user_name": "slene2" | ||||
| 	//	}) // user slene's  name will change to slene2 | ||||
| 	Update(values Params) (int64, error) | ||||
| 	UpdateWithCtx(ctx context.Context, values Params) (int64, error) | ||||
| 	// delete from table | ||||
| 	// for example: | ||||
| 	//	num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() | ||||
| 	// 	//delete two user  who's name is testing1 or testing2 | ||||
| 	Delete() (int64, error) | ||||
| 	DeleteWithCtx(context.Context) (int64, error) | ||||
| 	// return a insert queryer. | ||||
| 	// it can be used in times. | ||||
| 	// example: | ||||
| @ -352,18 +384,21 @@ type QuerySeter interface { | ||||
| 	//	num, err = i.Insert(&user2) // user table will add one record user2 at once | ||||
| 	//	err = i.Close() //don't forget call Close | ||||
| 	PrepareInsert() (Inserter, error) | ||||
| 	PrepareInsertWithCtx(context.Context) (Inserter, error) | ||||
| 	// query all data and map to containers. | ||||
| 	// cols means the columns when querying. | ||||
| 	// for example: | ||||
| 	//	var users []*User | ||||
| 	//	qs.All(&users) // users[0],users[1],users[2] ... | ||||
| 	All(container interface{}, cols ...string) (int64, error) | ||||
| 	AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) | ||||
| 	// query one row data and map to containers. | ||||
| 	// cols means the columns when querying. | ||||
| 	// for example: | ||||
| 	//	var user User | ||||
| 	//	qs.One(&user) //user.UserName == "slene" | ||||
| 	One(container interface{}, cols ...string) error | ||||
| 	OneWithCtx(ctx context.Context, container interface{}, cols ...string) error | ||||
| 	// query all data and map to []map[string]interface. | ||||
| 	// expres means condition expression. | ||||
| 	// it converts data to []map[column]value. | ||||
| @ -371,18 +406,21 @@ type QuerySeter interface { | ||||
| 	//	var maps []Params | ||||
| 	//	qs.Values(&maps) //maps[0]["UserName"]=="slene" | ||||
| 	Values(results *[]Params, exprs ...string) (int64, error) | ||||
| 	ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) | ||||
| 	// query all data and map to [][]interface | ||||
| 	// it converts data to [][column_index]value | ||||
| 	// for example: | ||||
| 	//	var list []ParamsList | ||||
| 	//	qs.ValuesList(&list) // list[0][1] == "slene" | ||||
| 	ValuesList(results *[]ParamsList, exprs ...string) (int64, error) | ||||
| 	ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) | ||||
| 	// query all data and map to []interface. | ||||
| 	// it's designed for one column record set, auto change to []value, not [][column]value. | ||||
| 	// for example: | ||||
| 	//	var list ParamsList | ||||
| 	//	qs.ValuesFlat(&list, "UserName") // list[0] == "slene" | ||||
| 	ValuesFlat(result *ParamsList, expr string) (int64, error) | ||||
| 	ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) | ||||
| 	// query all rows into map[string]interface with specify key and value column name. | ||||
| 	// keyCol = "name", valueCol = "value" | ||||
| 	// table data | ||||
| @ -405,6 +443,15 @@ type QuerySeter interface { | ||||
| 	// 	Found int | ||||
| 	// } | ||||
| 	RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) | ||||
| 	// aggregate func. | ||||
| 	// for example: | ||||
| 	// type result struct { | ||||
| 	//  DeptName string | ||||
| 	//	Total    int | ||||
| 	// } | ||||
| 	// var res []result | ||||
| 	//  o.QueryTable("dept_info").Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").All(&res) | ||||
| 	Aggregate(s string) QuerySeter | ||||
| } | ||||
| 
 | ||||
| // QueryM2Mer model to model query struct | ||||
| @ -422,18 +469,23 @@ type QueryM2Mer interface { | ||||
| 	// insert one or more rows to m2m table | ||||
| 	// make sure the relation is defined in post model struct tag. | ||||
| 	Add(...interface{}) (int64, error) | ||||
| 	AddWithCtx(context.Context, ...interface{}) (int64, error) | ||||
| 	// remove models following the origin model relationship | ||||
| 	// only delete rows from m2m table | ||||
| 	// for example: | ||||
| 	// tag3 := &Tag{Id:5,Name: "TestTag3"} | ||||
| 	// num, err = m2m.Remove(tag3) | ||||
| 	Remove(...interface{}) (int64, error) | ||||
| 	RemoveWithCtx(context.Context, ...interface{}) (int64, error) | ||||
| 	// check model is existed in relationship of origin model | ||||
| 	Exist(interface{}) bool | ||||
| 	ExistWithCtx(context.Context, interface{}) bool | ||||
| 	// clean all models in related of origin model | ||||
| 	Clear() (int64, error) | ||||
| 	ClearWithCtx(context.Context) (int64, error) | ||||
| 	// count all related models of origin model | ||||
| 	Count() (int64, error) | ||||
| 	CountWithCtx(context.Context) (int64, error) | ||||
| } | ||||
| 
 | ||||
| // RawPreparer raw query statement | ||||
| @ -507,11 +559,11 @@ type RawSeter interface { | ||||
| type stmtQuerier interface { | ||||
| 	Close() error | ||||
| 	Exec(args ...interface{}) (sql.Result, error) | ||||
| 	// ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) | ||||
| 	ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) | ||||
| 	Query(args ...interface{}) (*sql.Rows, error) | ||||
| 	// QueryContext(args ...interface{}) (*sql.Rows, error) | ||||
| 	QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) | ||||
| 	QueryRow(args ...interface{}) *sql.Row | ||||
| 	// QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row | ||||
| 	QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row | ||||
| } | ||||
| 
 | ||||
| // db querier | ||||
| @ -548,28 +600,28 @@ type txEnder interface { | ||||
| 
 | ||||
| // base database struct | ||||
| type dbBaser interface { | ||||
| 	Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error | ||||
| 	ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) | ||||
| 	Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) | ||||
| 	ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) | ||||
| 	Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error | ||||
| 	ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) | ||||
| 	Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) | ||||
| 	ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) | ||||
| 
 | ||||
| 	Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) | ||||
| 	InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) | ||||
| 	InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) | ||||
| 	InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) | ||||
| 	InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) | ||||
| 	Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) | ||||
| 	InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) | ||||
| 	InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) | ||||
| 	InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) | ||||
| 	InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) | ||||
| 
 | ||||
| 	Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) | ||||
| 	UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) | ||||
| 	Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) | ||||
| 	UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) | ||||
| 
 | ||||
| 	Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) | ||||
| 	DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) | ||||
| 	Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) | ||||
| 	DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) | ||||
| 
 | ||||
| 	SupportUpdateJoin() bool | ||||
| 	OperatorSQL(string) string | ||||
| 	GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) | ||||
| 	GenerateOperatorLeftCol(*fieldInfo, string, *string) | ||||
| 	PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) | ||||
| 	PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error) | ||||
| 	MaxLimit() uint64 | ||||
| 	TableQuote() string | ||||
| 	ReplaceMarks(*string) | ||||
| @ -578,12 +630,12 @@ type dbBaser interface { | ||||
| 	TimeToDB(*time.Time, *time.Location) | ||||
| 	DbTypes() map[string]string | ||||
| 	GetTables(dbQuerier) (map[string]bool, error) | ||||
| 	GetColumns(dbQuerier, string) (map[string][3]string, error) | ||||
| 	GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error) | ||||
| 	ShowTablesQuery() string | ||||
| 	ShowColumnsQuery(string) string | ||||
| 	IndexExists(dbQuerier, string, string) bool | ||||
| 	IndexExists(context.Context, dbQuerier, string, string) bool | ||||
| 	collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) | ||||
| 	setval(dbQuerier, *modelInfo, []string) error | ||||
| 	setval(context.Context, dbQuerier, *modelInfo, []string) error | ||||
| 
 | ||||
| 	GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string | ||||
| } | ||||
|  | ||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							| @ -25,7 +25,7 @@ require ( | ||||
| 	github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect | ||||
| 	github.com/gomodule/redigo v2.0.0+incompatible | ||||
| 	github.com/google/go-cmp v0.5.0 // indirect | ||||
| 	github.com/google/uuid v1.1.1 // indirect | ||||
| 	github.com/google/uuid v1.1.1 | ||||
| 	github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 | ||||
| 	github.com/hashicorp/golang-lru v0.5.4 | ||||
| 	github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 | ||||
|  | ||||
| @ -112,13 +112,13 @@ func registerAdmin() error { | ||||
| 			HttpServer: NewHttpServerWithCfg(BConfig), | ||||
| 		} | ||||
| 		// keep in mind that all data should be html escaped to avoid XSS attack | ||||
| 		beeAdminApp.Router("/", c, SetRouterMethods(c, "get:AdminIndex")) | ||||
| 		beeAdminApp.Router("/qps", c, SetRouterMethods(c, "get:QpsIndex")) | ||||
| 		beeAdminApp.Router("/prof", c, SetRouterMethods(c, "get:ProfIndex")) | ||||
| 		beeAdminApp.Router("/healthcheck", c, SetRouterMethods(c, "get:Healthcheck")) | ||||
| 		beeAdminApp.Router("/task", c, SetRouterMethods(c, "get:TaskStatus")) | ||||
| 		beeAdminApp.Router("/listconf", c, SetRouterMethods(c, "get:ListConf")) | ||||
| 		beeAdminApp.Router("/metrics", c, SetRouterMethods(c, "get:PrometheusMetrics")) | ||||
| 		beeAdminApp.Router("/", c, "get:AdminIndex") | ||||
| 		beeAdminApp.Router("/qps", c, "get:QpsIndex") | ||||
| 		beeAdminApp.Router("/prof", c, "get:ProfIndex") | ||||
| 		beeAdminApp.Router("/healthcheck", c, "get:Healthcheck") | ||||
| 		beeAdminApp.Router("/task", c, "get:TaskStatus") | ||||
| 		beeAdminApp.Router("/listconf", c, "get:ListConf") | ||||
| 		beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics") | ||||
| 
 | ||||
| 		go beeAdminApp.Run() | ||||
| 	} | ||||
|  | ||||
| @ -29,6 +29,7 @@ import ( | ||||
| 	"encoding/base64" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/beego/beego/v2/server/web/session" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @ -195,6 +196,22 @@ func (ctx *Context) RenderMethodResult(result interface{}) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Session return session store of this context of request | ||||
| func (ctx *Context) Session() (store session.Store, err error) { | ||||
| 	if ctx.Input != nil { | ||||
| 		if ctx.Input.CruSession != nil { | ||||
| 			store = ctx.Input.CruSession | ||||
| 			return | ||||
| 		} else { | ||||
| 			err = errors.New(`no valid session store(please initialize session)`) | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		err = errors.New(`no valid input`) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Response is a wrapper for the http.ResponseWriter | ||||
| // Started:  if true, response was already written to so the other handler will not be executed | ||||
| type Response struct { | ||||
|  | ||||
| @ -15,6 +15,7 @@ | ||||
| package context | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/beego/beego/v2/server/web/session" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| @ -45,3 +46,26 @@ func TestXsrfReset_01(t *testing.T) { | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestContext_Session(t *testing.T) { | ||||
| 	c := NewContext() | ||||
| 	if store, err := c.Session(); store != nil || err == nil { | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestContext_Session1(t *testing.T) { | ||||
| 	c := Context{} | ||||
| 	if store, err := c.Session(); store != nil || err == nil { | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestContext_Session2(t *testing.T) { | ||||
| 	c := NewContext() | ||||
| 	c.Input.CruSession = &session.MemSessionStore{} | ||||
| 
 | ||||
| 	if store, err := c.Session(); store == nil || err != nil { | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
| @ -18,6 +18,7 @@ import ( | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 
 | ||||
| @ -37,4 +38,5 @@ func TestFilterChain(t *testing.T) { | ||||
| 	ctx.Input.SetData("RouterPattern", "my-route") | ||||
| 	filter(ctx) | ||||
| 	assert.True(t, ctx.Input.GetData("invocation").(bool)) | ||||
| 	time.Sleep(1 * time.Second) | ||||
| } | ||||
|  | ||||
							
								
								
									
										35
									
								
								server/web/filter/session/filter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								server/web/filter/session/filter.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,35 @@ | ||||
| package session | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"github.com/beego/beego/v2/core/logs" | ||||
| 	"github.com/beego/beego/v2/server/web" | ||||
| 	webContext "github.com/beego/beego/v2/server/web/context" | ||||
| 	"github.com/beego/beego/v2/server/web/session" | ||||
| ) | ||||
| 
 | ||||
| //Session maintain session for web service | ||||
| //Session new a session storage and store it into webContext.Context | ||||
| func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { | ||||
| 	sessionConfig := session.NewManagerConfig(options...) | ||||
| 	sessionManager, _ := session.NewManager(string(providerType), sessionConfig) | ||||
| 	go sessionManager.GC() | ||||
| 
 | ||||
| 	return func(next web.FilterFunc) web.FilterFunc { | ||||
| 		return func(ctx *webContext.Context) { | ||||
| 			if ctx.Input.CruSession != nil { | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| 			if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil { | ||||
| 				logs.Error(`init session error:%s`, err.Error()) | ||||
| 			} else { | ||||
| 				//release session at the end of request | ||||
| 				defer sess.SessionRelease(context.Background(), ctx.ResponseWriter) | ||||
| 				ctx.Input.CruSession = sess | ||||
| 			} | ||||
| 
 | ||||
| 			next(ctx) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										86
									
								
								server/web/filter/session/filter_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								server/web/filter/session/filter_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,86 @@ | ||||
| package session | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/beego/beego/v2/server/web" | ||||
| 	webContext "github.com/beego/beego/v2/server/web/context" | ||||
| 	"github.com/beego/beego/v2/server/web/session" | ||||
| 	"github.com/google/uuid" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) { | ||||
| 	r, _ := http.NewRequest(method, path, nil) | ||||
| 	w := httptest.NewRecorder() | ||||
| 	handler.ServeHTTP(w, r) | ||||
| 
 | ||||
| 	if w.Code != code { | ||||
| 		t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSession(t *testing.T) { | ||||
| 	storeKey := uuid.New().String() | ||||
| 	handler := web.NewControllerRegister() | ||||
| 	handler.InsertFilterChain( | ||||
| 		"*", | ||||
| 		Session( | ||||
| 			session.ProviderMemory, | ||||
| 			session.CfgCookieName(`go_session_id`), | ||||
| 			session.CfgSetCookie(true), | ||||
| 			session.CfgGcLifeTime(3600), | ||||
| 			session.CfgMaxLifeTime(3600), | ||||
| 			session.CfgSecure(false), | ||||
| 			session.CfgCookieLifeTime(3600), | ||||
| 		), | ||||
| 	) | ||||
| 	handler.InsertFilterChain( | ||||
| 		"*", | ||||
| 		func(next web.FilterFunc) web.FilterFunc { | ||||
| 			return func(ctx *webContext.Context) { | ||||
| 				if store := ctx.Input.GetData(storeKey); store == nil { | ||||
| 					t.Error(`store should not be nil`) | ||||
| 				} | ||||
| 				next(ctx) | ||||
| 			} | ||||
| 		}, | ||||
| 	) | ||||
| 	handler.Any("*", func(ctx *webContext.Context) { | ||||
| 		ctx.Output.SetStatus(200) | ||||
| 	}) | ||||
| 
 | ||||
| 	testRequest(t, handler, "/dataset1/resource1", "GET", 200) | ||||
| } | ||||
| 
 | ||||
| func TestSession1(t *testing.T) { | ||||
| 	handler := web.NewControllerRegister() | ||||
| 	handler.InsertFilterChain( | ||||
| 		"*", | ||||
| 		Session( | ||||
| 			session.ProviderMemory, | ||||
| 			session.CfgCookieName(`go_session_id`), | ||||
| 			session.CfgSetCookie(true), | ||||
| 			session.CfgGcLifeTime(3600), | ||||
| 			session.CfgMaxLifeTime(3600), | ||||
| 			session.CfgSecure(false), | ||||
| 			session.CfgCookieLifeTime(3600), | ||||
| 		), | ||||
| 	) | ||||
| 	handler.InsertFilterChain( | ||||
| 		"*", | ||||
| 		func(next web.FilterFunc) web.FilterFunc { | ||||
| 			return func(ctx *webContext.Context) { | ||||
| 				if store, err := ctx.Session(); store == nil || err != nil { | ||||
| 					t.Error(`store should not be nil`) | ||||
| 				} | ||||
| 				next(ctx) | ||||
| 			} | ||||
| 		}, | ||||
| 	) | ||||
| 	handler.Any("*", func(ctx *webContext.Context) { | ||||
| 		ctx.Output.SetStatus(200) | ||||
| 	}) | ||||
| 
 | ||||
| 	testRequest(t, handler, "/dataset1/resource1", "GET", 200) | ||||
| } | ||||
| @ -15,9 +15,12 @@ | ||||
| package web | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 
 | ||||
| @ -36,13 +39,46 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) { | ||||
| 	ns := NewNamespace("/chain") | ||||
| 
 | ||||
| 	ns.Get("/*", func(ctx *context.Context) { | ||||
| 		ctx.Output.Body([]byte("hello")) | ||||
| 		_ = ctx.Output.Body([]byte("hello")) | ||||
| 	}) | ||||
| 
 | ||||
| 	r, _ := http.NewRequest("GET", "/chain/user", nil) | ||||
| 	w := httptest.NewRecorder() | ||||
| 
 | ||||
| 	BeeApp.Handlers.Init() | ||||
| 	BeeApp.Handlers.ServeHTTP(w, r) | ||||
| 
 | ||||
| 	assert.Equal(t, "filter-chain", w.Header().Get("filter")) | ||||
| } | ||||
| 
 | ||||
| func TestControllerRegister_InsertFilterChain_Order(t *testing.T) { | ||||
| 	InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { | ||||
| 		return func(ctx *context.Context) { | ||||
| 			ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano())) | ||||
| 			time.Sleep(time.Millisecond * 10) | ||||
| 			next(ctx) | ||||
| 		} | ||||
| 	}) | ||||
| 
 | ||||
| 
 | ||||
| 	InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { | ||||
| 		return func(ctx *context.Context) { | ||||
| 			ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano())) | ||||
| 			time.Sleep(time.Millisecond * 10) | ||||
| 			next(ctx) | ||||
| 		} | ||||
| 	}) | ||||
| 
 | ||||
| 	r, _ := http.NewRequest("GET", "/abc", nil) | ||||
| 	w := httptest.NewRecorder() | ||||
| 
 | ||||
| 	BeeApp.Handlers.Init() | ||||
| 	BeeApp.Handlers.ServeHTTP(w, r) | ||||
| 	first := w.Header().Get("first") | ||||
| 	second := w.Header().Get("second") | ||||
| 
 | ||||
| 	ft, _ := strconv.ParseInt(first, 10, 64) | ||||
| 	st, _ := strconv.ParseInt(second, 10, 64) | ||||
| 
 | ||||
| 	assert.True(t, st > ft) | ||||
| } | ||||
|  | ||||
| @ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) { | ||||
| 
 | ||||
| 	// setup the handler | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/", &TestFlashController{}, SetRouterMethods(&TestFlashController{}, "get:TestWriteFlash")) | ||||
| 	handler.Add("/", &TestFlashController{}, WithRouterMethods(&TestFlashController{}, "get:TestWriteFlash")) | ||||
| 	handler.ServeHTTP(w, r) | ||||
| 
 | ||||
| 	// get the Set-Cookie value | ||||
|  | ||||
| @ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { | ||||
| // Router same as beego.Rourer | ||||
| // refer: https://godoc.org/github.com/beego/beego/v2#Router | ||||
| func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { | ||||
| 	n.handlers.Add(rootpath, c, SetRouterMethods(c, mappingMethods...)) | ||||
| 	n.handlers.Add(rootpath, c, WithRouterMethods(c, mappingMethods...)) | ||||
| 	return n | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -121,24 +121,30 @@ type ControllerInfo struct { | ||||
| 	sessionOn      bool | ||||
| } | ||||
| 
 | ||||
| type ControllerOptions func(*ControllerInfo) | ||||
| type ControllerOption func(*ControllerInfo) | ||||
| 
 | ||||
| func (c *ControllerInfo) GetPattern() string { | ||||
| 	return c.pattern | ||||
| } | ||||
| 
 | ||||
| func SetRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOptions { | ||||
| func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption { | ||||
| 	return func(c *ControllerInfo) { | ||||
| 		c.methods = parseMappingMethods(ctrlInterface, mappingMethod) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func SetRouterSessionOn(sessionOn bool) ControllerOptions { | ||||
| func WithRouterSessionOn(sessionOn bool) ControllerOption { | ||||
| 	return func(c *ControllerInfo) { | ||||
| 		c.sessionOn = sessionOn | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type filterChainConfig struct { | ||||
| 	pattern string | ||||
| 	chain FilterChain | ||||
| 	opts []FilterOpt | ||||
| } | ||||
| 
 | ||||
| // ControllerRegister containers registered router rules, controller handlers and filters. | ||||
| type ControllerRegister struct { | ||||
| 	routers      map[string]*Tree | ||||
| @ -151,6 +157,9 @@ type ControllerRegister struct { | ||||
| 	// the filter created by FilterChain | ||||
| 	chainRoot *FilterRouter | ||||
| 
 | ||||
| 	// keep registered chain and build it when serve http | ||||
| 	filterChains []filterChainConfig | ||||
| 
 | ||||
| 	cfg *Config | ||||
| } | ||||
| 
 | ||||
| @ -171,11 +180,23 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { | ||||
| 			}, | ||||
| 		}, | ||||
| 		cfg: cfg, | ||||
| 		filterChains: make([]filterChainConfig, 0, 4), | ||||
| 	} | ||||
| 	res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) | ||||
| 	return res | ||||
| } | ||||
| 
 | ||||
| // Init will be executed when HttpServer start running | ||||
| func (p *ControllerRegister) Init() { | ||||
| 	for i := len(p.filterChains) - 1; i >= 0 ; i --  { | ||||
| 		fc := p.filterChains[i] | ||||
| 		root := p.chainRoot | ||||
| 		filterFunc := fc.chain(root.filterFunc) | ||||
| 		p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...) | ||||
| 		p.chainRoot.next = root | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Add controller handler and pattern rules to ControllerRegister. | ||||
| // usage: | ||||
| //	default methods is the same name as method | ||||
| @ -186,7 +207,7 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { | ||||
| //	Add("/api/delete",&RestController{},"delete:DeleteFood") | ||||
| //	Add("/api",&RestController{},"get,post:ApiFunc" | ||||
| //	Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") | ||||
| func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOptions) { | ||||
| func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOption) { | ||||
| 	p.addWithMethodParams(pattern, c, nil, opts...) | ||||
| } | ||||
| 
 | ||||
| @ -239,7 +260,7 @@ func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOptions) { | ||||
| func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOption) { | ||||
| 	reflectVal := reflect.ValueOf(c) | ||||
| 	t := reflect.Indirect(reflectVal).Type() | ||||
| 
 | ||||
| @ -311,7 +332,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { | ||||
| 					p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) | ||||
| 				} | ||||
| 
 | ||||
| 				p.addWithMethodParams(a.Router, c, a.MethodParams, SetRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) | ||||
| 				p.addWithMethodParams(a.Router, c, a.MethodParams, WithRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| @ -513,12 +534,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter | ||||
| //     } | ||||
| // } | ||||
| func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { | ||||
| 	root := p.chainRoot | ||||
| 	filterFunc := chain(root.filterFunc) | ||||
| 	opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) | ||||
| 	p.chainRoot = newFilterRouter(pattern, filterFunc, opts...) | ||||
| 	p.chainRoot.next = root | ||||
| 
 | ||||
| 	opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) | ||||
| 	p.filterChains = append(p.filterChains, filterChainConfig{ | ||||
| 		pattern: pattern, | ||||
| 		chain: chain, | ||||
| 		opts: opts, | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // add Filter into | ||||
| @ -583,7 +605,7 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str | ||||
| 	for _, l := range t.leaves { | ||||
| 		if c, ok := l.runObject.(*ControllerInfo); ok { | ||||
| 			if c.routerType == routerTypeBeego && | ||||
| 				strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { | ||||
| 				strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), `/`+controllerName) { | ||||
| 				find := false | ||||
| 				if HTTPMETHOD[strings.ToUpper(methodName)] { | ||||
| 					if len(c.methods) == 0 { | ||||
|  | ||||
| @ -26,6 +26,14 @@ import ( | ||||
| 	"github.com/beego/beego/v2/server/web/context" | ||||
| ) | ||||
| 
 | ||||
| type PrefixTestController struct { | ||||
| 	Controller | ||||
| } | ||||
| 
 | ||||
| func (ptc *PrefixTestController) PrefixList() { | ||||
| 	ptc.Ctx.Output.Body([]byte("i am list in prefix test")) | ||||
| } | ||||
| 
 | ||||
| type TestController struct { | ||||
| 	Controller | ||||
| } | ||||
| @ -87,10 +95,24 @@ func (jc *JSONController) Get() { | ||||
| 	jc.Ctx.Output.Body([]byte("ok")) | ||||
| } | ||||
| 
 | ||||
| func TestPrefixUrlFor(t *testing.T){ | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/my/prefix/list", &PrefixTestController{}, WithRouterMethods(&PrefixTestController{}, "get:PrefixList")) | ||||
| 
 | ||||
| 	if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` { | ||||
| 		logs.Info(a) | ||||
| 		t.Errorf("PrefixTestController.PrefixList must equal to /my/prefix/list") | ||||
| 	} | ||||
| 	if a := handler.URLFor(`TestController.PrefixList`); a != `` { | ||||
| 		logs.Info(a) | ||||
| 		t.Errorf("TestController.PrefixList must equal to empty string") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUrlFor(t *testing.T) { | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) | ||||
| 	handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "*:Param")) | ||||
| 	handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) | ||||
| 	handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) | ||||
| 	if a := handler.URLFor("TestController.List"); a != "/api/list" { | ||||
| 		logs.Info(a) | ||||
| 		t.Errorf("TestController.List must equal to /api/list") | ||||
| @ -113,9 +135,9 @@ func TestUrlFor3(t *testing.T) { | ||||
| 
 | ||||
| func TestUrlFor2(t *testing.T) { | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) | ||||
| 	handler.Add("/v1/:username/edit", &TestController{}, SetRouterMethods(&TestController{}, "get:GetURL")) | ||||
| 	handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:Param")) | ||||
| 	handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) | ||||
| 	handler.Add("/v1/:username/edit", &TestController{}, WithRouterMethods(&TestController{}, "get:GetURL")) | ||||
| 	handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) | ||||
| 	handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) | ||||
| 	if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { | ||||
| 		logs.Info(handler.URLFor("TestController.GetURL")) | ||||
| @ -145,7 +167,7 @@ func TestUserFunc(t *testing.T) { | ||||
| 	w := httptest.NewRecorder() | ||||
| 
 | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) | ||||
| 	handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) | ||||
| 	handler.ServeHTTP(w, r) | ||||
| 	if w.Body.String() != "i am list" { | ||||
| 		t.Errorf("user define func can't run") | ||||
| @ -235,7 +257,7 @@ func TestRouteOk(t *testing.T) { | ||||
| 	w := httptest.NewRecorder() | ||||
| 
 | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "get:GetParams")) | ||||
| 	handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "get:GetParams")) | ||||
| 	handler.ServeHTTP(w, r) | ||||
| 	body := w.Body.String() | ||||
| 	if body != "anderson+thomas+kungfu" { | ||||
| @ -249,7 +271,7 @@ func TestManyRoute(t *testing.T) { | ||||
| 	w := httptest.NewRecorder() | ||||
| 
 | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetManyRouter")) | ||||
| 	handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetManyRouter")) | ||||
| 	handler.ServeHTTP(w, r) | ||||
| 
 | ||||
| 	body := w.Body.String() | ||||
| @ -266,7 +288,7 @@ func TestEmptyResponse(t *testing.T) { | ||||
| 	w := httptest.NewRecorder() | ||||
| 
 | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/beego-empty.html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetEmptyBody")) | ||||
| 	handler.Add("/beego-empty.html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetEmptyBody")) | ||||
| 	handler.ServeHTTP(w, r) | ||||
| 
 | ||||
| 	if body := w.Body.String(); body != "" { | ||||
| @ -761,8 +783,8 @@ func TestRouterSessionSet(t *testing.T) { | ||||
| 	r, _ := http.NewRequest("GET", "/user", nil) | ||||
| 	w := httptest.NewRecorder() | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), | ||||
| 		SetRouterSessionOn(false)) | ||||
| 	handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), | ||||
| 		WithRouterSessionOn(false)) | ||||
| 	handler.ServeHTTP(w, r) | ||||
| 	if w.Header().Get("Set-Cookie") != "" { | ||||
| 		t.Errorf("TestRotuerSessionSet failed") | ||||
| @ -772,8 +794,8 @@ func TestRouterSessionSet(t *testing.T) { | ||||
| 	r, _ = http.NewRequest("GET", "/user", nil) | ||||
| 	w = httptest.NewRecorder() | ||||
| 	handler = NewControllerRegister() | ||||
| 	handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), | ||||
| 		SetRouterSessionOn(true)) | ||||
| 	handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), | ||||
| 		WithRouterSessionOn(true)) | ||||
| 	handler.ServeHTTP(w, r) | ||||
| 	if w.Header().Get("Set-Cookie") != "" { | ||||
| 		t.Errorf("TestRotuerSessionSet failed") | ||||
| @ -787,8 +809,8 @@ func TestRouterSessionSet(t *testing.T) { | ||||
| 	r, _ = http.NewRequest("GET", "/user", nil) | ||||
| 	w = httptest.NewRecorder() | ||||
| 	handler = NewControllerRegister() | ||||
| 	handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), | ||||
| 		SetRouterSessionOn(false)) | ||||
| 	handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), | ||||
| 		WithRouterSessionOn(false)) | ||||
| 	handler.ServeHTTP(w, r) | ||||
| 	if w.Header().Get("Set-Cookie") != "" { | ||||
| 		t.Errorf("TestRotuerSessionSet failed") | ||||
| @ -798,8 +820,8 @@ func TestRouterSessionSet(t *testing.T) { | ||||
| 	r, _ = http.NewRequest("GET", "/user", nil) | ||||
| 	w = httptest.NewRecorder() | ||||
| 	handler = NewControllerRegister() | ||||
| 	handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), | ||||
| 		SetRouterSessionOn(true)) | ||||
| 	handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), | ||||
| 		WithRouterSessionOn(true)) | ||||
| 	handler.ServeHTTP(w, r) | ||||
| 	if w.Header().Get("Set-Cookie") == "" { | ||||
| 		t.Errorf("TestRotuerSessionSet failed") | ||||
|  | ||||
| @ -84,7 +84,9 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { | ||||
| 
 | ||||
| 	initBeforeHTTPRun() | ||||
| 
 | ||||
| 	// init... | ||||
| 	app.initAddr(addr) | ||||
| 	app.Handlers.Init() | ||||
| 
 | ||||
| 	addr = app.Cfg.Listen.HTTPAddr | ||||
| 
 | ||||
| @ -266,8 +268,12 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { | ||||
| } | ||||
| 
 | ||||
| // Router see HttpServer.Router | ||||
| func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer { | ||||
| 	return BeeApp.Router(rootpath, c, opts...) | ||||
| func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer { | ||||
| 	return RouterWithOpts(rootpath, c, WithRouterMethods(c, mappingMethods...)) | ||||
| } | ||||
| 
 | ||||
| func RouterWithOpts(rootpath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { | ||||
| 	return BeeApp.RouterWithOpts(rootpath, c, opts...) | ||||
| } | ||||
| 
 | ||||
| // Router adds a patterned controller handler to BeeApp. | ||||
| @ -286,7 +292,11 @@ func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) * | ||||
| //  beego.Router("/api/create",&RestController{},"post:CreateFood") | ||||
| //  beego.Router("/api/update",&RestController{},"put:UpdateFood") | ||||
| //  beego.Router("/api/delete",&RestController{},"delete:DeleteFood") | ||||
| func (app *HttpServer) Router(rootPath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer { | ||||
| func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer { | ||||
| 	return app.RouterWithOpts(rootPath, c, WithRouterMethods(c, mappingMethods...)) | ||||
| } | ||||
| 
 | ||||
| func (app *HttpServer) RouterWithOpts(rootPath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { | ||||
| 	app.Handlers.Add(rootPath, c, opts...) | ||||
| 	return app | ||||
| } | ||||
|  | ||||
| @ -15,21 +15,22 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| func TestRedis(t *testing.T) { | ||||
| 	sessionConfig := &session.ManagerConfig{ | ||||
| 		CookieName:      "gosessionid", | ||||
| 		EnableSetCookie: true, | ||||
| 		Gclifetime:      3600, | ||||
| 		Maxlifetime:     3600, | ||||
| 		Secure:          false, | ||||
| 		CookieLifeTime:  3600, | ||||
| 	} | ||||
| 
 | ||||
| 	redisAddr := os.Getenv("REDIS_ADDR") | ||||
| 	if redisAddr == "" { | ||||
| 		redisAddr = "127.0.0.1:6379" | ||||
| 	} | ||||
| 	redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr) | ||||
| 
 | ||||
| 	sessionConfig := session.NewManagerConfig( | ||||
| 		session.CfgCookieName(`gosessionid`), | ||||
| 		session.CfgSetCookie(true), | ||||
| 		session.CfgGcLifeTime(3600), | ||||
| 		session.CfgMaxLifeTime(3600), | ||||
| 		session.CfgSecure(false), | ||||
| 		session.CfgCookieLifeTime(3600), | ||||
| 		session.CfgProviderConfig(redisConfig), | ||||
| 	) | ||||
| 
 | ||||
| 	sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr) | ||||
| 	globalSession, err := session.NewManager("redis", sessionConfig) | ||||
| 	if err != nil { | ||||
| 		t.Fatal("could not create manager:", err) | ||||
|  | ||||
| @ -13,15 +13,15 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| func TestRedisSentinel(t *testing.T) { | ||||
| 	sessionConfig := &session.ManagerConfig{ | ||||
| 		CookieName:      "gosessionid", | ||||
| 		EnableSetCookie: true, | ||||
| 		Gclifetime:      3600, | ||||
| 		Maxlifetime:     3600, | ||||
| 		Secure:          false, | ||||
| 		CookieLifeTime:  3600, | ||||
| 		ProviderConfig:  "127.0.0.1:6379,100,,0,master", | ||||
| 	} | ||||
| 	sessionConfig := session.NewManagerConfig( | ||||
| 		session.CfgCookieName(`gosessionid`), | ||||
| 		session.CfgSetCookie(true), | ||||
| 		session.CfgGcLifeTime(3600), | ||||
| 		session.CfgMaxLifeTime(3600), | ||||
| 		session.CfgSecure(false), | ||||
| 		session.CfgCookieLifeTime(3600), | ||||
| 		session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"), | ||||
| 	) | ||||
| 	globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) | ||||
| 	if e != nil { | ||||
| 		t.Log(e) | ||||
|  | ||||
| @ -91,25 +91,6 @@ func GetProvider(name string) (Provider, error) { | ||||
| 	return provider, nil | ||||
| } | ||||
| 
 | ||||
| // ManagerConfig define the session config | ||||
| type ManagerConfig struct { | ||||
| 	CookieName              string `json:"cookieName"` | ||||
| 	EnableSetCookie         bool   `json:"enableSetCookie,omitempty"` | ||||
| 	Gclifetime              int64  `json:"gclifetime"` | ||||
| 	Maxlifetime             int64  `json:"maxLifetime"` | ||||
| 	DisableHTTPOnly         bool   `json:"disableHTTPOnly"` | ||||
| 	Secure                  bool   `json:"secure"` | ||||
| 	CookieLifeTime          int    `json:"cookieLifeTime"` | ||||
| 	ProviderConfig          string `json:"providerConfig"` | ||||
| 	Domain                  string `json:"domain"` | ||||
| 	SessionIDLength         int64  `json:"sessionIDLength"` | ||||
| 	EnableSidInHTTPHeader   bool   `json:"EnableSidInHTTPHeader"` | ||||
| 	SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` | ||||
| 	EnableSidInURLQuery     bool   `json:"EnableSidInURLQuery"` | ||||
| 	SessionIDPrefix         string `json:"sessionIDPrefix"` | ||||
| 	CookieSameSite          http.SameSite `json:"cookieSameSite"` | ||||
| } | ||||
| 
 | ||||
| // Manager contains Provider and its configuration. | ||||
| type Manager struct { | ||||
| 	provider Provider | ||||
|  | ||||
							
								
								
									
										143
									
								
								server/web/session/session_config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								server/web/session/session_config.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,143 @@ | ||||
| package session | ||||
| 
 | ||||
| import "net/http" | ||||
| 
 | ||||
| // ManagerConfig define the session config | ||||
| type ManagerConfig struct { | ||||
| 	CookieName              string        `json:"cookieName"` | ||||
| 	EnableSetCookie         bool          `json:"enableSetCookie,omitempty"` | ||||
| 	Gclifetime              int64         `json:"gclifetime"` | ||||
| 	Maxlifetime             int64         `json:"maxLifetime"` | ||||
| 	DisableHTTPOnly         bool          `json:"disableHTTPOnly"` | ||||
| 	Secure                  bool          `json:"secure"` | ||||
| 	CookieLifeTime          int           `json:"cookieLifeTime"` | ||||
| 	ProviderConfig          string        `json:"providerConfig"` | ||||
| 	Domain                  string        `json:"domain"` | ||||
| 	SessionIDLength         int64         `json:"sessionIDLength"` | ||||
| 	EnableSidInHTTPHeader   bool          `json:"EnableSidInHTTPHeader"` | ||||
| 	SessionNameInHTTPHeader string        `json:"SessionNameInHTTPHeader"` | ||||
| 	EnableSidInURLQuery     bool          `json:"EnableSidInURLQuery"` | ||||
| 	SessionIDPrefix         string        `json:"sessionIDPrefix"` | ||||
| 	CookieSameSite          http.SameSite `json:"cookieSameSite"` | ||||
| } | ||||
| 
 | ||||
| func (c *ManagerConfig) Opts(opts ...ManagerConfigOpt) { | ||||
| 	for _, opt := range opts { | ||||
| 		opt(c) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type ManagerConfigOpt func(config *ManagerConfig) | ||||
| 
 | ||||
| func NewManagerConfig(opts ...ManagerConfigOpt) *ManagerConfig { | ||||
| 	config := &ManagerConfig{} | ||||
| 	for _, opt := range opts { | ||||
| 		opt(config) | ||||
| 	} | ||||
| 	return config | ||||
| } | ||||
| 
 | ||||
| // CfgCookieName set key of session id | ||||
| func CfgCookieName(cookieName string) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.CookieName = cookieName | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // CfgCookieName set len of session id | ||||
| func CfgSessionIdLength(len int64) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.SessionIDLength = len | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // CfgSessionIdPrefix set prefix of session id | ||||
| func CfgSessionIdPrefix(prefix string) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.SessionIDPrefix = prefix | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //CfgSetCookie whether set `Set-Cookie` header in HTTP response | ||||
| func CfgSetCookie(enable bool) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.EnableSetCookie = enable | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //CfgGcLifeTime set session gc lift time | ||||
| func CfgGcLifeTime(lifeTime int64) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.Gclifetime = lifeTime | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //CfgMaxLifeTime set session lift time | ||||
| func CfgMaxLifeTime(lifeTime int64) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.Maxlifetime = lifeTime | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //CfgGcLifeTime set session lift time | ||||
| func CfgCookieLifeTime(lifeTime int) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.CookieLifeTime = lifeTime | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //CfgProviderConfig configure session provider | ||||
| func CfgProviderConfig(providerConfig string) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.ProviderConfig = providerConfig | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //CfgDomain set cookie domain | ||||
| func CfgDomain(domain string) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.Domain = domain | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //CfgSessionIdInHTTPHeader enable session id in http header | ||||
| func CfgSessionIdInHTTPHeader(enable bool) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.EnableSidInHTTPHeader = enable | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //CfgSetSessionNameInHTTPHeader set key of session id in http header | ||||
| func CfgSetSessionNameInHTTPHeader(name string) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.SessionNameInHTTPHeader = name | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //EnableSidInURLQuery enable session id in query string | ||||
| func CfgEnableSidInURLQuery(enable bool) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.EnableSidInURLQuery = enable | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //DisableHTTPOnly set HTTPOnly for http.Cookie | ||||
| func CfgHTTPOnly(HTTPOnly bool) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.DisableHTTPOnly = !HTTPOnly | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //CfgSecure set Secure for http.Cookie | ||||
| func CfgSecure(Enable bool) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.Secure = Enable | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| //CfgSameSite set http.SameSite | ||||
| func CfgSameSite(sameSite http.SameSite) ManagerConfigOpt { | ||||
| 	return func(config *ManagerConfig) { | ||||
| 		config.CookieSameSite = sameSite | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										222
									
								
								server/web/session/session_config_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										222
									
								
								server/web/session/session_config_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,222 @@ | ||||
| package session | ||||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func TestCfgCookieLifeTime(t *testing.T) { | ||||
| 	value := 8754 | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgCookieLifeTime(value), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.CookieLifeTime != value { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgDomain(t *testing.T) { | ||||
| 	value := `http://domain.com` | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgDomain(value), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.Domain != value { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgSameSite(t *testing.T) { | ||||
| 	value := http.SameSiteLaxMode | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgSameSite(value), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.CookieSameSite != value { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgSecure(t *testing.T) { | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgSecure(true), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.Secure != true { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgSecure1(t *testing.T) { | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgSecure(false), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.Secure != false { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgSessionIdPrefix(t *testing.T) { | ||||
| 	value := `sodiausodkljalsd` | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgSessionIdPrefix(value), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.SessionIDPrefix != value { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgSetSessionNameInHTTPHeader(t *testing.T) { | ||||
| 	value := `sodiausodkljalsd` | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgSetSessionNameInHTTPHeader(value), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.SessionNameInHTTPHeader != value { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgCookieName(t *testing.T) { | ||||
| 	value := `sodiausodkljalsd` | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgCookieName(value), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.CookieName != value { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgEnableSidInURLQuery(t *testing.T) { | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgEnableSidInURLQuery(true), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.EnableSidInURLQuery != true { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgGcLifeTime(t *testing.T) { | ||||
| 	value := int64(5454) | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgGcLifeTime(value), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.Gclifetime != value { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgHTTPOnly(t *testing.T) { | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgHTTPOnly(true), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.DisableHTTPOnly != false { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgHTTPOnly2(t *testing.T) { | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgHTTPOnly(false), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.DisableHTTPOnly != true { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgMaxLifeTime(t *testing.T) { | ||||
| 	value := int64(5454) | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgMaxLifeTime(value), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.Maxlifetime != value { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgProviderConfig(t *testing.T) { | ||||
| 	value := `asodiuasldkj12i39809as` | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgProviderConfig(value), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.ProviderConfig != value { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgSessionIdInHTTPHeader(t *testing.T) { | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgSessionIdInHTTPHeader(true), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.EnableSidInHTTPHeader != true { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgSessionIdInHTTPHeader1(t *testing.T) { | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgSessionIdInHTTPHeader(false), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.EnableSidInHTTPHeader != false { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgSessionIdLength(t *testing.T) { | ||||
| 	value := int64(100) | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgSessionIdLength(value), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.SessionIDLength != value { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgSetCookie(t *testing.T) { | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgSetCookie(true), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.EnableSetCookie != true { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCfgSetCookie1(t *testing.T) { | ||||
| 	c := NewManagerConfig( | ||||
| 		CfgSetCookie(false), | ||||
| 	) | ||||
| 
 | ||||
| 	if c.EnableSetCookie != false { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestNewManagerConfig(t *testing.T) { | ||||
| 	c := NewManagerConfig() | ||||
| 	if c == nil { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestManagerConfig_Opts(t *testing.T) { | ||||
| 	c := NewManagerConfig() | ||||
| 	c.Opts(CfgSetCookie(true)) | ||||
| 
 | ||||
| 	if c.EnableSetCookie != true { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										18
									
								
								server/web/session/session_provider_type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								server/web/session/session_provider_type.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,18 @@ | ||||
| package session | ||||
| 
 | ||||
| type ProviderType string | ||||
| 
 | ||||
| const ( | ||||
| 	ProviderCookie        ProviderType = `cookie` | ||||
| 	ProviderFile          ProviderType = `file` | ||||
| 	ProviderMemory        ProviderType = `memory` | ||||
| 	ProviderCouchbase     ProviderType = `couchbase` | ||||
| 	ProviderLedis         ProviderType = `ledis` | ||||
| 	ProviderMemcache      ProviderType = `memcache` | ||||
| 	ProviderMysql         ProviderType = `mysql` | ||||
| 	ProviderPostgresql    ProviderType = `postgresql` | ||||
| 	ProviderRedis         ProviderType = `redis` | ||||
| 	ProviderRedisCluster  ProviderType = `redis_cluster` | ||||
| 	ProviderRedisSentinel ProviderType = `redis_sentinel` | ||||
| 	ProviderSsdb          ProviderType = `ssdb` | ||||
| ) | ||||
| @ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { | ||||
| 	var method = "GET" | ||||
| 
 | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) | ||||
| 	handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) | ||||
| 	handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) | ||||
| 	handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) | ||||
| 	handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) | ||||
| 	handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) | ||||
| 
 | ||||
| 	// Test original root | ||||
| 	testHelperFnContentCheck(t, handler, "Test original root", | ||||
| @ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { | ||||
| 
 | ||||
| 	// Replace the root path TestPreUnregController action with the action from | ||||
| 	// TestPostUnregController | ||||
| 	handler.Add("/", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot")) | ||||
| 	handler.Add("/", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot")) | ||||
| 
 | ||||
| 	// Test replacement root (expect change) | ||||
| 	testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) | ||||
| @ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { | ||||
| 	var method = "GET" | ||||
| 
 | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) | ||||
| 	handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) | ||||
| 	handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) | ||||
| 	handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) | ||||
| 	handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) | ||||
| 	handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) | ||||
| 
 | ||||
| 	// Test original root | ||||
| 	testHelperFnContentCheck(t, handler, | ||||
| @ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { | ||||
| 
 | ||||
| 	// Replace the "level1" path TestPreUnregController action with the action from | ||||
| 	// TestPostUnregController | ||||
| 	handler.Add("/level1", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) | ||||
| 	handler.Add("/level1", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) | ||||
| 
 | ||||
| 	// Test replacement root (expect no change from the original) | ||||
| 	testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) | ||||
| @ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { | ||||
| 	var method = "GET" | ||||
| 
 | ||||
| 	handler := NewControllerRegister() | ||||
| 	handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) | ||||
| 	handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) | ||||
| 	handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) | ||||
| 	handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) | ||||
| 	handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) | ||||
| 	handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) | ||||
| 
 | ||||
| 	// Test original root | ||||
| 	testHelperFnContentCheck(t, handler, | ||||
| @ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { | ||||
| 
 | ||||
| 	// Replace the "/level1/level2" path TestPreUnregController action with the action from | ||||
| 	// TestPostUnregController | ||||
| 	handler.Add("/level1/level2", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) | ||||
| 	handler.Add("/level1/level2", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) | ||||
| 
 | ||||
| 	// Test replacement root (expect no change from the original) | ||||
| 	testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) | ||||
|  | ||||
							
								
								
									
										7
									
								
								sonar-project.properties
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								sonar-project.properties
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,7 @@ | ||||
| sonar.organization=beego | ||||
| sonar.projectKey=beego_beego | ||||
| 
 | ||||
| # relative paths to source directories. More details and properties are described | ||||
| # in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/ | ||||
| sonar.sources=. | ||||
| sonar.exclusions=**/*_test.go | ||||
| @ -109,6 +109,23 @@ func TestTask_Run(t *testing.T) { | ||||
| 	assert.Equal(t, "Hello, world! 101", l[1].errinfo) | ||||
| } | ||||
| 
 | ||||
| func TestCrudTask(t *testing.T) { | ||||
| 	m := newTaskManager() | ||||
| 	m.AddTask("my-task1", NewTask("my-task1", "0/30 * * * * *", func(ctx context.Context) error { | ||||
| 		return nil | ||||
| 	})) | ||||
| 
 | ||||
| 	m.AddTask("my-task2", NewTask("my-task2", "0/30 * * * * *", func(ctx context.Context) error { | ||||
| 		return nil | ||||
| 	})) | ||||
| 
 | ||||
| 	m.DeleteTask("my-task1") | ||||
| 	assert.Equal(t, 1, len(m.adminTaskList)) | ||||
| 
 | ||||
| 	m.ClearTask() | ||||
| 	assert.Equal(t, 0, len(m.adminTaskList)) | ||||
| } | ||||
| 
 | ||||
| func wait(wg *sync.WaitGroup) chan bool { | ||||
| 	ch := make(chan bool) | ||||
| 	go func() { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user