From 9bd2934e42b930157ec085108a90e0294264bab0 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sun, 25 Oct 2020 21:07:03 +0800 Subject: [PATCH 01/41] add order clause --- client/orm/db_tables.go | 13 ++++---- client/orm/orm.go | 3 +- client/orm/orm_queryset.go | 11 ++++--- client/orm/structs/order_phrase.go | 48 ++++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 11 deletions(-) create mode 100644 client/orm/structs/order_phrase.go diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index 5fd472d1..c67af052 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -16,6 +16,7 @@ package orm import ( "fmt" + "github.com/astaxie/beego/client/orm/structs" "strings" "time" ) @@ -421,7 +422,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { } // generate order sql. -func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { +func (t *dbTables) getOrderSQL(orders []*structs.OrderClause) (orderSQL string) { if len(orders) == 0 { return } @@ -430,16 +431,16 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { orderSqls := make([]string, 0, len(orders)) for _, order := range orders { + column := order.Column asc := "ASC" - if order[0] == '-' { + if order.Sort == structs.DESCENDING { asc = "DESC" - order = order[1:] } - exprs := strings.Split(order, ExprSep) + clause := strings.Split(column, ExprSep) - index, _, fi, suc := t.parseExprs(t.mi, exprs) + index, _, fi, suc := t.parseExprs(t.mi, clause) if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) } orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) diff --git a/client/orm/orm.go b/client/orm/orm.go index a83faeb2..52a99572 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -58,6 +58,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/astaxie/beego/client/orm/structs" "os" "reflect" "time" @@ -351,7 +352,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s qs.relDepth = relDepth if len(order) > 0 { - qs.orders = []string{order} + qs.orders = structs.ParseOrderClause(order) } find := ind.FieldByIndex(fi.fieldIndex) diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index ed223e24..480dc561 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -17,8 +17,8 @@ package orm import ( "context" "fmt" - "github.com/astaxie/beego/client/orm/hints" + "github.com/astaxie/beego/client/orm/structs" ) type colValue struct { @@ -71,7 +71,7 @@ type querySet struct { limit int64 offset int64 groups []string - orders []string + orders []*structs.OrderClause distinct bool forUpdate bool useIndex int @@ -139,8 +139,11 @@ func (o querySet) GroupBy(exprs ...string) QuerySeter { // add ORDER expression. // "column" means ASC, "-column" means DESC. -func (o querySet) OrderBy(exprs ...string) QuerySeter { - o.orders = exprs +func (o querySet) OrderBy(expressions ...string) QuerySeter { + if len(expressions) <= 0 { + return &o + } + o.orders = structs.ParseOrderClause(expressions...) return &o } diff --git a/client/orm/structs/order_phrase.go b/client/orm/structs/order_phrase.go new file mode 100644 index 00000000..30faf429 --- /dev/null +++ b/client/orm/structs/order_phrase.go @@ -0,0 +1,48 @@ +package structs + +import "fmt" + +type Sort int8 + +const ( + ASCENDING Sort = 1 + DESCENDING Sort = 2 +) + +type OrderClause struct { + Column string + Sort Sort +} + +var _ fmt.Stringer = new(OrderClause) + +func (o *OrderClause) String() string { + sort := `` + if o.Sort == ASCENDING { + sort = `ASC` + } else if o.Sort == DESCENDING { + sort = `DESC` + } else { + return fmt.Sprintf("%s", o.Column) + } + return fmt.Sprintf("%s %s", o.Column, sort) +} + +func ParseOrderClause(expressions ...string) []*OrderClause { + var orders []*OrderClause + for _, expression := range expressions { + sort := ASCENDING + column := expression + if expression[0] == '-' { + sort = DESCENDING + column = expression[1:] + } + + orders = append(orders, &OrderClause{ + Column: column, + Sort: sort, + }) + } + + return orders +} From 544c621017dd82aa184bced157cb82fe201a0f4a Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sun, 25 Oct 2020 23:11:05 +0800 Subject: [PATCH 02/41] move dir & delete some useless function --- client/orm/db_tables.go | 17 +++++----- client/orm/orm.go | 4 +-- client/orm/orm_conds.go | 3 +- client/orm/orm_queryset.go | 6 ++-- client/orm/structs/clauses/const.go | 5 +++ client/orm/structs/clauses/order.go | 40 ++++++++++++++++++++++++ client/orm/structs/order_phrase.go | 48 ----------------------------- 7 files changed, 62 insertions(+), 61 deletions(-) create mode 100644 client/orm/structs/clauses/const.go create mode 100644 client/orm/structs/clauses/order.go delete mode 100644 client/orm/structs/order_phrase.go diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index c67af052..7c6e6fff 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -16,7 +16,7 @@ package orm import ( "fmt" - "github.com/astaxie/beego/client/orm/structs" + "github.com/astaxie/beego/client/orm/structs/clauses" "strings" "time" ) @@ -422,7 +422,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { } // generate order sql. -func (t *dbTables) getOrderSQL(orders []*structs.OrderClause) (orderSQL string) { +func (t *dbTables) getOrderSQL(orders []*clauses.Order) (orderSQL string) { if len(orders) == 0 { return } @@ -431,10 +431,13 @@ func (t *dbTables) getOrderSQL(orders []*structs.OrderClause) (orderSQL string) orderSqls := make([]string, 0, len(orders)) for _, order := range orders { - column := order.Column - asc := "ASC" - if order.Sort == structs.DESCENDING { - asc = "DESC" + column := order.GetColumn() + var sort string + switch order.GetSort() { + case clauses.ASCENDING: + sort = "ASC" + case clauses.DESCENDING: + sort = "DESC" } clause := strings.Split(column, ExprSep) @@ -443,7 +446,7 @@ func (t *dbTables) getOrderSQL(orders []*structs.OrderClause) (orderSQL string) panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) } - orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, sort)) } orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) diff --git a/client/orm/orm.go b/client/orm/orm.go index 52a99572..d1a66b88 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -58,7 +58,7 @@ import ( "database/sql" "errors" "fmt" - "github.com/astaxie/beego/client/orm/structs" + "github.com/astaxie/beego/client/orm/structs/clauses" "os" "reflect" "time" @@ -352,7 +352,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s qs.relDepth = relDepth if len(order) > 0 { - qs.orders = structs.ParseOrderClause(order) + qs.orders = clauses.ParseOrder(order) } find := ind.FieldByIndex(fi.fieldIndex) diff --git a/client/orm/orm_conds.go b/client/orm/orm_conds.go index f3fd66f0..7a798c6f 100644 --- a/client/orm/orm_conds.go +++ b/client/orm/orm_conds.go @@ -16,12 +16,13 @@ package orm import ( "fmt" + "github.com/astaxie/beego/client/orm/structs/clauses" "strings" ) // ExprSep define the expression separation const ( - ExprSep = "__" + ExprSep = clauses.ExprSep ) type condValue struct { diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index 480dc561..5ca10421 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -18,7 +18,7 @@ import ( "context" "fmt" "github.com/astaxie/beego/client/orm/hints" - "github.com/astaxie/beego/client/orm/structs" + "github.com/astaxie/beego/client/orm/structs/clauses" ) type colValue struct { @@ -71,7 +71,7 @@ type querySet struct { limit int64 offset int64 groups []string - orders []*structs.OrderClause + orders []*clauses.Order distinct bool forUpdate bool useIndex int @@ -143,7 +143,7 @@ func (o querySet) OrderBy(expressions ...string) QuerySeter { if len(expressions) <= 0 { return &o } - o.orders = structs.ParseOrderClause(expressions...) + o.orders = clauses.ParseOrder(expressions...) return &o } diff --git a/client/orm/structs/clauses/const.go b/client/orm/structs/clauses/const.go new file mode 100644 index 00000000..a0574a64 --- /dev/null +++ b/client/orm/structs/clauses/const.go @@ -0,0 +1,5 @@ +package clauses + +const ( + ExprSep = "__" +) diff --git a/client/orm/structs/clauses/order.go b/client/orm/structs/clauses/order.go new file mode 100644 index 00000000..050bcf1b --- /dev/null +++ b/client/orm/structs/clauses/order.go @@ -0,0 +1,40 @@ +package clauses + +type Sort int8 + +const ( + ASCENDING Sort = 1 + DESCENDING Sort = 2 +) + +type Order struct { + column string + sort Sort +} + +func (o *Order) GetColumn() string { + return o.column +} + +func (o *Order) GetSort() Sort { + return o.sort +} + +func ParseOrder(expressions ...string) []*Order { + var orders []*Order + for _, expression := range expressions { + sort := ASCENDING + column := expression + if expression[0] == '-' { + sort = DESCENDING + column = expression[1:] + } + + orders = append(orders, &Order{ + column: column, + sort: sort, + }) + } + + return orders +} diff --git a/client/orm/structs/order_phrase.go b/client/orm/structs/order_phrase.go deleted file mode 100644 index 30faf429..00000000 --- a/client/orm/structs/order_phrase.go +++ /dev/null @@ -1,48 +0,0 @@ -package structs - -import "fmt" - -type Sort int8 - -const ( - ASCENDING Sort = 1 - DESCENDING Sort = 2 -) - -type OrderClause struct { - Column string - Sort Sort -} - -var _ fmt.Stringer = new(OrderClause) - -func (o *OrderClause) String() string { - sort := `` - if o.Sort == ASCENDING { - sort = `ASC` - } else if o.Sort == DESCENDING { - sort = `DESC` - } else { - return fmt.Sprintf("%s", o.Column) - } - return fmt.Sprintf("%s %s", o.Column, sort) -} - -func ParseOrderClause(expressions ...string) []*OrderClause { - var orders []*OrderClause - for _, expression := range expressions { - sort := ASCENDING - column := expression - if expression[0] == '-' { - sort = DESCENDING - column = expression[1:] - } - - orders = append(orders, &OrderClause{ - Column: column, - Sort: sort, - }) - } - - return orders -} From b1d5ba8ece68900ce448cf8cd79fc42f76ccdda0 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sun, 25 Oct 2020 23:16:41 +0800 Subject: [PATCH 03/41] support sort none --- client/orm/db_tables.go | 4 ++-- client/orm/structs/clauses/order.go | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index 7c6e6fff..5c330f4e 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -434,9 +434,9 @@ func (t *dbTables) getOrderSQL(orders []*clauses.Order) (orderSQL string) { column := order.GetColumn() var sort string switch order.GetSort() { - case clauses.ASCENDING: + case clauses.SortAscending: sort = "ASC" - case clauses.DESCENDING: + case clauses.SortDescending: sort = "DESC" } clause := strings.Split(column, ExprSep) diff --git a/client/orm/structs/clauses/order.go b/client/orm/structs/clauses/order.go index 050bcf1b..66e8fe73 100644 --- a/client/orm/structs/clauses/order.go +++ b/client/orm/structs/clauses/order.go @@ -3,8 +3,9 @@ package clauses type Sort int8 const ( - ASCENDING Sort = 1 - DESCENDING Sort = 2 + SortNone Sort = 0 + SortAscending Sort = 1 + SortDescending Sort = 2 ) type Order struct { @@ -23,10 +24,10 @@ func (o *Order) GetSort() Sort { func ParseOrder(expressions ...string) []*Order { var orders []*Order for _, expression := range expressions { - sort := ASCENDING + sort := SortAscending column := expression if expression[0] == '-' { - sort = DESCENDING + sort = SortDescending column = expression[1:] } From 56fa213a6edf0c9ae0421ebf59683cbe6c41c8ca Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sun, 25 Oct 2020 23:49:21 +0800 Subject: [PATCH 04/41] support OrderClauses for QuerySeter --- client/orm/orm_queryset.go | 9 +++++++ client/orm/structs/clauses/order.go | 34 ++++++++++++++++++++++++ client/orm/structs/clauses/order_test.go | 29 ++++++++++++++++++++ client/orm/types.go | 19 +++++++++++++ 4 files changed, 91 insertions(+) create mode 100644 client/orm/structs/clauses/order_test.go diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index 5ca10421..e6493d8f 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -147,6 +147,15 @@ func (o querySet) OrderBy(expressions ...string) QuerySeter { return &o } +// add ORDER expression. +func (o querySet) OrderClauses(orders ...*clauses.Order) QuerySeter { + if len(orders) <= 0 { + return &o + } + o.orders = orders + return &o +} + // add DISTINCT to SELECT func (o querySet) Distinct() QuerySeter { o.distinct = true diff --git a/client/orm/structs/clauses/order.go b/client/orm/structs/clauses/order.go index 66e8fe73..64e4ebbd 100644 --- a/client/orm/structs/clauses/order.go +++ b/client/orm/structs/clauses/order.go @@ -8,9 +8,21 @@ const ( SortDescending Sort = 2 ) +type OrderOption func(order *Order) + type Order struct { column string sort Sort + isRaw bool +} + +func OrderClause(options ...OrderOption) *Order { + o := &Order{} + for _, option := range options { + option(o) + } + + return o } func (o *Order) GetColumn() string { @@ -21,6 +33,10 @@ func (o *Order) GetSort() Sort { return o.sort } +func (o *Order) IsRaw() bool { + return o.isRaw +} + func ParseOrder(expressions ...string) []*Order { var orders []*Order for _, expression := range expressions { @@ -39,3 +55,21 @@ func ParseOrder(expressions ...string) []*Order { return orders } + +func OrderColumn(column string) OrderOption { + return func(order *Order) { + order.column = column + } +} + +func OrderSort(sort Sort) OrderOption { + return func(order *Order) { + order.sort = sort + } +} + +func OrderRaw(isRaw bool) OrderOption { + return func(order *Order) { + order.isRaw = isRaw + } +} diff --git a/client/orm/structs/clauses/order_test.go b/client/orm/structs/clauses/order_test.go new file mode 100644 index 00000000..2f44975d --- /dev/null +++ b/client/orm/structs/clauses/order_test.go @@ -0,0 +1,29 @@ +package clauses + +import "testing" + +func TestOrderClause(t *testing.T) { + var ( + column = `a` + sort = SortDescending + raw = true + ) + + o := OrderClause( + OrderColumn(column), + OrderSort(sort), + OrderRaw(raw), + ) + + if o.GetColumn() != column { + t.Error() + } + + if o.GetSort() != sort { + t.Error() + } + + if o.IsRaw() != raw { + t.Error() + } +} diff --git a/client/orm/types.go b/client/orm/types.go index 34c61d51..f1d6def7 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -17,6 +17,7 @@ package orm import ( "context" "database/sql" + "github.com/astaxie/beego/client/orm/structs/clauses" "reflect" "time" @@ -289,6 +290,24 @@ type QuerySeter interface { // for example: // qs.OrderBy("-status") OrderBy(exprs ...string) QuerySeter + // add ORDER expression by order clauses + // for example: + // OrderClauses(clauses.OrderClause( + // clauses.OrderColumn(`status`), + // clauses.OrderSort(clauses.SortAscending), + // clauses.OrderRaw(false), + // )) + // OrderClauses(clauses.OrderClause( + // clauses.OrderColumn(`user__status`), + // clauses.OrderSort(clauses.SortAscending), + // clauses.OrderRaw(false), + // )) + // OrderClauses(clauses.OrderClause( + // clauses.OrderColumn(`random()`), + // clauses.OrderSort(clauses.SortNone), + // clauses.OrderRaw(true), + // )) + OrderClauses(orders ...*clauses.Order) QuerySeter // add FORCE INDEX expression. // for example: // qs.ForceIndex(`idx_name1`,`idx_name2`) From 7d4e88c1b951db2ac778feaf02ac530ff171c212 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Mon, 26 Oct 2020 00:01:15 +0800 Subject: [PATCH 05/41] support raw order --- client/orm/db_tables.go | 29 ++++++++++++++++------------- client/orm/structs/clauses/const.go | 1 + client/orm/structs/clauses/order.go | 20 +++++++++++++++++--- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index 5c330f4e..34fa0f07 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -432,21 +432,24 @@ func (t *dbTables) getOrderSQL(orders []*clauses.Order) (orderSQL string) { orderSqls := make([]string, 0, len(orders)) for _, order := range orders { column := order.GetColumn() - var sort string - switch order.GetSort() { - case clauses.SortAscending: - sort = "ASC" - case clauses.SortDescending: - sort = "DESC" - } - clause := strings.Split(column, ExprSep) + clause := strings.Split(column, clauses.ExprDot) - index, _, fi, suc := t.parseExprs(t.mi, clause) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) - } + if order.IsRaw() { + if len(clause) == 2 { + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString())) + } else if len(clause) == 1 { + orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString())) + } else { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) + } + } else { + index, _, fi, suc := t.parseExprs(t.mi, clause) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) + } - orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, sort)) + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString())) + } } orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) diff --git a/client/orm/structs/clauses/const.go b/client/orm/structs/clauses/const.go index a0574a64..747d3fd7 100644 --- a/client/orm/structs/clauses/const.go +++ b/client/orm/structs/clauses/const.go @@ -2,4 +2,5 @@ package clauses const ( ExprSep = "__" + ExprDot = "." ) diff --git a/client/orm/structs/clauses/order.go b/client/orm/structs/clauses/order.go index 64e4ebbd..2f7ff6ba 100644 --- a/client/orm/structs/clauses/order.go +++ b/client/orm/structs/clauses/order.go @@ -1,5 +1,7 @@ package clauses +import "strings" + type Sort int8 const ( @@ -33,6 +35,18 @@ func (o *Order) GetSort() Sort { return o.sort } +func (o *Order) SortString() string { + switch o.GetSort() { + case SortAscending: + return "ASC" + case SortDescending: + return "DESC" + } + + return `` +} + + func (o *Order) IsRaw() bool { return o.isRaw } @@ -41,10 +55,10 @@ func ParseOrder(expressions ...string) []*Order { var orders []*Order for _, expression := range expressions { sort := SortAscending - column := expression - if expression[0] == '-' { + column := strings.ReplaceAll(expression, ExprSep, ExprDot) + if column[0] == '-' { sort = SortDescending - column = expression[1:] + column = column[1:] } orders = append(orders, &Order{ From d24388ad819916bf1bea2920f0bd8577c4cb3ea9 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 26 Oct 2020 18:57:01 +0800 Subject: [PATCH 06/41] opt Methods --- client/orm/structs/clauses/order.go | 19 +++++-- client/orm/structs/clauses/order_test.go | 67 +++++++++++++++++++++--- client/orm/types.go | 10 ++-- 3 files changed, 78 insertions(+), 18 deletions(-) diff --git a/client/orm/structs/clauses/order.go b/client/orm/structs/clauses/order.go index 2f7ff6ba..e1110eec 100644 --- a/client/orm/structs/clauses/order.go +++ b/client/orm/structs/clauses/order.go @@ -46,7 +46,6 @@ func (o *Order) SortString() string { return `` } - func (o *Order) IsRaw() bool { return o.isRaw } @@ -76,14 +75,26 @@ func OrderColumn(column string) OrderOption { } } -func OrderSort(sort Sort) OrderOption { +func sort(sort Sort) OrderOption { return func(order *Order) { order.sort = sort } } -func OrderRaw(isRaw bool) OrderOption { +func OrderSortAscending() OrderOption { + return sort(SortAscending) +} + +func OrderSortDescending() OrderOption { + return sort(SortDescending) +} + +func OrderSortNone() OrderOption { + return sort(SortNone) +} + +func OrderRaw() OrderOption { return func(order *Order) { - order.isRaw = isRaw + order.isRaw = true } } diff --git a/client/orm/structs/clauses/order_test.go b/client/orm/structs/clauses/order_test.go index 2f44975d..dbb37a0a 100644 --- a/client/orm/structs/clauses/order_test.go +++ b/client/orm/structs/clauses/order_test.go @@ -5,25 +5,76 @@ import "testing" func TestOrderClause(t *testing.T) { var ( column = `a` - sort = SortDescending - raw = true ) o := OrderClause( OrderColumn(column), - OrderSort(sort), - OrderRaw(raw), ) if o.GetColumn() != column { t.Error() } +} - if o.GetSort() != sort { - t.Error() - } +func TestOrderSortAscending(t *testing.T) { + o := OrderClause( + OrderSortAscending(), + ) - if o.IsRaw() != raw { + if o.GetSort() != SortAscending { t.Error() } } + +func TestOrderSortDescending(t *testing.T) { + o := OrderClause( + OrderSortDescending(), + ) + + if o.GetSort() != SortDescending { + t.Error() + } +} + +func TestOrderSortNone(t *testing.T) { + o1 := OrderClause( + OrderSortNone(), + ) + + if o1.GetSort() != SortNone { + t.Error() + } + + o2 := OrderClause() + + if o2.GetSort() != SortNone { + t.Error() + } +} + +func TestOrderRaw(t *testing.T) { + o1 := OrderClause() + + if o1.IsRaw() { + t.Error() + } + + o2 := OrderClause( + OrderRaw(), + ) + + if !o2.IsRaw() { + t.Error() + } +} + +func TestOrderColumn(t *testing.T) { + o1 := OrderClause( + OrderColumn(`aaa`), + ) + + if o1.GetColumn() != `aaa` { + t.Error() + } +} + diff --git a/client/orm/types.go b/client/orm/types.go index f1d6def7..14a34025 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -294,18 +294,16 @@ type QuerySeter interface { // for example: // OrderClauses(clauses.OrderClause( // clauses.OrderColumn(`status`), - // clauses.OrderSort(clauses.SortAscending), - // clauses.OrderRaw(false), + // clauses.OrderSortAscending(),//default None // )) // OrderClauses(clauses.OrderClause( // clauses.OrderColumn(`user__status`), - // clauses.OrderSort(clauses.SortAscending), - // clauses.OrderRaw(false), + // clauses.OrderSortDescending(),//default None // )) // OrderClauses(clauses.OrderClause( // clauses.OrderColumn(`random()`), - // clauses.OrderSort(clauses.SortNone), - // clauses.OrderRaw(true), + // clauses.OrderSortNone(),//default None + // clauses.OrderRaw(),//default false.if true, do not check field is valid or not // )) OrderClauses(orders ...*clauses.Order) QuerySeter // add FORCE INDEX expression. From beedfa1b53d1847f6cb1f31b0cafac184fbee156 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Mon, 26 Oct 2020 19:52:21 +0800 Subject: [PATCH 07/41] add UT --- client/orm/orm_test.go | 34 +++++++++++++++++++++++++++++ client/orm/structs/clauses/order.go | 2 +- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index 565f6c60..8f057df1 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -21,6 +21,7 @@ import ( "context" "database/sql" "fmt" + "github.com/astaxie/beego/client/orm/structs/clauses" "io/ioutil" "math" "os" @@ -1077,6 +1078,26 @@ func TestOrderBy(t *testing.T) { num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderClauses( + clauses.OrderClause( + clauses.OrderColumn(`profile__age`), + clauses.OrderSortDescending(), + ), + ).Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsMysql { + num, err = qs.OrderClauses( + clauses.OrderClause( + clauses.OrderColumn(`rand()`), + clauses.OrderRaw(), + ), + ).Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + } } func TestAll(t *testing.T) { @@ -1163,6 +1184,19 @@ func TestValues(t *testing.T) { throwFail(t, AssertIs(maps[2]["Profile"], nil)) } + num, err = qs.OrderClauses( + clauses.OrderClause( + clauses.OrderColumn("Id"), + clauses.OrderSortAscending(), + ), + ).Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[2]["Profile"], nil)) + } + num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") throwFail(t, err) throwFail(t, AssertIs(num, 3)) diff --git a/client/orm/structs/clauses/order.go b/client/orm/structs/clauses/order.go index e1110eec..432bb8f4 100644 --- a/client/orm/structs/clauses/order.go +++ b/client/orm/structs/clauses/order.go @@ -71,7 +71,7 @@ func ParseOrder(expressions ...string) []*Order { func OrderColumn(column string) OrderOption { return func(order *Order) { - order.column = column + order.column = strings.ReplaceAll(column, ExprSep, ExprDot) } } From 6d828e793968d982bd623ee9cae6e8dbc433e363 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 27 Oct 2020 19:32:56 +0800 Subject: [PATCH 08/41] move dir and add ut for ParseOrder --- client/orm/db_tables.go | 7 +- client/orm/orm.go | 4 +- client/orm/orm_conds.go | 4 +- client/orm/orm_queryset.go | 8 +- client/orm/orm_test.go | 20 ++-- .../orm/structs/clauses/{ => order}/order.go | 47 ++++---- .../orm/structs/clauses/order/order_test.go | 112 ++++++++++++++++++ client/orm/structs/clauses/order_test.go | 80 ------------- client/orm/structs/{clauses => }/const.go | 2 +- client/orm/types.go | 32 +++-- 10 files changed, 179 insertions(+), 137 deletions(-) rename client/orm/structs/clauses/{ => order}/order.go (54%) create mode 100644 client/orm/structs/clauses/order/order_test.go delete mode 100644 client/orm/structs/clauses/order_test.go rename client/orm/structs/{clauses => }/const.go (81%) diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index 34fa0f07..94450f94 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -16,7 +16,8 @@ package orm import ( "fmt" - "github.com/astaxie/beego/client/orm/structs/clauses" + "github.com/astaxie/beego/client/orm/structs" + "github.com/astaxie/beego/client/orm/structs/clauses/order" "strings" "time" ) @@ -422,7 +423,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { } // generate order sql. -func (t *dbTables) getOrderSQL(orders []*clauses.Order) (orderSQL string) { +func (t *dbTables) getOrderSQL(orders []*order.Order) (orderSQL string) { if len(orders) == 0 { return } @@ -432,7 +433,7 @@ func (t *dbTables) getOrderSQL(orders []*clauses.Order) (orderSQL string) { orderSqls := make([]string, 0, len(orders)) for _, order := range orders { column := order.GetColumn() - clause := strings.Split(column, clauses.ExprDot) + clause := strings.Split(column, structs.ExprDot) if order.IsRaw() { if len(clause) == 2 { diff --git a/client/orm/orm.go b/client/orm/orm.go index d1a66b88..27ae1fc2 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -58,7 +58,7 @@ import ( "database/sql" "errors" "fmt" - "github.com/astaxie/beego/client/orm/structs/clauses" + order2 "github.com/astaxie/beego/client/orm/structs/clauses/order" "os" "reflect" "time" @@ -352,7 +352,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s qs.relDepth = relDepth if len(order) > 0 { - qs.orders = clauses.ParseOrder(order) + qs.orders = order2.ParseOrder(order) } find := ind.FieldByIndex(fi.fieldIndex) diff --git a/client/orm/orm_conds.go b/client/orm/orm_conds.go index 7a798c6f..a2f6743c 100644 --- a/client/orm/orm_conds.go +++ b/client/orm/orm_conds.go @@ -16,13 +16,13 @@ package orm import ( "fmt" - "github.com/astaxie/beego/client/orm/structs/clauses" + "github.com/astaxie/beego/client/orm/structs" "strings" ) // ExprSep define the expression separation const ( - ExprSep = clauses.ExprSep + ExprSep = structs.ExprSep ) type condValue struct { diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index e6493d8f..29b65b1a 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -18,7 +18,7 @@ import ( "context" "fmt" "github.com/astaxie/beego/client/orm/hints" - "github.com/astaxie/beego/client/orm/structs/clauses" + "github.com/astaxie/beego/client/orm/structs/clauses/order" ) type colValue struct { @@ -71,7 +71,7 @@ type querySet struct { limit int64 offset int64 groups []string - orders []*clauses.Order + orders []*order.Order distinct bool forUpdate bool useIndex int @@ -143,12 +143,12 @@ func (o querySet) OrderBy(expressions ...string) QuerySeter { if len(expressions) <= 0 { return &o } - o.orders = clauses.ParseOrder(expressions...) + o.orders = order.ParseOrder(expressions...) return &o } // add ORDER expression. -func (o querySet) OrderClauses(orders ...*clauses.Order) QuerySeter { +func (o querySet) OrderClauses(orders ...*order.Order) QuerySeter { if len(orders) <= 0 { return &o } diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index 8f057df1..e6041396 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -21,7 +21,7 @@ import ( "context" "database/sql" "fmt" - "github.com/astaxie/beego/client/orm/structs/clauses" + "github.com/astaxie/beego/client/orm/structs/clauses/order" "io/ioutil" "math" "os" @@ -1080,9 +1080,9 @@ func TestOrderBy(t *testing.T) { throwFail(t, AssertIs(num, 1)) num, err = qs.OrderClauses( - clauses.OrderClause( - clauses.OrderColumn(`profile__age`), - clauses.OrderSortDescending(), + order.Clause( + order.Column(`profile__age`), + order.SortDescending(), ), ).Filter("user_name", "astaxie").Count() throwFail(t, err) @@ -1090,9 +1090,9 @@ func TestOrderBy(t *testing.T) { if IsMysql { num, err = qs.OrderClauses( - clauses.OrderClause( - clauses.OrderColumn(`rand()`), - clauses.OrderRaw(), + order.Clause( + order.Column(`rand()`), + order.Raw(), ), ).Filter("user_name", "astaxie").Count() throwFail(t, err) @@ -1185,9 +1185,9 @@ func TestValues(t *testing.T) { } num, err = qs.OrderClauses( - clauses.OrderClause( - clauses.OrderColumn("Id"), - clauses.OrderSortAscending(), + order.Clause( + order.Column("Id"), + order.SortAscending(), ), ).Values(&maps) throwFail(t, err) diff --git a/client/orm/structs/clauses/order.go b/client/orm/structs/clauses/order/order.go similarity index 54% rename from client/orm/structs/clauses/order.go rename to client/orm/structs/clauses/order/order.go index 432bb8f4..c07202d5 100644 --- a/client/orm/structs/clauses/order.go +++ b/client/orm/structs/clauses/order/order.go @@ -1,16 +1,19 @@ -package clauses +package order -import "strings" +import ( + "github.com/astaxie/beego/client/orm/structs" + "strings" +) type Sort int8 const ( - SortNone Sort = 0 - SortAscending Sort = 1 - SortDescending Sort = 2 + None Sort = 0 + Ascending Sort = 1 + Descending Sort = 2 ) -type OrderOption func(order *Order) +type Option func(order *Order) type Order struct { column string @@ -18,7 +21,7 @@ type Order struct { isRaw bool } -func OrderClause(options ...OrderOption) *Order { +func Clause(options ...Option) *Order { o := &Order{} for _, option := range options { option(o) @@ -37,9 +40,9 @@ func (o *Order) GetSort() Sort { func (o *Order) SortString() string { switch o.GetSort() { - case SortAscending: + case Ascending: return "ASC" - case SortDescending: + case Descending: return "DESC" } @@ -53,10 +56,10 @@ func (o *Order) IsRaw() bool { func ParseOrder(expressions ...string) []*Order { var orders []*Order for _, expression := range expressions { - sort := SortAscending - column := strings.ReplaceAll(expression, ExprSep, ExprDot) + sort := Ascending + column := strings.ReplaceAll(expression, structs.ExprSep, structs.ExprDot) if column[0] == '-' { - sort = SortDescending + sort = Descending column = column[1:] } @@ -69,31 +72,31 @@ func ParseOrder(expressions ...string) []*Order { return orders } -func OrderColumn(column string) OrderOption { +func Column(column string) Option { return func(order *Order) { - order.column = strings.ReplaceAll(column, ExprSep, ExprDot) + order.column = strings.ReplaceAll(column, structs.ExprSep, structs.ExprDot) } } -func sort(sort Sort) OrderOption { +func sort(sort Sort) Option { return func(order *Order) { order.sort = sort } } -func OrderSortAscending() OrderOption { - return sort(SortAscending) +func SortAscending() Option { + return sort(Ascending) } -func OrderSortDescending() OrderOption { - return sort(SortDescending) +func SortDescending() Option { + return sort(Descending) } -func OrderSortNone() OrderOption { - return sort(SortNone) +func SortNone() Option { + return sort(None) } -func OrderRaw() OrderOption { +func Raw() Option { return func(order *Order) { order.isRaw = true } diff --git a/client/orm/structs/clauses/order/order_test.go b/client/orm/structs/clauses/order/order_test.go new file mode 100644 index 00000000..93072960 --- /dev/null +++ b/client/orm/structs/clauses/order/order_test.go @@ -0,0 +1,112 @@ +package order + +import ( + "testing" +) + +func TestOrderClause(t *testing.T) { + var ( + column = `a` + ) + + o := Clause( + Column(column), + ) + + if o.GetColumn() != column { + t.Error() + } +} + +func TestOrderSortAscending(t *testing.T) { + o := Clause( + SortAscending(), + ) + + if o.GetSort() != Ascending { + t.Error() + } +} + +func TestOrderSortDescending(t *testing.T) { + o := Clause( + SortDescending(), + ) + + if o.GetSort() != Descending { + t.Error() + } +} + +func TestOrderSortNone(t *testing.T) { + o1 := Clause( + SortNone(), + ) + + if o1.GetSort() != None { + t.Error() + } + + o2 := Clause() + + if o2.GetSort() != None { + t.Error() + } +} + +func TestOrderRaw(t *testing.T) { + o1 := Clause() + + if o1.IsRaw() { + t.Error() + } + + o2 := Clause( + Raw(), + ) + + if !o2.IsRaw() { + t.Error() + } +} + +func TestOrderColumn(t *testing.T) { + o1 := Clause( + Column(`aaa`), + ) + + if o1.GetColumn() != `aaa` { + t.Error() + } +} + +func TestParseOrder(t *testing.T) { + orders := ParseOrder( + `-user__status`, + `status`, + `user__status`, + ) + + t.Log(orders) + + if orders[0].GetSort() != Descending { + t.Error() + } + + if orders[0].GetColumn() != `user.status` { + t.Error() + } + + if orders[1].GetColumn() != `status` { + t.Error() + } + + if orders[1].GetSort() != Ascending { + t.Error() + } + + if orders[2].GetColumn() != `user.status` { + t.Error() + } + +} diff --git a/client/orm/structs/clauses/order_test.go b/client/orm/structs/clauses/order_test.go deleted file mode 100644 index dbb37a0a..00000000 --- a/client/orm/structs/clauses/order_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package clauses - -import "testing" - -func TestOrderClause(t *testing.T) { - var ( - column = `a` - ) - - o := OrderClause( - OrderColumn(column), - ) - - if o.GetColumn() != column { - t.Error() - } -} - -func TestOrderSortAscending(t *testing.T) { - o := OrderClause( - OrderSortAscending(), - ) - - if o.GetSort() != SortAscending { - t.Error() - } -} - -func TestOrderSortDescending(t *testing.T) { - o := OrderClause( - OrderSortDescending(), - ) - - if o.GetSort() != SortDescending { - t.Error() - } -} - -func TestOrderSortNone(t *testing.T) { - o1 := OrderClause( - OrderSortNone(), - ) - - if o1.GetSort() != SortNone { - t.Error() - } - - o2 := OrderClause() - - if o2.GetSort() != SortNone { - t.Error() - } -} - -func TestOrderRaw(t *testing.T) { - o1 := OrderClause() - - if o1.IsRaw() { - t.Error() - } - - o2 := OrderClause( - OrderRaw(), - ) - - if !o2.IsRaw() { - t.Error() - } -} - -func TestOrderColumn(t *testing.T) { - o1 := OrderClause( - OrderColumn(`aaa`), - ) - - if o1.GetColumn() != `aaa` { - t.Error() - } -} - diff --git a/client/orm/structs/clauses/const.go b/client/orm/structs/const.go similarity index 81% rename from client/orm/structs/clauses/const.go rename to client/orm/structs/const.go index 747d3fd7..42a1845a 100644 --- a/client/orm/structs/clauses/const.go +++ b/client/orm/structs/const.go @@ -1,4 +1,4 @@ -package clauses +package structs const ( ExprSep = "__" diff --git a/client/orm/types.go b/client/orm/types.go index 14a34025..d39aa75c 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -17,7 +17,7 @@ package orm import ( "context" "database/sql" - "github.com/astaxie/beego/client/orm/structs/clauses" + "github.com/astaxie/beego/client/orm/structs/clauses/order" "reflect" "time" @@ -292,20 +292,26 @@ type QuerySeter interface { OrderBy(exprs ...string) QuerySeter // add ORDER expression by order clauses // for example: - // OrderClauses(clauses.OrderClause( - // clauses.OrderColumn(`status`), - // clauses.OrderSortAscending(),//default None + // OrderClauses( + // order.Clause( + // order.Column("Id"), + // order.SortAscending(), + // ), + // order.Clause( + // order.Column("status"), + // order.SortDescending(), + // ), + // ) + // OrderClauses(order.Clause( + // order.Column(`user__status`), + // order.SortDescending(),//default None // )) - // OrderClauses(clauses.OrderClause( - // clauses.OrderColumn(`user__status`), - // clauses.OrderSortDescending(),//default None + // OrderClauses(order.Clause( + // order.Column(`random()`), + // order.SortNone(),//default None + // order.Raw(),//default false.if true, do not check field is valid or not // )) - // OrderClauses(clauses.OrderClause( - // clauses.OrderColumn(`random()`), - // clauses.OrderSortNone(),//default None - // clauses.OrderRaw(),//default false.if true, do not check field is valid or not - // )) - OrderClauses(orders ...*clauses.Order) QuerySeter + OrderClauses(orders ...*order.Order) QuerySeter // add FORCE INDEX expression. // for example: // qs.ForceIndex(`idx_name1`,`idx_name2`) From d147f4a0184abaf759b7c1f30eb7242b03b439df Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 27 Oct 2020 19:35:35 +0800 Subject: [PATCH 09/41] format comment --- client/orm/types.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/orm/types.go b/client/orm/types.go index d39aa75c..6630e14f 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -302,11 +302,11 @@ type QuerySeter interface { // order.SortDescending(), // ), // ) - // OrderClauses(order.Clause( + // OrderClauses(order.Clause( // order.Column(`user__status`), // order.SortDescending(),//default None // )) - // OrderClauses(order.Clause( + // OrderClauses(order.Clause( // order.Column(`random()`), // order.SortNone(),//default None // order.Raw(),//default false.if true, do not check field is valid or not From 759982b3b856f322a0d600487a627dac81bd09c4 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Tue, 27 Oct 2020 20:01:53 +0800 Subject: [PATCH 10/41] add more UT --- .../orm/structs/clauses/order/order_test.go | 44 ++++++++++++++++--- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/client/orm/structs/clauses/order/order_test.go b/client/orm/structs/clauses/order/order_test.go index 93072960..0e9356f0 100644 --- a/client/orm/structs/clauses/order/order_test.go +++ b/client/orm/structs/clauses/order/order_test.go @@ -4,7 +4,7 @@ import ( "testing" ) -func TestOrderClause(t *testing.T) { +func TestClause(t *testing.T) { var ( column = `a` ) @@ -18,7 +18,7 @@ func TestOrderClause(t *testing.T) { } } -func TestOrderSortAscending(t *testing.T) { +func TestSortAscending(t *testing.T) { o := Clause( SortAscending(), ) @@ -28,7 +28,7 @@ func TestOrderSortAscending(t *testing.T) { } } -func TestOrderSortDescending(t *testing.T) { +func TestSortDescending(t *testing.T) { o := Clause( SortDescending(), ) @@ -38,7 +38,7 @@ func TestOrderSortDescending(t *testing.T) { } } -func TestOrderSortNone(t *testing.T) { +func TestSortNone(t *testing.T) { o1 := Clause( SortNone(), ) @@ -54,7 +54,7 @@ func TestOrderSortNone(t *testing.T) { } } -func TestOrderRaw(t *testing.T) { +func TestRaw(t *testing.T) { o1 := Clause() if o1.IsRaw() { @@ -70,7 +70,7 @@ func TestOrderRaw(t *testing.T) { } } -func TestOrderColumn(t *testing.T) { +func TestColumn(t *testing.T) { o1 := Clause( Column(`aaa`), ) @@ -110,3 +110,35 @@ func TestParseOrder(t *testing.T) { } } + +func TestOrder_GetColumn(t *testing.T) { + o := Clause( + Column(`user__id`), + ) + if o.GetColumn() != `user.id` { + t.Error() + } +} + +func TestOrder_GetSort(t *testing.T) { + o := Clause( + SortDescending(), + ) + if o.GetSort() != Descending { + t.Error() + } +} + +func TestOrder_IsRaw(t *testing.T) { + o1 := Clause() + if o1.IsRaw() { + t.Error() + } + + o2 := Clause( + Raw(), + ) + if !o2.IsRaw() { + t.Error() + } +} From 508105d32ac0b4e34d1094535e577a8d9f90ee34 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Thu, 29 Oct 2020 18:57:16 +0800 Subject: [PATCH 11/41] fix UT:concurrent map iteration and map write --- task/task.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/task/task.go b/task/task.go index 8f25a0f3..00cbbfa7 100644 --- a/task/task.go +++ b/task/task.go @@ -452,9 +452,11 @@ func (m *taskManager) StartTask() { func (m *taskManager) run() { now := time.Now().Local() + m.taskLock.Lock() for _, t := range m.adminTaskList { t.SetNext(nil, now) } + m.taskLock.Unlock() for { // we only use RLock here because NewMapSorter copy the reference, do not change any thing From f8fb50999bc7305256d6c1844313169aab53a9c8 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Thu, 29 Oct 2020 19:05:32 +0800 Subject: [PATCH 12/41] move dir --- client/orm/{structs => clauses}/const.go | 2 +- .../order => clauses/order_clause}/order.go | 8 +++---- .../order_clause}/order_test.go | 2 +- client/orm/db_tables.go | 7 +++--- client/orm/orm.go | 4 ++-- client/orm/orm_conds.go | 4 ++-- client/orm/orm_queryset.go | 8 +++---- client/orm/orm_test.go | 20 ++++++++--------- client/orm/types.go | 22 +++++++++---------- 9 files changed, 38 insertions(+), 39 deletions(-) rename client/orm/{structs => clauses}/const.go (81%) rename client/orm/{structs/clauses/order => clauses/order_clause}/order.go (85%) rename client/orm/{structs/clauses/order => clauses/order_clause}/order_test.go (98%) diff --git a/client/orm/structs/const.go b/client/orm/clauses/const.go similarity index 81% rename from client/orm/structs/const.go rename to client/orm/clauses/const.go index 42a1845a..747d3fd7 100644 --- a/client/orm/structs/const.go +++ b/client/orm/clauses/const.go @@ -1,4 +1,4 @@ -package structs +package clauses const ( ExprSep = "__" diff --git a/client/orm/structs/clauses/order/order.go b/client/orm/clauses/order_clause/order.go similarity index 85% rename from client/orm/structs/clauses/order/order.go rename to client/orm/clauses/order_clause/order.go index c07202d5..510c5505 100644 --- a/client/orm/structs/clauses/order/order.go +++ b/client/orm/clauses/order_clause/order.go @@ -1,7 +1,7 @@ -package order +package order_clause import ( - "github.com/astaxie/beego/client/orm/structs" + "github.com/astaxie/beego/client/orm/clauses" "strings" ) @@ -57,7 +57,7 @@ func ParseOrder(expressions ...string) []*Order { var orders []*Order for _, expression := range expressions { sort := Ascending - column := strings.ReplaceAll(expression, structs.ExprSep, structs.ExprDot) + column := strings.ReplaceAll(expression, clauses.ExprSep, clauses.ExprDot) if column[0] == '-' { sort = Descending column = column[1:] @@ -74,7 +74,7 @@ func ParseOrder(expressions ...string) []*Order { func Column(column string) Option { return func(order *Order) { - order.column = strings.ReplaceAll(column, structs.ExprSep, structs.ExprDot) + order.column = strings.ReplaceAll(column, clauses.ExprSep, clauses.ExprDot) } } diff --git a/client/orm/structs/clauses/order/order_test.go b/client/orm/clauses/order_clause/order_test.go similarity index 98% rename from client/orm/structs/clauses/order/order_test.go rename to client/orm/clauses/order_clause/order_test.go index 0e9356f0..172e7492 100644 --- a/client/orm/structs/clauses/order/order_test.go +++ b/client/orm/clauses/order_clause/order_test.go @@ -1,4 +1,4 @@ -package order +package order_clause import ( "testing" diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index 94450f94..76a14fd0 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -16,8 +16,7 @@ package orm import ( "fmt" - "github.com/astaxie/beego/client/orm/structs" - "github.com/astaxie/beego/client/orm/structs/clauses/order" + "github.com/astaxie/beego/client/orm/clauses" "strings" "time" ) @@ -423,7 +422,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { } // generate order sql. -func (t *dbTables) getOrderSQL(orders []*order.Order) (orderSQL string) { +func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) { if len(orders) == 0 { return } @@ -433,7 +432,7 @@ func (t *dbTables) getOrderSQL(orders []*order.Order) (orderSQL string) { orderSqls := make([]string, 0, len(orders)) for _, order := range orders { column := order.GetColumn() - clause := strings.Split(column, structs.ExprDot) + clause := strings.Split(column, clauses.ExprDot) if order.IsRaw() { if len(clause) == 2 { diff --git a/client/orm/orm.go b/client/orm/orm.go index 27ae1fc2..a228c626 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -58,7 +58,7 @@ import ( "database/sql" "errors" "fmt" - order2 "github.com/astaxie/beego/client/orm/structs/clauses/order" + "github.com/astaxie/beego/client/orm/clauses/order_clause" "os" "reflect" "time" @@ -352,7 +352,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s qs.relDepth = relDepth if len(order) > 0 { - qs.orders = order2.ParseOrder(order) + qs.orders = order_clause.ParseOrder(order) } find := ind.FieldByIndex(fi.fieldIndex) diff --git a/client/orm/orm_conds.go b/client/orm/orm_conds.go index a2f6743c..fbf62c5a 100644 --- a/client/orm/orm_conds.go +++ b/client/orm/orm_conds.go @@ -16,13 +16,13 @@ package orm import ( "fmt" - "github.com/astaxie/beego/client/orm/structs" + "github.com/astaxie/beego/client/orm/clauses" "strings" ) // ExprSep define the expression separation const ( - ExprSep = structs.ExprSep + ExprSep = clauses.ExprSep ) type condValue struct { diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index 29b65b1a..45900487 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -18,7 +18,7 @@ import ( "context" "fmt" "github.com/astaxie/beego/client/orm/hints" - "github.com/astaxie/beego/client/orm/structs/clauses/order" + "github.com/astaxie/beego/client/orm/clauses/order_clause" ) type colValue struct { @@ -71,7 +71,7 @@ type querySet struct { limit int64 offset int64 groups []string - orders []*order.Order + orders []*order_clause.Order distinct bool forUpdate bool useIndex int @@ -143,12 +143,12 @@ func (o querySet) OrderBy(expressions ...string) QuerySeter { if len(expressions) <= 0 { return &o } - o.orders = order.ParseOrder(expressions...) + o.orders = order_clause.ParseOrder(expressions...) return &o } // add ORDER expression. -func (o querySet) OrderClauses(orders ...*order.Order) QuerySeter { +func (o querySet) OrderClauses(orders ...*order_clause.Order) QuerySeter { if len(orders) <= 0 { return &o } diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index e6041396..2f3f9d55 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -21,7 +21,7 @@ import ( "context" "database/sql" "fmt" - "github.com/astaxie/beego/client/orm/structs/clauses/order" + "github.com/astaxie/beego/client/orm/clauses/order_clause" "io/ioutil" "math" "os" @@ -1080,9 +1080,9 @@ func TestOrderBy(t *testing.T) { throwFail(t, AssertIs(num, 1)) num, err = qs.OrderClauses( - order.Clause( - order.Column(`profile__age`), - order.SortDescending(), + order_clause.Clause( + order_clause.Column(`profile__age`), + order_clause.SortDescending(), ), ).Filter("user_name", "astaxie").Count() throwFail(t, err) @@ -1090,9 +1090,9 @@ func TestOrderBy(t *testing.T) { if IsMysql { num, err = qs.OrderClauses( - order.Clause( - order.Column(`rand()`), - order.Raw(), + order_clause.Clause( + order_clause.Column(`rand()`), + order_clause.Raw(), ), ).Filter("user_name", "astaxie").Count() throwFail(t, err) @@ -1185,9 +1185,9 @@ func TestValues(t *testing.T) { } num, err = qs.OrderClauses( - order.Clause( - order.Column("Id"), - order.SortAscending(), + order_clause.Clause( + order_clause.Column("Id"), + order_clause.SortAscending(), ), ).Values(&maps) throwFail(t, err) diff --git a/client/orm/types.go b/client/orm/types.go index 6630e14f..5da40830 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -17,7 +17,7 @@ package orm import ( "context" "database/sql" - "github.com/astaxie/beego/client/orm/structs/clauses/order" + "github.com/astaxie/beego/client/orm/clauses/order_clause" "reflect" "time" @@ -293,25 +293,25 @@ type QuerySeter interface { // add ORDER expression by order clauses // for example: // OrderClauses( - // order.Clause( + // order_clause.Clause( // order.Column("Id"), // order.SortAscending(), // ), - // order.Clause( + // order_clause.Clause( // order.Column("status"), // order.SortDescending(), // ), // ) - // OrderClauses(order.Clause( - // order.Column(`user__status`), - // order.SortDescending(),//default None + // OrderClauses(order_clause.Clause( + // order_clause.Column(`user__status`), + // order_clause.SortDescending(),//default None // )) - // OrderClauses(order.Clause( - // order.Column(`random()`), - // order.SortNone(),//default None - // order.Raw(),//default false.if true, do not check field is valid or not + // OrderClauses(order_clause.Clause( + // order_clause.Column(`random()`), + // order_clause.SortNone(),//default None + // order_clause.Raw(),//default false.if true, do not check field is valid or not // )) - OrderClauses(orders ...*order.Order) QuerySeter + OrderClauses(orders ...*order_clause.Order) QuerySeter // add FORCE INDEX expression. // for example: // qs.ForceIndex(`idx_name1`,`idx_name2`) From f864e585e1b82815816e9acfbdfddad8dc9986b9 Mon Sep 17 00:00:00 2001 From: jianzhiyao Date: Thu, 29 Oct 2020 19:07:42 +0800 Subject: [PATCH 13/41] fix import --- client/orm/db_tables.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index 76a14fd0..82b8ab7e 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -17,6 +17,7 @@ package orm import ( "fmt" "github.com/astaxie/beego/client/orm/clauses" + "github.com/astaxie/beego/client/orm/clauses/order_clause" "strings" "time" ) From e597b05c938f0b6348b116426589c636e0966c97 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sat, 26 Dec 2020 18:38:52 +0800 Subject: [PATCH 14/41] fix-3928 --- server/web/router.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/web/router.go b/server/web/router.go index 5a663386..866ea745 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -548,7 +548,7 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str for _, l := range t.leaves { if c, ok := l.runObject.(*ControllerInfo); ok { if c.routerType == routerTypeBeego && - strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { + strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), `/`+controllerName) { find := false if HTTPMETHOD[strings.ToUpper(methodName)] { if len(c.methods) == 0 { From 6e4398f7ec0736cdc4aea489ae35e85c5898b6b0 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sat, 26 Dec 2020 22:45:56 +0800 Subject: [PATCH 15/41] add UT for issue 3928 --- server/web/router_test.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/server/web/router_test.go b/server/web/router_test.go index 87997322..521605e5 100644 --- a/server/web/router_test.go +++ b/server/web/router_test.go @@ -26,6 +26,14 @@ import ( "github.com/beego/beego/v2/server/web/context" ) +type PrefixTestController struct { + Controller +} + +func (ptc *PrefixTestController) PrefixList() { + ptc.Ctx.Output.Body([]byte("i am list in prefix test")) +} + type TestController struct { Controller } @@ -87,6 +95,20 @@ func (jc *JSONController) Get() { jc.Ctx.Output.Body([]byte("ok")) } +func TestPrefixUrlFor(t *testing.T){ + handler := NewControllerRegister() + handler.Add("/my/prefix/list", &PrefixTestController{}, "get:PrefixList") + + if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` { + logs.Info(a) + t.Errorf("PrefixTestController.PrefixList must equal to /my/prefix/list") + } + if a := handler.URLFor(`TestController.PrefixList`); a != `` { + logs.Info(a) + t.Errorf("TestController.PrefixList must equal to empty string") + } +} + func TestUrlFor(t *testing.T) { handler := NewControllerRegister() handler.Add("/api/list", &TestController{}, "*:List") From 41b1833898cd9843355f057599392a56c7e54ada Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sat, 2 Jan 2021 21:55:12 +0800 Subject: [PATCH 16/41] make session easier to be configured --- server/web/session/redis/sess_redis_test.go | 21 +-- .../sess_redis_sentinel_test.go | 18 +-- server/web/session/session.go | 19 --- server/web/session/session_config.go | 137 ++++++++++++++++++ server/web/session/session_provider_type.go | 16 ++ 5 files changed, 173 insertions(+), 38 deletions(-) create mode 100644 server/web/session/session_config.go create mode 100644 server/web/session/session_provider_type.go diff --git a/server/web/session/redis/sess_redis_test.go b/server/web/session/redis/sess_redis_test.go index fe5c363b..2b15eef1 100644 --- a/server/web/session/redis/sess_redis_test.go +++ b/server/web/session/redis/sess_redis_test.go @@ -15,21 +15,22 @@ import ( ) func TestRedis(t *testing.T) { - sessionConfig := &session.ManagerConfig{ - CookieName: "gosessionid", - EnableSetCookie: true, - Gclifetime: 3600, - Maxlifetime: 3600, - Secure: false, - CookieLifeTime: 3600, - } - redisAddr := os.Getenv("REDIS_ADDR") if redisAddr == "" { redisAddr = "127.0.0.1:6379" } + redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr) + + sessionConfig := session.NewManagerConfig( + session.CfgCookieName(`gosessionid`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + session.CfgProviderConfig(redisConfig), + ) - sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr) globalSession, err := session.NewManager("redis", sessionConfig) if err != nil { t.Fatal("could not create manager:", err) diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go index 0a8030ce..489e8998 100644 --- a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go @@ -13,15 +13,15 @@ import ( ) func TestRedisSentinel(t *testing.T) { - sessionConfig := &session.ManagerConfig{ - CookieName: "gosessionid", - EnableSetCookie: true, - Gclifetime: 3600, - Maxlifetime: 3600, - Secure: false, - CookieLifeTime: 3600, - ProviderConfig: "127.0.0.1:6379,100,,0,master", - } + sessionConfig := session.NewManagerConfig( + session.CfgCookieName(`gosessionid`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"), + ) globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) if e != nil { t.Log(e) diff --git a/server/web/session/session.go b/server/web/session/session.go index 6b53ec29..ca0407e8 100644 --- a/server/web/session/session.go +++ b/server/web/session/session.go @@ -91,25 +91,6 @@ func GetProvider(name string) (Provider, error) { return provider, nil } -// ManagerConfig define the session config -type ManagerConfig struct { - CookieName string `json:"cookieName"` - EnableSetCookie bool `json:"enableSetCookie,omitempty"` - Gclifetime int64 `json:"gclifetime"` - Maxlifetime int64 `json:"maxLifetime"` - DisableHTTPOnly bool `json:"disableHTTPOnly"` - Secure bool `json:"secure"` - CookieLifeTime int `json:"cookieLifeTime"` - ProviderConfig string `json:"providerConfig"` - Domain string `json:"domain"` - SessionIDLength int64 `json:"sessionIDLength"` - EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` - SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` - EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` - SessionIDPrefix string `json:"sessionIDPrefix"` - CookieSameSite http.SameSite `json:"cookieSameSite"` -} - // Manager contains Provider and its configuration. type Manager struct { provider Provider diff --git a/server/web/session/session_config.go b/server/web/session/session_config.go new file mode 100644 index 00000000..a1e24ae3 --- /dev/null +++ b/server/web/session/session_config.go @@ -0,0 +1,137 @@ +package session + +import "net/http" + +// ManagerConfig define the session config +type ManagerConfig struct { + CookieName string `json:"cookieName"` + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + Gclifetime int64 `json:"gclifetime"` + Maxlifetime int64 `json:"maxLifetime"` + DisableHTTPOnly bool `json:"disableHTTPOnly"` + Secure bool `json:"secure"` + CookieLifeTime int `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` + Domain string `json:"domain"` + SessionIDLength int64 `json:"sessionIDLength"` + EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` + SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` + EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` + SessionIDPrefix string `json:"sessionIDPrefix"` + CookieSameSite http.SameSite `json:"cookieSameSite"` +} + +type ManagerConfigOpt func(config *ManagerConfig) + +func NewManagerConfig(opts ...ManagerConfigOpt) *ManagerConfig { + config := &ManagerConfig{} + for _, opt := range opts { + opt(config) + } + return config +} + +// CfgCookieName set key of session id +func CfgCookieName(cookieName string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieName = cookieName + } +} + +// CfgCookieName set len of session id +func CfgSessionIdLength(len int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionIDLength = len + } +} + +// CfgSessionIdPrefix set prefix of session id +func CfgSessionIdPrefix(prefix string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionIDPrefix = prefix + } +} + +//CfgSetCookie whether set `Set-Cookie` header in HTTP response +func CfgSetCookie(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSetCookie = enable + } +} + +//CfgGcLifeTime set session gc lift time +func CfgGcLifeTime(lifeTime int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Gclifetime = lifeTime + } +} + +//CfgMaxLifeTime set session lift time +func CfgMaxLifeTime(lifeTime int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Maxlifetime = lifeTime + } +} + +//CfgGcLifeTime set session lift time +func CfgCookieLifeTime(lifeTime int) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieLifeTime = lifeTime + } +} + +//CfgProviderConfig configure session provider +func CfgProviderConfig(providerConfig string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.ProviderConfig = providerConfig + } +} + +//CfgDomain set cookie domain +func CfgDomain(domain string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Domain = domain + } +} + +//CfgSessionIdInHTTPHeader enable session id in http header +func CfgSessionIdInHTTPHeader(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSidInHTTPHeader = enable + } +} + +//CfgSetSessionNameInHTTPHeader set key of session id in http header +func CfgSetSessionNameInHTTPHeader(name string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionNameInHTTPHeader = name + } +} + +//EnableSidInURLQuery enable session id in query string +func CfgEnableSidInURLQuery(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSidInURLQuery = enable + } +} + +//DisableHTTPOnly set HTTPOnly for http.Cookie +func CfgHTTPOnly(HTTPOnly bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.DisableHTTPOnly = !HTTPOnly + } +} + +//CfgSecure set Secure for http.Cookie +func CfgSecure(Enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Secure = Enable + } +} + +//CfgSameSite set http.SameSite +func CfgSameSite(sameSite http.SameSite) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieSameSite = sameSite + } +} diff --git a/server/web/session/session_provider_type.go b/server/web/session/session_provider_type.go new file mode 100644 index 00000000..c14a3ecc --- /dev/null +++ b/server/web/session/session_provider_type.go @@ -0,0 +1,16 @@ +package session + +const ( + ProviderCookie = `cookie` + ProviderFile = `file` + ProviderMemory = `memory` + ProviderCouchbase = `couchbase` + ProviderLedis = `ledis` + ProviderMemcache = `memcache` + ProviderMysql = `mysql` + ProviderPostgresql = `postgresql` + ProviderRedis = `redis` + ProviderRedisCluster = `redis_cluster` + ProviderRedisSentinel = `redis_sentinel` + ProviderSsdb = `ssdb` +) From a4f8fbd5a14e97a3e8ad5f2ec9e87a80c6f7dea9 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sun, 3 Jan 2021 00:33:44 +0800 Subject: [PATCH 17/41] add session filter & UT --- go.mod | 2 +- server/web/filter/session/filter.go | 65 ++++++++++++ server/web/filter/session/filter_test.go | 112 ++++++++++++++++++++ server/web/session/session_provider_type.go | 26 ++--- 4 files changed, 192 insertions(+), 13 deletions(-) create mode 100644 server/web/filter/session/filter.go create mode 100644 server/web/filter/session/filter_test.go diff --git a/go.mod b/go.mod index 89baa406..ce67b7d2 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect github.com/gomodule/redigo v2.0.0+incompatible github.com/google/go-cmp v0.5.0 // indirect - github.com/google/uuid v1.1.1 // indirect + github.com/google/uuid v1.1.1 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/hashicorp/golang-lru v0.5.4 github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go new file mode 100644 index 00000000..b37a9f51 --- /dev/null +++ b/server/web/filter/session/filter.go @@ -0,0 +1,65 @@ +package session + +import ( + "context" + "errors" + "fmt" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" + "github.com/google/uuid" + "sync" +) + +var ( + sessionKey string + sessionKeyOnce sync.Once +) + +func getSessionKey() string { + + sessionKeyOnce.Do(func() { + //generate an unique session store key + sessionKey = fmt.Sprintf(`sess_store:%d`, uuid.New().ID()) + }) + + return sessionKey +} + +func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { + sessionConfig := session.NewManagerConfig(options...) + sessionManager, _ := session.NewManager(string(providerType), sessionConfig) + go sessionManager.GC() + + return func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + + if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil { + logs.Warning(`init session error:%s`, err.Error()) + } else { + //release session at the end of request + defer sess.SessionRelease(context.Background(), ctx.ResponseWriter) + ctx.Input.SetData(getSessionKey(), sess) + } + + next(ctx) + + } + } +} + +func GetStore(ctx *webContext.Context) (store session.Store, err error) { + if ctx == nil { + err = errors.New(`ctx is nil`) + return + } + + if s, ok := ctx.Input.GetData(getSessionKey()).(session.Store); ok { + store = s + return + } else { + err = errors.New(`can not get a valid session store`) + return + } +} diff --git a/server/web/filter/session/filter_test.go b/server/web/filter/session/filter_test.go new file mode 100644 index 00000000..7b24f6ad --- /dev/null +++ b/server/web/filter/session/filter_test.go @@ -0,0 +1,112 @@ +package session + +import ( + "context" + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" + "net/http" + "net/http/httptest" + "testing" +) + +func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code) + } +} + +func TestSession(t *testing.T) { + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store := ctx.Input.GetData(getSessionKey()); store == nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} + +func TestGetStore(t *testing.T) { + handler := web.NewControllerRegister() + + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + var ( + checkKey = `asodiuasdk1j)AS(87` + checkValue = `ASsd-09812-3` + + store session.Store + err error + + c = context.Background() + ) + + if store, err = GetStore(ctx); err == nil { + if store == nil { + t.Error(`store should not be nil`) + } else { + _ = store.Set(c, checkKey, checkValue) + } + } else { + t.Error(err) + } + + next(ctx) + + if store != nil { + if v := store.Get(c, checkKey); v != checkValue { + t.Error(v, `is not equals to`, checkValue) + } + }else{ + t.Error(`store should not be nil`) + } + + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} diff --git a/server/web/session/session_provider_type.go b/server/web/session/session_provider_type.go index c14a3ecc..78dc116d 100644 --- a/server/web/session/session_provider_type.go +++ b/server/web/session/session_provider_type.go @@ -1,16 +1,18 @@ package session +type ProviderType string + const ( - ProviderCookie = `cookie` - ProviderFile = `file` - ProviderMemory = `memory` - ProviderCouchbase = `couchbase` - ProviderLedis = `ledis` - ProviderMemcache = `memcache` - ProviderMysql = `mysql` - ProviderPostgresql = `postgresql` - ProviderRedis = `redis` - ProviderRedisCluster = `redis_cluster` - ProviderRedisSentinel = `redis_sentinel` - ProviderSsdb = `ssdb` + ProviderCookie ProviderType = `cookie` + ProviderFile ProviderType = `file` + ProviderMemory ProviderType = `memory` + ProviderCouchbase ProviderType = `couchbase` + ProviderLedis ProviderType = `ledis` + ProviderMemcache ProviderType = `memcache` + ProviderMysql ProviderType = `mysql` + ProviderPostgresql ProviderType = `postgresql` + ProviderRedis ProviderType = `redis` + ProviderRedisCluster ProviderType = `redis_cluster` + ProviderRedisSentinel ProviderType = `redis_sentinel` + ProviderSsdb ProviderType = `ssdb` ) From babf0dfc14872fafd146fd0e0583d5e804dbc0b5 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 3 Jan 2021 21:45:20 +0800 Subject: [PATCH 18/41] Make compatible --- adapter/app.go | 2 +- adapter/router.go | 2 +- server/web/admin.go | 14 +++++++------- server/web/flash_test.go | 2 +- server/web/namespace.go | 2 +- server/web/router.go | 12 ++++++------ server/web/router_test.go | 36 +++++++++++++++++------------------ server/web/server.go | 14 +++++++++++--- server/web/unregroute_test.go | 24 +++++++++++------------ 9 files changed, 58 insertions(+), 50 deletions(-) diff --git a/adapter/app.go b/adapter/app.go index 565a9795..8502256b 100644 --- a/adapter/app.go +++ b/adapter/app.go @@ -74,7 +74,7 @@ func oldMiddlewareToNew(mws []MiddleWare) []web.MiddleWare { // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { - return (*App)(web.Router(rootpath, c, web.SetRouterMethods(c, mappingMethods...))) + return (*App)(web.Router(rootpath, c, mappingMethods...)) } // UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful diff --git a/adapter/router.go b/adapter/router.go index 17e270ca..9a615efe 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister { // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - (*web.ControllerRegister)(p).Add(pattern, c, web.SetRouterMethods(c, mappingMethods...)) + (*web.ControllerRegister)(p).Add(pattern, c, web.WithRouterMethods(c, mappingMethods...)) } // Include only when the Runmode is dev will generate router file in the router/auto.go from the controller diff --git a/server/web/admin.go b/server/web/admin.go index d640c1be..89c9ddb9 100644 --- a/server/web/admin.go +++ b/server/web/admin.go @@ -112,13 +112,13 @@ func registerAdmin() error { HttpServer: NewHttpServerWithCfg(BConfig), } // keep in mind that all data should be html escaped to avoid XSS attack - beeAdminApp.Router("/", c, SetRouterMethods(c, "get:AdminIndex")) - beeAdminApp.Router("/qps", c, SetRouterMethods(c, "get:QpsIndex")) - beeAdminApp.Router("/prof", c, SetRouterMethods(c, "get:ProfIndex")) - beeAdminApp.Router("/healthcheck", c, SetRouterMethods(c, "get:Healthcheck")) - beeAdminApp.Router("/task", c, SetRouterMethods(c, "get:TaskStatus")) - beeAdminApp.Router("/listconf", c, SetRouterMethods(c, "get:ListConf")) - beeAdminApp.Router("/metrics", c, SetRouterMethods(c, "get:PrometheusMetrics")) + beeAdminApp.Router("/", c, "get:AdminIndex") + beeAdminApp.Router("/qps", c, "get:QpsIndex") + beeAdminApp.Router("/prof", c, "get:ProfIndex") + beeAdminApp.Router("/healthcheck", c, "get:Healthcheck") + beeAdminApp.Router("/task", c, "get:TaskStatus") + beeAdminApp.Router("/listconf", c, "get:ListConf") + beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics") go beeAdminApp.Run() } diff --git a/server/web/flash_test.go b/server/web/flash_test.go index c1ca9554..3e20c8fb 100644 --- a/server/web/flash_test.go +++ b/server/web/flash_test.go @@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) { // setup the handler handler := NewControllerRegister() - handler.Add("/", &TestFlashController{}, SetRouterMethods(&TestFlashController{}, "get:TestWriteFlash")) + handler.Add("/", &TestFlashController{}, WithRouterMethods(&TestFlashController{}, "get:TestWriteFlash")) handler.ServeHTTP(w, r) // get the Set-Cookie value diff --git a/server/web/namespace.go b/server/web/namespace.go index 3598a222..892f648a 100644 --- a/server/web/namespace.go +++ b/server/web/namespace.go @@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { // Router same as beego.Rourer // refer: https://godoc.org/github.com/beego/beego/v2#Router func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { - n.handlers.Add(rootpath, c, SetRouterMethods(c, mappingMethods...)) + n.handlers.Add(rootpath, c, WithRouterMethods(c, mappingMethods...)) return n } diff --git a/server/web/router.go b/server/web/router.go index ba85ad6e..613c4172 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -121,19 +121,19 @@ type ControllerInfo struct { sessionOn bool } -type ControllerOptions func(*ControllerInfo) +type ControllerOption func(*ControllerInfo) func (c *ControllerInfo) GetPattern() string { return c.pattern } -func SetRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOptions { +func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption { return func(c *ControllerInfo) { c.methods = parseMappingMethods(ctrlInterface, mappingMethod) } } -func SetRouterSessionOn(sessionOn bool) ControllerOptions { +func WithRouterSessionOn(sessionOn bool) ControllerOption { return func(c *ControllerInfo) { c.sessionOn = sessionOn } @@ -186,7 +186,7 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { // Add("/api/delete",&RestController{},"delete:DeleteFood") // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") -func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOptions) { +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOption) { p.addWithMethodParams(pattern, c, nil, opts...) } @@ -239,7 +239,7 @@ func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) { } } -func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOptions) { +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOption) { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() @@ -311,7 +311,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) } - p.addWithMethodParams(a.Router, c, a.MethodParams, SetRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) + p.addWithMethodParams(a.Router, c, a.MethodParams, WithRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) } } } diff --git a/server/web/router_test.go b/server/web/router_test.go index bd3953ba..474405e8 100644 --- a/server/web/router_test.go +++ b/server/web/router_test.go @@ -97,7 +97,7 @@ func (jc *JSONController) Get() { func TestPrefixUrlFor(t *testing.T){ handler := NewControllerRegister() - handler.Add("/my/prefix/list", &PrefixTestController{}, "get:PrefixList") + handler.Add("/my/prefix/list", &PrefixTestController{}, WithRouterMethods(&PrefixTestController{}, "get:PrefixList")) if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` { logs.Info(a) @@ -111,8 +111,8 @@ func TestPrefixUrlFor(t *testing.T){ func TestUrlFor(t *testing.T) { handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) - handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "*:Param")) + handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) + handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) if a := handler.URLFor("TestController.List"); a != "/api/list" { logs.Info(a) t.Errorf("TestController.List must equal to /api/list") @@ -135,9 +135,9 @@ func TestUrlFor3(t *testing.T) { func TestUrlFor2(t *testing.T) { handler := NewControllerRegister() - handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) - handler.Add("/v1/:username/edit", &TestController{}, SetRouterMethods(&TestController{}, "get:GetURL")) - handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, SetRouterMethods(&TestController{}, "*:Param")) + handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) + handler.Add("/v1/:username/edit", &TestController{}, WithRouterMethods(&TestController{}, "get:GetURL")) + handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { logs.Info(handler.URLFor("TestController.GetURL")) @@ -167,7 +167,7 @@ func TestUserFunc(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, SetRouterMethods(&TestController{}, "*:List")) + handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) handler.ServeHTTP(w, r) if w.Body.String() != "i am list" { t.Errorf("user define func can't run") @@ -257,7 +257,7 @@ func TestRouteOk(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/person/:last/:first", &TestController{}, SetRouterMethods(&TestController{}, "get:GetParams")) + handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "get:GetParams")) handler.ServeHTTP(w, r) body := w.Body.String() if body != "anderson+thomas+kungfu" { @@ -271,7 +271,7 @@ func TestManyRoute(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetManyRouter")) + handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetManyRouter")) handler.ServeHTTP(w, r) body := w.Body.String() @@ -288,7 +288,7 @@ func TestEmptyResponse(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego-empty.html", &TestController{}, SetRouterMethods(&TestController{}, "get:GetEmptyBody")) + handler.Add("/beego-empty.html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetEmptyBody")) handler.ServeHTTP(w, r) if body := w.Body.String(); body != "" { @@ -783,8 +783,8 @@ func TestRouterSessionSet(t *testing.T) { r, _ := http.NewRequest("GET", "/user", nil) w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), - SetRouterSessionOn(false)) + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(false)) handler.ServeHTTP(w, r) if w.Header().Get("Set-Cookie") != "" { t.Errorf("TestRotuerSessionSet failed") @@ -794,8 +794,8 @@ func TestRouterSessionSet(t *testing.T) { r, _ = http.NewRequest("GET", "/user", nil) w = httptest.NewRecorder() handler = NewControllerRegister() - handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), - SetRouterSessionOn(true)) + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(true)) handler.ServeHTTP(w, r) if w.Header().Get("Set-Cookie") != "" { t.Errorf("TestRotuerSessionSet failed") @@ -809,8 +809,8 @@ func TestRouterSessionSet(t *testing.T) { r, _ = http.NewRequest("GET", "/user", nil) w = httptest.NewRecorder() handler = NewControllerRegister() - handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), - SetRouterSessionOn(false)) + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(false)) handler.ServeHTTP(w, r) if w.Header().Get("Set-Cookie") != "" { t.Errorf("TestRotuerSessionSet failed") @@ -820,8 +820,8 @@ func TestRouterSessionSet(t *testing.T) { r, _ = http.NewRequest("GET", "/user", nil) w = httptest.NewRecorder() handler = NewControllerRegister() - handler.Add("/user", &TestController{}, SetRouterMethods(&TestController{}, "get:Get"), - SetRouterSessionOn(true)) + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(true)) handler.ServeHTTP(w, r) if w.Header().Get("Set-Cookie") == "" { t.Errorf("TestRotuerSessionSet failed") diff --git a/server/web/server.go b/server/web/server.go index 280828ff..6c4eaf9e 100644 --- a/server/web/server.go +++ b/server/web/server.go @@ -266,8 +266,12 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { } // Router see HttpServer.Router -func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer { - return BeeApp.Router(rootpath, c, opts...) +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer { + return RouterWithOpts(rootpath, c, WithRouterMethods(c, mappingMethods...)) +} + +func RouterWithOpts(rootpath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { + return BeeApp.RouterWithOpts(rootpath, c, opts...) } // Router adds a patterned controller handler to BeeApp. @@ -286,7 +290,11 @@ func Router(rootpath string, c ControllerInterface, opts ...ControllerOptions) * // beego.Router("/api/create",&RestController{},"post:CreateFood") // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") -func (app *HttpServer) Router(rootPath string, c ControllerInterface, opts ...ControllerOptions) *HttpServer { +func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer { + return app.RouterWithOpts(rootPath, c, WithRouterMethods(c, mappingMethods...)) +} + +func (app *HttpServer) RouterWithOpts(rootPath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { app.Handlers.Add(rootPath, c, opts...) return app } diff --git a/server/web/unregroute_test.go b/server/web/unregroute_test.go index 9745dbac..226cffb8 100644 --- a/server/web/unregroute_test.go +++ b/server/web/unregroute_test.go @@ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) - handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) - handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, "Test original root", @@ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { // Replace the root path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot")) + handler.Add("/", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot")) // Test replacement root (expect change) testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) @@ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) - handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) - handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { // Replace the "level1" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) @@ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) - handler.Add("/level1", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) - handler.Add("/level1/level2", &TestPreUnregController{}, SetRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { // Replace the "/level1/level2" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1/level2", &TestPostUnregController{}, SetRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) + handler.Add("/level1/level2", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) From c51a81222b1ccf8c75ddb10ec12161fefa1e148a Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 3 Jan 2021 19:28:06 +0800 Subject: [PATCH 19/41] Add golint action --- .github/workflows/golangci-lint.yml | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/golangci-lint.yml diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml new file mode 100644 index 00000000..f6224ff1 --- /dev/null +++ b/.github/workflows/golangci-lint.yml @@ -0,0 +1,32 @@ +name: golangci-lint +on: + push: + tags: + - v* + branches: + - master + - main + pull_request: +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: golangci-lint + uses: golangci/golangci-lint-action@v2 + with: + # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. + version: v1.29 + + # Optional: working directory, useful for monorepos +# working-directory: ./ + + # Optional: golangci-lint command line arguments. + args: --timeout=5m + + # Optional: show only new issues if it's a pull request. The default value is `false`. + only-new-issues: true + + # Optional: if set to true then the action will use pre-installed Go + # skip-go-installation: true \ No newline at end of file From 30dbf8fc3a0ad0c2bfcbe7e2de63746201b654bd Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Mon, 4 Jan 2021 16:29:03 +0800 Subject: [PATCH 20/41] 1.support dynamic registration model 2.support aggregete func --- client/orm/db.go | 32 +++++++++++------- client/orm/models.go | 17 +++------- client/orm/models_test.go | 16 +++++++++ client/orm/orm_queryset.go | 7 ++++ client/orm/orm_test.go | 69 ++++++++++++++++++++++++++++++++++++++ client/orm/types.go | 9 +++++ 6 files changed, 124 insertions(+), 26 deletions(-) diff --git a/client/orm/db.go b/client/orm/db.go index 4080f292..c994469f 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -948,9 +948,10 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi val := reflect.ValueOf(container) ind := reflect.Indirect(val) - errTyp := true + unregister := true one := true isPtr := true + name := "" if val.Kind() == reflect.Ptr { fn := "" @@ -963,19 +964,17 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi case reflect.Struct: isPtr = false fn = getFullName(typ) + name = getTableName(reflect.New(typ)) } } else { fn = getFullName(ind.Type()) + name = getTableName(ind) } - errTyp = fn != mi.fullName + unregister = fn != mi.fullName } - if errTyp { - if one { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) - } else { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) - } + if unregister { + RegisterModel(container) } rlimit := qs.limit @@ -1040,6 +1039,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi if qs.distinct { sqlSelect += " DISTINCT" } + if qs.aggregate != "" { + sels = qs.aggregate + } query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, specifyIndexes, join, where, groupBy, orderBy, limit) @@ -1064,16 +1066,20 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi } } + defer rs.Close() + + slice := ind + if unregister { + mi, _ = modelCache.get(name) + tCols = mi.fields.dbcols + colsNum = len(tCols) + } + refs := make([]interface{}, colsNum) for i := range refs { var ref interface{} refs[i] = &ref } - - defer rs.Close() - - slice := ind - var cnt int64 for rs.Next() { if one && cnt == 0 || !one { diff --git a/client/orm/models.go b/client/orm/models.go index 64dfab09..0f07e24d 100644 --- a/client/orm/models.go +++ b/client/orm/models.go @@ -332,10 +332,6 @@ end: // register register models to model cache func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) { - if mc.done { - err = fmt.Errorf("register must be run before BootStrap") - return - } for _, model := range models { val := reflect.ValueOf(model) @@ -352,7 +348,9 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m err = fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ) return } - + if val.Elem().Kind() == reflect.Slice { + val = reflect.New(val.Elem().Type().Elem()) + } table := getTableName(val) if prefixOrSuffixStr != "" { @@ -371,8 +369,7 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m } if _, ok := mc.get(table); ok { - err = fmt.Errorf(" table name `%s` repeat register, must be unique\n", table) - return + return nil } mi := newModelInfo(val) @@ -389,12 +386,6 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m } } } - - if mi.fields.pk == nil { - err = fmt.Errorf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) - return - } - } mi.table = table diff --git a/client/orm/models_test.go b/client/orm/models_test.go index e3f74c0b..5add6e45 100644 --- a/client/orm/models_test.go +++ b/client/orm/models_test.go @@ -255,6 +255,22 @@ func NewTM() *TM { return obj } +type DeptInfo struct { + ID int `orm:"column(id)"` + Created time.Time `orm:"auto_now_add"` + DeptName string + EmployeeName string + Salary int +} + +type UnregisterModel struct { + ID int `orm:"column(id)"` + Created time.Time `orm:"auto_now_add"` + DeptName string + EmployeeName string + Salary int +} + type User struct { ID int `orm:"column(id)"` UserName string `orm:"size(30);unique"` diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index 177cfc3a..293e4d29 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -79,6 +79,7 @@ type querySet struct { orm *ormBase ctx context.Context forContext bool + aggregate string } var _ QuerySeter = new(querySet) @@ -323,3 +324,9 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o.orm = orm return o } + +// aggregate func +func (o querySet) Aggregate(s string) QuerySeter { + o.aggregate = s + return &o +} \ No newline at end of file diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index f1074973..91f2f929 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -205,6 +205,7 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(Index)) RegisterModel(new(StrPk)) RegisterModel(new(TM)) + RegisterModel(new(DeptInfo)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -232,6 +233,7 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(Index)) RegisterModel(new(StrPk)) RegisterModel(new(TM)) + RegisterModel(new(DeptInfo)) BootStrap() @@ -333,6 +335,73 @@ func TestTM(t *testing.T) { throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC")) } +func TestUnregisterModel(t *testing.T) { + data := []*DeptInfo{ + { + DeptName: "A", + EmployeeName: "A1", + Salary: 1000, + }, + { + DeptName: "A", + EmployeeName: "A2", + Salary: 2000, + }, + { + DeptName: "B", + EmployeeName: "B1", + Salary: 2000, + }, + { + DeptName: "B", + EmployeeName: "B2", + Salary: 4000, + }, + { + DeptName: "B", + EmployeeName: "B3", + Salary: 3000, + }, + } + qs := dORM.QueryTable("dept_info") + i, _ := qs.PrepareInsert() + for _, d := range data { + _, err := i.Insert(d) + if err != nil { + throwFail(t, err) + } + } + + f := func() { + var res []UnregisterModel + n, err := dORM.QueryTable("dept_info").All(&res) + throwFail(t, err) + throwFail(t, AssertIs(n, 5)) + throwFail(t, AssertIs(res[0].EmployeeName, "A1")) + + type Sum struct { + DeptName string + Total int + } + var sun []Sum + qs.Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").OrderBy("dept_name").All(&sun) + throwFail(t, AssertIs(sun[0].DeptName, "A")) + throwFail(t, AssertIs(sun[0].Total, 3000)) + + type Max struct { + DeptName string + Max float64 + } + var max []Max + qs.Aggregate("dept_name,max(salary) as max").GroupBy("dept_name").OrderBy("dept_name").All(&max) + throwFail(t, AssertIs(max[1].DeptName, "B")) + throwFail(t, AssertIs(max[1].Max, 4000)) + } + for i := 0; i < 5; i++ { + f() + } +} + func TestNullDataTypes(t *testing.T) { d := DataNull{} diff --git a/client/orm/types.go b/client/orm/types.go index cb735ac8..60889ede 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -405,6 +405,15 @@ type QuerySeter interface { // Found int // } RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) + // aggregate func. + // for example: + // type result struct { + // DeptName string + // Total int + // } + // var res []result + // o.QueryTable("dept_info").Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").All(&res) + Aggregate(s string) QuerySeter } // QueryM2Mer model to model query struct From 5575cc1a5ce44f29b05b72db531d24abdc2907e1 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Mon, 4 Jan 2021 20:43:34 +0800 Subject: [PATCH 21/41] Add more golint action parameter --- .github/workflows/golangci-lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index f6224ff1..85b159db 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -23,7 +23,7 @@ jobs: # working-directory: ./ # Optional: golangci-lint command line arguments. - args: --timeout=5m + args: --timeout=5m --print-issued-lines=true --print-linter-name=true --uniq-by-line=true # Optional: show only new issues if it's a pull request. The default value is `false`. only-new-issues: true From d29b0a589ccb90542fe5901d30e72e04f36b3697 Mon Sep 17 00:00:00 2001 From: AllenX2018 Date: Thu, 7 Jan 2021 11:40:32 +0800 Subject: [PATCH 22/41] Fix orm many2many generated table error --- client/orm/models_utils.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client/orm/models_utils.go b/client/orm/models_utils.go index 950ca243..b2e5760e 100644 --- a/client/orm/models_utils.go +++ b/client/orm/models_utils.go @@ -109,6 +109,9 @@ func getTableUnique(val reflect.Value) [][]string { // get whether the table needs to be created for the database alias func isApplicableTableForDB(val reflect.Value, db string) bool { + if !val.IsValid() { + return true + } fun := val.MethodByName("IsApplicableTableForDB") if fun.IsValid() { vals := fun.Call([]reflect.Value{reflect.ValueOf(db)}) From 9105402f8ce1ad21523bc913d9dbd48bdfd7230c Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Fri, 8 Jan 2021 20:52:36 +0800 Subject: [PATCH 23/41] reverse filter chain sort and add test to ensure the FIFO --- client/httplib/httplib_test.go | 22 +++++++++++++++++++ server/web/filter_chain_test.go | 38 ++++++++++++++++++++++++++++++++- server/web/router.go | 32 ++++++++++++++++++++++----- server/web/server.go | 2 ++ 4 files changed, 88 insertions(+), 6 deletions(-) diff --git a/client/httplib/httplib_test.go b/client/httplib/httplib_test.go index 1763b1b5..b8cd1112 100644 --- a/client/httplib/httplib_test.go +++ b/client/httplib/httplib_test.go @@ -301,6 +301,28 @@ func TestAddFilter(t *testing.T) { assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains)) } +func TestFilterChainOrder(t *testing.T) { + req := Get("http://beego.me") + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("first"), nil + } + }) + + + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("second"), nil + } + }) + + resp, err := req.DoRequestWithCtx(context.Background()) + assert.Nil(t, err) + data := make([]byte, 5) + _, _ = resp.Body.Read(data) + assert.Equal(t, "first", string(data)) +} + func TestHead(t *testing.T) { req := Head("http://beego.me") assert.NotNil(t, req) diff --git a/server/web/filter_chain_test.go b/server/web/filter_chain_test.go index 2a428b78..5dd38fc5 100644 --- a/server/web/filter_chain_test.go +++ b/server/web/filter_chain_test.go @@ -15,9 +15,12 @@ package web import ( + "fmt" "net/http" "net/http/httptest" + "strconv" "testing" + "time" "github.com/stretchr/testify/assert" @@ -36,13 +39,46 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) { ns := NewNamespace("/chain") ns.Get("/*", func(ctx *context.Context) { - ctx.Output.Body([]byte("hello")) + _ = ctx.Output.Body([]byte("hello")) }) r, _ := http.NewRequest("GET", "/chain/user", nil) w := httptest.NewRecorder() + BeeApp.Handlers.Init() BeeApp.Handlers.ServeHTTP(w, r) assert.Equal(t, "filter-chain", w.Header().Get("filter")) } + +func TestControllerRegister_InsertFilterChain_Order(t *testing.T) { + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + r, _ := http.NewRequest("GET", "/abc", nil) + w := httptest.NewRecorder() + + BeeApp.Handlers.Init() + BeeApp.Handlers.ServeHTTP(w, r) + first := w.Header().Get("first") + second := w.Header().Get("second") + + ft, _ := strconv.ParseInt(first, 10, 64) + st, _ := strconv.ParseInt(second, 10, 64) + + assert.True(t, st > ft) +} diff --git a/server/web/router.go b/server/web/router.go index 613c4172..9ba078bf 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -139,6 +139,12 @@ func WithRouterSessionOn(sessionOn bool) ControllerOption { } } +type filterChainConfig struct { + pattern string + chain FilterChain + opts []FilterOpt +} + // ControllerRegister containers registered router rules, controller handlers and filters. type ControllerRegister struct { routers map[string]*Tree @@ -151,6 +157,9 @@ type ControllerRegister struct { // the filter created by FilterChain chainRoot *FilterRouter + // keep registered chain and build it when serve http + filterChains []filterChainConfig + cfg *Config } @@ -171,11 +180,23 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { }, }, cfg: cfg, + filterChains: make([]filterChainConfig, 0, 4), } res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) return res } +// Init will be executed when HttpServer start running +func (p *ControllerRegister) Init() { + for i := len(p.filterChains) - 1; i >= 0 ; i -- { + fc := p.filterChains[i] + root := p.chainRoot + filterFunc := fc.chain(root.filterFunc) + p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...) + p.chainRoot.next = root + } +} + // Add controller handler and pattern rules to ControllerRegister. // usage: // default methods is the same name as method @@ -513,12 +534,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // } // } func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { - root := p.chainRoot - filterFunc := chain(root.filterFunc) - opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) - p.chainRoot = newFilterRouter(pattern, filterFunc, opts...) - p.chainRoot.next = root + opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) + p.filterChains = append(p.filterChains, filterChainConfig{ + pattern: pattern, + chain: chain, + opts: opts, + }) } // add Filter into diff --git a/server/web/server.go b/server/web/server.go index 6c4eaf9e..548b906b 100644 --- a/server/web/server.go +++ b/server/web/server.go @@ -84,7 +84,9 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { initBeforeHTTPRun() + // init... app.initAddr(addr) + app.Handlers.Init() addr = app.Cfg.Listen.HTTPAddr From 833d7349216f33a1ca7503913562a65b20372f34 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sat, 9 Jan 2021 10:13:44 +0800 Subject: [PATCH 24/41] add UT and ManagerConfig.Opts(opts ...ManagerConfigOpt) --- server/web/session/session_config.go | 6 + server/web/session/session_config_test.go | 222 ++++++++++++++++++++++ 2 files changed, 228 insertions(+) create mode 100644 server/web/session/session_config_test.go diff --git a/server/web/session/session_config.go b/server/web/session/session_config.go index a1e24ae3..e42247db 100644 --- a/server/web/session/session_config.go +++ b/server/web/session/session_config.go @@ -21,6 +21,12 @@ type ManagerConfig struct { CookieSameSite http.SameSite `json:"cookieSameSite"` } +func (c *ManagerConfig) Opts(opts ...ManagerConfigOpt) { + for _, opt := range opts { + opt(c) + } +} + type ManagerConfigOpt func(config *ManagerConfig) func NewManagerConfig(opts ...ManagerConfigOpt) *ManagerConfig { diff --git a/server/web/session/session_config_test.go b/server/web/session/session_config_test.go new file mode 100644 index 00000000..a596c5c6 --- /dev/null +++ b/server/web/session/session_config_test.go @@ -0,0 +1,222 @@ +package session + +import ( + "net/http" + "testing" +) + +func TestCfgCookieLifeTime(t *testing.T) { + value := 8754 + c := NewManagerConfig( + CfgCookieLifeTime(value), + ) + + if c.CookieLifeTime != value { + t.Error() + } +} + +func TestCfgDomain(t *testing.T) { + value := `http://domain.com` + c := NewManagerConfig( + CfgDomain(value), + ) + + if c.Domain != value { + t.Error() + } +} + +func TestCfgSameSite(t *testing.T) { + value := http.SameSiteLaxMode + c := NewManagerConfig( + CfgSameSite(value), + ) + + if c.CookieSameSite != value { + t.Error() + } +} + +func TestCfgSecure(t *testing.T) { + c := NewManagerConfig( + CfgSecure(true), + ) + + if c.Secure != true { + t.Error() + } +} + +func TestCfgSecure1(t *testing.T) { + c := NewManagerConfig( + CfgSecure(false), + ) + + if c.Secure != false { + t.Error() + } +} + +func TestCfgSessionIdPrefix(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgSessionIdPrefix(value), + ) + + if c.SessionIDPrefix != value { + t.Error() + } +} + +func TestCfgSetSessionNameInHTTPHeader(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgSetSessionNameInHTTPHeader(value), + ) + + if c.SessionNameInHTTPHeader != value { + t.Error() + } +} + +func TestCfgCookieName(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgCookieName(value), + ) + + if c.CookieName != value { + t.Error() + } +} + +func TestCfgEnableSidInURLQuery(t *testing.T) { + c := NewManagerConfig( + CfgEnableSidInURLQuery(true), + ) + + if c.EnableSidInURLQuery != true { + t.Error() + } +} + +func TestCfgGcLifeTime(t *testing.T) { + value := int64(5454) + c := NewManagerConfig( + CfgGcLifeTime(value), + ) + + if c.Gclifetime != value { + t.Error() + } +} + +func TestCfgHTTPOnly(t *testing.T) { + c := NewManagerConfig( + CfgHTTPOnly(true), + ) + + if c.DisableHTTPOnly != false { + t.Error() + } +} + +func TestCfgHTTPOnly2(t *testing.T) { + c := NewManagerConfig( + CfgHTTPOnly(false), + ) + + if c.DisableHTTPOnly != true { + t.Error() + } +} + +func TestCfgMaxLifeTime(t *testing.T) { + value := int64(5454) + c := NewManagerConfig( + CfgMaxLifeTime(value), + ) + + if c.Maxlifetime != value { + t.Error() + } +} + +func TestCfgProviderConfig(t *testing.T) { + value := `asodiuasldkj12i39809as` + c := NewManagerConfig( + CfgProviderConfig(value), + ) + + if c.ProviderConfig != value { + t.Error() + } +} + +func TestCfgSessionIdInHTTPHeader(t *testing.T) { + c := NewManagerConfig( + CfgSessionIdInHTTPHeader(true), + ) + + if c.EnableSidInHTTPHeader != true { + t.Error() + } +} + +func TestCfgSessionIdInHTTPHeader1(t *testing.T) { + c := NewManagerConfig( + CfgSessionIdInHTTPHeader(false), + ) + + if c.EnableSidInHTTPHeader != false { + t.Error() + } +} + +func TestCfgSessionIdLength(t *testing.T) { + value := int64(100) + c := NewManagerConfig( + CfgSessionIdLength(value), + ) + + if c.SessionIDLength != value { + t.Error() + } +} + +func TestCfgSetCookie(t *testing.T) { + c := NewManagerConfig( + CfgSetCookie(true), + ) + + if c.EnableSetCookie != true { + t.Error() + } +} + +func TestCfgSetCookie1(t *testing.T) { + c := NewManagerConfig( + CfgSetCookie(false), + ) + + if c.EnableSetCookie != false { + t.Error() + } +} + +func TestNewManagerConfig(t *testing.T) { + c := NewManagerConfig() + if c == nil { + t.Error() + } +} + +func TestManagerConfig_Opts(t *testing.T) { + c := NewManagerConfig() + c.Opts(CfgSetCookie(true)) + + if c.EnableSetCookie != true { + t.Error() + } +} \ No newline at end of file From 962eac05f62540b416a54325879b207eeb2ec73b Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sat, 9 Jan 2021 11:21:29 +0800 Subject: [PATCH 25/41] add storeKey supports --- server/web/filter/session/filter.go | 46 ++++++---- server/web/filter/session/filter_test.go | 102 ++++++++++++++++++++++- 2 files changed, 130 insertions(+), 18 deletions(-) diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go index b37a9f51..69bfaa9d 100644 --- a/server/web/filter/session/filter.go +++ b/server/web/filter/session/filter.go @@ -9,25 +9,30 @@ import ( webContext "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/session" "github.com/google/uuid" - "sync" ) var ( - sessionKey string - sessionKeyOnce sync.Once + sessionFormatSign = uuid.New().ID() + defaultStorageKey = uuid.New().String() ) -func getSessionKey() string { - - sessionKeyOnce.Do(func() { - //generate an unique session store key - sessionKey = fmt.Sprintf(`sess_store:%d`, uuid.New().ID()) - }) - - return sessionKey +func sessionStoreKey(key string) string { + return fmt.Sprintf( + `sess_%d:%s`, + sessionFormatSign, + key, + ) } -func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { +//Session maintain session for web service +//Session new a session storage and store it into webContext.Context +// +//params: +//ctx: pointer of beego web context +//storeName: set the storage key in ctx.Input +// +//if you want to get session storage, just see GetStore +func Session(providerType session.ProviderType, storeName string, options ...session.ManagerConfigOpt) web.FilterChain { sessionConfig := session.NewManagerConfig(options...) sessionManager, _ := session.NewManager(string(providerType), sessionConfig) go sessionManager.GC() @@ -40,7 +45,7 @@ func Session(providerType session.ProviderType, options ...session.ManagerConfig } else { //release session at the end of request defer sess.SessionRelease(context.Background(), ctx.ResponseWriter) - ctx.Input.SetData(getSessionKey(), sess) + ctx.Input.SetData(sessionStoreKey(storeName), sess) } next(ctx) @@ -49,13 +54,14 @@ func Session(providerType session.ProviderType, options ...session.ManagerConfig } } -func GetStore(ctx *webContext.Context) (store session.Store, err error) { +//GetStore get session storage in beego web context +func GetStore(ctx *webContext.Context, storeName string) (store session.Store, err error) { if ctx == nil { err = errors.New(`ctx is nil`) return } - if s, ok := ctx.Input.GetData(getSessionKey()).(session.Store); ok { + if s, ok := ctx.Input.GetData(sessionStoreKey(storeName)).(session.Store); ok { store = s return } else { @@ -63,3 +69,13 @@ func GetStore(ctx *webContext.Context) (store session.Store, err error) { return } } + +//DefaultSession call Session with default storage key +func DefaultSession(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { + return Session(providerType, defaultStorageKey, options...) +} + +//GetDefaultStore call GetStore with default storage key +func GetDefaultStore(ctx *webContext.Context) (store session.Store, err error) { + return GetStore(ctx, defaultStorageKey) +} diff --git a/server/web/filter/session/filter_test.go b/server/web/filter/session/filter_test.go index 7b24f6ad..ea4e7a6e 100644 --- a/server/web/filter/session/filter_test.go +++ b/server/web/filter/session/filter_test.go @@ -5,6 +5,7 @@ import ( "github.com/beego/beego/v2/server/web" webContext "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/session" + "github.com/google/uuid" "net/http" "net/http/httptest" "testing" @@ -21,11 +22,13 @@ func testRequest(t *testing.T, handler *web.ControllerRegister, path string, met } func TestSession(t *testing.T) { + storeKey := uuid.New().String() handler := web.NewControllerRegister() handler.InsertFilterChain( "*", Session( session.ProviderMemory, + storeKey, session.CfgCookieName(`go_session_id`), session.CfgSetCookie(true), session.CfgGcLifeTime(3600), @@ -38,7 +41,7 @@ func TestSession(t *testing.T) { "*", func(next web.FilterFunc) web.FilterFunc { return func(ctx *webContext.Context) { - if store := ctx.Input.GetData(getSessionKey()); store == nil { + if store := ctx.Input.GetData(storeKey); store == nil { t.Error(`store should not be nil`) } next(ctx) @@ -53,11 +56,104 @@ func TestSession(t *testing.T) { } func TestGetStore(t *testing.T) { + storeKey := uuid.New().String() handler := web.NewControllerRegister() handler.InsertFilterChain( "*", Session( + session.ProviderMemory, + storeKey, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + var ( + checkKey = `asodiuasdk1j)AS(87` + checkValue = `ASsd-09812-3` + + store session.Store + err error + + c = context.Background() + ) + + if store, err = GetStore(ctx, storeKey); err == nil { + if store == nil { + t.Error(`store should not be nil`) + } else { + _ = store.Set(c, checkKey, checkValue) + } + } else { + t.Error(err) + } + + next(ctx) + + if store != nil { + if v := store.Get(c, checkKey); v != checkValue { + t.Error(v, `is not equals to`, checkValue) + } + } else { + t.Error(`store should not be nil`) + } + + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} + +func TestDefaultSession(t *testing.T) { + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + DefaultSession( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store := ctx.Input.GetData(defaultStorageKey); store == nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} + +func TestGetDefaultStore(t *testing.T) { + handler := web.NewControllerRegister() + + handler.InsertFilterChain( + "*", + DefaultSession( session.ProviderMemory, session.CfgCookieName(`go_session_id`), session.CfgSetCookie(true), @@ -81,7 +177,7 @@ func TestGetStore(t *testing.T) { c = context.Background() ) - if store, err = GetStore(ctx); err == nil { + if store, err = GetDefaultStore(ctx); err == nil { if store == nil { t.Error(`store should not be nil`) } else { @@ -97,7 +193,7 @@ func TestGetStore(t *testing.T) { if v := store.Get(c, checkKey); v != checkValue { t.Error(v, `is not equals to`, checkValue) } - }else{ + } else { t.Error(`store should not be nil`) } From c4b585741a23262a36d6df2f483131e6aecbb893 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sat, 9 Jan 2021 11:27:36 +0800 Subject: [PATCH 26/41] rename params --- server/web/filter/session/filter.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go index 69bfaa9d..2db1db88 100644 --- a/server/web/filter/session/filter.go +++ b/server/web/filter/session/filter.go @@ -29,10 +29,10 @@ func sessionStoreKey(key string) string { // //params: //ctx: pointer of beego web context -//storeName: set the storage key in ctx.Input +//storeKey: set the storage key in ctx.Input // //if you want to get session storage, just see GetStore -func Session(providerType session.ProviderType, storeName string, options ...session.ManagerConfigOpt) web.FilterChain { +func Session(providerType session.ProviderType, storeKey string, options ...session.ManagerConfigOpt) web.FilterChain { sessionConfig := session.NewManagerConfig(options...) sessionManager, _ := session.NewManager(string(providerType), sessionConfig) go sessionManager.GC() @@ -45,7 +45,7 @@ func Session(providerType session.ProviderType, storeName string, options ...ses } else { //release session at the end of request defer sess.SessionRelease(context.Background(), ctx.ResponseWriter) - ctx.Input.SetData(sessionStoreKey(storeName), sess) + ctx.Input.SetData(sessionStoreKey(storeKey), sess) } next(ctx) @@ -55,13 +55,13 @@ func Session(providerType session.ProviderType, storeName string, options ...ses } //GetStore get session storage in beego web context -func GetStore(ctx *webContext.Context, storeName string) (store session.Store, err error) { +func GetStore(ctx *webContext.Context, storeKey string) (store session.Store, err error) { if ctx == nil { err = errors.New(`ctx is nil`) return } - if s, ok := ctx.Input.GetData(sessionStoreKey(storeName)).(session.Store); ok { + if s, ok := ctx.Input.GetData(sessionStoreKey(storeKey)).(session.Store); ok { store = s return } else { From d7b79f23a6afcfc5cdecebea1098eae2f48840c4 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 9 Jan 2021 21:15:48 +0800 Subject: [PATCH 27/41] Add sonar check action --- .github/workflows/sonar.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/sonar.yml diff --git a/.github/workflows/sonar.yml b/.github/workflows/sonar.yml new file mode 100644 index 00000000..9c63eaac --- /dev/null +++ b/.github/workflows/sonar.yml @@ -0,0 +1,19 @@ +on: + # Trigger analysis when pushing in master or pull requests, and when creating + # a pull request. + pull_request: + types: [opened, synchronize, reopened] +name: Main Workflow +jobs: + sonarcloud: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + # Disabling shallow clone is recommended for improving relevancy of reporting + fetch-depth: 0 + - name: SonarCloud Scan + uses: sonarsource/sonarcloud-github-action@master + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} \ No newline at end of file From 0cf3b035ebf0ab74023324aa23e1f5ca5974d5ac Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Fri, 8 Jan 2021 22:54:39 +0800 Subject: [PATCH 28/41] update changelog --- .github/workflows/changelog.yml | 2 +- CHANGELOG.md | 1 + server/web/filter/prometheus/filter_test.go | 2 ++ task/task_test.go | 17 +++++++++++++++++ 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index 7e4b1032..50e91510 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -8,7 +8,7 @@ on: pull_request: types: [opened, synchronize, reopened, labeled, unlabeled] branches: - - master + - develop jobs: changelog: diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a259efc..14c40055 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ # developing +- Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427) - Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398) - Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391) - Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385) diff --git a/server/web/filter/prometheus/filter_test.go b/server/web/filter/prometheus/filter_test.go index f00f20e7..be008d41 100644 --- a/server/web/filter/prometheus/filter_test.go +++ b/server/web/filter/prometheus/filter_test.go @@ -18,6 +18,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" @@ -37,4 +38,5 @@ func TestFilterChain(t *testing.T) { ctx.Input.SetData("RouterPattern", "my-route") filter(ctx) assert.True(t, ctx.Input.GetData("invocation").(bool)) + time.Sleep(1 * time.Second) } diff --git a/task/task_test.go b/task/task_test.go index 5e117cbd..c87757ef 100644 --- a/task/task_test.go +++ b/task/task_test.go @@ -109,6 +109,23 @@ func TestTask_Run(t *testing.T) { assert.Equal(t, "Hello, world! 101", l[1].errinfo) } +func TestCrudTask(t *testing.T) { + m := newTaskManager() + m.AddTask("my-task1", NewTask("my-task1", "0/30 * * * * *", func(ctx context.Context) error { + return nil + })) + + m.AddTask("my-task2", NewTask("my-task2", "0/30 * * * * *", func(ctx context.Context) error { + return nil + })) + + m.DeleteTask("my-task1") + assert.Equal(t, 1, len(m.adminTaskList)) + + m.ClearTask() + assert.Equal(t, 0, len(m.adminTaskList)) +} + func wait(wg *sync.WaitGroup) chan bool { ch := make(chan bool) go func() { From 3150285542bd2339030a92e8cfedf6736038f6ff Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sat, 9 Jan 2021 21:27:30 +0800 Subject: [PATCH 29/41] fix imports --- client/orm/db_tables.go | 4 ++-- client/orm/orm_queryset.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index 82b8ab7e..d62d8106 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -16,8 +16,8 @@ package orm import ( "fmt" - "github.com/astaxie/beego/client/orm/clauses" - "github.com/astaxie/beego/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/clauses" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "strings" "time" ) diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index 9f7a3c1e..66e1c442 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -17,8 +17,8 @@ package orm import ( "context" "fmt" - "github.com/astaxie/beego/v2/client/orm/hints" - "github.com/astaxie/beego/v2/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/hints" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" ) type colValue struct { From 791c28b1d4a5ec03207d912d550f9cf8e8fec238 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sat, 9 Jan 2021 22:16:21 +0800 Subject: [PATCH 30/41] fix imports --- client/orm/clauses/order_clause/order.go | 2 +- client/orm/orm.go | 2 +- client/orm/orm_conds.go | 2 +- client/orm/orm_test.go | 2 +- client/orm/types.go | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/client/orm/clauses/order_clause/order.go b/client/orm/clauses/order_clause/order.go index 510c5505..e45c2f85 100644 --- a/client/orm/clauses/order_clause/order.go +++ b/client/orm/clauses/order_clause/order.go @@ -1,7 +1,7 @@ package order_clause import ( - "github.com/astaxie/beego/client/orm/clauses" + "github.com/beego/beego/v2/client/orm/clauses" "strings" ) diff --git a/client/orm/orm.go b/client/orm/orm.go index f4422648..f33c4b51 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -58,7 +58,7 @@ import ( "database/sql" "errors" "fmt" - "github.com/astaxie/beego/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "os" "reflect" "time" diff --git a/client/orm/orm_conds.go b/client/orm/orm_conds.go index 99828f3c..0080d53c 100644 --- a/client/orm/orm_conds.go +++ b/client/orm/orm_conds.go @@ -16,7 +16,7 @@ package orm import ( "fmt" - "github.com/astaxie/beego/client/orm/clauses" + "github.com/beego/beego/v2/client/orm/clauses" "strings" ) diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index 1dd3bc8d..f2943739 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -21,7 +21,7 @@ import ( "context" "database/sql" "fmt" - "github.com/astaxie/beego/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "io/ioutil" "math" "os" diff --git a/client/orm/types.go b/client/orm/types.go index 37aa9f3f..b22848fe 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -17,7 +17,7 @@ package orm import ( "context" "database/sql" - "github.com/astaxie/beego/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "reflect" "time" From a98edc03cd321611308b3101a38cdecca30773e1 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sat, 9 Jan 2021 23:39:16 +0800 Subject: [PATCH 31/41] adapt CruSession --- server/web/filter/session/filter.go | 36 ++------- server/web/filter/session/filter_test.go | 96 +----------------------- server/web/session/session.go | 5 +- 3 files changed, 12 insertions(+), 125 deletions(-) diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go index 2db1db88..685f86c1 100644 --- a/server/web/filter/session/filter.go +++ b/server/web/filter/session/filter.go @@ -3,27 +3,12 @@ package session import ( "context" "errors" - "fmt" "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/server/web" webContext "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/session" - "github.com/google/uuid" ) -var ( - sessionFormatSign = uuid.New().ID() - defaultStorageKey = uuid.New().String() -) - -func sessionStoreKey(key string) string { - return fmt.Sprintf( - `sess_%d:%s`, - sessionFormatSign, - key, - ) -} - //Session maintain session for web service //Session new a session storage and store it into webContext.Context // @@ -32,20 +17,23 @@ func sessionStoreKey(key string) string { //storeKey: set the storage key in ctx.Input // //if you want to get session storage, just see GetStore -func Session(providerType session.ProviderType, storeKey string, options ...session.ManagerConfigOpt) web.FilterChain { +func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { sessionConfig := session.NewManagerConfig(options...) sessionManager, _ := session.NewManager(string(providerType), sessionConfig) go sessionManager.GC() return func(next web.FilterFunc) web.FilterFunc { return func(ctx *webContext.Context) { + if ctx.Input.CruSession != nil { + return + } if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil { logs.Warning(`init session error:%s`, err.Error()) } else { //release session at the end of request defer sess.SessionRelease(context.Background(), ctx.ResponseWriter) - ctx.Input.SetData(sessionStoreKey(storeKey), sess) + ctx.Input.CruSession = sess } next(ctx) @@ -55,13 +43,13 @@ func Session(providerType session.ProviderType, storeKey string, options ...sess } //GetStore get session storage in beego web context -func GetStore(ctx *webContext.Context, storeKey string) (store session.Store, err error) { +func GetStore(ctx *webContext.Context) (store session.Store, err error) { if ctx == nil { err = errors.New(`ctx is nil`) return } - if s, ok := ctx.Input.GetData(sessionStoreKey(storeKey)).(session.Store); ok { + if s := ctx.Input.CruSession; s != nil { store = s return } else { @@ -69,13 +57,3 @@ func GetStore(ctx *webContext.Context, storeKey string) (store session.Store, er return } } - -//DefaultSession call Session with default storage key -func DefaultSession(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { - return Session(providerType, defaultStorageKey, options...) -} - -//GetDefaultStore call GetStore with default storage key -func GetDefaultStore(ctx *webContext.Context) (store session.Store, err error) { - return GetStore(ctx, defaultStorageKey) -} diff --git a/server/web/filter/session/filter_test.go b/server/web/filter/session/filter_test.go index ea4e7a6e..7e38cdc2 100644 --- a/server/web/filter/session/filter_test.go +++ b/server/web/filter/session/filter_test.go @@ -28,7 +28,6 @@ func TestSession(t *testing.T) { "*", Session( session.ProviderMemory, - storeKey, session.CfgCookieName(`go_session_id`), session.CfgSetCookie(true), session.CfgGcLifeTime(3600), @@ -56,14 +55,12 @@ func TestSession(t *testing.T) { } func TestGetStore(t *testing.T) { - storeKey := uuid.New().String() handler := web.NewControllerRegister() handler.InsertFilterChain( "*", Session( session.ProviderMemory, - storeKey, session.CfgCookieName(`go_session_id`), session.CfgSetCookie(true), session.CfgGcLifeTime(3600), @@ -86,98 +83,7 @@ func TestGetStore(t *testing.T) { c = context.Background() ) - if store, err = GetStore(ctx, storeKey); err == nil { - if store == nil { - t.Error(`store should not be nil`) - } else { - _ = store.Set(c, checkKey, checkValue) - } - } else { - t.Error(err) - } - - next(ctx) - - if store != nil { - if v := store.Get(c, checkKey); v != checkValue { - t.Error(v, `is not equals to`, checkValue) - } - } else { - t.Error(`store should not be nil`) - } - - } - }, - ) - handler.Any("*", func(ctx *webContext.Context) { - ctx.Output.SetStatus(200) - }) - - testRequest(t, handler, "/dataset1/resource1", "GET", 200) -} - -func TestDefaultSession(t *testing.T) { - handler := web.NewControllerRegister() - handler.InsertFilterChain( - "*", - DefaultSession( - session.ProviderMemory, - session.CfgCookieName(`go_session_id`), - session.CfgSetCookie(true), - session.CfgGcLifeTime(3600), - session.CfgMaxLifeTime(3600), - session.CfgSecure(false), - session.CfgCookieLifeTime(3600), - ), - ) - handler.InsertFilterChain( - "*", - func(next web.FilterFunc) web.FilterFunc { - return func(ctx *webContext.Context) { - if store := ctx.Input.GetData(defaultStorageKey); store == nil { - t.Error(`store should not be nil`) - } - next(ctx) - } - }, - ) - handler.Any("*", func(ctx *webContext.Context) { - ctx.Output.SetStatus(200) - }) - - testRequest(t, handler, "/dataset1/resource1", "GET", 200) -} - -func TestGetDefaultStore(t *testing.T) { - handler := web.NewControllerRegister() - - handler.InsertFilterChain( - "*", - DefaultSession( - session.ProviderMemory, - session.CfgCookieName(`go_session_id`), - session.CfgSetCookie(true), - session.CfgGcLifeTime(3600), - session.CfgMaxLifeTime(3600), - session.CfgSecure(false), - session.CfgCookieLifeTime(3600), - ), - ) - handler.InsertFilterChain( - "*", - func(next web.FilterFunc) web.FilterFunc { - return func(ctx *webContext.Context) { - var ( - checkKey = `asodiuasdk1j)AS(87` - checkValue = `ASsd-09812-3` - - store session.Store - err error - - c = context.Background() - ) - - if store, err = GetDefaultStore(ctx); err == nil { + if store, err = GetStore(ctx); err == nil { if store == nil { t.Error(`store should not be nil`) } else { diff --git a/server/web/session/session.go b/server/web/session/session.go index ca0407e8..911f45fe 100644 --- a/server/web/session/session.go +++ b/server/web/session/session.go @@ -279,7 +279,10 @@ func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) // it can do gc in times after gc lifetime. func (manager *Manager) GC() { manager.provider.SessionGC(nil) - time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) + ticker := time.NewTicker(time.Duration(manager.config.Gclifetime) * time.Second) + for range ticker.C { + manager.provider.SessionGC(nil) + } } // SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. From 25bf1259c61bea15c04c92228cc15f64dc18dd66 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sat, 9 Jan 2021 23:55:18 +0800 Subject: [PATCH 32/41] add change log --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a259efc..28cdd2f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,4 +3,5 @@ - Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391) - Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385) - Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385) -- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) \ No newline at end of file +- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) +- Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294) \ No newline at end of file From 1f475585e55e5c710f95d25d974d7e44b58b62ec Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sun, 10 Jan 2021 00:33:53 +0800 Subject: [PATCH 33/41] change log level and recover code --- CHANGELOG.md | 3 ++- server/web/filter/session/filter.go | 2 +- server/web/session/session.go | 5 +---- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a259efc..6e9c6c1b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,4 +3,5 @@ - Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391) - Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385) - Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385) -- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) \ No newline at end of file +- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) +- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404) \ No newline at end of file diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go index 685f86c1..1776d604 100644 --- a/server/web/filter/session/filter.go +++ b/server/web/filter/session/filter.go @@ -29,7 +29,7 @@ func Session(providerType session.ProviderType, options ...session.ManagerConfig } if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil { - logs.Warning(`init session error:%s`, err.Error()) + logs.Error(`init session error:%s`, err.Error()) } else { //release session at the end of request defer sess.SessionRelease(context.Background(), ctx.ResponseWriter) diff --git a/server/web/session/session.go b/server/web/session/session.go index 911f45fe..ca0407e8 100644 --- a/server/web/session/session.go +++ b/server/web/session/session.go @@ -279,10 +279,7 @@ func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) // it can do gc in times after gc lifetime. func (manager *Manager) GC() { manager.provider.SessionGC(nil) - ticker := time.NewTicker(time.Duration(manager.config.Gclifetime) * time.Second) - for range ticker.C { - manager.provider.SessionGC(nil) - } + time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) } // SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. From 2590fbd5cb1276cbd220db08542c5a56478a6b9b Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sat, 9 Jan 2021 21:36:30 +0800 Subject: [PATCH 34/41] Fix sonar check --- .github/workflows/sonar.yml | 13 ++++++++----- CHANGELOG.md | 1 + sonar-project.properties | 6 ++++++ 3 files changed, 15 insertions(+), 5 deletions(-) create mode 100644 sonar-project.properties diff --git a/.github/workflows/sonar.yml b/.github/workflows/sonar.yml index 9c63eaac..1a3b927a 100644 --- a/.github/workflows/sonar.yml +++ b/.github/workflows/sonar.yml @@ -1,9 +1,15 @@ on: # Trigger analysis when pushing in master or pull requests, and when creating # a pull request. + push: + branches: + - develop pull_request: types: [opened, synchronize, reopened] -name: Main Workflow +name: Sonar Check +env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} jobs: sonarcloud: runs-on: ubuntu-latest @@ -13,7 +19,4 @@ jobs: # Disabling shallow clone is recommended for improving relevancy of reporting fetch-depth: 0 - name: SonarCloud Scan - uses: sonarsource/sonarcloud-github-action@master - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} \ No newline at end of file + uses: sonarsource/sonarcloud-github-action@master \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 14c40055..ec59f070 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ # developing +- Add sonar check. [4432](https://github.com/beego/beego/pull/4432) - Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427) - Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398) - Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391) diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 00000000..2fc78d8d --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,6 @@ +sonar.organization=beego +sonar.projectKey=beego_beego + +# relative paths to source directories. More details and properties are described +# in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/ +sonar.sources=. \ No newline at end of file From eed869a29671aa2732a52a36ed650754795e0ac5 Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 10 Jan 2021 13:20:55 +0800 Subject: [PATCH 35/41] Remove sonar --- .github/workflows/sonar.yml | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 .github/workflows/sonar.yml diff --git a/.github/workflows/sonar.yml b/.github/workflows/sonar.yml deleted file mode 100644 index 1a3b927a..00000000 --- a/.github/workflows/sonar.yml +++ /dev/null @@ -1,22 +0,0 @@ -on: - # Trigger analysis when pushing in master or pull requests, and when creating - # a pull request. - push: - branches: - - develop - pull_request: - types: [opened, synchronize, reopened] -name: Sonar Check -env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} -jobs: - sonarcloud: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - with: - # Disabling shallow clone is recommended for improving relevancy of reporting - fetch-depth: 0 - - name: SonarCloud Scan - uses: sonarsource/sonarcloud-github-action@master \ No newline at end of file From 31f79c2ee2baecda21fe17b469739fd8915b82b9 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sun, 10 Jan 2021 13:50:43 +0800 Subject: [PATCH 36/41] move getting session store to web context --- server/web/context/context.go | 17 +++++++ server/web/context/context_test.go | 46 ++++++++++++++++++ server/web/filter/session/filter.go | 26 +--------- server/web/filter/session/filter_test.go | 60 ------------------------ 4 files changed, 64 insertions(+), 85 deletions(-) diff --git a/server/web/context/context.go b/server/web/context/context.go index 6070c996..edbf14f5 100644 --- a/server/web/context/context.go +++ b/server/web/context/context.go @@ -29,6 +29,7 @@ import ( "encoding/base64" "errors" "fmt" + "github.com/beego/beego/v2/server/web/session" "net" "net/http" "strconv" @@ -195,6 +196,22 @@ func (ctx *Context) RenderMethodResult(result interface{}) { } } +// Session return session store of this context of request +func (ctx *Context) Session() (store session.Store,err error){ + if ctx.Input != nil { + if ctx.Input.CruSession != nil { + store = ctx.Input.CruSession + return + } else { + err = errors.New(`no valid session store(please initialize session)`) + return + } + } else { + err = errors.New(`no valid input`) + return + } +} + // Response is a wrapper for the http.ResponseWriter // Started: if true, response was already written to so the other handler will not be executed type Response struct { diff --git a/server/web/context/context_test.go b/server/web/context/context_test.go index 7c0535e0..7fdb310f 100644 --- a/server/web/context/context_test.go +++ b/server/web/context/context_test.go @@ -15,6 +15,9 @@ package context import ( + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/filter/session" + webSession "github.com/beego/beego/v2/server/web/session" "net/http" "net/http/httptest" "testing" @@ -45,3 +48,46 @@ func TestXsrfReset_01(t *testing.T) { t.FailNow() } } + +func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code) + } +} + +func TestContext_Session(t *testing.T) { + handler := web.NewControllerRegister() + + handler.InsertFilterChain( + "*", + session.Session( + webSession.ProviderMemory, + webSession.CfgCookieName(`go_session_id`), + webSession.CfgSetCookie(true), + webSession.CfgGcLifeTime(3600), + webSession.CfgMaxLifeTime(3600), + webSession.CfgSecure(false), + webSession.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *Context) { + if _, err := ctx.Session(); err == nil { + t.Error() + } + + } + }, + ) + handler.Any("*", func(ctx *Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} \ No newline at end of file diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go index 1776d604..bcf9edf4 100644 --- a/server/web/filter/session/filter.go +++ b/server/web/filter/session/filter.go @@ -2,7 +2,6 @@ package session import ( "context" - "errors" "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/server/web" webContext "github.com/beego/beego/v2/server/web/context" @@ -11,12 +10,6 @@ import ( //Session maintain session for web service //Session new a session storage and store it into webContext.Context -// -//params: -//ctx: pointer of beego web context -//storeKey: set the storage key in ctx.Input -// -//if you want to get session storage, just see GetStore func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { sessionConfig := session.NewManagerConfig(options...) sessionManager, _ := session.NewManager(string(providerType), sessionConfig) @@ -37,23 +30,6 @@ func Session(providerType session.ProviderType, options ...session.ManagerConfig } next(ctx) - } } -} - -//GetStore get session storage in beego web context -func GetStore(ctx *webContext.Context) (store session.Store, err error) { - if ctx == nil { - err = errors.New(`ctx is nil`) - return - } - - if s := ctx.Input.CruSession; s != nil { - store = s - return - } else { - err = errors.New(`can not get a valid session store`) - return - } -} +} \ No newline at end of file diff --git a/server/web/filter/session/filter_test.go b/server/web/filter/session/filter_test.go index 7e38cdc2..03a88afd 100644 --- a/server/web/filter/session/filter_test.go +++ b/server/web/filter/session/filter_test.go @@ -1,7 +1,6 @@ package session import ( - "context" "github.com/beego/beego/v2/server/web" webContext "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/session" @@ -53,62 +52,3 @@ func TestSession(t *testing.T) { testRequest(t, handler, "/dataset1/resource1", "GET", 200) } - -func TestGetStore(t *testing.T) { - handler := web.NewControllerRegister() - - handler.InsertFilterChain( - "*", - Session( - session.ProviderMemory, - session.CfgCookieName(`go_session_id`), - session.CfgSetCookie(true), - session.CfgGcLifeTime(3600), - session.CfgMaxLifeTime(3600), - session.CfgSecure(false), - session.CfgCookieLifeTime(3600), - ), - ) - handler.InsertFilterChain( - "*", - func(next web.FilterFunc) web.FilterFunc { - return func(ctx *webContext.Context) { - var ( - checkKey = `asodiuasdk1j)AS(87` - checkValue = `ASsd-09812-3` - - store session.Store - err error - - c = context.Background() - ) - - if store, err = GetStore(ctx); err == nil { - if store == nil { - t.Error(`store should not be nil`) - } else { - _ = store.Set(c, checkKey, checkValue) - } - } else { - t.Error(err) - } - - next(ctx) - - if store != nil { - if v := store.Get(c, checkKey); v != checkValue { - t.Error(v, `is not equals to`, checkValue) - } - } else { - t.Error(`store should not be nil`) - } - - } - }, - ) - handler.Any("*", func(ctx *webContext.Context) { - ctx.Output.SetStatus(200) - }) - - testRequest(t, handler, "/dataset1/resource1", "GET", 200) -} From 87ffa0b730d4f6fe96a880adb26393bd1dd4b0f3 Mon Sep 17 00:00:00 2001 From: Anker Jam Date: Sun, 10 Jan 2021 14:22:11 +0800 Subject: [PATCH 37/41] fix UT --- server/web/context/context.go | 2 +- server/web/context/context_test.go | 56 +++++++----------------- server/web/filter/session/filter_test.go | 32 ++++++++++++++ 3 files changed, 50 insertions(+), 40 deletions(-) diff --git a/server/web/context/context.go b/server/web/context/context.go index edbf14f5..099729d0 100644 --- a/server/web/context/context.go +++ b/server/web/context/context.go @@ -197,7 +197,7 @@ func (ctx *Context) RenderMethodResult(result interface{}) { } // Session return session store of this context of request -func (ctx *Context) Session() (store session.Store,err error){ +func (ctx *Context) Session() (store session.Store, err error) { if ctx.Input != nil { if ctx.Input.CruSession != nil { store = ctx.Input.CruSession diff --git a/server/web/context/context_test.go b/server/web/context/context_test.go index 7fdb310f..977c3cbf 100644 --- a/server/web/context/context_test.go +++ b/server/web/context/context_test.go @@ -15,9 +15,7 @@ package context import ( - "github.com/beego/beego/v2/server/web" - "github.com/beego/beego/v2/server/web/filter/session" - webSession "github.com/beego/beego/v2/server/web/session" + "github.com/beego/beego/v2/server/web/session" "net/http" "net/http/httptest" "testing" @@ -49,45 +47,25 @@ func TestXsrfReset_01(t *testing.T) { } } -func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) { - r, _ := http.NewRequest(method, path, nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - - if w.Code != code { - t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code) +func TestContext_Session(t *testing.T) { + c := NewContext() + if store, err := c.Session(); store != nil || err == nil { + t.FailNow() } } -func TestContext_Session(t *testing.T) { - handler := web.NewControllerRegister() +func TestContext_Session1(t *testing.T) { + c := Context{} + if store, err := c.Session(); store != nil || err == nil { + t.FailNow() + } +} - handler.InsertFilterChain( - "*", - session.Session( - webSession.ProviderMemory, - webSession.CfgCookieName(`go_session_id`), - webSession.CfgSetCookie(true), - webSession.CfgGcLifeTime(3600), - webSession.CfgMaxLifeTime(3600), - webSession.CfgSecure(false), - webSession.CfgCookieLifeTime(3600), - ), - ) - handler.InsertFilterChain( - "*", - func(next web.FilterFunc) web.FilterFunc { - return func(ctx *Context) { - if _, err := ctx.Session(); err == nil { - t.Error() - } +func TestContext_Session2(t *testing.T) { + c := NewContext() + c.Input.CruSession = &session.MemSessionStore{} - } - }, - ) - handler.Any("*", func(ctx *Context) { - ctx.Output.SetStatus(200) - }) - - testRequest(t, handler, "/dataset1/resource1", "GET", 200) + if store, err := c.Session(); store == nil || err != nil { + t.FailNow() + } } \ No newline at end of file diff --git a/server/web/filter/session/filter_test.go b/server/web/filter/session/filter_test.go index 03a88afd..687789a5 100644 --- a/server/web/filter/session/filter_test.go +++ b/server/web/filter/session/filter_test.go @@ -52,3 +52,35 @@ func TestSession(t *testing.T) { testRequest(t, handler, "/dataset1/resource1", "GET", 200) } + +func TestSession1(t *testing.T) { + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store, err := ctx.Session(); store == nil || err != nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} From 95e998e36e26a80003575928757f35afde7ebf9a Mon Sep 17 00:00:00 2001 From: Ming Deng Date: Sun, 10 Jan 2021 18:13:00 +0800 Subject: [PATCH 38/41] sonar ignore test --- CHANGELOG.md | 2 +- sonar-project.properties | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3bde738f..2d98d10d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ # developing -- Add sonar check. [4432](https://github.com/beego/beego/pull/4432) +- Add sonar check and ignore test. [4432](https://github.com/beego/beego/pull/4432) [4433](https://github.com/beego/beego/pull/4433) - Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427) - Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398) - Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391) diff --git a/sonar-project.properties b/sonar-project.properties index 2fc78d8d..1a12fb33 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -3,4 +3,5 @@ sonar.projectKey=beego_beego # relative paths to source directories. More details and properties are described # in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/ -sonar.sources=. \ No newline at end of file +sonar.sources=. +sonar.exclusions=**/*_test.go \ No newline at end of file From 21777d3143de175360b6f35e55da63ec6a674240 Mon Sep 17 00:00:00 2001 From: Penghui Liao Date: Fri, 8 Jan 2021 19:04:00 +0800 Subject: [PATCH 39/41] Add context support for orm Signed-off-by: Penghui Liao --- client/orm/cmd.go | 7 ++- client/orm/db.go | 114 +++++++++++++++---------------------- client/orm/db_mysql.go | 11 ++-- client/orm/db_oracle.go | 11 ++-- client/orm/db_postgres.go | 9 +-- client/orm/db_sqlite.go | 13 +++-- client/orm/db_tidb.go | 5 +- client/orm/models_test.go | 20 +++---- client/orm/orm.go | 25 ++++---- client/orm/orm_log.go | 17 +++++- client/orm/orm_object.go | 9 ++- client/orm/orm_querym2m.go | 13 +++-- client/orm/orm_queryset.go | 74 ++++++++++++++---------- client/orm/orm_test.go | 20 +++++++ client/orm/types.go | 42 +++++++------- 15 files changed, 215 insertions(+), 175 deletions(-) diff --git a/client/orm/cmd.go b/client/orm/cmd.go index b0661971..d3836828 100644 --- a/client/orm/cmd.go +++ b/client/orm/cmd.go @@ -15,6 +15,7 @@ package orm import ( + "context" "flag" "fmt" "os" @@ -76,6 +77,7 @@ func RunCommand() { // sync database struct command interface. type commandSyncDb struct { + ctx context.Context al *alias force bool verbose bool @@ -154,7 +156,7 @@ func (d *commandSyncDb) Run() error { } var fields []*fieldInfo - columns, err := d.al.DbBaser.GetColumns(db, mi.table) + columns, err := d.al.DbBaser.GetColumns(d.ctx, db, mi.table) if err != nil { if d.rtOnError { return err @@ -188,7 +190,7 @@ func (d *commandSyncDb) Run() error { } for _, idx := range indexes[mi.table] { - if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { + if !d.al.DbBaser.IndexExists(d.ctx, db, idx.Table, idx.Name) { if !d.noInfo { fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) } @@ -290,6 +292,7 @@ func RunSyncdb(name string, force bool, verbose bool) error { al := getDbAlias(name) cmd := new(commandSyncDb) + cmd.ctx = context.TODO() cmd.al = al cmd.force = force cmd.noInfo = !verbose diff --git a/client/orm/db.go b/client/orm/db.go index c994469f..a49d6df7 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "errors" "fmt" @@ -268,7 +269,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } // create insert sql preparation statement object. -func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { +func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { Q := d.ins.TableQuote() dbcols := make([]string, 0, len(mi.fields.dbcols)) @@ -289,12 +290,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, d.ins.HasReturningID(mi, &query) - stmt, err := q.Prepare(query) + stmt, err := q.PrepareContext(ctx, query) return stmt, query, err } // insert struct with prepared statement and given struct reflect value. -func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { +func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { return 0, err @@ -306,7 +307,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, err := row.Scan(&id) return id, err } - res, err := stmt.Exec(values...) + res, err := stmt.ExecContext(ctx, values...) if err == nil { return res.LastInsertId() } @@ -314,7 +315,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, } // query sql ,read records and persist in dbBaser. -func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { var whereCols []string var args []interface{} @@ -360,7 +361,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo d.ins.ReplaceMarks(&query) - row := q.QueryRow(query, args...) + row := q.QueryRowContext(ctx, query, args...) if err := row.Scan(refs...); err != nil { if err == sql.ErrNoRows { return ErrNoRows @@ -375,26 +376,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo } // execute insert sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { +func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { names := make([]string, 0, len(mi.fields.dbcols)) values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) if err != nil { return 0, err } - id, err := d.InsertValue(q, mi, false, names, values) + id, err := d.InsertValue(ctx, q, mi, false, names, values) if err != nil { return 0, err } if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return id, err } // multi-insert sql with given slice struct reflect.Value. -func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { +func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { var ( cnt int64 nums int @@ -440,7 +441,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul } if i > 1 && i%bulk == 0 || length == i { - num, err := d.InsertValue(q, mi, true, names, values[:nums]) + num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums]) if err != nil { return cnt, err } @@ -451,7 +452,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul var err error if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return cnt, err @@ -459,7 +460,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -482,7 +483,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -498,7 +499,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err := row.Scan(&id) return id, err @@ -507,7 +508,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s // InsertOrUpdate a row // If your primary key or unique column conflict will update // If no will insert -func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { args0 := "" iouStr := "" argsMap := map[string]string{} @@ -590,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -607,7 +608,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err = row.Scan(&id) if err != nil && err.Error() == `pq: syntax error at or near "ON"` { @@ -617,7 +618,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a } // execute update sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if !ok { return 0, ErrMissPK @@ -674,7 +675,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, setValues...) + res, err := q.ExecContext(ctx, query, setValues...) if err == nil { return res.RowsAffected() } @@ -683,7 +684,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. // execute delete sql dbQuerier with given struct reflect.Value. // delete index is pk. -func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { var whereCols []string var args []interface{} // if specify cols length > 0, then use it for where condition. @@ -712,7 +713,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, args...) + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { @@ -726,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) } } - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -738,7 +739,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. // update table-related record by querySet. // need querySet not struct reflect.Value to update related records. -func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { +func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) for col, val := range params { @@ -819,13 +820,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } d.ins.ReplaceMarks(&query) - var err error - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, values...) - } else { - res, err = q.Exec(query, values...) - } + res, err := q.ExecContext(ctx, query, values...) if err == nil { return res.RowsAffected() } @@ -834,13 +829,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con // delete related records. // do UpdateBanch or DeleteBanch by condition of tables' relationship. -func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { +func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { for _, fi := range mi.fields.fieldsReverse { fi = fi.reverseFieldInfo switch fi.onDelete { case odCascade: cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) + _, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz) if err != nil { return err } @@ -850,7 +845,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * if fi.onDelete == odSetDefault { params[fi.column] = fi.initial.String() } - _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) + _, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz) if err != nil { return err } @@ -861,7 +856,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * } // delete table-related records. -func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { +func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) tables.skipEnd = true @@ -886,7 +881,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con d.ins.ReplaceMarks(&query) var rs *sql.Rows - r, err := q.Query(query, args...) + r, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -920,19 +915,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) d.ins.ReplaceMarks(&query) - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, args...) - } else { - res, err = q.Exec(query, args...) - } + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } if num > 0 { - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -943,7 +933,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } // read related records. -func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { val := reflect.ValueOf(container) ind := reflect.Indirect(val) @@ -1052,18 +1042,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi d.ins.ReplaceMarks(&query) - var rs *sql.Rows - var err error - if qs != nil && qs.forContext { - rs, err = q.QueryContext(qs.ctx, query, args...) - if err != nil { - return 0, err - } - } else { - rs, err = q.Query(query, args...) - if err != nil { - return 0, err - } + rs, err := q.QueryContext(ctx, query, args...) + if err != nil { + return 0, err } defer rs.Close() @@ -1178,7 +1159,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi } // excute count sql and return count result int64. -func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { +func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) @@ -1200,12 +1181,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition d.ins.ReplaceMarks(&query) - var row *sql.Row - if qs != nil && qs.forContext { - row = q.QueryRowContext(qs.ctx, query, args...) - } else { - row = q.QueryRow(query, args...) - } + row := q.QueryRowContext(ctx, query, args...) err = row.Scan(&cnt) return } @@ -1655,7 +1631,7 @@ setValue: } // query sql, read values , save to *[]ParamList. -func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { +func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { var ( maps []Params @@ -1738,7 +1714,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond d.ins.ReplaceMarks(&query) - rs, err := q.Query(query, args...) + rs, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -1853,7 +1829,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool { } // sync auto key -func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { return nil } @@ -1898,10 +1874,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { } // get all cloumns in table. -func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { columns := make(map[string][3]string) query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return columns, err } @@ -1940,7 +1916,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string { } // not implement. -func (d *dbBase) IndexExists(dbQuerier, string, string) bool { +func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool { panic(ErrNotImplement) } diff --git a/client/orm/db_mysql.go b/client/orm/db_mysql.go index ee68baf7..c89b1e52 100644 --- a/client/orm/db_mysql.go +++ b/client/orm/db_mysql.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "reflect" "strings" @@ -93,8 +94,8 @@ func (d *dbBaseMysql) ShowColumnsQuery(table string) string { } // execute sql to check index exist. -func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) @@ -105,7 +106,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool // If your primary key or unique column conflict will update // If no will insert // Add "`" for mysql sql building -func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { var iouStr string argsMap := map[string]string{} @@ -161,7 +162,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -178,7 +179,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err = row.Scan(&id) return id, err diff --git a/client/orm/db_oracle.go b/client/orm/db_oracle.go index 1de440b6..a3b93ff3 100644 --- a/client/orm/db_oracle.go +++ b/client/orm/db_oracle.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "strings" @@ -89,8 +90,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string { } // check index is exist -func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ +func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) @@ -124,7 +125,7 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -147,7 +148,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -163,7 +164,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err := row.Scan(&id) return id, err diff --git a/client/orm/db_postgres.go b/client/orm/db_postgres.go index 12431d6e..b2f321db 100644 --- a/client/orm/db_postgres.go +++ b/client/orm/db_postgres.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "strconv" ) @@ -140,7 +141,7 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { } // sync auto key -func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { if len(autoFields) == 0 { return nil } @@ -151,7 +152,7 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string mi.table, name, Q, name, Q, Q, mi.table, Q) - if _, err := db.Exec(query); err != nil { + if _, err := db.ExecContext(ctx, query); err != nil { return err } } @@ -174,9 +175,9 @@ func (d *dbBasePostgres) DbTypes() map[string]string { } // check index exist in postgresql. -func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) - row := db.QueryRow(query) + row := db.QueryRowContext(ctx, query) var cnt int row.Scan(&cnt) return cnt > 0 diff --git a/client/orm/db_sqlite.go b/client/orm/db_sqlite.go index aff713a5..6a4b3131 100644 --- a/client/orm/db_sqlite.go +++ b/client/orm/db_sqlite.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "fmt" "reflect" @@ -73,11 +74,11 @@ type dbBaseSqlite struct { var _ dbBaser = new(dbBaseSqlite) // override base db read for update behavior as SQlite does not support syntax -func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { if isForUpdate { DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") } - return d.dbBase.Read(q, mi, ind, tz, cols, false) + return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false) } // get sqlite operator. @@ -114,9 +115,9 @@ func (d *dbBaseSqlite) ShowTablesQuery() string { } // get columns in sqlite. -func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return nil, err } @@ -140,9 +141,9 @@ func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { } // check index exist in sqlite. -func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("PRAGMA index_list('%s')", table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { panic(err) } diff --git a/client/orm/db_tidb.go b/client/orm/db_tidb.go index 6020a488..48c5b4e7 100644 --- a/client/orm/db_tidb.go +++ b/client/orm/db_tidb.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" ) @@ -47,8 +48,8 @@ func (d *dbBaseTidb) ShowColumnsQuery(table string) string { } // execute sql to check index exist. -func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) diff --git a/client/orm/models_test.go b/client/orm/models_test.go index 5add6e45..3fd35765 100644 --- a/client/orm/models_test.go +++ b/client/orm/models_test.go @@ -492,45 +492,45 @@ var ( helpinfo = `need driver and source! Default DB Drivers. - + driver: url mysql: https://github.com/go-sql-driver/mysql sqlite3: https://github.com/mattn/go-sqlite3 postgres: https://github.com/lib/pq tidb: https://github.com/pingcap/tidb - + usage: - + go get -u github.com/beego/beego/v2/client/orm go get -u github.com/go-sql-driver/mysql go get -u github.com/mattn/go-sqlite3 go get -u github.com/lib/pq go get -u github.com/pingcap/tidb - + #### MySQL mysql -u root -e 'create database orm_test;' export ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" go test -v github.com/beego/beego/v2/client/orm - - + + #### Sqlite3 export ORM_DRIVER=sqlite3 export ORM_SOURCE='file:memory_test?mode=memory' go test -v github.com/beego/beego/v2/client/orm - - + + #### PostgreSQL psql -c 'create database orm_test;' -U postgres export ORM_DRIVER=postgres export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" go test -v github.com/beego/beego/v2/client/orm - + #### TiDB export ORM_DRIVER=tidb export ORM_SOURCE='memory://test/test' go test -v github.com/beego/beego/v2/pgk/orm - + ` ) diff --git a/client/orm/orm.go b/client/orm/orm.go index f33c4b51..37546202 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -136,7 +136,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error { } func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) } // read data to model, like Read(), but use "SELECT FOR UPDATE" form @@ -145,7 +145,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { } func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) + return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true) } // Try to read a row from the database, or insert one if it doesn't exist @@ -155,7 +155,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) if err == ErrNoRows { // Create id, err := o.InsertWithCtx(ctx, md) @@ -180,7 +180,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) { } func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return id, err } @@ -223,7 +223,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac for i := 0; i < sind.Len(); i++ { ind := reflect.Indirect(sind.Index(i)) mi, _ := o.getMiInd(ind.Interface(), false) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return cnt, err } @@ -234,7 +234,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac } } else { mi, _ := o.getMiInd(sind.Index(0).Interface(), false) - return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) + return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ) } return cnt, nil } @@ -245,7 +245,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) ( } func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) + id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...) if err != nil { return id, err } @@ -262,7 +262,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) + return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols) } // delete model in database @@ -272,7 +272,7 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) + num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols) if err != nil { return num, err } @@ -297,7 +297,7 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) } - return newQueryM2M(md, o, mi, fi, ind) + return newQueryM2M(ctx, md, o, mi, fi, ind) } // load related models to md model. @@ -470,7 +470,7 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in if qs == nil { panic(fmt.Errorf(" table name: `%s` not exists", name)) } - return + return qs.WithContext(ctx) } // return a raw query seter for raw sql string. @@ -596,9 +596,8 @@ func NewOrm() Ormer { func NewOrmUsingDB(aliasName string) Ormer { if al, ok := dataBaseCache.get(aliasName); ok { return newDBWithAlias(al) - } else { - panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } + panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } // NewOrmWithDB create a new ormer object with specify *sql.DB for query diff --git a/client/orm/orm_log.go b/client/orm/orm_log.go index 61addeb5..8ac373b5 100644 --- a/client/orm/orm_log.go +++ b/client/orm/orm_log.go @@ -85,20 +85,31 @@ func (d *stmtQueryLog) Close() error { } func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { + return d.ExecContext(context.Background(), args...) +} + +func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { a := time.Now() - res, err := d.stmt.Exec(args...) + res, err := d.stmt.ExecContext(ctx, args...) debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) return res, err } - func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { + return d.QueryContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { a := time.Now() - res, err := d.stmt.Query(args...) + res, err := d.stmt.QueryContext(ctx, args...) debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) return res, err } func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { + return d.QueryRowContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { a := time.Now() res := d.stmt.QueryRow(args...) debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) diff --git a/client/orm/orm_object.go b/client/orm/orm_object.go index 6f9798d3..7306438a 100644 --- a/client/orm/orm_object.go +++ b/client/orm/orm_object.go @@ -15,12 +15,14 @@ package orm import ( + "context" "fmt" "reflect" ) // an insert queryer struct type insertSet struct { + ctx context.Context mi *modelInfo orm *ormBase stmt stmtQuerier @@ -44,7 +46,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { if name != o.mi.fullName { panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) } - id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) + id, err := o.orm.alias.DbBaser.InsertStmt(o.ctx, o.stmt, o.mi, ind, o.orm.alias.TZ) if err != nil { return id, err } @@ -70,11 +72,12 @@ func (o *insertSet) Close() error { } // create new insert queryer. -func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) { +func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) { bi := new(insertSet) + bi.ctx = ctx bi.orm = orm bi.mi = mi - st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) + st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi) if err != nil { return nil, err } diff --git a/client/orm/orm_querym2m.go b/client/orm/orm_querym2m.go index 17e1b5d1..1174c598 100644 --- a/client/orm/orm_querym2m.go +++ b/client/orm/orm_querym2m.go @@ -14,10 +14,14 @@ package orm -import "reflect" +import ( + "context" + "reflect" +) // model to model struct type queryM2M struct { + ctx context.Context md interface{} mi *modelInfo fi *fieldInfo @@ -96,7 +100,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { } names = append(names, otherNames...) values = append(values, otherValues...) - return dbase.InsertValue(orm.db, mi, true, names, values) + return dbase.InsertValue(o.ctx, orm.db, mi, true, names, values) } // remove models following the origin model relationship @@ -129,12 +133,13 @@ func (o *queryM2M) Count() (int64, error) { var _ QueryM2Mer = new(queryM2M) // create new M2M queryer. -func newQueryM2M(md interface{}, o *ormBase, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { +func newQueryM2M(ctx context.Context, md interface{}, o *ormBase, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { qm2m := new(queryM2M) + qm2m.ctx = ctx qm2m.md = md qm2m.mi = mi qm2m.fi = fi qm2m.ind = ind - qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet) + qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).WithContext(ctx).(*querySet) return qm2m } diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index 692c24cf..f7a9f5f6 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -17,8 +17,9 @@ package orm import ( "context" "fmt" - "github.com/beego/beego/v2/client/orm/hints" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/hints" ) type colValue struct { @@ -64,22 +65,21 @@ func ColValue(opt operator, value interface{}) interface{} { // real query struct type querySet struct { - mi *modelInfo - cond *Condition - related []string - relDepth int - limit int64 - offset int64 - groups []string - orders []*order_clause.Order - distinct bool - forUpdate bool - useIndex int - indexes []string - orm *ormBase - ctx context.Context - forContext bool - aggregate string + mi *modelInfo + cond *Condition + related []string + relDepth int + limit int64 + offset int64 + groups []string + orders []*order_clause.Order + distinct bool + forUpdate bool + useIndex int + indexes []string + orm *ormBase + ctx context.Context + aggregate string } var _ QuerySeter = new(querySet) @@ -221,25 +221,36 @@ func (o querySet) GetCond() *Condition { return o.cond } +func (o querySet) getContext() context.Context { + if o.ctx != nil { + return o.ctx + } + return context.Background() +} + // return QuerySeter execution result number func (o *querySet) Count() (int64, error) { - return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + ctx := o.getContext() + return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } // check result empty or not after QuerySeter executed func (o *querySet) Exist() bool { - cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + ctx := o.getContext() + cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return cnt > 0 } // execute update with parameters func (o *querySet) Update(values Params) (int64, error) { - return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) + ctx := o.getContext() + return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) } // execute delete func (o *querySet) Delete() (int64, error) { - return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + ctx := o.getContext() + return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } // return a insert queryer. @@ -248,20 +259,23 @@ func (o *querySet) Delete() (int64, error) { // i,err := sq.PrepareInsert() // i.Add(&user1{},&user2{}) func (o *querySet) PrepareInsert() (Inserter, error) { - return newInsertSet(o.orm, o.mi) + ctx := o.getContext() + return newInsertSet(ctx, o.orm, o.mi) } // query all data and map to containers. // cols means the columns when querying. func (o *querySet) All(container interface{}, cols ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + ctx := o.getContext() + return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) } // query one row data and map to containers. // cols means the columns when querying. func (o *querySet) One(container interface{}, cols ...string) error { + ctx := o.getContext() o.limit = 1 - num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) if err != nil { return err } @@ -279,19 +293,22 @@ func (o *querySet) One(container interface{}, cols ...string) error { // expres means condition expression. // it converts data to []map[column]value. func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) + ctx := o.getContext() + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to [][]interface // it converts data to [][column_index]value func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) + ctx := o.getContext() + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to []interface. // it's designed for one row record set, auto change to []value, not [][column]value. func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) + ctx := o.getContext() + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) } // query all rows into map[string]interface with specify key and value column name. @@ -325,7 +342,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) // set context to QuerySeter. func (o querySet) WithContext(ctx context.Context) QuerySeter { o.ctx = ctx - o.forContext = true return &o } @@ -341,4 +357,4 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { func (o querySet) Aggregate(s string) QuerySeter { o.aggregate = s return &o -} \ No newline at end of file +} diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index d13a0b65..8b0004a0 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -2820,3 +2820,23 @@ func TestCondition(t *testing.T) { throwFail(t, AssertIs(!cycleFlag, true)) return } + +func TestContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + user := User{UserName: "slene"} + + err := dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, err) + + cancel() + err = dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, AssertIs(err, context.Canceled)) + + ctx, cancel = context.WithCancel(context.Background()) + cancel() + + qs := dORM.QueryTable(user).WithContext(ctx) + _, err = qs.Filter("UserName", "slene").Count() + throwFail(t, AssertIs(err, context.Canceled)) +} diff --git a/client/orm/types.go b/client/orm/types.go index dd6c0b95..da1062d8 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -236,6 +236,8 @@ type Inserter interface { // QuerySeter query seter type QuerySeter interface { + // add query context for querySeter + WithContext(context.Context) QuerySeter // add condition expression to QuerySeter. // for example: // filter by UserName == 'slene' @@ -539,11 +541,11 @@ type RawSeter interface { type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) - // ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) Query(args ...interface{}) (*sql.Rows, error) - // QueryContext(args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) QueryRow(args ...interface{}) *sql.Row - // QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row } // db querier @@ -580,28 +582,28 @@ type txEnder interface { // base database struct type dbBaser interface { - Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) - Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error + ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) - Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) - InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) - InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) - InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) + InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) + InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) SupportUpdateJoin() bool OperatorSQL(string) string GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorLeftCol(*fieldInfo, string, *string) - PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) + PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error) MaxLimit() uint64 TableQuote() string ReplaceMarks(*string) @@ -610,12 +612,12 @@ type dbBaser interface { TimeToDB(*time.Time, *time.Location) DbTypes() map[string]string GetTables(dbQuerier) (map[string]bool, error) - GetColumns(dbQuerier, string) (map[string][3]string, error) + GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error) ShowTablesQuery() string ShowColumnsQuery(string) string - IndexExists(dbQuerier, string, string) bool + IndexExists(context.Context, dbQuerier, string, string) bool collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) - setval(dbQuerier, *modelInfo, []string) error + setval(context.Context, dbQuerier, *modelInfo, []string) error GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string } From ef227bf46734173bed3623fc8d5703e89e7fa972 Mon Sep 17 00:00:00 2001 From: Penghui Liao Date: Sun, 10 Jan 2021 13:59:01 +0800 Subject: [PATCH 40/41] Deprecate QueryM2MWithCtx and QueryTableWithCtx - Add methods with `WithCtx` suffix and remove ctx fileld from QueryStter and QueryM2M. - Deprecate QueryTableWithCtx and QueryM2MWithCtx. Signed-off-by: Penghui Liao --- CHANGELOG.md | 3 +- client/orm/cmd.go | 7 ++- client/orm/do_nothing_orm.go | 2 + client/orm/do_nothing_orm_test.go | 2 - client/orm/filter/opentracing/filter.go | 2 +- client/orm/filter/prometheus/filter.go | 2 +- client/orm/filter_orm_decorator.go | 31 +++++++----- client/orm/filter_orm_decorator_test.go | 4 +- client/orm/orm.go | 22 +++++---- client/orm/orm_querym2m.go | 34 +++++++++---- client/orm/orm_queryset.go | 64 +++++++++++++++---------- client/orm/orm_test.go | 4 +- client/orm/types.go | 21 +++++++- 13 files changed, 130 insertions(+), 68 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 222046bc..717bee62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,4 +7,5 @@ - Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385) - Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) - Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294) -- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404) \ No newline at end of file +- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404) +- Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424) diff --git a/client/orm/cmd.go b/client/orm/cmd.go index d3836828..b377a5f2 100644 --- a/client/orm/cmd.go +++ b/client/orm/cmd.go @@ -77,7 +77,6 @@ func RunCommand() { // sync database struct command interface. type commandSyncDb struct { - ctx context.Context al *alias force bool verbose bool @@ -143,6 +142,7 @@ func (d *commandSyncDb) Run() error { fmt.Printf(" %s\n", err.Error()) } + ctx := context.Background() for i, mi := range modelCache.allOrdered() { if !isApplicableTableForDB(mi.addrField, d.al.Name) { @@ -156,7 +156,7 @@ func (d *commandSyncDb) Run() error { } var fields []*fieldInfo - columns, err := d.al.DbBaser.GetColumns(d.ctx, db, mi.table) + columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table) if err != nil { if d.rtOnError { return err @@ -190,7 +190,7 @@ func (d *commandSyncDb) Run() error { } for _, idx := range indexes[mi.table] { - if !d.al.DbBaser.IndexExists(d.ctx, db, idx.Table, idx.Name) { + if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) { if !d.noInfo { fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) } @@ -292,7 +292,6 @@ func RunSyncdb(name string, force bool, verbose bool) error { al := getDbAlias(name) cmd := new(commandSyncDb) - cmd.ctx = context.TODO() cmd.al = al cmd.force = force cmd.noInfo = !verbose diff --git a/client/orm/do_nothing_orm.go b/client/orm/do_nothing_orm.go index c6da420d..59ffe877 100644 --- a/client/orm/do_nothing_orm.go +++ b/client/orm/do_nothing_orm.go @@ -66,6 +66,7 @@ func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer { return nil } +// NOTE: this method is deprecated, context parameter will not take effect. func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { return nil } @@ -74,6 +75,7 @@ func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { return nil } +// NOTE: this method is deprecated, context parameter will not take effect. func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { return nil } diff --git a/client/orm/do_nothing_orm_test.go b/client/orm/do_nothing_orm_test.go index 4d477353..e10f70af 100644 --- a/client/orm/do_nothing_orm_test.go +++ b/client/orm/do_nothing_orm_test.go @@ -36,7 +36,6 @@ func TestDoNothingOrm(t *testing.T) { assert.Nil(t, o.Driver()) - assert.Nil(t, o.QueryM2MWithCtx(nil, nil, "")) assert.Nil(t, o.QueryM2M(nil, "")) assert.Nil(t, o.ReadWithCtx(nil, nil)) assert.Nil(t, o.Read(nil)) @@ -92,7 +91,6 @@ func TestDoNothingOrm(t *testing.T) { assert.Nil(t, err) assert.Equal(t, int64(0), i) - assert.Nil(t, o.QueryTableWithCtx(nil, nil)) assert.Nil(t, o.QueryTable(nil)) assert.Nil(t, o.Read(nil)) diff --git a/client/orm/filter/opentracing/filter.go b/client/orm/filter/opentracing/filter.go index 75852c63..7afa07f1 100644 --- a/client/orm/filter/opentracing/filter.go +++ b/client/orm/filter/opentracing/filter.go @@ -27,7 +27,7 @@ import ( // this Filter's behavior looks a little bit strange // for example: // if we want to trace QuerySetter -// actually we trace invoking "QueryTable" and "QueryTableWithCtx" +// actually we trace invoking "QueryTable" // the method Begin*, Commit and Rollback are ignored. // When use using those methods, it means that they want to manager their transaction manually, so we won't handle them. type FilterChainBuilder struct { diff --git a/client/orm/filter/prometheus/filter.go b/client/orm/filter/prometheus/filter.go index db60876e..e68e9670 100644 --- a/client/orm/filter/prometheus/filter.go +++ b/client/orm/filter/prometheus/filter.go @@ -31,7 +31,7 @@ import ( // this Filter's behavior looks a little bit strange // for example: // if we want to records the metrics of QuerySetter -// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx" +// actually we only records metrics of invoking "QueryTable" type FilterChainBuilder struct { summaryVec prometheus.ObserverVec AppName string diff --git a/client/orm/filter_orm_decorator.go b/client/orm/filter_orm_decorator.go index a60390a1..caf2b3f9 100644 --- a/client/orm/filter_orm_decorator.go +++ b/client/orm/filter_orm_decorator.go @@ -20,6 +20,7 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/core/utils" ) @@ -161,36 +162,34 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac } func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer { - return f.QueryM2MWithCtx(context.Background(), md, name) -} - -func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { mi, _ := modelCache.getByMd(md) inv := &Invocation{ - Method: "QueryM2MWithCtx", + Method: "QueryM2M", Args: []interface{}{md, name}, Md: md, mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, f: func(c context.Context) []interface{} { - res := f.ormer.QueryM2MWithCtx(c, md, name) + res := f.ormer.QueryM2M(md, name) return []interface{}{res} }, } - res := f.root(ctx, inv) + res := f.root(context.Background(), inv) if res[0] == nil { return nil } return res[0].(QueryM2Mer) } -func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { - return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName) +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.") + return f.QueryM2M(md, name) } -func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { +func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { var ( name string md interface{} @@ -209,18 +208,18 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT } inv := &Invocation{ - Method: "QueryTableWithCtx", + Method: "QueryTable", Args: []interface{}{ptrStructOrTableName}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, Md: md, mi: mi, f: func(c context.Context) []interface{} { - res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName) + res := f.ormer.QueryTable(ptrStructOrTableName) return []interface{}{res} }, } - res := f.root(ctx, inv) + res := f.root(context.Background(), inv) if res[0] == nil { return nil @@ -228,6 +227,12 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT return res[0].(QuerySeter) } +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.") + return f.QueryTable(ptrStructOrTableName) +} + func (f *filterOrmDecorator) DBStats() *sql.DBStats { inv := &Invocation{ Method: "DBStats", diff --git a/client/orm/filter_orm_decorator_test.go b/client/orm/filter_orm_decorator_test.go index 9e223358..566499dd 100644 --- a/client/orm/filter_orm_decorator_test.go +++ b/client/orm/filter_orm_decorator_test.go @@ -268,7 +268,7 @@ func TestFilterOrmDecorator_QueryM2M(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { - assert.Equal(t, "QueryM2MWithCtx", inv.Method) + assert.Equal(t, "QueryM2M", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) @@ -284,7 +284,7 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { - assert.Equal(t, "QueryTableWithCtx", inv.Method) + assert.Equal(t, "QueryTable", inv.Method) assert.Equal(t, 1, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) diff --git a/client/orm/orm.go b/client/orm/orm.go index 37546202..660f2939 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -284,9 +284,6 @@ func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...str // create a models to models queryer func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { - return o.QueryM2MWithCtx(context.Background(), md, name) -} -func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -297,7 +294,13 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) } - return newQueryM2M(ctx, md, o, mi, fi, ind) + return newQueryM2M(md, o, mi, fi, ind) +} + +// NOTE: this method is deprecated, context parameter will not take effect. +func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.") + return o.QueryM2M(md, name) } // load related models to md model. @@ -452,9 +455,6 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { - return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName) -} -func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { var name string if table, ok := ptrStructOrTableName.(string); ok { name = nameStrategyMap[defaultNameStrategy](table) @@ -470,7 +470,13 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in if qs == nil { panic(fmt.Errorf(" table name: `%s` not exists", name)) } - return qs.WithContext(ctx) + return qs +} + +// NOTE: this method is deprecated, context parameter will not take effect. +func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.") + return o.QueryTable(ptrStructOrTableName) } // return a raw query seter for raw sql string. diff --git a/client/orm/orm_querym2m.go b/client/orm/orm_querym2m.go index 1174c598..9da49bba 100644 --- a/client/orm/orm_querym2m.go +++ b/client/orm/orm_querym2m.go @@ -21,7 +21,6 @@ import ( // model to model struct type queryM2M struct { - ctx context.Context md interface{} mi *modelInfo fi *fieldInfo @@ -37,6 +36,10 @@ type queryM2M struct { // // make sure the relation is defined in post model struct tag. func (o *queryM2M) Add(mds ...interface{}) (int64, error) { + return o.AddWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi mi := fi.relThroughModelInfo mfi := fi.reverseFieldInfo @@ -100,11 +103,15 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { } names = append(names, otherNames...) values = append(values, otherValues...) - return dbase.InsertValue(o.ctx, orm.db, mi, true, names, values) + return dbase.InsertValue(ctx, orm.db, mi, true, names, values) } // remove models following the origin model relationship func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { + return o.RemoveWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) @@ -113,33 +120,44 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { // check model is existed in relationship of origin model func (o *queryM2M) Exist(md interface{}) bool { + return o.ExistWithCtx(context.Background(), md) +} + +func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool { fi := o.fi return o.qs.Filter(fi.reverseFieldInfo.name, o.md). - Filter(fi.reverseFieldInfoTwo.name, md).Exist() + Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx) } // clean all models in related of origin model func (o *queryM2M) Clear() (int64, error) { + return o.ClearWithCtx(context.Background()) +} + +func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx) } // count all related models of origin model func (o *queryM2M) Count() (int64, error) { + return o.CountWithCtx(context.Background()) +} + +func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx) } var _ QueryM2Mer = new(queryM2M) // create new M2M queryer. -func newQueryM2M(ctx context.Context, md interface{}, o *ormBase, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { +func newQueryM2M(md interface{}, o *ormBase, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { qm2m := new(queryM2M) - qm2m.ctx = ctx qm2m.md = md qm2m.mi = mi qm2m.fi = fi qm2m.ind = ind - qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).WithContext(ctx).(*querySet) + qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet) return qm2m } diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index f7a9f5f6..9f7b8441 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -78,7 +78,6 @@ type querySet struct { useIndex int indexes []string orm *ormBase - ctx context.Context aggregate string } @@ -221,35 +220,40 @@ func (o querySet) GetCond() *Condition { return o.cond } -func (o querySet) getContext() context.Context { - if o.ctx != nil { - return o.ctx - } - return context.Background() -} - // return QuerySeter execution result number func (o *querySet) Count() (int64, error) { - ctx := o.getContext() + return o.CountWithCtx(context.Background()) +} + +func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) { return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } // check result empty or not after QuerySeter executed func (o *querySet) Exist() bool { - ctx := o.getContext() + return o.ExistWithCtx(context.Background()) +} + +func (o *querySet) ExistWithCtx(ctx context.Context) bool { cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return cnt > 0 } // execute update with parameters func (o *querySet) Update(values Params) (int64, error) { - ctx := o.getContext() + return o.UpdateWithCtx(context.Background(), values) +} + +func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) { return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) } // execute delete func (o *querySet) Delete() (int64, error) { - ctx := o.getContext() + return o.DeleteWithCtx(context.Background()) +} + +func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) { return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } @@ -259,21 +263,30 @@ func (o *querySet) Delete() (int64, error) { // i,err := sq.PrepareInsert() // i.Add(&user1{},&user2{}) func (o *querySet) PrepareInsert() (Inserter, error) { - ctx := o.getContext() + return o.PrepareInsertWithCtx(context.Background()) +} + +func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) { return newInsertSet(ctx, o.orm, o.mi) } // query all data and map to containers. // cols means the columns when querying. func (o *querySet) All(container interface{}, cols ...string) (int64, error) { - ctx := o.getContext() + return o.AllWithCtx(context.Background(), container, cols...) +} + +func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) { return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) } // query one row data and map to containers. // cols means the columns when querying. func (o *querySet) One(container interface{}, cols ...string) error { - ctx := o.getContext() + return o.OneWithCtx(context.Background(), container, cols...) +} + +func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error { o.limit = 1 num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) if err != nil { @@ -293,21 +306,30 @@ func (o *querySet) One(container interface{}, cols ...string) error { // expres means condition expression. // it converts data to []map[column]value. func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { - ctx := o.getContext() + return o.ValuesWithCtx(context.Background(), results, exprs...) +} + +func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) { return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to [][]interface // it converts data to [][column_index]value func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { - ctx := o.getContext() + return o.ValuesListWithCtx(context.Background(), results, exprs...) +} + +func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) { return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to []interface. // it's designed for one row record set, auto change to []value, not [][column]value. func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { - ctx := o.getContext() + return o.ValuesFlatWithCtx(context.Background(), result, expr) +} + +func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) { return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) } @@ -339,12 +361,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) panic(ErrNotImplement) } -// set context to QuerySeter. -func (o querySet) WithContext(ctx context.Context) QuerySeter { - o.ctx = ctx - return &o -} - // create new QuerySeter. func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o := new(querySet) diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index 8b0004a0..e2e25ac4 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -2836,7 +2836,7 @@ func TestContextCanceled(t *testing.T) { ctx, cancel = context.WithCancel(context.Background()) cancel() - qs := dORM.QueryTable(user).WithContext(ctx) - _, err = qs.Filter("UserName", "slene").Count() + qs := dORM.QueryTable(user) + _, err = qs.Filter("UserName", "slene").CountWithCtx(ctx) throwFail(t, AssertIs(err, context.Canceled)) } diff --git a/client/orm/types.go b/client/orm/types.go index da1062d8..203f057a 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -197,12 +197,16 @@ type DQL interface { // post := Post{Id: 4} // m2m := Ormer.QueryM2M(&post, "Tags") QueryM2M(md interface{}, name string) QueryM2Mer + // NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), QueryTable(ptrStructOrTableName interface{}) QuerySeter + // NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter DBStats() *sql.DBStats @@ -236,8 +240,6 @@ type Inserter interface { // QuerySeter query seter type QuerySeter interface { - // add query context for querySeter - WithContext(context.Context) QuerySeter // add condition expression to QuerySeter. // for example: // filter by UserName == 'slene' @@ -352,9 +354,11 @@ type QuerySeter interface { // for example: // num, err = qs.Filter("profile__age__gt", 28).Count() Count() (int64, error) + CountWithCtx(context.Context) (int64, error) // check result empty or not after QuerySeter executed // the same as QuerySeter.Count > 0 Exist() bool + ExistWithCtx(context.Context) bool // execute update with parameters // for example: // num, err = qs.Filter("user_name", "slene").Update(Params{ @@ -364,11 +368,13 @@ type QuerySeter interface { // "user_name": "slene2" // }) // user slene's name will change to slene2 Update(values Params) (int64, error) + UpdateWithCtx(ctx context.Context, values Params) (int64, error) // delete from table // for example: // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() // //delete two user who's name is testing1 or testing2 Delete() (int64, error) + DeleteWithCtx(context.Context) (int64, error) // return a insert queryer. // it can be used in times. // example: @@ -377,18 +383,21 @@ type QuerySeter interface { // num, err = i.Insert(&user2) // user table will add one record user2 at once // err = i.Close() //don't forget call Close PrepareInsert() (Inserter, error) + PrepareInsertWithCtx(context.Context) (Inserter, error) // query all data and map to containers. // cols means the columns when querying. // for example: // var users []*User // qs.All(&users) // users[0],users[1],users[2] ... All(container interface{}, cols ...string) (int64, error) + AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) // query one row data and map to containers. // cols means the columns when querying. // for example: // var user User // qs.One(&user) //user.UserName == "slene" One(container interface{}, cols ...string) error + OneWithCtx(ctx context.Context, container interface{}, cols ...string) error // query all data and map to []map[string]interface. // expres means condition expression. // it converts data to []map[column]value. @@ -396,18 +405,21 @@ type QuerySeter interface { // var maps []Params // qs.Values(&maps) //maps[0]["UserName"]=="slene" Values(results *[]Params, exprs ...string) (int64, error) + ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) // query all data and map to [][]interface // it converts data to [][column_index]value // for example: // var list []ParamsList // qs.ValuesList(&list) // list[0][1] == "slene" ValuesList(results *[]ParamsList, exprs ...string) (int64, error) + ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) // query all data and map to []interface. // it's designed for one column record set, auto change to []value, not [][column]value. // for example: // var list ParamsList // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" ValuesFlat(result *ParamsList, expr string) (int64, error) + ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) // query all rows into map[string]interface with specify key and value column name. // keyCol = "name", valueCol = "value" // table data @@ -456,18 +468,23 @@ type QueryM2Mer interface { // insert one or more rows to m2m table // make sure the relation is defined in post model struct tag. Add(...interface{}) (int64, error) + AddWithCtx(context.Context, ...interface{}) (int64, error) // remove models following the origin model relationship // only delete rows from m2m table // for example: // tag3 := &Tag{Id:5,Name: "TestTag3"} // num, err = m2m.Remove(tag3) Remove(...interface{}) (int64, error) + RemoveWithCtx(context.Context, ...interface{}) (int64, error) // check model is existed in relationship of origin model Exist(interface{}) bool + ExistWithCtx(context.Context, interface{}) bool // clean all models in related of origin model Clear() (int64, error) + ClearWithCtx(context.Context) (int64, error) // count all related models of origin model Count() (int64, error) + CountWithCtx(context.Context) (int64, error) } // RawPreparer raw query statement From c3b6c01c1505391121cea69c7404f797c7ff3493 Mon Sep 17 00:00:00 2001 From: Penghui Liao Date: Sun, 10 Jan 2021 23:46:24 +0800 Subject: [PATCH 41/41] Add InsertWithCtx method on Inserter interface. Signed-off-by: Penghui Liao --- client/orm/orm_object.go | 8 +++++--- client/orm/types.go | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/client/orm/orm_object.go b/client/orm/orm_object.go index 7306438a..50c1ca41 100644 --- a/client/orm/orm_object.go +++ b/client/orm/orm_object.go @@ -22,7 +22,6 @@ import ( // an insert queryer struct type insertSet struct { - ctx context.Context mi *modelInfo orm *ormBase stmt stmtQuerier @@ -33,6 +32,10 @@ var _ Inserter = new(insertSet) // insert model ignore it's registered or not. func (o *insertSet) Insert(md interface{}) (int64, error) { + return o.InsertWithCtx(context.Background(), md) +} + +func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { if o.closed { return 0, ErrStmtClosed } @@ -46,7 +49,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { if name != o.mi.fullName { panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) } - id, err := o.orm.alias.DbBaser.InsertStmt(o.ctx, o.stmt, o.mi, ind, o.orm.alias.TZ) + id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ) if err != nil { return id, err } @@ -74,7 +77,6 @@ func (o *insertSet) Close() error { // create new insert queryer. func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) { bi := new(insertSet) - bi.ctx = ctx bi.orm = orm bi.mi = mi st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi) diff --git a/client/orm/types.go b/client/orm/types.go index 203f057a..59eb9055 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -17,10 +17,10 @@ package orm import ( "context" "database/sql" - "github.com/beego/beego/v2/client/orm/clauses/order_clause" "reflect" "time" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/core/utils" ) @@ -235,6 +235,7 @@ type TxOrmer interface { // Inserter insert prepared statement type Inserter interface { Insert(interface{}) (int64, error) + InsertWithCtx(context.Context, interface{}) (int64, error) Close() error }