diff --git a/.deepsource.toml b/.deepsource.toml new file mode 100644 index 00000000..7901d2d2 --- /dev/null +++ b/.deepsource.toml @@ -0,0 +1,12 @@ +version = 1 + +test_patterns = ["**/*_test.go"] + +exclude_patterns = ["scripts/**"] + +[[analyzers]] +name = "go" +enabled = true + + [analyzers.meta] + import_paths = ["github.com/beego/beego"] diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml new file mode 100644 index 00000000..50e91510 --- /dev/null +++ b/.github/workflows/changelog.yml @@ -0,0 +1,34 @@ +# This action requires that any PR targeting the master branch should touch at +# least one CHANGELOG file. If a CHANGELOG entry is not required, add the "Skip +# Changelog" label to disable this action. + +name: changelog + +on: + pull_request: + types: [opened, synchronize, reopened, labeled, unlabeled] + branches: + - develop + +jobs: + changelog: + runs-on: ubuntu-latest + if: "!contains(github.event.pull_request.labels.*.name, 'Skip Changelog')" + + steps: + - uses: actions/checkout@v2 + + - name: Check for CHANGELOG changes + run: | + # Only the latest commit of the feature branch is available + # automatically. To diff with the base branch, we need to + # fetch that too (and we only need its latest commit). + git fetch origin ${{ github.base_ref }} --depth=1 + if [[ $(git diff --name-only FETCH_HEAD | grep CHANGELOG) ]] + then + echo "A CHANGELOG was modified. Looks good!" + else + echo "No CHANGELOG was modified." + echo "Please add a CHANGELOG entry, or add the \"Skip Changelog\" label if not required." + false + fi \ No newline at end of file diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml new file mode 100644 index 00000000..85b159db --- /dev/null +++ b/.github/workflows/golangci-lint.yml @@ -0,0 +1,32 @@ +name: golangci-lint +on: + push: + tags: + - v* + branches: + - master + - main + pull_request: +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: golangci-lint + uses: golangci/golangci-lint-action@v2 + with: + # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. + version: v1.29 + + # Optional: working directory, useful for monorepos +# working-directory: ./ + + # Optional: golangci-lint command line arguments. + args: --timeout=5m --print-issued-lines=true --print-linter-name=true --uniq-by-line=true + + # Optional: show only new issues if it's a pull request. The default value is `false`. + only-new-issues: true + + # Optional: if set to true then the action will use pre-installed Go + # skip-go-installation: true \ No newline at end of file diff --git a/.gitignore b/.gitignore index 304c4b73..0306c438 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ _beeTmp2/ pkg/_beeTmp/ pkg/_beeTmp2/ test/tmp/ + +profile.out diff --git a/.travis.yml b/.travis.yml index 8e495e66..f625016e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,6 +16,7 @@ env: - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" before_install: + - export CODECOV_TOKEN="4f4bc484-32a8-43b7-9f48-20966bd48ceb" # link the local repo with ${GOPATH}/src// - GO_REPO_NAMESPACE=${GO_REPO_FULLNAME%/*} # relies on GOPATH to contain only one directory... @@ -55,30 +56,9 @@ before_install: - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.serialize.name test" - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put sub.sub.key1 sub.sub.key" install: - - go get github.com/lib/pq - - go get github.com/go-sql-driver/mysql - - go get github.com/mattn/go-sqlite3 - - go get github.com/bradfitz/gomemcache/memcache - - go get github.com/gomodule/redigo/redis - - go get github.com/beego/x2j - - go get github.com/couchbase/go-couchbase - - go get github.com/beego/goyaml2 - - go get gopkg.in/yaml.v2 - - go get github.com/belogik/goes - - go get github.com/ledisdb/ledisdb - - go get github.com/ssdb/gossdb/ssdb - - go get github.com/cloudflare/golz4 - - go get github.com/gogo/protobuf/proto - - go get github.com/Knetic/govaluate - - go get github.com/casbin/casbin - - go get github.com/elazarl/go-bindata-assetfs - - go get github.com/OwnLocal/goes - - go get github.com/shiena/ansicolor - go get -u honnef.co/go/tools/cmd/staticcheck - go get -u github.com/mdempsky/unconvert - go get -u github.com/gordonklaus/ineffassign - - go get -u golang.org/x/lint/golint - - go get -u github.com/go-redis/redis before_script: # - @@ -87,19 +67,19 @@ before_script: - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" - - sh -c "go get github.com/golang/lint/golint; golint ./...;" - sh -c "go list ./... | grep -v vendor | xargs go vet -v" - mkdir -p res/var - ./ssdb/ssdb-server ./ssdb/ssdb.conf -d after_script: - killall -w ssdb-server - rm -rf ./res/var/* +after_success: + - bash <(curl -s https://codecov.io/bash) script: - - go test ./... + - GO111MODULE=on go test -coverprofile=coverage.txt -covermode=atomic ./... - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" ./ - unconvert $(go list ./... | grep -v /vendor/) - ineffassign . - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s - - golint ./... addons: postgresql: "9.6" diff --git a/CHANGELOG.md b/CHANGELOG.md index fc786fb1..fd275a2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,40 @@ # developing +- Error codes definition of cache module. [4493](https://github.com/beego/beego/pull/4493) +- Remove generateCommentRoute http hook. Using `bee generate routers` commands instead.[4486](https://github.com/beego/beego/pull/4486) [bee PR 762](https://github.com/beego/bee/pull/762) +- Fix: /abc.html/aaa match /abc/aaa. [4459](https://github.com/beego/beego/pull/4459) +- ORM mock. [4407](https://github.com/beego/beego/pull/4407) +- Add sonar check and ignore test. [4432](https://github.com/beego/beego/pull/4432) [4433](https://github.com/beego/beego/pull/4433) +- Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427) +- Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398) +- Support `RollbackUnlessCommit` API. [4542](https://github.com/beego/beego/pull/4542) - Fix 4503 and 4504: Add `when` to `Write([]byte)` method and add `prefix` to `writeMsg`. [4507](https://github.com/beego/beego/pull/4507) - Fix 4480: log format incorrect. [4482](https://github.com/beego/beego/pull/4482) - Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391) - Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385) - Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385) - Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) +- Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294) +- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404) +- Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416) +- Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424) +- Finish timeout option for tasks #4441 [4441](https://github.com/beego/beego/pull/4441) +- Error Module brief design & using httplib module to validate this design. [4453](https://github.com/beego/beego/pull/4453) - Fix 4444: panic when 404 not found. [4446](https://github.com/beego/beego/pull/4446) -- Fix 4435: fix panic when controller dir not found. [4452](https://github.com/beego/beego/pull/4452) \ No newline at end of file +- Fix 4435: fix panic when controller dir not found. [4452](https://github.com/beego/beego/pull/4452) +- Fix 4456: Fix router method expression [4456](https://github.com/beego/beego/pull/4456) +- Remove some `go get` lines in `.travis.yml` file [4469](https://github.com/beego/beego/pull/4469) +- Fix 4451: support QueryExecutor interface. [4461](https://github.com/beego/beego/pull/4461) +- Add some testing scripts [4461](https://github.com/beego/beego/pull/4461) +- Refactor httplib: Move debug code to a filter [4440](https://github.com/beego/beego/issues/4440) +- fix: code quality issues [4513](https://github.com/beego/beego/pull/4513) +- Optimize maligned structs to reduce memory foot-print [4525](https://github.com/beego/beego/pull/4525) +- Feat: add token bucket ratelimit filter [4508](https://github.com/beego/beego/pull/4508) +- Improve: Avoid ignoring mistakes that need attention [4548](https://github.com/beego/beego/pull/4548) +- Integration: DeepSource [4560](https://github.com/beego/beego/pull/4560) + + + +## Fix Sonar +- [4473](https://github.com/beego/beego/pull/4473) +- [4474](https://github.com/beego/beego/pull/4474) +- [4479](https://github.com/beego/beego/pull/4479) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2f9189e2..83a7eaea 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -36,7 +36,7 @@ We provide docker compose file to start all middlewares. You can run: ```shell script -docker-compose -f scripts/test_docker_compose.yml up -d +docker-compose -f scripts/test_docker_compose.yaml up -d ``` Unit tests read addresses from environment, here is an example: @@ -53,7 +53,7 @@ export SSDB_ADDR="192.168.0.105:8888" ### Pull requests -First of all. beego follow the gitflow. So please send you pull request to **develop-2** branch. We will close the pull +First of all. beego follow the gitflow. So please send you pull request to **develop** branch. We will close the pull request to master branch. We are always happy to receive pull requests, and do our best to review them as fast as possible. Not sure if that typo diff --git a/ERROR_SPECIFICATION.md b/ERROR_SPECIFICATION.md new file mode 100644 index 00000000..68a04bd1 --- /dev/null +++ b/ERROR_SPECIFICATION.md @@ -0,0 +1,5 @@ +# Error Module + +## Module code +- httplib 1 +- cache 2 \ No newline at end of file diff --git a/README.md b/README.md index 85b3df2a..aa023fc7 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/beego/beego/v2?status.svg)](http://godoc.org/github.com/beego/beego/v2) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) [![Go Report Card](https://goreportcard.com/badge/github.com/beego/beego/v2)](https://goreportcard.com/report/github.com/beego/beego/v2) +# Beego [![Build Status](https://travis-ci.org/beego/beego.svg?branch=master)](https://travis-ci.org/beego/beego) [![GoDoc](http://godoc.org/github.com/beego/beego?status.svg)](http://godoc.org/github.com/beego/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) [![Go Report Card](https://goreportcard.com/badge/github.com/beego/beego)](https://goreportcard.com/report/github.com/beego/beego) Beego is used for rapid development of enterprise application in Go, including RESTful APIs, web apps and backend services. @@ -19,7 +19,7 @@ Beego is compos of four parts: ## Quick Start -[Officail website](http://beego.me) +[Official website](http://beego.me) [Example](https://github.com/beego/beego-example) @@ -40,7 +40,7 @@ Beego is compos of four parts: #### Download and install - go get github.com/beego/beego/v2@v2.0.0 + go get github.com/beego/beego/v2@latest #### Create file `hello.go` @@ -90,8 +90,7 @@ Congratulations! You've just built your first **beego** app. ## Community * [http://beego.me/community](http://beego.me/community) -* Welcome to join us in Slack: [https://beego.slack.com](https://beego.slack.com), you can get invited - from [here](https://github.com/beego/beedoc/issues/232) +* Welcome to join us in Slack: [https://beego.slack.com invite](https://join.slack.com/t/beego/shared_invite/zt-fqlfjaxs-_CRmiITCSbEqQG9NeBqXKA), * QQ Group Group ID:523992905 * [Contribution Guide](https://github.com/beego/beedoc/blob/master/en-US/intro/contributing.md). diff --git a/adapter/cache/cache.go b/adapter/cache/cache.go index f615b26f..1291568c 100644 --- a/adapter/cache/cache.go +++ b/adapter/cache/cache.go @@ -16,7 +16,7 @@ // Usage: // // import( -// "github.com/beego/beego/v2/cache" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("memory", `{"interval":60}`) diff --git a/adapter/cache/cache_test.go b/adapter/cache/cache_test.go index f6217e1a..261e1e5e 100644 --- a/adapter/cache/cache_test.go +++ b/adapter/cache/cache_test.go @@ -19,12 +19,22 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/assert" +) + +const ( + initError = "init err" + setError = "set Error" + checkError = "check err" + getError = "get err" + getMultiError = "GetMulti Error" ) func TestCacheIncr(t *testing.T) { bm, err := NewCache("memory", `{"interval":20}`) if err != nil { - t.Error("init err") + t.Error(initError) } // timeoutDuration := 10 * time.Second @@ -45,147 +55,95 @@ func TestCacheIncr(t *testing.T) { func TestCache(t *testing.T) { bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } + assert.Nil(t, err) - time.Sleep(30 * time.Second) + timeoutDuration := 5 * time.Second + err = bm.Put("astaxie", 1, timeoutDuration) + assert.Nil(t, err) - if bm.IsExist("astaxie") { - t.Error("check err") - } + assert.True(t, bm.IsExist("astaxie")) - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.Equal(t, 1, bm.Get("astaxie")) - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } + time.Sleep(10 * time.Second) - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } + assert.False(t, bm.IsExist("astaxie")) - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } + err = bm.Put("astaxie", 1, timeoutDuration) + assert.Nil(t, err) - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } + err = bm.Incr("astaxie") + assert.Nil(t, err) - // test GetMulti - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } + assert.Equal(t, 2, bm.Get("astaxie")) - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } + assert.Nil(t, bm.Decr("astaxie")) + + assert.Equal(t, 1, bm.Get("astaxie")) + + assert.Nil(t, bm.Delete("astaxie")) + + assert.False(t, bm.IsExist("astaxie")) + + assert.Nil(t, bm.Put("astaxie", "author", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie")) + + assert.Equal(t, "author", bm.Get("astaxie")) + + assert.Nil(t, bm.Put("astaxie1", "author1", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie1")) vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } + + assert.Equal(t, 2, len(vv)) + + + assert.Equal(t, "author", vv[0]) + + assert.Equal(t, "author1", vv[1]) } func TestFileCache(t *testing.T) { bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } + assert.Nil(t, err) + timeoutDuration := 5 * time.Second - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } + assert.Nil(t, bm.Put("astaxie", 1, timeoutDuration)) - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } + assert.True(t, bm.IsExist("astaxie")) - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } + assert.Equal(t, 1, bm.Get("astaxie")) - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } + assert.Nil(t, bm.Incr("astaxie")) - // test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } + assert.Equal(t, 2, bm.Get("astaxie")) - // test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } + assert.Nil(t, bm.Decr("astaxie")) + + assert.Equal(t, 1, bm.Get("astaxie")) + assert.Nil(t, bm.Delete("astaxie")) + + assert.False(t, bm.IsExist("astaxie")) + + assert.Nil(t, bm.Put("astaxie", "author", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie")) + + assert.Equal(t, "author", bm.Get("astaxie")) + + assert.Nil(t, bm.Put("astaxie1", "author1", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie1")) vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - os.RemoveAll("cache") + assert.Equal(t, 2, len(vv)) + + assert.Equal(t, "author", vv[0]) + assert.Equal(t, "author1", vv[1]) + assert.Nil(t, os.RemoveAll("cache")) } diff --git a/adapter/cache/memcache/memcache.go b/adapter/cache/memcache/memcache.go index 16948f65..37b4b282 100644 --- a/adapter/cache/memcache/memcache.go +++ b/adapter/cache/memcache/memcache.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/cache/memcache" -// "github.com/beego/beego/v2/cache" +// _ "github.com/beego/beego/v2/client/cache/memcache" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) diff --git a/adapter/cache/memcache/memcache_test.go b/adapter/cache/memcache/memcache_test.go index 6382543e..cbef74f6 100644 --- a/adapter/cache/memcache/memcache_test.go +++ b/adapter/cache/memcache/memcache_test.go @@ -21,9 +21,19 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/beego/beego/v2/adapter/cache" ) +const ( + initError = "init err" + setError = "set Error" + checkError = "check err" + getError = "get err" + getMultiError = "GetMulti Error" +) + func TestMemcacheCache(t *testing.T) { addr := os.Getenv("MEMCACHE_ADDR") @@ -32,83 +42,52 @@ func TestMemcacheCache(t *testing.T) { } bm, err := cache.NewCache("memcache", fmt.Sprintf(`{"conn": "%s"}`, addr)) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } + assert.Nil(t, err) + timeoutDuration := 5 * time.Second + + assert.Nil(t, bm.Put("astaxie", "1", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie")) time.Sleep(11 * time.Second) - if bm.IsExist("astaxie") { - t.Error("check err") - } - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.False(t, bm.IsExist("astaxie")) - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { - t.Error("get err") - } + assert.Nil(t, bm.Put("astaxie", "1", timeoutDuration)) + v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))) + assert.Nil(t, err) + assert.Equal(t, 1, v) - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } + assert.Nil(t, bm.Incr("astaxie")) - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 { - t.Error("get err") - } + v, err = strconv.Atoi(string(bm.Get("astaxie").([]byte))) + assert.Nil(t, err) + assert.Equal(t, 2, v) - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } + assert.Nil(t, bm.Decr("astaxie")) - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } + v, err = strconv.Atoi(string(bm.Get("astaxie").([]byte))) + assert.Nil(t, err) + assert.Equal(t, 1, v) - // test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } + assert.Nil(t, bm.Delete("astaxie")) - if v := bm.Get("astaxie").([]byte); string(v) != "author" { - t.Error("get err") - } + assert.False(t, bm.IsExist("astaxie")) - // test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } + assert.Nil(t, bm.Put("astaxie", "author", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie")) + + assert.Equal(t, []byte("author"), bm.Get("astaxie")) + + assert.Nil(t, bm.Put("astaxie1", "author1", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie1")) vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { - t.Error("GetMulti ERROR") - } - if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Equal(t, []byte("author"), vv[0]) + assert.Equal(t, []byte("author1"), vv[1]) - // test clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } + assert.Nil(t, bm.ClearAll()) } diff --git a/adapter/cache/redis/redis.go b/adapter/cache/redis/redis.go index bfbeeb9c..003bc6b1 100644 --- a/adapter/cache/redis/redis.go +++ b/adapter/cache/redis/redis.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/cache/redis" -// "github.com/beego/beego/v2/cache" +// _ "github.com/beego/beego/v2/client/cache/redis" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) diff --git a/adapter/cache/redis/redis_test.go b/adapter/cache/redis/redis_test.go index 39a30e8e..3f0ddf6e 100644 --- a/adapter/cache/redis/redis_test.go +++ b/adapter/cache/redis/redis_test.go @@ -21,10 +21,19 @@ import ( "time" "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/assert" "github.com/beego/beego/v2/adapter/cache" ) +const ( + initError = "init err" + setError = "set Error" + checkError = "check err" + getError = "get err" + getMultiError = "GetMulti Error" +) + func TestRedisCache(t *testing.T) { redisAddr := os.Getenv("REDIS_ADDR") if redisAddr == "" { @@ -32,98 +41,79 @@ func TestRedisCache(t *testing.T) { } bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, redisAddr)) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } + assert.Nil(t, err) + timeoutDuration := 5 * time.Second - time.Sleep(11 * time.Second) + assert.Nil(t, bm.Put("astaxie", 1, timeoutDuration)) - if bm.IsExist("astaxie") { - t.Error("check err") - } - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.True(t, bm.IsExist("astaxie")) - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { - t.Error("get err") - } + time.Sleep(7 * time.Second) - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } + assert.False(t, bm.IsExist("astaxie")) - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 { - t.Error("get err") - } + assert.Nil(t, bm.Put("astaxie", 1, timeoutDuration)) - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } + v, err := redis.Int(bm.Get("astaxie"), err) + assert.Nil(t, err) + assert.Equal(t, 1, v) - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } + assert.Nil(t, bm.Incr("astaxie")) - // test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } + v, err = redis.Int(bm.Get("astaxie"), err) + assert.Nil(t, err) + assert.Equal(t, 2, v) - if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" { - t.Error("get err") - } + assert.Nil(t, bm.Decr("astaxie")) - // test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } + v, err = redis.Int(bm.Get("astaxie"), err) + assert.Nil(t, err) + assert.Equal(t, 1, v) + + assert.Nil(t, bm.Delete("astaxie")) + + assert.False(t, bm.IsExist("astaxie")) + + assert.Nil(t, bm.Put("astaxie", "author", timeoutDuration)) + assert.True(t, bm.IsExist("astaxie")) + + vs, err := redis.String(bm.Get("astaxie"), err) + assert.Nil(t, err) + assert.Equal(t, "author", vs) + + assert.Nil(t, bm.Put("astaxie1", "author1", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie1")) vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[0], nil); v != "author" { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[1], nil); v != "author1" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + + vs, err = redis.String(vv[0], nil) + + assert.Nil(t, err) + assert.Equal(t, "author", vs) + + vs, err = redis.String(vv[1], nil) + + assert.Nil(t, err) + assert.Equal(t, "author1", vs) + + assert.Nil(t, bm.ClearAll()) // test clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } } -func TestCache_Scan(t *testing.T) { +func TestCacheScan(t *testing.T) { timeoutDuration := 10 * time.Second // init bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) if err != nil { - t.Error("init err") + t.Error(initError) } // insert all for i := 0; i < 10000; i++ { if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { - t.Error("set Error", err) + t.Error(setError, err) } } diff --git a/adapter/cache/ssdb/ssdb_test.go b/adapter/cache/ssdb/ssdb_test.go index 98e805d1..d61d0a05 100644 --- a/adapter/cache/ssdb/ssdb_test.go +++ b/adapter/cache/ssdb/ssdb_test.go @@ -7,9 +7,19 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/beego/beego/v2/adapter/cache" ) +const ( + initError = "init err" + setError = "set Error" + checkError = "check err" + getError = "get err" + getMultiError = "GetMulti Error" +) + func TestSsdbcacheCache(t *testing.T) { ssdbAddr := os.Getenv("SSDB_ADDR") if ssdbAddr == "" { @@ -17,95 +27,59 @@ func TestSsdbcacheCache(t *testing.T) { } ssdb, err := cache.NewCache("ssdb", fmt.Sprintf(`{"conn": "%s"}`, ssdbAddr)) - if err != nil { - t.Error("init err") - } + assert.Nil(t, err) + + assert.False(t, ssdb.IsExist("ssdb")) // test put and exist - if ssdb.IsExist("ssdb") { - t.Error("check err") - } - timeoutDuration := 10 * time.Second + timeoutDuration := 3 * time.Second // timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb") { - t.Error("check err") - } + assert.Nil(t, ssdb.Put("ssdb", "ssdb", timeoutDuration)) + assert.True(t, ssdb.IsExist("ssdb")) - // Get test done - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.Nil(t, ssdb.Put("ssdb", "ssdb", timeoutDuration)) - if v := ssdb.Get("ssdb"); v != "ssdb" { - t.Error("get Error") - } + assert.Equal(t, "ssdb", ssdb.Get("ssdb")) // inc/dec test done - if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if err = ssdb.Incr("ssdb"); err != nil { - t.Error("incr Error", err) - } + assert.Nil(t, ssdb.Put("ssdb", "2", timeoutDuration)) - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { - t.Error("get err") - } + assert.Nil(t, ssdb.Incr("ssdb")) - if err = ssdb.Decr("ssdb"); err != nil { - t.Error("decr error") - } + v, err := strconv.Atoi(ssdb.Get("ssdb").(string)) + assert.Nil(t, err) + assert.Equal(t, 3, v) + + assert.Nil(t, ssdb.Decr("ssdb")) + + assert.Nil(t, ssdb.Put("ssdb", "3", timeoutDuration)) // test del - if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { - t.Error("get err") - } - if err := ssdb.Delete("ssdb"); err == nil { - if ssdb.IsExist("ssdb") { - t.Error("delete err") - } - } + v, err = strconv.Atoi(ssdb.Get("ssdb").(string)) + assert.Nil(t, err) + assert.Equal(t, 3, v) + + assert.Nil(t, ssdb.Delete("ssdb")) + assert.False(t, ssdb.IsExist("ssdb")) // test string - if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb") { - t.Error("check err") - } - if v := ssdb.Get("ssdb").(string); v != "ssdb" { - t.Error("get err") - } + assert.Nil(t, ssdb.Put("ssdb", "ssdb", -10*time.Second)) + + assert.True(t, ssdb.IsExist("ssdb")) + assert.Equal(t, "ssdb", ssdb.Get("ssdb")) // test GetMulti done - if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb1") { - t.Error("check err") - } - vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) - if len(vv) != 2 { - t.Error("getmulti error") - } - if vv[0].(string) != "ssdb" { - t.Error("getmulti error") - } - if vv[1].(string) != "ssdb1" { - t.Error("getmulti error") - } + assert.Nil(t, ssdb.Put("ssdb1", "ssdb1", -10*time.Second)) + assert.True(t, ssdb.IsExist("ssdb1") ) + vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) + assert.Equal(t, 2, len(vv)) + + assert.Equal(t, "ssdb", vv[0]) + assert.Equal(t, "ssdb1", vv[1]) + + assert.Nil(t, ssdb.ClearAll()) + assert.False(t, ssdb.IsExist("ssdb")) + assert.False(t, ssdb.IsExist("ssdb1")) // test clear all done - if err = ssdb.ClearAll(); err != nil { - t.Error("clear all err") - } - if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") { - t.Error("check err") - } } diff --git a/adapter/config/config.go b/adapter/config/config.go index a935e281..bf2496fc 100644 --- a/adapter/config/config.go +++ b/adapter/config/config.go @@ -14,7 +14,7 @@ // Package config is used to parse config. // Usage: -// import "github.com/beego/beego/v2/config" +// import "github.com/beego/beego/v2/core/config" // Examples. // // cnf, err := config.NewConfig("ini", "config.conf") diff --git a/adapter/config/ini_test.go b/adapter/config/ini_test.go index 60f1febd..07992ba7 100644 --- a/adapter/config/ini_test.go +++ b/adapter/config/ini_test.go @@ -81,7 +81,8 @@ password = ${GOPATH} } ) - f, err := os.Create("testini.conf") + cfgFile := "testini.conf" + f, err := os.Create(cfgFile) if err != nil { t.Fatal(err) } @@ -91,8 +92,8 @@ password = ${GOPATH} t.Fatal(err) } f.Close() - defer os.Remove("testini.conf") - iniconf, err := NewConfig("ini", "testini.conf") + defer os.Remove(cfgFile) + iniconf, err := NewConfig("ini", cfgFile) if err != nil { t.Fatal(err) } diff --git a/adapter/config/json_test.go b/adapter/config/json_test.go index 16f42409..f0076f2a 100644 --- a/adapter/config/json_test.go +++ b/adapter/config/json_test.go @@ -18,6 +18,8 @@ import ( "fmt" "os" "testing" + + "github.com/stretchr/testify/assert" ) func TestJsonStartsWithArray(t *testing.T) { @@ -32,7 +34,8 @@ func TestJsonStartsWithArray(t *testing.T) { "serviceAPI": "http://www.test.com/employee" } ]` - f, err := os.Create("testjsonWithArray.conf") + cfgFileName := "testjsonWithArray.conf" + f, err := os.Create(cfgFileName) if err != nil { t.Fatal(err) } @@ -42,8 +45,8 @@ func TestJsonStartsWithArray(t *testing.T) { t.Fatal(err) } f.Close() - defer os.Remove("testjsonWithArray.conf") - jsonconf, err := NewConfig("json", "testjsonWithArray.conf") + defer os.Remove(cfgFileName) + jsonconf, err := NewConfig("json", cfgFileName) if err != nil { t.Fatal(err) } @@ -132,7 +135,8 @@ func TestJson(t *testing.T) { } ) - f, err := os.Create("testjson.conf") + cfgFileName := "testjson.conf" + f, err := os.Create(cfgFileName) if err != nil { t.Fatal(err) } @@ -142,8 +146,8 @@ func TestJson(t *testing.T) { t.Fatal(err) } f.Close() - defer os.Remove("testjson.conf") - jsonconf, err := NewConfig("json", "testjson.conf") + defer os.Remove(cfgFileName) + jsonconf, err := NewConfig("json", cfgFileName) if err != nil { t.Fatal(err) } @@ -167,56 +171,39 @@ func TestJson(t *testing.T) { default: value, err = jsonconf.DIY(k) } - if err != nil { - t.Fatalf("get key %q value fatal,%v err %s", k, v, err) - } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { - t.Fatalf("get key %q value, want %v got %v .", k, v, value) - } - } - if err = jsonconf.Set("name", "astaxie"); err != nil { - t.Fatal(err) - } - if jsonconf.String("name") != "astaxie" { - t.Fatal("get name error") + assert.Nil(t, err) + assert.Equal(t, fmt.Sprintf("%v", v), fmt.Sprintf("%v", value)) } - if db, err := jsonconf.DIY("database"); err != nil { - t.Fatal(err) - } else if m, ok := db.(map[string]interface{}); !ok { - t.Log(db) - t.Fatal("db not map[string]interface{}") - } else { - if m["host"].(string) != "host" { - t.Fatal("get host err") - } - } + assert.Nil(t, jsonconf.Set("name", "astaxie")) - if _, err := jsonconf.Int("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an Int") - } + assert.Equal(t, "astaxie", jsonconf.String("name")) - if _, err := jsonconf.Int64("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an Int64") - } + db, err := jsonconf.DIY("database") + assert.Nil(t, err) - if _, err := jsonconf.Float("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting a Float") - } + m, ok := db.(map[string]interface{}) + assert.True(t, ok) + assert.Equal(t,"host" , m["host"]) - if _, err := jsonconf.DIY("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an interface{}") - } + _, err = jsonconf.Int("unknown") + assert.NotNil(t, err) - if val := jsonconf.String("unknown"); val != "" { - t.Error("unknown keys should return an empty string when expecting a String") - } + _, err = jsonconf.Int64("unknown") + assert.NotNil(t, err) - if _, err := jsonconf.Bool("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting a Bool") - } + _, err = jsonconf.Float("unknown") + assert.NotNil(t, err) - if !jsonconf.DefaultBool("unknown", true) { - t.Error("unknown keys with default value wrong") - } + _, err = jsonconf.DIY("unknown") + assert.NotNil(t, err) + + val := jsonconf.String("unknown") + assert.Equal(t, "", val) + + _, err = jsonconf.Bool("unknown") + assert.NotNil(t, err) + + assert.True(t, jsonconf.DefaultBool("unknown", true)) } diff --git a/adapter/config/xml/xml.go b/adapter/config/xml/xml.go index 190cee97..8c623033 100644 --- a/adapter/config/xml/xml.go +++ b/adapter/config/xml/xml.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/config/xml" -// "github.com/beego/beego/v2/config" +// _ "github.com/beego/beego/v2/core/config/xml" +// "github.com/beego/beego/v2/core/config" // ) // // cnf, err := config.NewConfig("xml", "config.xml") diff --git a/adapter/config/xml/xml_test.go b/adapter/config/xml/xml_test.go index 5e43ca0f..95b21fd9 100644 --- a/adapter/config/xml/xml_test.go +++ b/adapter/config/xml/xml_test.go @@ -58,7 +58,8 @@ func TestXML(t *testing.T) { } ) - f, err := os.Create("testxml.conf") + cfgFileName := "testxml.conf" + f, err := os.Create(cfgFileName) if err != nil { t.Fatal(err) } @@ -68,9 +69,9 @@ func TestXML(t *testing.T) { t.Fatal(err) } f.Close() - defer os.Remove("testxml.conf") + defer os.Remove(cfgFileName) - xmlconf, err := config.NewConfig("xml", "testxml.conf") + xmlconf, err := config.NewConfig("xml", cfgFileName) if err != nil { t.Fatal(err) } diff --git a/adapter/config/yaml/yaml.go b/adapter/config/yaml/yaml.go index 8d0bb697..538f1178 100644 --- a/adapter/config/yaml/yaml.go +++ b/adapter/config/yaml/yaml.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/config/yaml" -// "github.com/beego/beego/v2/config" +// _ "github.com/beego/beego/v2/core/config/yaml" +// "github.com/beego/beego/v2/core/config" // ) // // cnf, err := config.NewConfig("yaml", "config.yaml") diff --git a/adapter/config/yaml/yaml_test.go b/adapter/config/yaml/yaml_test.go index d567b554..323b5e87 100644 --- a/adapter/config/yaml/yaml_test.go +++ b/adapter/config/yaml/yaml_test.go @@ -54,7 +54,8 @@ func TestYaml(t *testing.T) { "emptystrings": []string{}, } ) - f, err := os.Create("testyaml.conf") + cfgFileName := "testyaml.conf" + f, err := os.Create(cfgFileName) if err != nil { t.Fatal(err) } @@ -64,8 +65,8 @@ func TestYaml(t *testing.T) { t.Fatal(err) } f.Close() - defer os.Remove("testyaml.conf") - yamlconf, err := config.NewConfig("yaml", "testyaml.conf") + defer os.Remove(cfgFileName) + yamlconf, err := config.NewConfig("yaml", cfgFileName) if err != nil { t.Fatal(err) } diff --git a/adapter/context/context.go b/adapter/context/context.go index 16e631fc..bb8f7cd9 100644 --- a/adapter/context/context.go +++ b/adapter/context/context.go @@ -15,7 +15,7 @@ // Package context provide the context utils // Usage: // -// import "github.com/beego/beego/v2/context" +// import "github.com/beego/beego/v2/server/web/context" // // ctx := context.Context{Request:req,ResponseWriter:rw} // diff --git a/adapter/context/param/conv.go b/adapter/context/param/conv.go new file mode 100644 index 00000000..ec4c6b7e --- /dev/null +++ b/adapter/context/param/conv.go @@ -0,0 +1,18 @@ +package param + +import ( + "reflect" + + beecontext "github.com/beego/beego/v2/adapter/context" + "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/context/param" +) + +// ConvertParams converts http method params to values that will be passed to the method controller as arguments +func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) { + nps := make([]*param.MethodParam, 0, len(methodParams)) + for _, mp := range methodParams { + nps = append(nps, (*param.MethodParam)(mp)) + } + return param.ConvertParams(nps, methodType, (*context.Context)(ctx)) +} diff --git a/adapter/context/param/conv_test.go b/adapter/context/param/conv_test.go new file mode 100644 index 00000000..6f18f240 --- /dev/null +++ b/adapter/context/param/conv_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package param + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/adapter/context" +) + + +// Demo is used to test, it's empty +func Demo(i int) { + +} + +func TestConvertParams(t *testing.T) { + res := ConvertParams(nil, reflect.TypeOf(Demo), context.NewContext()) + assert.Equal(t, 0, len(res)) + ctx := context.NewContext() + ctx.Input.RequestBody = []byte("11") + res = ConvertParams([]*MethodParam{ + New("A", InBody), + }, reflect.TypeOf(Demo), ctx) + assert.Equal(t, int64(11), res[0].Int()) +} diff --git a/adapter/context/param/methodparams.go b/adapter/context/param/methodparams.go new file mode 100644 index 00000000..000539db --- /dev/null +++ b/adapter/context/param/methodparams.go @@ -0,0 +1,29 @@ +package param + +import ( + "github.com/beego/beego/v2/server/web/context/param" +) + +// MethodParam keeps param information to be auto passed to controller methods +type MethodParam param.MethodParam + +// New creates a new MethodParam with name and specific options +func New(name string, opts ...MethodParamOption) *MethodParam { + newOps := make([]param.MethodParamOption, 0, len(opts)) + for _, o := range opts { + newOps = append(newOps, oldMpoToNew(o)) + } + return (*MethodParam)(param.New(name, newOps...)) +} + +// Make creates an array of MethodParmas or an empty array +func Make(list ...*MethodParam) []*MethodParam { + if len(list) > 0 { + return list + } + return nil +} + +func (mp *MethodParam) String() string { + return (*param.MethodParam)(mp).String() +} diff --git a/adapter/context/param/methodparams_test.go b/adapter/context/param/methodparams_test.go new file mode 100644 index 00000000..9d5155bf --- /dev/null +++ b/adapter/context/param/methodparams_test.go @@ -0,0 +1,34 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package param + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMethodParamString(t *testing.T) { + method := New("myName", IsRequired, InHeader, Default("abc")) + s := method.String() + assert.Equal(t, `param.New("myName", param.IsRequired, param.InHeader, param.Default("abc"))`, s) +} + +func TestMake(t *testing.T) { + res := Make() + assert.Equal(t, 0, len(res)) + res = Make(New("myName", InBody)) + assert.Equal(t, 1, len(res)) +} diff --git a/adapter/context/param/options.go b/adapter/context/param/options.go new file mode 100644 index 00000000..1d9364c2 --- /dev/null +++ b/adapter/context/param/options.go @@ -0,0 +1,45 @@ +package param + +import ( + "github.com/beego/beego/v2/server/web/context/param" +) + +// MethodParamOption defines a func which apply options on a MethodParam +type MethodParamOption func(*MethodParam) + +// IsRequired indicates that this param is required and can not be omitted from the http request +var IsRequired MethodParamOption = func(p *MethodParam) { + param.IsRequired((*param.MethodParam)(p)) +} + +// InHeader indicates that this param is passed via an http header +var InHeader MethodParamOption = func(p *MethodParam) { + param.InHeader((*param.MethodParam)(p)) +} + +// InPath indicates that this param is part of the URL path +var InPath MethodParamOption = func(p *MethodParam) { + param.InPath((*param.MethodParam)(p)) +} + +// InBody indicates that this param is passed as an http request body +var InBody MethodParamOption = func(p *MethodParam) { + param.InBody((*param.MethodParam)(p)) +} + +// Default provides a default value for the http param +func Default(defaultValue interface{}) MethodParamOption { + return newMpoToOld(param.Default(defaultValue)) +} + +func newMpoToOld(n param.MethodParamOption) MethodParamOption { + return func(methodParam *MethodParam) { + n((*param.MethodParam)(methodParam)) + } +} + +func oldMpoToNew(old MethodParamOption) param.MethodParamOption { + return func(methodParam *param.MethodParam) { + old((*MethodParam)(methodParam)) + } +} diff --git a/adapter/grace/grace.go b/adapter/grace/grace.go index 6e582bac..de047eb1 100644 --- a/adapter/grace/grace.go +++ b/adapter/grace/grace.go @@ -22,7 +22,7 @@ // "net/http" // "os" // -// "github.com/beego/beego/v2/grace" +// "github.com/beego/beego/v2/server/web/grace" // ) // // func handler(w http.ResponseWriter, r *http.Request) { diff --git a/adapter/httplib/httplib.go b/adapter/httplib/httplib.go index 0a182cae..005eee0f 100644 --- a/adapter/httplib/httplib.go +++ b/adapter/httplib/httplib.go @@ -15,7 +15,7 @@ // Package httplib is used as http.Client // Usage: // -// import "github.com/beego/beego/v2/httplib" +// import "github.com/beego/beego/v2/client/httplib" // // b := httplib.Post("http://beego.me/") // b.Param("username","astaxie") @@ -115,12 +115,6 @@ func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { return b } -// Debug sets show debug or not when executing request. -func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { - b.delegate.Debug(isdebug) - return b -} - // Retries sets Retries times. // default is 0 means no retried. // -1 means retried forever. @@ -135,17 +129,6 @@ func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { return b } -// DumpBody setting whether need to Dump the Body. -func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { - b.delegate.DumpBody(isdump) - return b -} - -// DumpRequest return the DumpRequest -func (b *BeegoHTTPRequest) DumpRequest() []byte { - return b.delegate.DumpRequest() -} - // SetTimeout sets connect time out and read-write time out for BeegoRequest. func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { b.delegate.SetTimeout(connectTimeout, readWriteTimeout) diff --git a/adapter/httplib/httplib_test.go b/adapter/httplib/httplib_test.go index e7605c87..298d84f9 100644 --- a/adapter/httplib/httplib_test.go +++ b/adapter/httplib/httplib_test.go @@ -15,6 +15,7 @@ package httplib import ( + "bytes" "errors" "io/ioutil" "net" @@ -25,8 +26,11 @@ import ( "time" ) +const getUrl = "http://httpbin.org/get" +const ipUrl = "http://httpbin.org/ip" + func TestResponse(t *testing.T) { - req := Get("http://httpbin.org/get") + req := Get(getUrl) resp, err := req.Response() if err != nil { t.Fatal(err) @@ -63,7 +67,8 @@ func TestDoRequest(t *testing.T) { } func TestGet(t *testing.T) { - req := Get("http://httpbin.org/get") + + req := Get(getUrl) b, err := req.Bytes() if err != nil { t.Fatal(err) @@ -205,7 +210,7 @@ func TestWithSetting(t *testing.T) { setting.ReadWriteTimeout = 5 * time.Second SetDefaultSetting(setting) - str, err := Get("http://httpbin.org/get").String() + str, err := Get(getUrl).String() if err != nil { t.Fatal(err) } @@ -218,7 +223,8 @@ func TestWithSetting(t *testing.T) { } func TestToJson(t *testing.T) { - req := Get("http://httpbin.org/ip") + + req := Get(ipUrl) resp, err := req.Response() if err != nil { t.Fatal(err) @@ -249,28 +255,28 @@ func TestToJson(t *testing.T) { func TestToFile(t *testing.T) { f := "beego_testfile" - req := Get("http://httpbin.org/ip") + req := Get(ipUrl) err := req.ToFile(f) if err != nil { t.Fatal(err) } defer os.Remove(f) b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { + if n := bytes.Index(b, []byte("origin")); n == -1 { t.Fatal(err) } } func TestToFileDir(t *testing.T) { f := "./files/beego_testfile" - req := Get("http://httpbin.org/ip") + req := Get(ipUrl) err := req.ToFile(f) if err != nil { t.Fatal(err) } defer os.RemoveAll("./files") b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { + if n := bytes.Index(b, []byte("origin")); n == -1 { t.Fatal(err) } } diff --git a/adapter/log.go b/adapter/log.go index 25e82d26..edf101ad 100644 --- a/adapter/log.go +++ b/adapter/log.go @@ -23,7 +23,7 @@ import ( ) // Log levels to control the logging output. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. const ( LevelEmergency = webLog.LevelEmergency LevelAlert = webLog.LevelAlert @@ -36,90 +36,90 @@ const ( ) // BeeLogger references the used application logger. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. var BeeLogger = logs.GetBeeLogger() // SetLevel sets the global log level used by the simple logger. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func SetLevel(l int) { logs.SetLevel(l) } // SetLogFuncCall set the CallDepth, default is 3 -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func SetLogFuncCall(b bool) { logs.SetLogFuncCall(b) } // SetLogger sets a new logger. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func SetLogger(adaptername string, config string) error { return logs.SetLogger(adaptername, config) } // Emergency logs a message at emergency level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Emergency(v ...interface{}) { logs.Emergency(generateFmtStr(len(v)), v...) } // Alert logs a message at alert level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Alert(v ...interface{}) { logs.Alert(generateFmtStr(len(v)), v...) } // Critical logs a message at critical level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Critical(v ...interface{}) { logs.Critical(generateFmtStr(len(v)), v...) } // Error logs a message at error level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Error(v ...interface{}) { logs.Error(generateFmtStr(len(v)), v...) } // Warning logs a message at warning level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Warning(v ...interface{}) { logs.Warning(generateFmtStr(len(v)), v...) } // Warn compatibility alias for Warning() -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Warn(v ...interface{}) { logs.Warn(generateFmtStr(len(v)), v...) } // Notice logs a message at notice level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Notice(v ...interface{}) { logs.Notice(generateFmtStr(len(v)), v...) } // Informational logs a message at info level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Informational(v ...interface{}) { logs.Informational(generateFmtStr(len(v)), v...) } // Info compatibility alias for Warning() -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Info(v ...interface{}) { logs.Info(generateFmtStr(len(v)), v...) } // Debug logs a message at debug level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Debug(v ...interface{}) { logs.Debug(generateFmtStr(len(v)), v...) } // Trace logs a message at trace level. // compatibility alias for Warning() -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Trace(v ...interface{}) { logs.Trace(generateFmtStr(len(v)), v...) } diff --git a/adapter/logs/log.go b/adapter/logs/log.go index 9d098d8f..d53cc2ce 100644 --- a/adapter/logs/log.go +++ b/adapter/logs/log.go @@ -15,7 +15,7 @@ // Package logs provide a general log interface // Usage: // -// import "github.com/beego/beego/v2/logs" +// import "github.com/beego/beego/v2/core/logs" // // log := NewLogger(10000) // log.SetLogger("console", "") diff --git a/adapter/logs/logger_test.go b/adapter/logs/logger_test.go index 9f2cc5a5..42708fa5 100644 --- a/adapter/logs/logger_test.go +++ b/adapter/logs/logger_test.go @@ -18,7 +18,7 @@ import ( "testing" ) -func TestBeeLogger_Info(t *testing.T) { +func TestBeeLoggerInfo(t *testing.T) { log := NewLogger(1000) log.SetLogger("file", `{"net":"tcp","addr":":7020"}`) } diff --git a/adapter/metric/prometheus_test.go b/adapter/metric/prometheus_test.go index 53984845..72212dd4 100644 --- a/adapter/metric/prometheus_test.go +++ b/adapter/metric/prometheus_test.go @@ -15,6 +15,7 @@ package metric import ( + "fmt" "net/http" "net/url" "testing" @@ -26,7 +27,9 @@ import ( ) func TestPrometheusMiddleWare(t *testing.T) { - middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + fmt.Print("you are coming") + })) writer := &context.Response{} request := &http.Request{ URL: &url.URL{ diff --git a/adapter/namespace.go b/adapter/namespace.go index 709f6aa5..af7c77f8 100644 --- a/adapter/namespace.go +++ b/adapter/namespace.go @@ -238,141 +238,158 @@ func AddNamespace(nl ...*Namespace) { // NSCond is Namespace Condition func NSCond(cond namespaceCond) LinkNamespace { + wc := web.NSCond(func(b *context.Context) bool { + return cond((*adtContext.Context)(b)) + }) return func(namespace *Namespace) { - web.NSCond(func(b *context.Context) bool { - return cond((*adtContext.Context)(b)) - }) + wc((*web.Namespace)(namespace)) } } // NSBefore Namespace BeforeRouter filter func NSBefore(filterList ...FilterFunc) LinkNamespace { + nfs := oldToNewFilter(filterList) + wf := web.NSBefore(nfs...) return func(namespace *Namespace) { - nfs := oldToNewFilter(filterList) - web.NSBefore(nfs...) + wf((*web.Namespace)(namespace)) } } // NSAfter add Namespace FinishRouter filter func NSAfter(filterList ...FilterFunc) LinkNamespace { + nfs := oldToNewFilter(filterList) + wf := web.NSAfter(nfs...) return func(namespace *Namespace) { - nfs := oldToNewFilter(filterList) - web.NSAfter(nfs...) + wf((*web.Namespace)(namespace)) } } // NSInclude Namespace Include ControllerInterface func NSInclude(cList ...ControllerInterface) LinkNamespace { + nfs := oldToNewCtrlIntfs(cList) + wi := web.NSInclude(nfs...) return func(namespace *Namespace) { - nfs := oldToNewCtrlIntfs(cList) - web.NSInclude(nfs...) + wi((*web.Namespace)(namespace)) } } // NSRouter call Namespace Router func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { + wn := web.NSRouter(rootpath, c, mappingMethods...) return func(namespace *Namespace) { - web.Router(rootpath, c, mappingMethods...) + wn((*web.Namespace)(namespace)) } } // NSGet call Namespace Get func NSGet(rootpath string, f FilterFunc) LinkNamespace { + ln := web.NSGet(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSGet(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + ln((*web.Namespace)(ns)) } } // NSPost call Namespace Post func NSPost(rootpath string, f FilterFunc) LinkNamespace { + wp := web.NSPost(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.Post(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wp((*web.Namespace)(ns)) } } // NSHead call Namespace Head func NSHead(rootpath string, f FilterFunc) LinkNamespace { + wb := web.NSHead(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSHead(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wb((*web.Namespace)(ns)) } } // NSPut call Namespace Put func NSPut(rootpath string, f FilterFunc) LinkNamespace { + wn := web.NSPut(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSPut(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wn((*web.Namespace)(ns)) } } // NSDelete call Namespace Delete func NSDelete(rootpath string, f FilterFunc) LinkNamespace { + wn := web.NSDelete(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSDelete(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wn((*web.Namespace)(ns)) } } // NSAny call Namespace Any func NSAny(rootpath string, f FilterFunc) LinkNamespace { + wn := web.NSAny(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSAny(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wn((*web.Namespace)(ns)) } } // NSOptions call Namespace Options func NSOptions(rootpath string, f FilterFunc) LinkNamespace { + wo := web.NSOptions(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSOptions(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wo((*web.Namespace)(ns)) } } // NSPatch call Namespace Patch func NSPatch(rootpath string, f FilterFunc) LinkNamespace { + wn := web.NSPatch(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSPatch(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wn((*web.Namespace)(ns)) } } // NSAutoRouter call Namespace AutoRouter func NSAutoRouter(c ControllerInterface) LinkNamespace { + wn := web.NSAutoRouter(c) return func(ns *Namespace) { - web.NSAutoRouter(c) + wn((*web.Namespace)(ns)) } } // NSAutoPrefix call Namespace AutoPrefix func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { + wn := web.NSAutoPrefix(prefix, c) return func(ns *Namespace) { - web.NSAutoPrefix(prefix, c) + wn((*web.Namespace)(ns)) } } // NSNamespace add sub Namespace func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { + nps := oldToNewLinkNs(params) + wn := web.NSNamespace(prefix, nps...) return func(ns *Namespace) { - nps := oldToNewLinkNs(params) - web.NSNamespace(prefix, nps...) + wn((*web.Namespace)(ns)) } } // NSHandler add handler func NSHandler(rootpath string, h http.Handler) LinkNamespace { + wn := web.NSHandler(rootpath, h) return func(ns *Namespace) { - web.NSHandler(rootpath, h) + wn((*web.Namespace)(ns)) } } diff --git a/adapter/orm/models_boot.go b/adapter/orm/models_boot.go index e004f35a..678b86e6 100644 --- a/adapter/orm/models_boot.go +++ b/adapter/orm/models_boot.go @@ -25,7 +25,7 @@ func RegisterModel(models ...interface{}) { // RegisterModelWithPrefix register models with a prefix func RegisterModelWithPrefix(prefix string, models ...interface{}) { - orm.RegisterModelWithPrefix(prefix, models) + orm.RegisterModelWithPrefix(prefix, models...) } // RegisterModelWithSuffix register models with a suffix diff --git a/adapter/orm/models_boot_test.go b/adapter/orm/models_boot_test.go new file mode 100644 index 00000000..5471885b --- /dev/null +++ b/adapter/orm/models_boot_test.go @@ -0,0 +1,31 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" +) + +type User struct { + Id int +} + +type Seller struct { + Id int +} + +func TestRegisterModelWithPrefix(t *testing.T) { + RegisterModelWithPrefix("test", &User{}, &Seller{}) +} diff --git a/adapter/orm/orm.go b/adapter/orm/orm.go index c603de2f..f3283fd4 100644 --- a/adapter/orm/orm.go +++ b/adapter/orm/orm.go @@ -21,7 +21,7 @@ // // import ( // "fmt" -// "github.com/beego/beego/v2/orm" +// "github.com/beego/beego/v2/client/orm" // _ "github.com/go-sql-driver/mysql" // import your used driver // ) // diff --git a/adapter/orm/query_setter_adapter.go b/adapter/orm/query_setter_adapter.go index 7f506759..edea0a15 100644 --- a/adapter/orm/query_setter_adapter.go +++ b/adapter/orm/query_setter_adapter.go @@ -21,14 +21,16 @@ import ( type baseQuerySetter struct { } +const shouldNotInvoke = "you should not invoke this method." + func (b *baseQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter { - panic("you should not invoke this method.") + panic(shouldNotInvoke) } func (b *baseQuerySetter) UseIndex(indexes ...string) orm.QuerySeter { - panic("you should not invoke this method.") + panic(shouldNotInvoke) } func (b *baseQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter { - panic("you should not invoke this method.") + panic(shouldNotInvoke) } diff --git a/adapter/orm/utils.go b/adapter/orm/utils.go index 22bf8d63..cd54f867 100644 --- a/adapter/orm/utils.go +++ b/adapter/orm/utils.go @@ -195,7 +195,7 @@ func snakeStringWithAcronym(s string) string { } data = append(data, d) } - return strings.ToLower(string(data[:])) + return strings.ToLower(string(data)) } // snake string, XxYy to xx_yy , XxYY to xx_y_y @@ -213,7 +213,7 @@ func snakeString(s string) string { } data = append(data, d) } - return strings.ToLower(string(data[:])) + return strings.ToLower(string(data)) } // SetNameStrategy set different name strategy @@ -241,7 +241,7 @@ func camelString(s string) string { } data = append(data, d) } - return string(data[:]) + return string(data) } type argString []string diff --git a/adapter/orm/utils_test.go b/adapter/orm/utils_test.go index 7d94cada..fbf8663e 100644 --- a/adapter/orm/utils_test.go +++ b/adapter/orm/utils_test.go @@ -16,6 +16,8 @@ package orm import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestCamelString(t *testing.T) { @@ -29,9 +31,7 @@ func TestCamelString(t *testing.T) { for _, v := range snake { res := camelString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } + assert.Equal(t, answer[v], res) } } @@ -46,9 +46,7 @@ func TestSnakeString(t *testing.T) { for _, v := range camel { res := snakeString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } + assert.Equal(t, answer[v], res) } } @@ -63,8 +61,6 @@ func TestSnakeStringWithAcronym(t *testing.T) { for _, v := range camel { res := snakeStringWithAcronym(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } + assert.Equal(t, answer[v], res) } } diff --git a/adapter/plugins/apiauth/apiauth.go b/adapter/plugins/apiauth/apiauth.go index fd0c7ff4..d5511427 100644 --- a/adapter/plugins/apiauth/apiauth.go +++ b/adapter/plugins/apiauth/apiauth.go @@ -17,7 +17,7 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/apiauth" +// "github.com/beego/beego/v2/server/web/filter/apiauth" // ) // // func main(){ diff --git a/adapter/plugins/auth/basic.go b/adapter/plugins/auth/basic.go index 4ef3343f..173252ca 100644 --- a/adapter/plugins/auth/basic.go +++ b/adapter/plugins/auth/basic.go @@ -16,7 +16,7 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/auth" +// "github.com/beego/beego/v2/server/web/filter/auth" // ) // // func main(){ diff --git a/adapter/plugins/authz/authz.go b/adapter/plugins/authz/authz.go index 114c8c9a..096d7efb 100644 --- a/adapter/plugins/authz/authz.go +++ b/adapter/plugins/authz/authz.go @@ -16,7 +16,7 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/authz" +// "github.com/beego/beego/v2/server/web/filter/authz" // "github.com/casbin/casbin" // ) // diff --git a/adapter/plugins/authz/authz_test.go b/adapter/plugins/authz/authz_test.go index fa5410ca..4963ceab 100644 --- a/adapter/plugins/authz/authz_test.go +++ b/adapter/plugins/authz/authz_test.go @@ -26,6 +26,11 @@ import ( "github.com/beego/beego/v2/adapter/plugins/auth" ) +const ( + authCfg = "authz_model.conf" + authCsv = "authz_policy.csv" +) + func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { r, _ := http.NewRequest(method, path, nil) r.SetBasicAuth(user, "123") @@ -40,70 +45,79 @@ func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, p func TestBasic(t *testing.T) { handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + _ = handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) + + _ = handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer(authCfg, authCsv))) handler.Any("*", func(ctx *context.Context) { ctx.Output.SetStatus(200) }) - testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) - testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) - testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) - testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) + const d1r1 = "/dataset1/resource1" + testRequest(t, handler, "alice", d1r1, "GET", 200) + testRequest(t, handler, "alice", d1r1, "POST", 200) + const d1r2 = "/dataset1/resource2" + testRequest(t, handler, "alice", d1r2, "GET", 200) + testRequest(t, handler, "alice", d1r2, "POST", 403) } func TestPathWildcard(t *testing.T) { handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + _ = handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) + _ = handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer(authCfg, authCsv))) handler.Any("*", func(ctx *context.Context) { ctx.Output.SetStatus(200) }) - testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) - testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) - testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) - testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) - testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) + const d2r1 = "/dataset2/resource1" + testRequest(t, handler, "bob", d2r1, "GET", 200) + testRequest(t, handler, "bob", d2r1, "POST", 200) + testRequest(t, handler, "bob", d2r1, "DELETE", 200) + const d2r2 = "/dataset2/resource2" + testRequest(t, handler, "bob", d2r2, "GET", 200) + testRequest(t, handler, "bob", d2r2, "POST", 403) + testRequest(t, handler, "bob", d2r2, "DELETE", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) + const item1 = "/dataset2/folder1/item1" + testRequest(t, handler, "bob", item1, "GET", 403) + testRequest(t, handler, "bob", item1, "POST", 200) + testRequest(t, handler, "bob", item1, "DELETE", 403) + const item2 = "/dataset2/folder1/item2" + testRequest(t, handler, "bob", item2, "GET", 403) + testRequest(t, handler, "bob", item2, "POST", 200) + testRequest(t, handler, "bob", item2, "DELETE", 403) } func TestRBAC(t *testing.T) { handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) - e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) + _ = handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) + e := casbin.NewEnforcer(authCfg, authCsv) + _ = handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) handler.Any("*", func(ctx *context.Context) { ctx.Output.SetStatus(200) }) // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. - testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) - testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) - testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) - testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + const dataSet1 = "/dataset1/item" + testRequest(t, handler, "cathy", dataSet1, "GET", 200) + testRequest(t, handler, "cathy", dataSet1, "POST", 200) + testRequest(t, handler, "cathy", dataSet1, "DELETE", 200) + const dataSet2 = "/dataset2/item" + testRequest(t, handler, "cathy", dataSet2, "GET", 403) + testRequest(t, handler, "cathy", dataSet2, "POST", 403) + testRequest(t, handler, "cathy", dataSet2, "DELETE", 403) // delete all roles on user cathy, so cathy cannot access any resources now. e.DeleteRolesForUser("cathy") - testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + testRequest(t, handler, "cathy", dataSet1, "GET", 403) + testRequest(t, handler, "cathy", dataSet1, "POST", 403) + testRequest(t, handler, "cathy", dataSet1, "DELETE", 403) + testRequest(t, handler, "cathy", dataSet2, "GET", 403) + testRequest(t, handler, "cathy", dataSet2, "POST", 403) + testRequest(t, handler, "cathy", dataSet2, "DELETE", 403) } diff --git a/adapter/plugins/cors/cors.go b/adapter/plugins/cors/cors.go index 6a836585..89ac9c68 100644 --- a/adapter/plugins/cors/cors.go +++ b/adapter/plugins/cors/cors.go @@ -16,7 +16,7 @@ // Usage // import ( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/cors" +// "github.com/beego/beego/v2/server/web/filter/cors" // ) // // func main() { diff --git a/adapter/router.go b/adapter/router.go index 900e3eb7..9a615efe 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -87,7 +87,7 @@ func NewControllerRegister() *ControllerRegister { // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - (*web.ControllerRegister)(p).Add(pattern, c, mappingMethods...) + (*web.ControllerRegister)(p).Add(pattern, c, web.WithRouterMethods(c, mappingMethods...)) } // Include only when the Runmode is dev will generate router file in the router/auto.go from the controller diff --git a/adapter/session/couchbase/sess_couchbase.go b/adapter/session/couchbase/sess_couchbase.go index 4ce2d69d..2dc4ce18 100644 --- a/adapter/session/couchbase/sess_couchbase.go +++ b/adapter/session/couchbase/sess_couchbase.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/couchbase" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/couchbase" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/adapter/session/memcache/sess_memcache.go b/adapter/session/memcache/sess_memcache.go index e81d06c6..dfbc2d6e 100644 --- a/adapter/session/memcache/sess_memcache.go +++ b/adapter/session/memcache/sess_memcache.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/memcache" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/memcache" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/adapter/session/mysql/sess_mysql.go b/adapter/session/mysql/sess_mysql.go index d47e7496..5272d3fa 100644 --- a/adapter/session/mysql/sess_mysql.go +++ b/adapter/session/mysql/sess_mysql.go @@ -28,8 +28,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/mysql" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/mysql" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/adapter/session/postgres/sess_postgresql.go b/adapter/session/postgres/sess_postgresql.go index a24794d6..a6278f17 100644 --- a/adapter/session/postgres/sess_postgresql.go +++ b/adapter/session/postgres/sess_postgresql.go @@ -38,8 +38,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/postgresql" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/postgresql" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/adapter/session/redis/sess_redis.go b/adapter/session/redis/sess_redis.go index a5fcedf6..d38a675d 100644 --- a/adapter/session/redis/sess_redis.go +++ b/adapter/session/redis/sess_redis.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/adapter/session/redis_cluster/redis_cluster.go b/adapter/session/redis_cluster/redis_cluster.go index f4c8e4d1..623d72cc 100644 --- a/adapter/session/redis_cluster/redis_cluster.go +++ b/adapter/session/redis_cluster/redis_cluster.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis_cluster" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis_cluster" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/adapter/session/redis_sentinel/sess_redis_sentinel.go b/adapter/session/redis_sentinel/sess_redis_sentinel.go index 4498e55d..4b7dab77 100644 --- a/adapter/session/redis_sentinel/sess_redis_sentinel.go +++ b/adapter/session/redis_sentinel/sess_redis_sentinel.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis_sentinel" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis_sentinel" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/adapter/session/redis_sentinel/sess_redis_sentinel_test.go b/adapter/session/redis_sentinel/sess_redis_sentinel_test.go index 0a6249ee..b08d0256 100644 --- a/adapter/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/adapter/session/redis_sentinel/sess_redis_sentinel_test.go @@ -5,6 +5,8 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" + "github.com/beego/beego/v2/adapter/session" ) @@ -19,71 +21,55 @@ func TestRedisSentinel(t *testing.T) { ProviderConfig: "127.0.0.1:6379,100,,0,master", } globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) + if e != nil { t.Log(e) return } - // todo test if e==nil + go globalSessions.GC() r, _ := http.NewRequest("GET", "/", nil) w := httptest.NewRecorder() sess, err := globalSessions.SessionStart(w, r) - if err != nil { - t.Fatal("session start failed:", err) - } + assert.Nil(t, err) defer sess.SessionRelease(w) // SET AND GET err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set username failed:", err) - } + assert.Nil(t, err) username := sess.Get("username") - if username != "astaxie" { - t.Fatal("get username failed") - } + assert.Equal(t, "astaxie", username) // DELETE err = sess.Delete("username") - if err != nil { - t.Fatal("delete username failed:", err) - } + assert.Nil(t, err) + username = sess.Get("username") - if username != nil { - t.Fatal("delete username failed") - } + assert.Nil(t, username) // FLUSH err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set failed:", err) - } + assert.Nil(t, err) + err = sess.Set("password", "1qaz2wsx") - if err != nil { - t.Fatal("set failed:", err) - } + assert.Nil(t, err) + username = sess.Get("username") - if username != "astaxie" { - t.Fatal("get username failed") - } + assert.Equal(t, "astaxie", username) + password := sess.Get("password") - if password != "1qaz2wsx" { - t.Fatal("get password failed") - } + assert.Equal(t, "1qaz2wsx", password) + err = sess.Flush() - if err != nil { - t.Fatal("flush failed:", err) - } + assert.Nil(t, err) + username = sess.Get("username") - if username != nil { - t.Fatal("flush failed") - } + assert.Nil(t, username) + password = sess.Get("password") - if password != nil { - t.Fatal("flush failed") - } + assert.Nil(t, password) sess.SessionRelease(w) diff --git a/adapter/session/sess_cookie_test.go b/adapter/session/sess_cookie_test.go index b6726005..5d6b44e3 100644 --- a/adapter/session/sess_cookie_test.go +++ b/adapter/session/sess_cookie_test.go @@ -22,6 +22,8 @@ import ( "testing" ) +const setCookieKey = "Set-Cookie" + func TestCookie(t *testing.T) { config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` conf := new(ManagerConfig) @@ -46,7 +48,8 @@ func TestCookie(t *testing.T) { t.Fatal("get username error") } sess.SessionRelease(w) - if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + + if cookiestr := w.Header().Get(setCookieKey); cookiestr == "" { t.Fatal("setcookie error") } else { parts := strings.Split(strings.TrimSpace(cookiestr), ";") @@ -79,7 +82,7 @@ func TestDestorySessionCookie(t *testing.T) { // request again ,will get same sesssion id . r1, _ := http.NewRequest("GET", "/", nil) - r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + r1.Header.Set("Cookie", w.Header().Get(setCookieKey)) w = httptest.NewRecorder() newSession, err := globalSessions.SessionStart(w, r1) if err != nil { @@ -92,7 +95,7 @@ func TestDestorySessionCookie(t *testing.T) { // After destroy session , will get a new session id . globalSessions.SessionDestroy(w, r1) r2, _ := http.NewRequest("GET", "/", nil) - r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + r2.Header.Set("Cookie", w.Header().Get(setCookieKey)) w = httptest.NewRecorder() newSession, err = globalSessions.SessionStart(w, r2) diff --git a/adapter/session/session.go b/adapter/session/session.go index 40e947fd..703adbde 100644 --- a/adapter/session/session.go +++ b/adapter/session/session.go @@ -16,7 +16,7 @@ // // Usage: // import( -// "github.com/beego/beego/v2/session" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/adapter/templatefunc_test.go b/adapter/templatefunc_test.go index f5113606..2fd18e3d 100644 --- a/adapter/templatefunc_test.go +++ b/adapter/templatefunc_test.go @@ -19,19 +19,15 @@ import ( "net/url" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestSubstr(t *testing.T) { s := `012345` - if Substr(s, 0, 2) != "01" { - t.Error("should be equal") - } - if Substr(s, 0, 100) != "012345" { - t.Error("should be equal") - } - if Substr(s, 12, 100) != "012345" { - t.Error("should be equal") - } + assert.Equal(t, "01", Substr(s, 0, 2)) + assert.Equal(t, "012345", Substr(s, 0, 100)) + assert.Equal(t, "012345", Substr(s, 12, 100)) } func TestHtml2str(t *testing.T) { @@ -39,73 +35,51 @@ func TestHtml2str(t *testing.T) { \n` - if HTML2str(h) != "123\\n\n\\n" { - t.Error("should be equal") - } + assert.Equal(t, "123\\n\n\\n", HTML2str(h)) } func TestDateFormat(t *testing.T) { ts := "Mon, 01 Jul 2013 13:27:42 CST" tt, _ := time.Parse(time.RFC1123, ts) - if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" { - t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) - } + assert.Equal(t, "2013-07-01 13:27:42", DateFormat(tt, "2006-01-02 15:04:05")) } func TestDate(t *testing.T) { ts := "Mon, 01 Jul 2013 13:27:42 CST" tt, _ := time.Parse(time.RFC1123, ts) - if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" { - t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) - } - if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" { - t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss) - } - if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" { - t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss) - } - if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" { - t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss) - } + assert.Equal(t, "2013-07-01 13:27:42", Date(tt, "Y-m-d H:i:s")) + + assert.Equal(t, "13-7-1 01:27:42 PM", Date(tt, "y-n-j h:i:s A")) + assert.Equal(t, "Mon, 01 Jul 2013 1:27:42 pm", Date(tt, "D, d M Y g:i:s a")) + assert.Equal(t, "Monday, 01 July 2013 13:27:42", Date(tt, "l, d F Y G:i:s")) } func TestCompareRelated(t *testing.T) { - if !Compare("abc", "abc") { - t.Error("should be equal") - } - if Compare("abc", "aBc") { - t.Error("should be not equal") - } - if !Compare("1", 1) { - t.Error("should be equal") - } - if CompareNot("abc", "abc") { - t.Error("should be equal") - } - if !CompareNot("abc", "aBc") { - t.Error("should be not equal") - } - if !NotNil("a string") { - t.Error("should not be nil") - } + assert.True(t, Compare("abc", "abc")) + + assert.False(t, Compare("abc", "aBc")) + + assert.True(t, Compare("1", 1)) + + assert.False(t, CompareNot("abc", "abc")) + + assert.True(t, CompareNot("abc", "aBc")) + assert.True(t, NotNil("a string")) } func TestHtmlquote(t *testing.T) { h := `<' ”“&">` s := `<' ”“&">` - if Htmlquote(s) != h { - t.Error("should be equal") - } + assert.Equal(t, h, Htmlquote(s)) } func TestHtmlunquote(t *testing.T) { h := `<' ”“&">` s := `<' ”“&">` - if Htmlunquote(h) != s { - t.Error("should be equal") - } + assert.Equal(t, s, Htmlunquote(h)) + } func TestParseForm(t *testing.T) { @@ -148,55 +122,42 @@ func TestParseForm(t *testing.T) { "hobby": []string{"", "Basketball", "Football"}, "memo": []string{"nothing"}, } - if err := ParseForm(form, u); err == nil { - t.Fatal("nothing will be changed") - } - if err := ParseForm(form, &u); err != nil { - t.Fatal(err) - } - if u.ID != 0 { - t.Errorf("ID should equal 0 but got %v", u.ID) - } - if len(u.tag) != 0 { - t.Errorf("tag's length should equal 0 but got %v", len(u.tag)) - } - if u.Name.(string) != "test" { - t.Errorf("Name should equal `test` but got `%v`", u.Name.(string)) - } - if u.Age != 40 { - t.Errorf("Age should equal 40 but got %v", u.Age) - } - if u.Email != "test@gmail.com" { - t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email) - } - if u.Intro != "I am an engineer!" { - t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) - } - if !u.StrBool { - t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) - } + + assert.NotNil(t, ParseForm(form, u)) + + assert.Nil(t, ParseForm(form, &u)) + + assert.Equal(t, 0, u.ID) + + assert.Equal(t, 0, len(u.tag)) + + assert.Equal(t, "test", u.Name) + + assert.Equal(t, 40, u.Age) + + assert.Equal(t, "test@gmail.com", u.Email) + + assert.Equal(t, "I am an engineer!", u.Intro) + + assert.True(t, u.StrBool) + y, m, d := u.Date.Date() - if y != 2014 || m.String() != "November" || d != 12 { - t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) - } - if u.Organization != "beego" { - t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization) - } - if u.Title != "CXO" { - t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) - } - if u.Hobby[0] != "" { - t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) - } - if u.Hobby[1] != "Basketball" { - t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) - } - if u.Hobby[2] != "Football" { - t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) - } - if len(u.Memo) != 0 { - t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) - } + + assert.Equal(t, 2014, y) + assert.Equal(t, "November", m.String()) + assert.Equal(t, 12, d) + + assert.Equal(t, "beego", u.Organization) + + assert.Equal(t, "CXO", u.Title) + + assert.Equal(t, "", u.Hobby[0]) + + assert.Equal(t, "Basketball", u.Hobby[1]) + + assert.Equal(t, "Football", u.Hobby[2]) + + assert.Equal(t, 0, len(u.Memo)) } func TestRenderForm(t *testing.T) { @@ -212,18 +173,14 @@ func TestRenderForm(t *testing.T) { u := user{Name: "test", Intro: "Some Text"} output := RenderForm(u) - if output != template.HTML("") { - t.Errorf("output should be empty but got %v", output) - } + assert.Equal(t, template.HTML(""), output) output = RenderForm(&u) result := template.HTML( `Name:
` + `年龄:
` + `Sex:
` + `Intro: `) - if output != result { - t.Errorf("output should equal `%v` but got `%v`", result, output) - } + assert.Equal(t, result, output) } func TestMapGet(t *testing.T) { @@ -233,29 +190,18 @@ func TestMapGet(t *testing.T) { "1": 2, } - if res, err := MapGet(m1, "a"); err == nil { - if res.(int64) != 1 { - t.Errorf("Should return 1, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + res, err := MapGet(m1, "a") + assert.Nil(t, err) + assert.Equal(t, int64(1), res) - if res, err := MapGet(m1, "1"); err == nil { - if res.(int64) != 2 { - t.Errorf("Should return 2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + res, err = MapGet(m1, "1") + assert.Nil(t, err) + assert.Equal(t, int64(2), res) - if res, err := MapGet(m1, 1); err == nil { - if res.(int64) != 2 { - t.Errorf("Should return 2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + + res, err = MapGet(m1, 1) + assert.Nil(t, err) + assert.Equal(t, int64(2), res) // test 2 level map m2 := M{ @@ -264,13 +210,9 @@ func TestMapGet(t *testing.T) { }, } - if res, err := MapGet(m2, 1, 2); err == nil { - if res.(float64) != 3.5 { - t.Errorf("Should return 3.5, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + res, err = MapGet(m2, 1, 2) + assert.Nil(t, err) + assert.Equal(t, 3.5, res) // test 5 level map m5 := M{ @@ -285,20 +227,13 @@ func TestMapGet(t *testing.T) { }, } - if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil { - if res.(float64) != 1.2 { - t.Errorf("Should return 1.2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + res, err = MapGet(m5, 1, 2, 3, 4, 5) + assert.Nil(t, err) + assert.Equal(t, 1.2, res) // check whether element not exists in map - if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil { - if res != nil { - t.Errorf("Should return nil, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + res, err = MapGet(m5, 5, 4, 3, 2, 1) + assert.Nil(t, err) + assert.Nil(t, res) + } diff --git a/adapter/toolbox/statistics_test.go b/adapter/toolbox/statistics_test.go index ac29476c..f4371c3f 100644 --- a/adapter/toolbox/statistics_test.go +++ b/adapter/toolbox/statistics_test.go @@ -21,13 +21,16 @@ import ( ) func TestStatics(t *testing.T) { - StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000)) - StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000)) - StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000)) - StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000)) - StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000)) - StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) - StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) + userApi := "/api/user" + post := "POST" + adminUser := "&admin.user" + StatisticsMap.AddStatistics(post, userApi, adminUser, time.Duration(2000)) + StatisticsMap.AddStatistics(post, userApi, adminUser, time.Duration(120000)) + StatisticsMap.AddStatistics("GET", userApi, adminUser, time.Duration(13000)) + StatisticsMap.AddStatistics(post, "/api/admin", adminUser, time.Duration(14000)) + StatisticsMap.AddStatistics(post, "/api/user/astaxie", adminUser, time.Duration(12000)) + StatisticsMap.AddStatistics(post, "/api/user/xiemengjun", adminUser, time.Duration(13000)) + StatisticsMap.AddStatistics("DELETE", userApi, adminUser, time.Duration(1400)) t.Log(StatisticsMap.GetMap()) data := StatisticsMap.GetMapData() diff --git a/adapter/toolbox/task.go b/adapter/toolbox/task.go index bdd6679f..7b7cd68a 100644 --- a/adapter/toolbox/task.go +++ b/adapter/toolbox/task.go @@ -289,3 +289,7 @@ func (o *oldToNewAdapter) SetPrev(ctx context.Context, t time.Time) { func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time { return o.delegate.GetPrev() } + +func (o *oldToNewAdapter) GetTimeout(ctx context.Context) time.Duration { + return 0 +} diff --git a/adapter/utils/captcha/README.md b/adapter/utils/captcha/README.md index 74e1cf82..07a4dc4d 100644 --- a/adapter/utils/captcha/README.md +++ b/adapter/utils/captcha/README.md @@ -7,8 +7,8 @@ package controllers import ( "github.com/beego/beego/v2" - "github.com/beego/beego/v2/cache" - "github.com/beego/beego/v2/utils/captcha" + "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/server/web/captcha" ) var cpt *captcha.Captcha diff --git a/adapter/utils/captcha/captcha.go b/adapter/utils/captcha/captcha.go index 4f5dd867..edca528d 100644 --- a/adapter/utils/captcha/captcha.go +++ b/adapter/utils/captcha/captcha.go @@ -20,8 +20,8 @@ // // import ( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/cache" -// "github.com/beego/beego/v2/utils/captcha" +// "github.com/beego/beego/v2/client/cache" +// "github.com/beego/beego/v2/server/web/captcha" // ) // // var cpt *captcha.Captcha diff --git a/adapter/utils/pagination/doc.go b/adapter/utils/pagination/doc.go index d180b093..43ee78b6 100644 --- a/adapter/utils/pagination/doc.go +++ b/adapter/utils/pagination/doc.go @@ -8,7 +8,7 @@ In your beego.Controller: package controllers - import "github.com/beego/beego/v2/utils/pagination" + import "github.com/beego/beego/v2/server/web/pagination" type PostsController struct { beego.Controller diff --git a/adapter/utils/rand_test.go b/adapter/utils/rand_test.go index 6c238b5e..1cb26029 100644 --- a/adapter/utils/rand_test.go +++ b/adapter/utils/rand_test.go @@ -16,7 +16,7 @@ package utils import "testing" -func TestRand_01(t *testing.T) { +func TestRand01(t *testing.T) { bs0 := RandomCreateBytes(16) bs1 := RandomCreateBytes(16) diff --git a/adapter/validation/validation.go b/adapter/validation/validation.go index 8226fa20..eadd4361 100644 --- a/adapter/validation/validation.go +++ b/adapter/validation/validation.go @@ -15,7 +15,7 @@ // Package validation for validations // // import ( -// "github.com/beego/beego/v2/validation" +// "github.com/beego/beego/v2/core/validation" // "log" // ) // diff --git a/adapter/validation/validation_test.go b/adapter/validation/validation_test.go index b4b5b1b6..2e29b641 100644 --- a/adapter/validation/validation_test.go +++ b/adapter/validation/validation_test.go @@ -18,131 +18,83 @@ import ( "regexp" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestRequired(t *testing.T) { valid := Validation{} - if valid.Required(nil, "nil").Ok { - t.Error("nil object should be false") - } - if !valid.Required(true, "bool").Ok { - t.Error("Bool value should always return true") - } - if !valid.Required(false, "bool").Ok { - t.Error("Bool value should always return true") - } - if valid.Required("", "string").Ok { - t.Error("\"'\" string should be false") - } - if valid.Required(" ", "string").Ok { - t.Error("\" \" string should be false") // For #2361 - } - if valid.Required("\n", "string").Ok { - t.Error("new line string should be false") // For #2361 - } - if !valid.Required("astaxie", "string").Ok { - t.Error("string should be true") - } - if valid.Required(0, "zero").Ok { - t.Error("Integer should not be equal 0") - } - if !valid.Required(1, "int").Ok { - t.Error("Integer except 0 should be true") - } - if !valid.Required(time.Now(), "time").Ok { - t.Error("time should be true") - } - if valid.Required([]string{}, "emptySlice").Ok { - t.Error("empty slice should be false") - } - if !valid.Required([]interface{}{"ok"}, "slice").Ok { - t.Error("slice should be true") - } + assert.False(t, valid.Required(nil, "nil").Ok) + assert.True(t, valid.Required(true, "bool").Ok) + + assert.True(t, valid.Required(false, "bool").Ok) + assert.False(t, valid.Required("", "string").Ok) + assert.False(t, valid.Required(" ", "string").Ok) + assert.False(t, valid.Required("\n", "string").Ok) + + assert.True(t, valid.Required("astaxie", "string").Ok) + assert.False(t, valid.Required(0, "zero").Ok) + + assert.True(t, valid.Required(1, "int").Ok) + + assert.True(t, valid.Required(time.Now(), "time").Ok) + + assert.False(t, valid.Required([]string{}, "emptySlice").Ok) + + assert.True(t, valid.Required([]interface{}{"ok"}, "slice").Ok) } func TestMin(t *testing.T) { valid := Validation{} - if valid.Min(-1, 0, "min0").Ok { - t.Error("-1 is less than the minimum value of 0 should be false") - } - if !valid.Min(1, 0, "min0").Ok { - t.Error("1 is greater or equal than the minimum value of 0 should be true") - } + assert.False(t, valid.Min(-1, 0, "min0").Ok) + assert.True(t, valid.Min(1, 0, "min0").Ok) + } func TestMax(t *testing.T) { valid := Validation{} - if valid.Max(1, 0, "max0").Ok { - t.Error("1 is greater than the minimum value of 0 should be false") - } - if !valid.Max(-1, 0, "max0").Ok { - t.Error("-1 is less or equal than the maximum value of 0 should be true") - } + assert.False(t, valid.Max(1, 0, "max0").Ok) + assert.True(t, valid.Max(-1, 0, "max0").Ok) } func TestRange(t *testing.T) { valid := Validation{} - if valid.Range(-1, 0, 1, "range0_1").Ok { - t.Error("-1 is between 0 and 1 should be false") - } - if !valid.Range(1, 0, 1, "range0_1").Ok { - t.Error("1 is between 0 and 1 should be true") - } + assert.False(t, valid.Range(-1, 0, 1, "range0_1").Ok) + + assert.True(t, valid.Range(1, 0, 1, "range0_1").Ok) } func TestMinSize(t *testing.T) { valid := Validation{} - if valid.MinSize("", 1, "minSize1").Ok { - t.Error("the length of \"\" is less than the minimum value of 1 should be false") - } - if !valid.MinSize("ok", 1, "minSize1").Ok { - t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true") - } - if valid.MinSize([]string{}, 1, "minSize1").Ok { - t.Error("the length of empty slice is less than the minimum value of 1 should be false") - } - if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok { - t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true") - } + assert.False(t, valid.MinSize("", 1, "minSize1").Ok) + + assert.True(t, valid.MinSize("ok", 1, "minSize1").Ok) + assert.False(t, valid.MinSize([]string{}, 1, "minSize1").Ok) + assert.True(t, valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok) } func TestMaxSize(t *testing.T) { valid := Validation{} - if valid.MaxSize("ok", 1, "maxSize1").Ok { - t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false") - } - if !valid.MaxSize("", 1, "maxSize1").Ok { - t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true") - } - if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok { - t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false") - } - if !valid.MaxSize([]string{}, 1, "maxSize1").Ok { - t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true") - } + assert.False(t, valid.MaxSize("ok", 1, "maxSize1").Ok) + assert.True(t, valid.MaxSize("", 1, "maxSize1").Ok) + assert.False(t, valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok) + assert.True(t, valid.MaxSize([]string{}, 1, "maxSize1").Ok) } func TestLength(t *testing.T) { valid := Validation{} - if valid.Length("", 1, "length1").Ok { - t.Error("the length of \"\" must equal 1 should be false") - } - if !valid.Length("1", 1, "length1").Ok { - t.Error("the length of \"1\" must equal 1 should be true") - } - if valid.Length([]string{}, 1, "length1").Ok { - t.Error("the length of empty slice must equal 1 should be false") - } - if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok { - t.Error("the length of [\"ok\"] must equal 1 should be true") - } + assert.False(t, valid.Length("", 1, "length1").Ok) + assert.True(t, valid.Length("1", 1, "length1").Ok) + + assert.False(t, valid.Length([]string{}, 1, "length1").Ok) + assert.True(t, valid.Length([]interface{}{"ok"}, 1, "length1").Ok) } func TestAlpha(t *testing.T) { @@ -178,13 +130,16 @@ func TestAlphaNumeric(t *testing.T) { } } +const email = "suchuangji@gmail.com" + func TestMatch(t *testing.T) { valid := Validation{} if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") } - if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + + if !valid.Match(email, regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") } } @@ -217,7 +172,7 @@ func TestEmail(t *testing.T) { if valid.Email("not@a email", "email").Ok { t.Error("\"not@a email\" is a valid email address should be false") } - if !valid.Email("suchuangji@gmail.com", "email").Ok { + if !valid.Email(email, "email").Ok { t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") } if valid.Email("@suchuangji@gmail.com", "email").Ok { @@ -242,7 +197,7 @@ func TestIP(t *testing.T) { func TestBase64(t *testing.T) { valid := Validation{} - if valid.Base64("suchuangji@gmail.com", "base64").Ok { + if valid.Base64(email, "base64").Ok { t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false") } if !valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok { @@ -370,44 +325,25 @@ func TestValid(t *testing.T) { u := user{Name: "test@/test/;com", Age: 40} b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Error("validation should be passed") - } + assert.Nil(t, err) + assert.True(t, b) uptr := &user{Name: "test", Age: 40} valid.Clear() b, err = valid.Valid(uptr) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } - if len(valid.Errors) != 1 { - t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) - } - if valid.Errors[0].Key != "Name.Match" { - t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) - } + + assert.Nil(t, err) + assert.False(t, b) + assert.Equal(t, 1, len(valid.Errors)) + assert.Equal(t, "Name.Match", valid.Errors[0].Key) u = user{Name: "test@/test/;com", Age: 180} valid.Clear() b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } - if len(valid.Errors) != 1 { - t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) - } - if valid.Errors[0].Key != "Age.Range." { - t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) - } + assert.Nil(t, err) + assert.False(t, b) + assert.Equal(t, 1, len(valid.Errors)) + assert.Equal(t, "Age.Range.", valid.Errors[0].Key) } func TestRecursiveValid(t *testing.T) { @@ -432,12 +368,8 @@ func TestRecursiveValid(t *testing.T) { u := Account{Password: "abc123_", U: User{}} b, err := valid.RecursiveValid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) } func TestSkipValid(t *testing.T) { @@ -474,21 +406,13 @@ func TestSkipValid(t *testing.T) { valid := Validation{} b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) valid = Validation{RequiredFirst: true} b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } + assert.Nil(t, err) + assert.True(t, b) } func TestPointer(t *testing.T) { @@ -506,12 +430,8 @@ func TestPointer(t *testing.T) { valid := Validation{} b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) validEmail := "a@a.com" u = User{ @@ -521,12 +441,8 @@ func TestPointer(t *testing.T) { valid = Validation{RequiredFirst: true} b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } + assert.Nil(t, err) + assert.True(t, b) u = User{ ReqEmail: &validEmail, @@ -535,12 +451,8 @@ func TestPointer(t *testing.T) { valid = Validation{} b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) invalidEmail := "a@a" u = User{ @@ -550,12 +462,8 @@ func TestPointer(t *testing.T) { valid = Validation{RequiredFirst: true} b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) u = User{ ReqEmail: &validEmail, @@ -564,12 +472,8 @@ func TestPointer(t *testing.T) { valid = Validation{} b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) } func TestCanSkipAlso(t *testing.T) { @@ -589,21 +493,14 @@ func TestCanSkipAlso(t *testing.T) { valid := Validation{RequiredFirst: true} b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) valid = Validation{RequiredFirst: true} valid.CanSkipAlso("Range") b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } + + assert.Nil(t, err) + assert.True(t, b) } diff --git a/client/cache/README.md b/client/cache/README.md index 7e65cbbf..df1ea095 100644 --- a/client/cache/README.md +++ b/client/cache/README.md @@ -4,7 +4,7 @@ cache is a Go cache manager. It can use many cache adapters. The repo is inspire ## How to install? - go get github.com/beego/beego/v2/cache + go get github.com/beego/beego/v2/client/cache ## What adapters are supported? @@ -15,7 +15,7 @@ As of now this cache support memory, Memcache and Redis. First you must import it import ( - "github.com/beego/beego/v2/cache" + "github.com/beego/beego/v2/client/cache" ) Then init a Cache (example with memory adapter) diff --git a/client/cache/cache.go b/client/cache/cache.go index e73a1c1a..87f7ba62 100644 --- a/client/cache/cache.go +++ b/client/cache/cache.go @@ -16,7 +16,7 @@ // Usage: // // import( -// "github.com/beego/beego/v2/cache" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("memory", `{"interval":60}`) @@ -33,8 +33,9 @@ package cache import ( "context" - "fmt" "time" + + "github.com/beego/beego/v2/core/berror" ) // Cache interface contains all behaviors for cache adapter. @@ -55,12 +56,14 @@ type Cache interface { // Set a cached value with key and expire time. Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error // Delete cached value by key. + // Should not return error if key not found Delete(ctx context.Context, key string) error // Increment a cached int value by key, as a counter. Incr(ctx context.Context, key string) error // Decrement a cached int value by key, as a counter. Decr(ctx context.Context, key string) error // Check if a cached value exists or not. + // if key is expired, return (false, nil) IsExist(ctx context.Context, key string) (bool, error) // Clear all cache. ClearAll(ctx context.Context) error @@ -78,7 +81,7 @@ var adapters = make(map[string]Instance) // it panics. func Register(name string, adapter Instance) { if adapter == nil { - panic("cache: Register adapter is nil") + panic(berror.Error(NilCacheAdapter, "cache: Register adapter is nil").Error()) } if _, ok := adapters[name]; ok { panic("cache: Register called twice for adapter " + name) @@ -92,7 +95,7 @@ func Register(name string, adapter Instance) { func NewCache(adapterName, config string) (adapter Cache, err error) { instanceFunc, ok := adapters[adapterName] if !ok { - err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + err = berror.Errorf(UnknownAdapter, "cache: unknown adapter name %s (forgot to import?)", adapterName) return } adapter = instanceFunc() diff --git a/client/cache/cache_test.go b/client/cache/cache_test.go index 85f83fc4..db651e94 100644 --- a/client/cache/cache_test.go +++ b/client/cache/cache_test.go @@ -16,17 +16,19 @@ package cache import ( "context" + "math" "os" + "strings" "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestCacheIncr(t *testing.T) { bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } + assert.Nil(t, err) // timeoutDuration := 10 * time.Second bm.Put(context.Background(), "edwardhey", 0, time.Second*20) @@ -46,11 +48,9 @@ func TestCacheIncr(t *testing.T) { } func TestCache(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second + bm, err := NewCache("memory", `{"interval":1}`) + assert.Nil(t, err) + timeoutDuration := 5 * time.Second if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } @@ -62,7 +62,7 @@ func TestCache(t *testing.T) { t.Error("get err") } - time.Sleep(30 * time.Second) + time.Sleep(7 * time.Second) if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("check err") @@ -73,130 +73,97 @@ func TestCache(t *testing.T) { } // test different integer type for incr & decr - testMultiIncrDecr(t, bm, timeoutDuration) + testMultiTypeIncrDecr(t, bm, timeoutDuration) + + // test overflow of incr&decr + testIncrOverFlow(t, bm, timeoutDuration) + testDecrOverFlow(t, bm, timeoutDuration) bm.Delete(context.Background(), "astaxie") - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("delete err") - } + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) - // test GetMulti - if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } - if v, _ := bm.Get(context.Background(), "astaxie"); v.(string) != "author" { - t.Error("get err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie", "author", timeoutDuration)) - if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { - t.Error("check err") - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + v, _ := bm.Get(context.Background(), "astaxie") + assert.Equal(t, "author", v) + + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Equal(t, "author", vv[0]) + assert.Equal(t,"author1", vv[1]) + + vv, err = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0] != nil { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - if err != nil && err.Error() != "key [astaxie0] error: the key isn't exist" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Nil(t, vv[0]) + assert.Equal(t, "author1", vv[1]) + + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "key isn't exist")) } func TestFileCache(t *testing.T) { bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) - if err != nil { - t.Error("init err") - } + assert.Nil(t, err) timeoutDuration := 10 * time.Second - if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie", 1, timeoutDuration)) - if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 1 { - t.Error("get err") - } + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + v, _ := bm.Get(context.Background(), "astaxie") + assert.Equal(t, 1, v) // test different integer type for incr & decr - testMultiIncrDecr(t, bm, timeoutDuration) + testMultiTypeIncrDecr(t, bm, timeoutDuration) + + // test overflow of incr&decr + testIncrOverFlow(t, bm, timeoutDuration) + testDecrOverFlow(t, bm, timeoutDuration) bm.Delete(context.Background(), "astaxie") - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("delete err") - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) // test string - if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } - if v, _ := bm.Get(context.Background(), "astaxie"); v.(string) != "author" { - t.Error("get err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie", "author", timeoutDuration)) + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + v, _ = bm.Get(context.Background(), "astaxie") + assert.Equal(t, "author", v) // test GetMulti - if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { - t.Error("check err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Equal(t, "author", vv[0]) + assert.Equal(t, "author1", vv[1]) vv, err = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0] != nil { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - if err == nil { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Nil(t, vv[0]) + + assert.Equal(t, "author1", vv[1]) + assert.NotNil(t, err) os.RemoveAll("cache") } -func testMultiIncrDecr(t *testing.T, c Cache, timeout time.Duration) { +func testMultiTypeIncrDecr(t *testing.T, c Cache, timeout time.Duration) { testIncrDecr(t, c, 1, 2, timeout) testIncrDecr(t, c, int32(1), int32(2), timeout) testIncrDecr(t, c, int64(1), int64(2), timeout) @@ -206,30 +173,52 @@ func testMultiIncrDecr(t *testing.T, c Cache, timeout time.Duration) { } func testIncrDecr(t *testing.T, c Cache, beforeIncr interface{}, afterIncr interface{}, timeout time.Duration) { - var err error ctx := context.Background() key := "incDecKey" - if err = c.Put(ctx, key, beforeIncr, timeout); err != nil { - t.Error("Get Error", err) - } - if err = c.Incr(ctx, key); err != nil { - t.Error("Incr Error", err) - } + assert.Nil(t, c.Put(ctx, key, beforeIncr, timeout)) + assert.Nil(t, c.Incr(ctx, key)) - if v, _ := c.Get(ctx, key); v != afterIncr { - t.Error("Get Error") - } - if err = c.Decr(ctx, key); err != nil { - t.Error("Decr Error", err) - } + v, _ := c.Get(ctx, key) + assert.Equal(t, afterIncr, v) - if v, _ := c.Get(ctx, key); v != beforeIncr { - t.Error("Get Error") - } + assert.Nil(t, c.Decr(ctx, key)) - if err := c.Delete(ctx, key); err != nil { - t.Error("Delete Error") + v, _ = c.Get(ctx, key) + assert.Equal(t, v, beforeIncr) + assert.Nil(t, c.Delete(ctx, key)) +} + +func testIncrOverFlow(t *testing.T, c Cache, timeout time.Duration) { + ctx := context.Background() + key := "incKey" + + assert.Nil(t, c.Put(ctx, key, int64(math.MaxInt64), timeout)) + // int64 + defer func() { + assert.Nil(t, c.Delete(ctx, key)) + }() + assert.NotNil(t, c.Incr(ctx, key)) +} + +func testDecrOverFlow(t *testing.T, c Cache, timeout time.Duration) { + var err error + ctx := context.Background() + key := "decKey" + + // int64 + if err = c.Put(ctx, key, int64(math.MinInt64), timeout); err != nil { + t.Error("Put Error: ", err.Error()) + return + } + defer func() { + if err = c.Delete(ctx, key); err != nil { + t.Errorf("Delete error: %s", err.Error()) + } + }() + if err = c.Decr(ctx, key); err == nil { + t.Error("Decr error") + return } } diff --git a/client/cache/calc_utils.go b/client/cache/calc_utils.go new file mode 100644 index 00000000..e18946c1 --- /dev/null +++ b/client/cache/calc_utils.go @@ -0,0 +1,92 @@ +package cache + +import ( + "math" + + "github.com/beego/beego/v2/core/berror" +) + +var ( + ErrIncrementOverflow = berror.Error(IncrementOverflow, "this incr invocation will overflow.") + ErrDecrementOverflow = berror.Error(DecrementOverflow, "this decr invocation will overflow.") + ErrNotIntegerType = berror.Error(NotIntegerType, "item val is not (u)int (u)int32 (u)int64") +) + + + +func incr(originVal interface{}) (interface{}, error) { + switch val := originVal.(type) { + case int: + tmp := val + 1 + if val > 0 && tmp < 0 { + return nil, ErrIncrementOverflow + } + return tmp, nil + case int32: + if val == math.MaxInt32 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + case int64: + if val == math.MaxInt64 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + case uint: + tmp := val + 1 + if tmp < val { + return nil, ErrIncrementOverflow + } + return tmp, nil + case uint32: + if val == math.MaxUint32 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + case uint64: + if val == math.MaxUint64 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + default: + return nil, ErrNotIntegerType + } +} + +func decr(originVal interface{}) (interface{}, error) { + switch val := originVal.(type) { + case int: + tmp := val - 1 + if val < 0 && tmp > 0 { + return nil, ErrDecrementOverflow + } + return tmp, nil + case int32: + if val == math.MinInt32 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case int64: + if val == math.MinInt64 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case uint: + if val == 0 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case uint32: + if val == 0 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case uint64: + if val == 0 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + default: + return nil, ErrNotIntegerType + } +} diff --git a/client/cache/calc_utils_test.go b/client/cache/calc_utils_test.go new file mode 100644 index 00000000..1f8d3377 --- /dev/null +++ b/client/cache/calc_utils_test.go @@ -0,0 +1,140 @@ +package cache + +import ( + "math" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIncr(t *testing.T) { + // int + var originVal interface{} = int(1) + var updateVal interface{} = int(2) + val, err := incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(int(1<<(strconv.IntSize-1) - 1)) + assert.Equal(t, ErrIncrementOverflow, err) + + // int32 + originVal = int32(1) + updateVal = int32(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(int32(math.MaxInt32)) + assert.Equal(t, ErrIncrementOverflow, err) + + // int64 + originVal = int64(1) + updateVal = int64(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(int64(math.MaxInt64)) + assert.Equal(t, ErrIncrementOverflow, err) + + // uint + originVal = uint(1) + updateVal = uint(2) + val, err = incr(originVal) + + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(uint(1<<(strconv.IntSize) - 1)) + assert.Equal(t, ErrIncrementOverflow, err) + + // uint32 + originVal = uint32(1) + updateVal = uint32(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(uint32(math.MaxUint32)) + assert.Equal(t, ErrIncrementOverflow, err) + + // uint64 + originVal = uint64(1) + updateVal = uint64(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(uint64(math.MaxUint64)) + assert.Equal(t, ErrIncrementOverflow, err) + // other type + _, err = incr("string") + assert.Equal(t, ErrNotIntegerType, err) +} + +func TestDecr(t *testing.T) { + // int + var originVal interface{} = int(2) + var updateVal interface{} = int(1) + val, err := decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(int(-1 << (strconv.IntSize - 1))) + assert.Equal(t, ErrDecrementOverflow, err) + // int32 + originVal = int32(2) + updateVal = int32(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(int32(math.MinInt32)) + assert.Equal(t, ErrDecrementOverflow, err) + + // int64 + originVal = int64(2) + updateVal = int64(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(int64(math.MinInt64)) + assert.Equal(t, ErrDecrementOverflow, err) + + // uint + originVal = uint(2) + updateVal = uint(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(uint(0)) + assert.Equal(t, ErrDecrementOverflow, err) + + // uint32 + originVal = uint32(2) + updateVal = uint32(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(uint32(0)) + assert.Equal(t, ErrDecrementOverflow, err) + + // uint64 + originVal = uint64(2) + updateVal = uint64(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(uint64(0)) + assert.Equal(t, ErrDecrementOverflow, err) + + // other type + _, err = decr("string") + assert.Equal(t, ErrNotIntegerType, err) +} diff --git a/client/cache/conv_test.go b/client/cache/conv_test.go index b90e224a..523150d1 100644 --- a/client/cache/conv_test.go +++ b/client/cache/conv_test.go @@ -16,128 +16,74 @@ package cache import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestGetString(t *testing.T) { var t1 = "test1" - if "test1" != GetString(t1) { - t.Error("get string from string error") - } - var t2 = []byte("test2") - if "test2" != GetString(t2) { - t.Error("get string from byte array error") - } - var t3 = 1 - if "1" != GetString(t3) { - t.Error("get string from int error") - } - var t4 int64 = 1 - if "1" != GetString(t4) { - t.Error("get string from int64 error") - } - var t5 = 1.1 - if "1.1" != GetString(t5) { - t.Error("get string from float64 error") - } - if "" != GetString(nil) { - t.Error("get string from nil error") - } + assert.Equal(t, "test1", GetString(t1)) + var t2 = []byte("test2") + assert.Equal(t, "test2", GetString(t2)) + var t3 = 1 + assert.Equal(t, "1", GetString(t3)) + var t4 int64 = 1 + assert.Equal(t, "1", GetString(t4)) + var t5 = 1.1 + assert.Equal(t, "1.1", GetString(t5)) + assert.Equal(t, "", GetString(nil)) } func TestGetInt(t *testing.T) { var t1 = 1 - if 1 != GetInt(t1) { - t.Error("get int from int error") - } + assert.Equal(t, 1, GetInt(t1)) var t2 int32 = 32 - if 32 != GetInt(t2) { - t.Error("get int from int32 error") - } + assert.Equal(t, 32, GetInt(t2)) + var t3 int64 = 64 - if 64 != GetInt(t3) { - t.Error("get int from int64 error") - } + assert.Equal(t, 64, GetInt(t3)) var t4 = "128" - if 128 != GetInt(t4) { - t.Error("get int from num string error") - } - if 0 != GetInt(nil) { - t.Error("get int from nil error") - } + + assert.Equal(t, 128, GetInt(t4)) + assert.Equal(t, 0, GetInt(nil)) } func TestGetInt64(t *testing.T) { var i int64 = 1 var t1 = 1 - if i != GetInt64(t1) { - t.Error("get int64 from int error") - } + assert.Equal(t, i, GetInt64(t1)) var t2 int32 = 1 - if i != GetInt64(t2) { - t.Error("get int64 from int32 error") - } + + assert.Equal(t, i, GetInt64(t2)) var t3 int64 = 1 - if i != GetInt64(t3) { - t.Error("get int64 from int64 error") - } + assert.Equal(t, i, GetInt64(t3)) var t4 = "1" - if i != GetInt64(t4) { - t.Error("get int64 from num string error") - } - if 0 != GetInt64(nil) { - t.Error("get int64 from nil") - } + assert.Equal(t, i, GetInt64(t4)) + assert.Equal(t, int64(0), GetInt64(nil)) } func TestGetFloat64(t *testing.T) { var f = 1.11 var t1 float32 = 1.11 - if f != GetFloat64(t1) { - t.Error("get float64 from float32 error") - } + assert.Equal(t, f, GetFloat64(t1)) var t2 = 1.11 - if f != GetFloat64(t2) { - t.Error("get float64 from float64 error") - } + assert.Equal(t, f, GetFloat64(t2)) var t3 = "1.11" - if f != GetFloat64(t3) { - t.Error("get float64 from string error") - } + assert.Equal(t, f, GetFloat64(t3)) var f2 float64 = 1 var t4 = 1 - if f2 != GetFloat64(t4) { - t.Error("get float64 from int error") - } + assert.Equal(t, f2, GetFloat64(t4)) - if 0 != GetFloat64(nil) { - t.Error("get float64 from nil error") - } + assert.Equal(t, float64(0), GetFloat64(nil)) } func TestGetBool(t *testing.T) { var t1 = true - if !GetBool(t1) { - t.Error("get bool from bool error") - } + assert.True(t, GetBool(t1)) var t2 = "true" - if !GetBool(t2) { - t.Error("get bool from string error") - } - if GetBool(nil) { - t.Error("get bool from nil error") - } -} + assert.True(t, GetBool(t2)) -func byteArrayEquals(a []byte, b []byte) bool { - if len(a) != len(b) { - return false - } - for i, v := range a { - if v != b[i] { - return false - } - } - return true + assert.False(t, GetBool(nil)) } diff --git a/client/cache/error_code.go b/client/cache/error_code.go new file mode 100644 index 00000000..4305d7e4 --- /dev/null +++ b/client/cache/error_code.go @@ -0,0 +1,182 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "github.com/beego/beego/v2/core/berror" +) + +var NilCacheAdapter = berror.DefineCode(4002001, moduleName, "NilCacheAdapter", ` +It means that you register cache adapter by pass nil. +A cache adapter is an instance of Cache interface. +`) + +var DuplicateAdapter = berror.DefineCode(4002002, moduleName, "DuplicateAdapter", ` +You register two adapter with same name. In beego cache module, one name one adapter. +Once you got this error, please check the error stack, search adapter +`) + +var UnknownAdapter = berror.DefineCode(4002003, moduleName, "UnknownAdapter", ` +Unknown adapter, do you forget to register the adapter? +You must register adapter before use it. For example, if you want to use redis implementation, +you must import the cache/redis package. +`) + +var IncrementOverflow = berror.DefineCode(4002004, moduleName, "IncrementOverflow", ` +The increment operation will overflow. +`) + +var DecrementOverflow = berror.DefineCode(4002005, moduleName, "DecrementOverflow", ` +The decrement operation will overflow. +`) + +var NotIntegerType = berror.DefineCode(4002006, moduleName, "NotIntegerType", ` +The type of value is not (u)int (u)int32 (u)int64. +When you want to call Incr or Decr function of Cache API, you must confirm that the value's type is one of (u)int (u)int32 (u)int64. +`) + +var InvalidFileCacheDirectoryLevelCfg = berror.DefineCode(4002007, moduleName, "InvalidFileCacheDirectoryLevelCfg", ` +You pass invalid DirectoryLevel parameter when you try to StartAndGC file cache instance. +This parameter must be a integer, and please check your input. +`) + +var InvalidFileCacheEmbedExpiryCfg = berror.DefineCode(4002008, moduleName, "InvalidFileCacheEmbedExpiryCfg", ` +You pass invalid EmbedExpiry parameter when you try to StartAndGC file cache instance. +This parameter must be a integer, and please check your input. +`) + +var CreateFileCacheDirFailed = berror.DefineCode(4002009, moduleName, "CreateFileCacheDirFailed", ` +Beego failed to create file cache directory. There are two cases: +1. You pass invalid CachePath parameter. Please check your input. +2. Beego doesn't have the permission to create this directory. Please check your file mode. +`) + +var InvalidFileCachePath = berror.DefineCode(4002010, moduleName, "InvalidFilePath", ` +The file path of FileCache is invalid. Please correct the config. +`) + +var ReadFileCacheContentFailed = berror.DefineCode(4002011, moduleName, "ReadFileCacheContentFailed", ` +Usually you won't got this error. It means that Beego cannot read the data from the file. +You need to check whether the file exist. Sometimes it may be deleted by other processes. +If the file exists, please check the permission that Beego is able to read data from the file. +`) + +var InvalidGobEncodedData = berror.DefineCode(4002012, moduleName, "InvalidEncodedData", ` +The data is invalid. When you try to decode the invalid data, you got this error. +Please confirm that the data is encoded by GOB correctly. +`) + +var GobEncodeDataFailed = berror.DefineCode(4002013, moduleName, "GobEncodeDataFailed", ` +Beego could not encode the data to GOB byte array. In general, the data type is invalid. +For example, GOB doesn't support function type. +Basic types, string, structure, structure pointer are supported. +`) + +var KeyExpired = berror.DefineCode(4002014, moduleName, "KeyExpired", ` +Cache key is expired. +You should notice that, a key is expired and then it may be deleted by GC goroutine. +So when you query a key which may be expired, you may got this code, or KeyNotExist. +`) + +var KeyNotExist = berror.DefineCode(4002015, moduleName, "KeyNotExist", ` +Key not found. +`) + +var MultiGetFailed = berror.DefineCode(4002016, moduleName, "MultiGetFailed", ` +Get multiple keys failed. Please check the detail msg to find out the root cause. +`) + +var InvalidMemoryCacheCfg = berror.DefineCode(4002017, moduleName, "InvalidMemoryCacheCfg", ` +The config is invalid. Please check your input. It must be a json string. +`) + +var InvalidMemCacheCfg = berror.DefineCode(4002018, moduleName, "InvalidMemCacheCfg", ` +The config is invalid. Please check your input, it must be json string and contains "conn" field. +`) + +var InvalidMemCacheValue = berror.DefineCode(4002019, moduleName, "InvalidMemCacheValue", ` +The value must be string or byte[], please check your input. +`) + +var InvalidRedisCacheCfg = berror.DefineCode(4002020, moduleName, "InvalidRedisCacheCfg", ` +The config must be json string, and has "conn" field. +`) + +var InvalidSsdbCacheCfg = berror.DefineCode(4002021, moduleName, "InvalidSsdbCacheCfg", ` +The config must be json string, and has "conn" field. The value of "conn" field should be "host:port". +"port" must be a valid integer. +`) + +var InvalidSsdbCacheValue = berror.DefineCode(4002022, moduleName, "InvalidSsdbCacheValue", ` +SSDB cache only accept string value. Please check your input. +`) + + + + + + + + + +var DeleteFileCacheItemFailed = berror.DefineCode(5002001, moduleName, "DeleteFileCacheItemFailed", ` +Beego try to delete file cache item failed. +Please check whether Beego generated file correctly. +And then confirm whether this file is already deleted by other processes or other people. +`) + +var MemCacheCurdFailed = berror.DefineCode(5002002, moduleName, "MemCacheError", ` +When you want to get, put, delete key-value from remote memcache servers, you may get error: +1. You pass invalid servers address, so Beego could not connect to remote server; +2. The servers address is correct, but there is some net issue. Typically there is some firewalls between application and memcache server; +3. Key is invalid. The key's length should be less than 250 and must not contains special characters; +4. The response from memcache server is invalid; +`) + +var RedisCacheCurdFailed = berror.DefineCode(5002003, moduleName, "RedisCacheCurdFailed", ` +When Beego uses client to send request to redis server, it failed. +1. The server addresses is invalid; +2. Network issue, firewall issue or network is unstable; +3. Client failed to manage connection. In extreme cases, Beego's redis client didn't maintain connections correctly, for example, Beego try to send request via closed connection; +4. The request are huge and redis server spent too much time to process it, and client is timeout; + +In general, if you always got this error whatever you do, in most cases, it was caused by network issue. +You could check your network state, and confirm that firewall rules are correct. +`) + +var InvalidConnection = berror.DefineCode(5002004, moduleName, "InvalidConnection", ` +The connection is invalid. Please check your connection info, network, firewall. +You could simply uses ping, telnet or write some simple tests to test network. +`) + +var DialFailed = berror.DefineCode(5002005, moduleName, "DialFailed", ` +When Beego try to dial to remote servers, it failed. Please check your connection info and network state, server state. +`) + +var SsdbCacheCurdFailed = berror.DefineCode(5002006, moduleName, "SsdbCacheCurdFailed", ` +When you try to use SSDB cache, it failed. There are many cases: +1. servers unavailable; +2. network issue, including network unstable, firewall; +3. connection issue; +4. request are huge and servers spent too much time to process it, got timeout; +`) + +var SsdbBadResponse = berror.DefineCode(5002007, moduleName, "SsdbBadResponse", ` +The reponse from SSDB server is invalid. +Usually it indicates something wrong on server side. +`) + +var ErrKeyExpired = berror.Error(KeyExpired, "the key is expired") +var ErrKeyNotExist = berror.Error(KeyNotExist, "the key isn't exist") \ No newline at end of file diff --git a/client/cache/file.go b/client/cache/file.go index 043c4650..ea00c72e 100644 --- a/client/cache/file.go +++ b/client/cache/file.go @@ -30,7 +30,7 @@ import ( "strings" "time" - "github.com/pkg/errors" + "github.com/beego/beego/v2/core/berror" ) // FileCacheItem is basic unit of file cache adapter which @@ -73,38 +73,60 @@ func (fc *FileCache) StartAndGC(config string) error { if err != nil { return err } - if _, ok := cfg["CachePath"]; !ok { - cfg["CachePath"] = FileCachePath - } - if _, ok := cfg["FileSuffix"]; !ok { - cfg["FileSuffix"] = FileCacheFileSuffix - } - if _, ok := cfg["DirectoryLevel"]; !ok { - cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) - } - if _, ok := cfg["EmbedExpiry"]; !ok { - cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) - } - fc.CachePath = cfg["CachePath"] - fc.FileSuffix = cfg["FileSuffix"] - fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"]) - fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"]) - fc.Init() - return nil + const cpKey = "CachePath" + const fsKey = "FileSuffix" + const dlKey = "DirectoryLevel" + const eeKey = "EmbedExpiry" + + if _, ok := cfg[cpKey]; !ok { + cfg[cpKey] = FileCachePath + } + + if _, ok := cfg[fsKey]; !ok { + cfg[fsKey] = FileCacheFileSuffix + } + + if _, ok := cfg[dlKey]; !ok { + cfg[dlKey] = strconv.Itoa(FileCacheDirectoryLevel) + } + + if _, ok := cfg[eeKey]; !ok { + cfg[eeKey] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) + } + fc.CachePath = cfg[cpKey] + fc.FileSuffix = cfg[fsKey] + fc.DirectoryLevel, err = strconv.Atoi(cfg[dlKey]) + if err != nil { + return berror.Wrapf(err, InvalidFileCacheDirectoryLevelCfg, + "invalid directory level config, please check your input, it must be integer: %s", cfg[dlKey]) + } + fc.EmbedExpiry, err = strconv.Atoi(cfg[eeKey]) + if err != nil { + return berror.Wrapf(err, InvalidFileCacheEmbedExpiryCfg, + "invalid embed expiry config, please check your input, it must be integer: %s", cfg[eeKey]) + } + return fc.Init() } // Init makes new a dir for file cache if it does not already exist -func (fc *FileCache) Init() { - if ok, _ := exists(fc.CachePath); !ok { // todo : error handle - _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle +func (fc *FileCache) Init() error { + ok, err := exists(fc.CachePath) + if err != nil || ok { + return err } + err = os.MkdirAll(fc.CachePath, os.ModePerm) + if err != nil { + return berror.Wrapf(err, CreateFileCacheDirFailed, + "could not create directory, please check the config [%s] and file mode.", fc.CachePath) + } + return nil } // getCachedFilename returns an md5 encoded file name. -func (fc *FileCache) getCacheFileName(key string) string { +func (fc *FileCache) getCacheFileName(key string) (string, error) { m := md5.New() - io.WriteString(m, key) + _, _ = io.WriteString(m, key) keyMd5 := hex.EncodeToString(m.Sum(nil)) cachePath := fc.CachePath switch fc.DirectoryLevel { @@ -113,18 +135,29 @@ func (fc *FileCache) getCacheFileName(key string) string { case 1: cachePath = filepath.Join(cachePath, keyMd5[0:2]) } - - if ok, _ := exists(cachePath); !ok { // todo : error handle - _ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle + ok, err := exists(cachePath) + if err != nil { + return "", err + } + if !ok { + err = os.MkdirAll(cachePath, os.ModePerm) + if err != nil { + return "", berror.Wrapf(err, CreateFileCacheDirFailed, + "could not create the directory: %s", cachePath) + } } - return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)) + return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)), nil } // Get value from file cache. // if nonexistent or expired return an empty string. func (fc *FileCache) Get(ctx context.Context, key string) (interface{}, error) { - fileData, err := FileGetContents(fc.getCacheFileName(key)) + fn, err := fc.getCacheFileName(key) + if err != nil { + return nil, err + } + fileData, err := FileGetContents(fn) if err != nil { return nil, err } @@ -136,7 +169,7 @@ func (fc *FileCache) Get(ctx context.Context, key string) (interface{}, error) { } if to.Expired.Before(time.Now()) { - return nil, errors.New("The key is expired") + return nil, ErrKeyExpired } return to.Data, nil } @@ -159,7 +192,7 @@ func (fc *FileCache) GetMulti(ctx context.Context, keys []string) ([]interface{} if len(keysErr) == 0 { return rc, nil } - return rc, errors.New(strings.Join(keysErr, "; ")) + return rc, berror.Error(MultiGetFailed, strings.Join(keysErr, "; ")) } // Put value into file cache. @@ -179,14 +212,26 @@ func (fc *FileCache) Put(ctx context.Context, key string, val interface{}, timeo if err != nil { return err } - return FilePutContents(fc.getCacheFileName(key), data) + + fn, err := fc.getCacheFileName(key) + if err != nil { + return err + } + return FilePutContents(fn, data) } // Delete file cache value. func (fc *FileCache) Delete(ctx context.Context, key string) error { - filename := fc.getCacheFileName(key) + filename, err := fc.getCacheFileName(key) + if err != nil { + return err + } if ok, _ := exists(filename); ok { - return os.Remove(filename) + err = os.Remove(filename) + if err != nil { + return berror.Wrapf(err, DeleteFileCacheItemFailed, + "can not delete this file cache key-value, key is %s and file name is %s", key, filename) + } } return nil } @@ -199,25 +244,12 @@ func (fc *FileCache) Incr(ctx context.Context, key string) error { return err } - var res interface{} - switch val := data.(type) { - case int: - res = val + 1 - case int32: - res = val + 1 - case int64: - res = val + 1 - case uint: - res = val + 1 - case uint32: - res = val + 1 - case uint64: - res = val + 1 - default: - return errors.Errorf("data is not (u)int (u)int32 (u)int64") + val, err := incr(data) + if err != nil { + return err } - return fc.Put(context.Background(), key, res, time.Duration(fc.EmbedExpiry)) + return fc.Put(context.Background(), key, val, time.Duration(fc.EmbedExpiry)) } // Decr decreases cached int value. @@ -227,43 +259,21 @@ func (fc *FileCache) Decr(ctx context.Context, key string) error { return err } - var res interface{} - switch val := data.(type) { - case int: - res = val - 1 - case int32: - res = val - 1 - case int64: - res = val - 1 - case uint: - if val > 0 { - res = val - 1 - } else { - return errors.New("data val is less than 0") - } - case uint32: - if val > 0 { - res = val - 1 - } else { - return errors.New("data val is less than 0") - } - case uint64: - if val > 0 { - res = val - 1 - } else { - return errors.New("data val is less than 0") - } - default: - return errors.Errorf("data is not (u)int (u)int32 (u)int64") + val, err := decr(data) + if err != nil { + return err } - return fc.Put(context.Background(), key, res, time.Duration(fc.EmbedExpiry)) + return fc.Put(context.Background(), key, val, time.Duration(fc.EmbedExpiry)) } // IsExist checks if value exists. func (fc *FileCache) IsExist(ctx context.Context, key string) (bool, error) { - ret, _ := exists(fc.getCacheFileName(key)) - return ret, nil + fn, err := fc.getCacheFileName(key) + if err != nil { + return false, err + } + return exists(fn) } // ClearAll cleans cached files (not implemented) @@ -280,13 +290,19 @@ func exists(path string) (bool, error) { if os.IsNotExist(err) { return false, nil } - return false, err + return false, berror.Wrapf(err, InvalidFileCachePath, "file cache path is invalid: %s", path) } // FileGetContents Reads bytes from a file. // if non-existent, create this file. -func FileGetContents(filename string) (data []byte, e error) { - return ioutil.ReadFile(filename) +func FileGetContents(filename string) ([]byte, error) { + data, err := ioutil.ReadFile(filename) + if err != nil { + return nil, berror.Wrapf(err, ReadFileCacheContentFailed, + "could not read the data from the file: %s, " + + "please confirm that file exist and Beego has the permission to read the content.", filename) + } + return data, nil } // FilePutContents puts bytes into a file. @@ -301,16 +317,21 @@ func GobEncode(data interface{}) ([]byte, error) { enc := gob.NewEncoder(buf) err := enc.Encode(data) if err != nil { - return nil, err + return nil, berror.Wrap(err, GobEncodeDataFailed, "could not encode this data") } - return buf.Bytes(), err + return buf.Bytes(), nil } // GobDecode Gob decodes a file cache item. func GobDecode(data []byte, to *FileCacheItem) error { buf := bytes.NewBuffer(data) dec := gob.NewDecoder(buf) - return dec.Decode(&to) + err := dec.Decode(&to) + if err != nil { + return berror.Wrap(err, InvalidGobEncodedData, + "could not decode this data to FileCacheItem. Make sure that the data is encoded by GOB.") + } + return nil } func init() { diff --git a/client/cache/file_test.go b/client/cache/file_test.go new file mode 100644 index 00000000..3ffc27f3 --- /dev/null +++ b/client/cache/file_test.go @@ -0,0 +1,108 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFileCacheStartAndGC(t *testing.T) { + fc := NewFileCache().(*FileCache) + err := fc.StartAndGC(`{`) + assert.NotNil(t, err) + err = fc.StartAndGC(`{}`) + assert.Nil(t, err) + + assert.Equal(t, fc.CachePath, FileCachePath) + assert.Equal(t, fc.DirectoryLevel, FileCacheDirectoryLevel) + assert.Equal(t, fc.EmbedExpiry, int(FileCacheEmbedExpiry)) + assert.Equal(t, fc.FileSuffix, FileCacheFileSuffix) + + err = fc.StartAndGC(`{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) + // could not create dir + assert.NotNil(t, err) + + str := getTestCacheFilePath() + err = fc.StartAndGC(fmt.Sprintf(`{"CachePath":"%s","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`, str)) + assert.Nil(t, err) + assert.Equal(t, fc.CachePath, str) + assert.Equal(t, fc.DirectoryLevel, 2) + assert.Equal(t, fc.EmbedExpiry, 0) + assert.Equal(t, fc.FileSuffix, ".bin") + + err = fc.StartAndGC(fmt.Sprintf(`{"CachePath":"%s","FileSuffix":".bin","DirectoryLevel":"aaa","EmbedExpiry":"0"}`, str)) + assert.NotNil(t, err) + + err = fc.StartAndGC(fmt.Sprintf(`{"CachePath":"%s","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"aaa"}`, str)) + assert.NotNil(t, err) +} + +func TestFileCacheInit(t *testing.T) { + fc := NewFileCache().(*FileCache) + fc.CachePath = "////aaa" + err := fc.Init() + assert.NotNil(t, err) + fc.CachePath = getTestCacheFilePath() + err = fc.Init() + assert.Nil(t, err) +} + +func TestFileGetContents(t *testing.T) { + _, err := FileGetContents("/bin/aaa") + assert.NotNil(t, err) + fn := filepath.Join(os.TempDir(), "fileCache.txt") + f, err := os.Create(fn) + assert.Nil(t, err) + _, err = f.WriteString("text") + assert.Nil(t, err) + data, err := FileGetContents(fn) + assert.Nil(t, err) + assert.Equal(t, "text", string(data)) +} + +func TestGobEncodeDecode(t *testing.T) { + _, err := GobEncode(func() { + fmt.Print("test func") + }) + assert.NotNil(t, err) + data, err := GobEncode(&FileCacheItem{ + Data: "hello", + }) + assert.Nil(t, err) + err = GobDecode([]byte("wrong data"), &FileCacheItem{}) + assert.NotNil(t, err) + dci := &FileCacheItem{} + err = GobDecode(data, dci) + assert.Nil(t, err) + assert.Equal(t, "hello", dci.Data) +} + +func TestFileCacheDelete(t *testing.T) { + fc := NewFileCache() + err := fc.StartAndGC(`{}`) + assert.Nil(t, err) + err = fc.Delete(context.Background(), "my-key") + assert.Nil(t, err) +} + +func getTestCacheFilePath() string { + return filepath.Join(os.TempDir(), "test", "file.txt") +} \ No newline at end of file diff --git a/client/cache/memcache/memcache.go b/client/cache/memcache/memcache.go index 527d08ca..3816444f 100644 --- a/client/cache/memcache/memcache.go +++ b/client/cache/memcache/memcache.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/cache/memcache" -// "github.com/beego/beego/v2/cache" +// _ "github.com/beego/beego/v2/client/cache/memcache" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) @@ -32,7 +32,6 @@ package memcache import ( "context" "encoding/json" - "errors" "fmt" "strings" "time" @@ -40,6 +39,7 @@ import ( "github.com/bradfitz/gomemcache/memcache" "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" ) // Cache Memcache adapter. @@ -55,36 +55,31 @@ func NewMemCache() cache.Cache { // Get get value from memcache. func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil, err - } - } if item, err := rc.conn.Get(key); err == nil { return item.Value, nil } else { - return nil, err + return nil, berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not read data from memcache, please check your key, network and connection. Root cause: %s", + err.Error()) } } // GetMulti gets a value from a key in memcache. func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { rv := make([]interface{}, len(keys)) - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return rv, err - } - } mv, err := rc.conn.GetMulti(keys) if err != nil { - return rv, err + return rv, berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not read multiple key-values from memcache, " + + "please check your keys, network and connection. Root cause: %s", + err.Error()) } keysErr := make([]string, 0) for i, ki := range keys { if _, ok := mv[ki]; !ok { - keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, "the key isn't exist")) + keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, "key not exist")) continue } rv[i] = mv[ki].Value @@ -93,78 +88,54 @@ func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, er if len(keysErr) == 0 { return rv, nil } - return rv, fmt.Errorf(strings.Join(keysErr, "; ")) + return rv, berror.Error(cache.MultiGetFailed, strings.Join(keysErr, "; ")) } // Put puts a value into memcache. func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)} if v, ok := val.([]byte); ok { item.Value = v } else if str, ok := val.(string); ok { item.Value = []byte(str) } else { - return errors.New("val only support string and []byte") + return berror.Errorf(cache.InvalidMemCacheValue, + "the value must be string or byte[]. key: %s, value:%v", key, val) } - return rc.conn.Set(&item) + return berror.Wrapf(rc.conn.Set(&item), cache.MemCacheCurdFailed, + "could not put key-value to memcache, key: %s", key) } // Delete deletes a value in memcache. func (rc *Cache) Delete(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return rc.conn.Delete(key) + return berror.Wrapf(rc.conn.Delete(key), cache.MemCacheCurdFailed, + "could not delete key-value from memcache, key: %s", key) } // Incr increases counter. func (rc *Cache) Incr(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Increment(key, 1) - return err + return berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not increase value for key: %s", key) } // Decr decreases counter. func (rc *Cache) Decr(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Decrement(key, 1) - return err + return berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not decrease value for key: %s", key) } // IsExist checks if a value exists in memcache. func (rc *Cache) IsExist(ctx context.Context, key string) (bool, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return false, err - } - } - _, err := rc.conn.Get(key) + _, err := rc.Get(ctx, key) return err == nil, err } // ClearAll clears all cache in memcache. func (rc *Cache) ClearAll(context.Context) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return rc.conn.FlushAll() + return berror.Wrap(rc.conn.FlushAll(), cache.MemCacheCurdFailed, + "try to clear all key-value pairs failed") } // StartAndGC starts the memcache adapter. @@ -172,21 +143,15 @@ func (rc *Cache) ClearAll(context.Context) error { // If an error occurs during connecting, an error is returned func (rc *Cache) StartAndGC(config string) error { var cf map[string]string - json.Unmarshal([]byte(config), &cf) + if err := json.Unmarshal([]byte(config), &cf); err != nil { + return berror.Wrapf(err, cache.InvalidMemCacheCfg, + "could not unmarshal this config, it must be valid json stringP: %s", config) + } + if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") + return berror.Errorf(cache.InvalidMemCacheCfg, `config must contains "conn" field: %s`, config) } rc.conninfo = strings.Split(cf["conn"], ";") - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return nil -} - -// connect to memcache and keep the connection. -func (rc *Cache) connectInit() error { rc.conn = memcache.New(rc.conninfo...) return nil } diff --git a/client/cache/memcache/memcache_test.go b/client/cache/memcache/memcache_test.go index 083e661c..a6c1f19c 100644 --- a/client/cache/memcache/memcache_test.go +++ b/client/cache/memcache/memcache_test.go @@ -19,10 +19,12 @@ import ( "fmt" "os" "strconv" + "strings" "testing" "time" _ "github.com/bradfitz/gomemcache/memcache" + "github.com/stretchr/testify/assert" "github.com/beego/beego/v2/client/cache" ) @@ -34,78 +36,63 @@ func TestMemcacheCache(t *testing.T) { } bm, err := cache.NewCache("memcache", fmt.Sprintf(`{"conn": "%s"}`, addr)) - if err != nil { - t.Error("init err") - } + assert.Nil(t, err) + timeoutDuration := 10 * time.Second - if err = bm.Put(context.Background(), "astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "1", timeoutDuration)) + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) time.Sleep(11 * time.Second) - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("check err") - } - if err = bm.Put(context.Background(), "astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "1", timeoutDuration)) val, _ := bm.Get(context.Background(), "astaxie") - if v, err := strconv.Atoi(string(val.([]byte))); err != nil || v != 1 { - t.Error("get err") - } + v, err := strconv.Atoi(string(val.([]byte))) + assert.Nil(t, err) + assert.Equal(t, 1, v) - if err = bm.Incr(context.Background(), "astaxie"); err != nil { - t.Error("Incr Error", err) - } + assert.Nil(t, bm.Incr(context.Background(), "astaxie")) val, _ = bm.Get(context.Background(), "astaxie") - if v, err := strconv.Atoi(string(val.([]byte))); err != nil || v != 2 { - t.Error("get err") - } + v, err = strconv.Atoi(string(val.([]byte))) + assert.Nil(t, err) + assert.Equal(t, 2, v) - if err = bm.Decr(context.Background(), "astaxie"); err != nil { - t.Error("Decr Error", err) - } + assert.Nil(t, bm.Decr(context.Background(), "astaxie")) val, _ = bm.Get(context.Background(), "astaxie") - if v, err := strconv.Atoi(string(val.([]byte))); err != nil || v != 1 { - t.Error("get err") - } + v, err = strconv.Atoi(string(val.([]byte))) + assert.Nil(t, err) + assert.Equal(t, 1, v) bm.Delete(context.Background(), "astaxie") - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("delete err") - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t,bm.Put(context.Background(), "astaxie", "author", timeoutDuration) ) // test string - if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) val, _ = bm.Get(context.Background(), "astaxie") - if v := val.([]byte); string(v) != "author" { - t.Error("get err") - } + vs := val.([]byte) + assert.Equal(t, "author", string(vs)) // test GetMulti - if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { - t.Error("check err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { t.Error("GetMulti ERROR") } @@ -114,21 +101,14 @@ func TestMemcacheCache(t *testing.T) { } vv, err = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0] != nil { - t.Error("GetMulti ERROR") - } - if string(vv[1].([]byte)) != "author1" { - t.Error("GetMulti ERROR") - } - if err != nil && err.Error() == "key [astaxie0] error: key isn't exist" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Nil(t, vv[0]) + assert.Equal(t, "author1", string(vv[1].([]byte))) + + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "key not exist")) + + assert.Nil(t, bm.ClearAll(context.Background())) // test clear all - if err = bm.ClearAll(context.Background()); err != nil { - t.Error("clear all err") - } } diff --git a/client/cache/memory.go b/client/cache/memory.go index 28c7d980..f294595d 100644 --- a/client/cache/memory.go +++ b/client/cache/memory.go @@ -17,11 +17,12 @@ package cache import ( "context" "encoding/json" - "errors" "fmt" "strings" "sync" "time" + + "github.com/beego/beego/v2/core/berror" ) var ( @@ -41,7 +42,7 @@ func (mi *MemoryItem) isExpire() bool { if mi.lifespan == 0 { return false } - return time.Now().Sub(mi.createdTime) > mi.lifespan + return time.Since(mi.createdTime) > mi.lifespan } // MemoryCache is a memory cache adapter. @@ -64,13 +65,14 @@ func NewMemoryCache() Cache { func (bc *MemoryCache) Get(ctx context.Context, key string) (interface{}, error) { bc.RLock() defer bc.RUnlock() - if itm, ok := bc.items[key]; ok { + if itm, ok := + bc.items[key]; ok { if itm.isExpire() { - return nil, errors.New("the key is expired") + return nil, ErrKeyExpired } return itm.val, nil } - return nil, errors.New("the key isn't exist") + return nil, ErrKeyNotExist } // GetMulti gets caches from memory. @@ -91,7 +93,7 @@ func (bc *MemoryCache) GetMulti(ctx context.Context, keys []string) ([]interface if len(keysErr) == 0 { return rc, nil } - return rc, errors.New(strings.Join(keysErr, "; ")) + return rc, berror.Error(MultiGetFailed, strings.Join(keysErr, "; ")) } // Put puts cache into memory. @@ -108,16 +110,11 @@ func (bc *MemoryCache) Put(ctx context.Context, key string, val interface{}, tim } // Delete cache in memory. +// If the key is not found, it will not return error func (bc *MemoryCache) Delete(ctx context.Context, key string) error { bc.Lock() defer bc.Unlock() - if _, ok := bc.items[key]; !ok { - return errors.New("key not exist") - } delete(bc.items, key) - if _, ok := bc.items[key]; ok { - return errors.New("delete key error") - } return nil } @@ -128,24 +125,14 @@ func (bc *MemoryCache) Incr(ctx context.Context, key string) error { defer bc.Unlock() itm, ok := bc.items[key] if !ok { - return errors.New("key not exist") + return ErrKeyNotExist } - switch val := itm.val.(type) { - case int: - itm.val = val + 1 - case int32: - itm.val = val + 1 - case int64: - itm.val = val + 1 - case uint: - itm.val = val + 1 - case uint32: - itm.val = val + 1 - case uint64: - itm.val = val + 1 - default: - return errors.New("item val is not (u)int (u)int32 (u)int64") + + val, err := incr(itm.val) + if err != nil { + return err } + itm.val = val return nil } @@ -155,36 +142,14 @@ func (bc *MemoryCache) Decr(ctx context.Context, key string) error { defer bc.Unlock() itm, ok := bc.items[key] if !ok { - return errors.New("key not exist") + return ErrKeyNotExist } - switch val := itm.val.(type) { - case int: - itm.val = val - 1 - case int64: - itm.val = val - 1 - case int32: - itm.val = val - 1 - case uint: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint32: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint64: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - default: - return errors.New("item val is not int int64 int32") + + val, err := decr(itm.val) + if err != nil { + return err } + itm.val = val return nil } @@ -209,7 +174,9 @@ func (bc *MemoryCache) ClearAll(context.Context) error { // StartAndGC starts memory cache. Checks expiration in every clock time. func (bc *MemoryCache) StartAndGC(config string) error { var cf map[string]int - json.Unmarshal([]byte(config), &cf) + if err := json.Unmarshal([]byte(config), &cf); err != nil { + return berror.Wrapf(err, InvalidMemoryCacheCfg, "invalid config, please check your input: %s", config) + } if _, ok := cf["interval"]; !ok { cf = make(map[string]int) cf["interval"] = DefaultEvery diff --git a/client/cache/module.go b/client/cache/module.go new file mode 100644 index 00000000..5a4e499e --- /dev/null +++ b/client/cache/module.go @@ -0,0 +1,17 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +const moduleName = "cache" diff --git a/client/cache/redis/redis.go b/client/cache/redis/redis.go index dcf0cd5a..7e70af2e 100644 --- a/client/cache/redis/redis.go +++ b/client/cache/redis/redis.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/cache/redis" -// "github.com/beego/beego/v2/cache" +// _ "github.com/beego/beego/v2/client/cache/redis" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) @@ -32,7 +32,6 @@ package redis import ( "context" "encoding/json" - "errors" "fmt" "strconv" "strings" @@ -41,6 +40,7 @@ import ( "github.com/gomodule/redigo/redis" "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" ) var ( @@ -67,15 +67,20 @@ func NewRedisCache() cache.Cache { } // Execute the redis commands. args[0] must be the key name -func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { - if len(args) < 1 { - return nil, errors.New("missing required arguments") - } +func (rc *Cache) do(commandName string, args ...interface{}) (interface{}, error) { args[0] = rc.associate(args[0]) c := rc.p.Get() - defer c.Close() + defer func() { + _ = c.Close() + }() - return c.Do(commandName, args...) + reply, err := c.Do(commandName, args...) + if err != nil { + return nil, berror.Wrapf(err, cache.RedisCacheCurdFailed, + "could not execute this command: %s", commandName) + } + + return reply, nil } // associate with config key. @@ -95,7 +100,9 @@ func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { // GetMulti gets cache from redis. func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { c := rc.p.Get() - defer c.Close() + defer func() { + _ = c.Close() + }() var args []interface{} for _, key := range keys { args = append(args, rc.associate(key)) @@ -137,13 +144,16 @@ func (rc *Cache) Decr(ctx context.Context, key string) error { } // ClearAll deletes all cache in the redis collection +// Be careful about this method, because it scans all keys and the delete them one by one func (rc *Cache) ClearAll(context.Context) error { cachedKeys, err := rc.Scan(rc.key + ":*") if err != nil { return err } c := rc.p.Get() - defer c.Close() + defer func() { + _ = c.Close() + }() for _, str := range cachedKeys { if _, err = c.Do("DEL", str); err != nil { return err @@ -155,7 +165,9 @@ func (rc *Cache) ClearAll(context.Context) error { // Scan scans all keys matching a given pattern. func (rc *Cache) Scan(pattern string) (keys []string, err error) { c := rc.p.Get() - defer c.Close() + defer func() { + _ = c.Close() + }() var ( cursor uint64 = 0 // start result []interface{} @@ -186,13 +198,16 @@ func (rc *Cache) Scan(pattern string) (keys []string, err error) { // Cached items in redis are stored forever, no garbage collection happens func (rc *Cache) StartAndGC(config string) error { var cf map[string]string - json.Unmarshal([]byte(config), &cf) + err := json.Unmarshal([]byte(config), &cf) + if err != nil { + return berror.Wrapf(err, cache.InvalidRedisCacheCfg, "could not unmarshal the config: %s", config) + } if _, ok := cf["key"]; !ok { cf["key"] = DefaultKey } if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") + return berror.Wrapf(err, cache.InvalidRedisCacheCfg, "config missing conn field: %s", config) } // Format redis://@: @@ -229,9 +244,16 @@ func (rc *Cache) StartAndGC(config string) error { rc.connectInit() c := rc.p.Get() - defer c.Close() + defer func() { + _ = c.Close() + }() - return c.Err() + // test connection + if err = c.Err(); err != nil { + return berror.Wrapf(err, cache.InvalidConnection, + "can not connect to remote redis server, please check the connection info and network state: %s", config) + } + return nil } // connect to redis. @@ -239,19 +261,20 @@ func (rc *Cache) connectInit() { dialFunc := func() (c redis.Conn, err error) { c, err = redis.Dial("tcp", rc.conninfo) if err != nil { - return nil, err + return nil, berror.Wrapf(err, cache.DialFailed, + "could not dial to remote server: %s ", rc.conninfo) } if rc.password != "" { - if _, err := c.Do("AUTH", rc.password); err != nil { - c.Close() + if _, err = c.Do("AUTH", rc.password); err != nil { + _ = c.Close() return nil, err } } _, selecterr := c.Do("SELECT", rc.dbNum) if selecterr != nil { - c.Close() + _ = c.Close() return nil, selecterr } return diff --git a/client/cache/redis/redis_test.go b/client/cache/redis/redis_test.go index 3344bc34..3e794514 100644 --- a/client/cache/redis/redis_test.go +++ b/client/cache/redis/redis_test.go @@ -35,96 +35,74 @@ func TestRedisCache(t *testing.T) { } bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, redisAddr)) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } + assert.Nil(t, err) + timeoutDuration := 3 * time.Second - time.Sleep(11 * time.Second) + assert.Nil(t, bm.Put(context.Background(), "astaxie", 1, timeoutDuration)) + + + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + time.Sleep(5 * time.Second) + + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", 1, timeoutDuration)) - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("check err") - } - if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } val, _ := bm.Get(context.Background(), "astaxie") - if v, _ := redis.Int(val, err); v != 1 { - t.Error("get err") - } + v, _ := redis.Int(val, err) + assert.Equal(t, 1, v) - if err = bm.Incr(context.Background(), "astaxie"); err != nil { - t.Error("Incr Error", err) - } + assert.Nil(t, bm.Incr(context.Background(), "astaxie")) val, _ = bm.Get(context.Background(), "astaxie") - if v, _ := redis.Int(val, err); v != 2 { - t.Error("get err") - } + v, _ = redis.Int(val, err) + assert.Equal(t, 2, v) - if err = bm.Decr(context.Background(), "astaxie"); err != nil { - t.Error("Decr Error", err) - } + assert.Nil(t, bm.Decr(context.Background(), "astaxie")) val, _ = bm.Get(context.Background(), "astaxie") - if v, _ := redis.Int(val, err); v != 1 { - t.Error("get err") - } + v, _ = redis.Int(val, err) + assert.Equal(t, 1, v) bm.Delete(context.Background(), "astaxie") - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("delete err") - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "author", timeoutDuration)) // test string - if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } + + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) val, _ = bm.Get(context.Background(), "astaxie") - if v, _ := redis.String(val, err); v != "author" { - t.Error("get err") - } + vs, _ := redis.String(val, err) + assert.Equal(t, "author", vs) // test GetMulti - if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { - t.Error("check err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[0], nil); v != "author" { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[1], nil); v != "author1" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + vs, _ = redis.String(vv[0], nil) + assert.Equal(t, "author", vs) + + vs, _ = redis.String(vv[1], nil) + assert.Equal(t, "author1", vs) vv, _ = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) - if vv[0] != nil { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[1], nil); v != "author1" { - t.Error("GetMulti ERROR") - } + assert.Nil(t, vv[0]) + + vs, _ = redis.String(vv[1], nil) + assert.Equal(t, "author1", vs) // test clear all - if err = bm.ClearAll(context.Background()); err != nil { - t.Error("clear all err") - } + assert.Nil(t, bm.ClearAll(context.Background())) } func TestCache_Scan(t *testing.T) { @@ -137,35 +115,24 @@ func TestCache_Scan(t *testing.T) { // init bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, addr)) - if err != nil { - t.Error("init err") - } + + assert.Nil(t, err) // insert all for i := 0; i < 100; i++ { - if err = bm.Put(context.Background(), fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.Nil(t, bm.Put(context.Background(), fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration)) } time.Sleep(time.Second) // scan all for the first time keys, err := bm.(*Cache).Scan(DefaultKey + ":*") - if err != nil { - t.Error("scan Error", err) - } + assert.Nil(t, err) assert.Equal(t, 100, len(keys), "scan all error") // clear all - if err = bm.ClearAll(context.Background()); err != nil { - t.Error("clear all err") - } + assert.Nil(t, bm.ClearAll(context.Background())) // scan all for the second time keys, err = bm.(*Cache).Scan(DefaultKey + ":*") - if err != nil { - t.Error("scan Error", err) - } - if len(keys) != 0 { - t.Error("scan all err") - } + assert.Nil(t, err) + assert.Equal(t, 0, len(keys)) } diff --git a/client/cache/ssdb/ssdb.go b/client/cache/ssdb/ssdb.go index 93fa9feb..e715d07f 100644 --- a/client/cache/ssdb/ssdb.go +++ b/client/cache/ssdb/ssdb.go @@ -3,7 +3,6 @@ package ssdb import ( "context" "encoding/json" - "errors" "fmt" "strconv" "strings" @@ -12,6 +11,7 @@ import ( "github.com/ssdb/gossdb/ssdb" "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" ) // Cache SSDB adapter @@ -27,31 +27,21 @@ func NewSsdbCache() cache.Cache { // Get gets a key's value from memcache. func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil, err - } - } value, err := rc.conn.Get(key) if err == nil { return value, nil } - return nil, err + return nil, berror.Wrapf(err, cache.SsdbCacheCurdFailed, "could not get value, key: %s", key) } // GetMulti gets one or keys values from ssdb. func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { size := len(keys) values := make([]interface{}, size) - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return values, err - } - } res, err := rc.conn.Do("multi_get", keys) if err != nil { - return values, err + return values, berror.Wrapf(err, cache.SsdbCacheCurdFailed, "multi_get failed, key: %v", keys) } resSize := len(res) @@ -63,14 +53,14 @@ func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, er keysErr := make([]string, 0) for i, ki := range keys { if _, ok := keyIdx[ki]; !ok { - keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, "the key isn't exist")) + keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, "key not exist")) continue } values[i] = res[keyIdx[ki]+1] } if len(keysErr) != 0 { - return values, fmt.Errorf(strings.Join(keysErr, "; ")) + return values, berror.Error(cache.MultiGetFailed, strings.Join(keysErr, "; ")) } return values, nil @@ -78,26 +68,16 @@ func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, er // DelMulti deletes one or more keys from memcache func (rc *Cache) DelMulti(keys []string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Do("multi_del", keys) - return err + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "multi_del failed: %v", keys) } // Put puts value into memcache. // value: must be of type string func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } v, ok := val.(string) if !ok { - return errors.New("value must string") + return berror.Errorf(cache.InvalidSsdbCacheValue, "value must be string: %v", val) } var resp []string var err error @@ -108,57 +88,37 @@ func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout t resp, err = rc.conn.Do("setx", key, v, ttl) } if err != nil { - return err + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "set or setx failed, key: %s", key) } if len(resp) == 2 && resp[0] == "ok" { return nil } - return errors.New("bad response") + return berror.Errorf(cache.SsdbBadResponse, "the response from SSDB server is invalid: %v", resp) } // Delete deletes a value in memcache. func (rc *Cache) Delete(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Del(key) - return err + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "del failed: %s", key) } // Incr increases a key's counter. func (rc *Cache) Incr(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Do("incr", key, 1) - return err + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "increase failed: %s", key) } // Decr decrements a key's counter. func (rc *Cache) Decr(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Do("incr", key, -1) - return err + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "decrease failed: %s", key) } // IsExist checks if a key exists in memcache. func (rc *Cache) IsExist(ctx context.Context, key string) (bool, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return false, err - } - } resp, err := rc.conn.Do("exists", key) if err != nil { - return false, err + return false, berror.Wrapf(err, cache.SsdbCacheCurdFailed, "exists failed: %s", key) } if len(resp) == 2 && resp[1] == "1" { return true, nil @@ -167,13 +127,9 @@ func (rc *Cache) IsExist(ctx context.Context, key string) (bool, error) { } -// ClearAll clears all cached items in memcache. +// ClearAll clears all cached items in ssdb. +// If there are many keys, this method may spent much time. func (rc *Cache) ClearAll(context.Context) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } keyStart, keyEnd, limit := "", "", 50 resp, err := rc.Scan(keyStart, keyEnd, limit) for err == nil { @@ -187,21 +143,16 @@ func (rc *Cache) ClearAll(context.Context) error { } _, e := rc.conn.Do("multi_del", keys) if e != nil { - return e + return berror.Wrapf(e, cache.SsdbCacheCurdFailed, "multi_del failed: %v", keys) } keyStart = resp[size-2] resp, err = rc.Scan(keyStart, keyEnd, limit) } - return err + return berror.Wrap(err, cache.SsdbCacheCurdFailed, "scan failed") } // Scan key all cached in ssdb. func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil, err - } - } resp, err := rc.conn.Do("scan", keyStart, keyEnd, limit) if err != nil { return nil, err @@ -214,30 +165,36 @@ func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, erro // If an error occurs during connection, an error is returned func (rc *Cache) StartAndGC(config string) error { var cf map[string]string - json.Unmarshal([]byte(config), &cf) + err := json.Unmarshal([]byte(config), &cf) + if err != nil { + return berror.Wrapf(err, cache.InvalidSsdbCacheCfg, + "unmarshal this config failed, it must be a valid json string: %s", config) + } if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") + return berror.Wrapf(err, cache.InvalidSsdbCacheCfg, + "Missing conn field: %s", config) } rc.conninfo = strings.Split(cf["conn"], ";") - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return nil + return rc.connectInit() } // connect to memcache and keep the connection. func (rc *Cache) connectInit() error { conninfoArray := strings.Split(rc.conninfo[0], ":") + if len(conninfoArray) < 2 { + return berror.Errorf(cache.InvalidSsdbCacheCfg, "The value of conn should be host:port: %s", rc.conninfo[0]) + } host := conninfoArray[0] port, e := strconv.Atoi(conninfoArray[1]) if e != nil { - return e + return berror.Errorf(cache.InvalidSsdbCacheCfg, "Port is invalid. It must be integer, %s", rc.conninfo[0]) } var err error - rc.conn, err = ssdb.Connect(host, port) - return err + if rc.conn, err = ssdb.Connect(host, port); err != nil { + return berror.Wrapf(err, cache.InvalidConnection, + "could not connect to SSDB, please check your connection info, network and firewall: %s", rc.conninfo[0]) + } + return nil } func init() { diff --git a/client/cache/ssdb/ssdb_test.go b/client/cache/ssdb/ssdb_test.go index 8ac1efd6..fea755f4 100644 --- a/client/cache/ssdb/ssdb_test.go +++ b/client/cache/ssdb/ssdb_test.go @@ -5,9 +5,12 @@ import ( "fmt" "os" "strconv" + "strings" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/beego/beego/v2/client/cache" ) @@ -19,114 +22,80 @@ func TestSsdbcacheCache(t *testing.T) { } ssdb, err := cache.NewCache("ssdb", fmt.Sprintf(`{"conn": "%s"}`, ssdbAddr)) - if err != nil { - t.Error("init err") - } + assert.Nil(t, err) // test put and exist - if res, _ := ssdb.IsExist(context.Background(), "ssdb"); res { - t.Error("check err") - } - timeoutDuration := 10 * time.Second + res, _ := ssdb.IsExist(context.Background(), "ssdb") + assert.False(t, res) + timeoutDuration := 3 * time.Second // timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent - if err = ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := ssdb.IsExist(context.Background(), "ssdb"); !res { - t.Error("check err") - } + + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration)) + + res, _ = ssdb.IsExist(context.Background(), "ssdb") + assert.True(t, res) // Get test done - if err = ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration)) - if v, _ := ssdb.Get(context.Background(), "ssdb"); v != "ssdb" { - t.Error("get Error") - } + v, _ := ssdb.Get(context.Background(), "ssdb") + assert.Equal(t, "ssdb", v) // inc/dec test done - if err = ssdb.Put(context.Background(), "ssdb", "2", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if err = ssdb.Incr(context.Background(), "ssdb"); err != nil { - t.Error("incr Error", err) - } + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "2", timeoutDuration)) + + assert.Nil(t, ssdb.Incr(context.Background(), "ssdb")) val, _ := ssdb.Get(context.Background(), "ssdb") - if v, err := strconv.Atoi(val.(string)); err != nil || v != 3 { - t.Error("get err") - } + v, err = strconv.Atoi(val.(string)) + assert.Nil(t, err) + assert.Equal(t, 3, v) - if err = ssdb.Decr(context.Background(), "ssdb"); err != nil { - t.Error("decr error") - } + assert.Nil(t, ssdb.Decr(context.Background(), "ssdb")) // test del - if err = ssdb.Put(context.Background(), "ssdb", "3", timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "3", timeoutDuration)) val, _ = ssdb.Get(context.Background(), "ssdb") - if v, err := strconv.Atoi(val.(string)); err != nil || v != 3 { - t.Error("get err") - } - if err := ssdb.Delete(context.Background(), "ssdb"); err == nil { - if e, _ := ssdb.IsExist(context.Background(), "ssdb"); e { - t.Error("delete err") - } - } + v, err = strconv.Atoi(val.(string)) + assert.Equal(t, 3, v) + assert.Nil(t, err) + assert.Nil(t, ssdb.Delete(context.Background(), "ssdb")) + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "ssdb", -10*time.Second)) // test string - if err = ssdb.Put(context.Background(), "ssdb", "ssdb", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if res, _ := ssdb.IsExist(context.Background(), "ssdb"); !res { - t.Error("check err") - } - if v, _ := ssdb.Get(context.Background(), "ssdb"); v.(string) != "ssdb" { - t.Error("get err") - } + + res, _ = ssdb.IsExist(context.Background(), "ssdb") + assert.True(t, res) + + v, _ = ssdb.Get(context.Background(), "ssdb") + assert.Equal(t, "ssdb", v.(string)) // test GetMulti done - if err = ssdb.Put(context.Background(), "ssdb1", "ssdb1", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if res, _ := ssdb.IsExist(context.Background(), "ssdb1"); !res { - t.Error("check err") - } + assert.Nil(t, ssdb.Put(context.Background(), "ssdb1", "ssdb1", -10*time.Second)) + + res, _ = ssdb.IsExist(context.Background(), "ssdb1") + assert.True(t, res) vv, _ := ssdb.GetMulti(context.Background(), []string{"ssdb", "ssdb1"}) - if len(vv) != 2 { - t.Error("getmulti error") - } - if vv[0].(string) != "ssdb" { - t.Error("getmulti error") - } - if vv[1].(string) != "ssdb1" { - t.Error("getmulti error") - } + assert.Equal(t, 2, len(vv)) + + assert.Equal(t, "ssdb", vv[0]) + assert.Equal(t, "ssdb1", vv[1]) vv, err = ssdb.GetMulti(context.Background(), []string{"ssdb", "ssdb11"}) - if len(vv) != 2 { - t.Error("getmulti error") - } - if vv[0].(string) != "ssdb" { - t.Error("getmulti error") - } - if vv[1] != nil { - t.Error("getmulti error") - } - if err != nil && err.Error() != "key [ssdb11] error: the key isn't exist" { - t.Error("getmulti error") - } + + assert.Equal(t, 2, len(vv)) + + assert.Equal(t, "ssdb", vv[0]) + assert.Nil(t, vv[1]) + + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "key not exist")) // test clear all done - if err = ssdb.ClearAll(context.Background()); err != nil { - t.Error("clear all err") - } + assert.Nil(t, ssdb.ClearAll(context.Background())) e1, _ := ssdb.IsExist(context.Background(), "ssdb") e2, _ := ssdb.IsExist(context.Background(), "ssdb1") - if e1 || e2 { - t.Error("check err") - } + assert.False(t, e1) + assert.False(t, e2) } diff --git a/client/httplib/README.md b/client/httplib/README.md index 90a6c505..1d22f341 100644 --- a/client/httplib/README.md +++ b/client/httplib/README.md @@ -8,7 +8,7 @@ httplib is an libs help you to curl remote url. you can use Get to crawl data. - import "github.com/beego/beego/v2/httplib" + import "github.com/beego/beego/v2/client/httplib" str, err := httplib.Get("http://beego.me/").String() if err != nil { @@ -95,4 +95,4 @@ httplib support mutil file upload, use `req.PostFile()` See godoc for further documentation and examples. -* [godoc.org/github.com/beego/beego/v2/httplib](https://godoc.org/github.com/beego/beego/v2/httplib) +* [godoc.org/github.com/beego/beego/v2/client/httplib](https://godoc.org/github.com/beego/beego/v2/client/httplib) diff --git a/client/httplib/error_code.go b/client/httplib/error_code.go new file mode 100644 index 00000000..bd349a34 --- /dev/null +++ b/client/httplib/error_code.go @@ -0,0 +1,126 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "github.com/beego/beego/v2/core/berror" +) + +var InvalidUrl = berror.DefineCode(4001001, moduleName, "InvalidUrl", ` +You pass a invalid url to httplib module. Please check your url, be careful about special character. +`) + +var InvalidUrlProtocolVersion = berror.DefineCode(4001002, moduleName, "InvalidUrlProtocolVersion", ` +You pass a invalid protocol version. In practice, we use HTTP/1.0, HTTP/1.1, HTTP/1.2 +But something like HTTP/3.2 is valid for client, and the major version is 3, minor version is 2. +but you must confirm that server support those abnormal protocol version. +`) + +var UnsupportedBodyType = berror.DefineCode(4001003, moduleName, "UnsupportedBodyType", ` +You use a invalid data as request body. +For now, we only support type string and byte[]. +`) + +var InvalidXMLBody = berror.DefineCode(4001004, moduleName, "InvalidXMLBody", ` +You pass invalid data which could not be converted to XML documents. In general, if you pass structure, it works well. +Sometimes you got XML document and you want to make it as request body. So you call XMLBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + +var InvalidYAMLBody = berror.DefineCode(4001005, moduleName, "InvalidYAMLBody", ` +You pass invalid data which could not be converted to YAML documents. In general, if you pass structure, it works well. +Sometimes you got YAML document and you want to make it as request body. So you call YAMLBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + +var InvalidJSONBody = berror.DefineCode(4001006, moduleName, "InvalidJSONBody", ` +You pass invalid data which could not be converted to JSON documents. In general, if you pass structure, it works well. +Sometimes you got JSON document and you want to make it as request body. So you call JSONBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + +// start with 5 -------------------------------------------------------------------------- + +var CreateFormFileFailed = berror.DefineCode(5001001, moduleName, "CreateFormFileFailed", ` +In normal case than handling files with BeegoRequest, you should not see this error code. +Unexpected EOF, invalid characters, bad file descriptor may cause this error. +`) + +var ReadFileFailed = berror.DefineCode(5001002, moduleName, "ReadFileFailed", ` +There are several cases that cause this error: +1. file not found. Please check the file name; +2. file not found, but file name is correct. If you use relative file path, it's very possible for you to see this code. +make sure that this file is in correct directory which Beego looks for; +3. Beego don't have the privilege to read this file, please change file mode; +`) + +var CopyFileFailed = berror.DefineCode(5001003, moduleName, "CopyFileFailed", ` +When we try to read file content and then copy it to another writer, and failed. +1. Unexpected EOF; +2. Bad file descriptor; +3. Write conflict; + +Please check your file content, and confirm that file is not processed by other process (or by user manually). +`) + +var CloseFileFailed = berror.DefineCode(5001004, moduleName, "CloseFileFailed", ` +After handling files, Beego try to close file but failed. Usually it was caused by bad file descriptor. +`) + +var SendRequestFailed = berror.DefineCode(5001005, moduleName, "SendRequestRetryExhausted", ` +Beego send HTTP request, but it failed. +If you config retry times, it means that Beego had retried and failed. +When you got this error, there are vary kind of reason: +1. Network unstable and timeout. In this case, sometimes server has received the request. +2. Server error. Make sure that server works well. +3. The request is invalid, which means that you pass some invalid parameter. +`) + +var ReadGzipBodyFailed = berror.DefineCode(5001006, moduleName, "BuildGzipReaderFailed", ` +Beego parse gzip-encode body failed. Usually Beego got invalid response. +Please confirm that server returns gzip data. +`) + +var CreateFileIfNotExistFailed = berror.DefineCode(5001007, moduleName, "CreateFileIfNotExist", ` +Beego want to create file if not exist and failed. +In most cases, it means that Beego doesn't have the privilege to create this file. +Please change file mode to ensure that Beego is able to create files on specific directory. +Or you can run Beego with higher authority. +In some cases, you pass invalid filename. Make sure that the file name is valid on your system. +`) + +var UnmarshalJSONResponseToObjectFailed = berror.DefineCode(5001008, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid json document +`) + +var UnmarshalXMLResponseToObjectFailed = berror.DefineCode(5001009, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid XML document +`) + +var UnmarshalYAMLResponseToObjectFailed = berror.DefineCode(5001010, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid YAML document +`) diff --git a/client/httplib/filter/log/filter.go b/client/httplib/filter/log/filter.go new file mode 100644 index 00000000..9d2e09d3 --- /dev/null +++ b/client/httplib/filter/log/filter.go @@ -0,0 +1,130 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package log + +import ( + "context" + "io" + "net/http" + "net/http/httputil" + + "github.com/beego/beego/v2/client/httplib" + "github.com/beego/beego/v2/core/logs" +) + +// FilterChainBuilder can build a log filter +type FilterChainBuilder struct { + printableContentTypes []string // only print the body of included mime types of request and response + log func(f interface{}, v ...interface{}) // custom log function +} + +// BuilderOption option constructor +type BuilderOption func(*FilterChainBuilder) + +type logInfo struct { + req []byte + resp []byte + err error +} + +var defaultprintableContentTypes = []string{ + "text/plain", "text/xml", "text/html", "text/csv", + "text/calendar", "text/javascript", "text/javascript", + "text/css", +} + +// NewFilterChainBuilder initialize a filterChainBuilder, pass options to customize +func NewFilterChainBuilder(opts ...BuilderOption) *FilterChainBuilder { + res := &FilterChainBuilder{ + printableContentTypes: defaultprintableContentTypes, + log: logs.Debug, + } + for _, o := range opts { + o(res) + } + + return res +} + +// WithLog return option constructor modify log function +func WithLog(f func(f interface{}, v ...interface{})) BuilderOption { + return func(h *FilterChainBuilder) { + h.log = f + } +} + +// WithprintableContentTypes return option constructor modify printableContentTypes +func WithprintableContentTypes(types []string) BuilderOption { + return func(h *FilterChainBuilder) { + h.printableContentTypes = types + } +} + +// FilterChain can print the request after FilterChain processing and response before processsing +func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + info := &logInfo{} + defer info.print(builder.log) + resp, err := next(ctx, req) + info.err = err + contentType := req.GetRequest().Header.Get("Content-Type") + shouldPrintBody := builder.shouldPrintBody(contentType, req.GetRequest().Body) + dump, err := httputil.DumpRequest(req.GetRequest(), shouldPrintBody) + info.req = dump + if err != nil { + logs.Error(err) + } + if resp != nil { + contentType = resp.Header.Get("Content-Type") + shouldPrintBody = builder.shouldPrintBody(contentType, resp.Body) + dump, err = httputil.DumpResponse(resp, shouldPrintBody) + info.resp = dump + if err != nil { + logs.Error(err) + } + } + return resp, err + } +} + +func (builder *FilterChainBuilder) shouldPrintBody(contentType string, body io.ReadCloser) bool { + if contains(builder.printableContentTypes, contentType) { + return true + } + if body != nil { + logs.Warn("printableContentTypes do not contain %s, if you want to print request and response body please add it.", contentType) + } + return false +} + +func contains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} + +func (info *logInfo) print(log func(f interface{}, v ...interface{})) { + log("Request: ====================") + log("%q", info.req) + log("Response: ===================") + log("%q", info.resp) + if info.err != nil { + log("Error: ======================") + log("%q", info.err) + } +} diff --git a/client/httplib/filter/log/filter_test.go b/client/httplib/filter/log/filter_test.go new file mode 100644 index 00000000..4ee94a0d --- /dev/null +++ b/client/httplib/filter/log/filter_test.go @@ -0,0 +1,62 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package log + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/beego/beego/v2/client/httplib" + "github.com/stretchr/testify/assert" +) + +func TestFilterChain(t *testing.T) { + next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + time.Sleep(100 * time.Millisecond) + return &http.Response{ + StatusCode: 404, + }, nil + } + builder := NewFilterChainBuilder() + filter := builder.FilterChain(next) + req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego") + resp, err := filter(context.Background(), req) + assert.NotNil(t, resp) + assert.Nil(t, err) +} + +func TestContains(t *testing.T) { + jsonType := "application/json" + cases := []struct { + Name string + Types []string + ContentType string + Expected bool + }{ + {"case1", []string{jsonType}, jsonType, true}, + {"case2", []string{"text/plain"}, jsonType, false}, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + if ans := contains(c.Types, c.ContentType); ans != c.Expected { + t.Fatalf("Types: %v, ContentType: %v, expected %v, but %v got", + c.Types, c.ContentType, c.Expected, ans) + } + }) + } +} diff --git a/client/httplib/filter/opentracing/filter.go b/client/httplib/filter/opentracing/filter.go index cde50261..a46effc8 100644 --- a/client/httplib/filter/opentracing/filter.go +++ b/client/httplib/filter/opentracing/filter.go @@ -21,11 +21,14 @@ import ( logKit "github.com/go-kit/kit/log" opentracingKit "github.com/go-kit/kit/tracing/opentracing" "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/log" "github.com/beego/beego/v2/client/httplib" ) type FilterChainBuilder struct { + // TagURL true will tag span with url + TagURL bool // CustomSpanFunc users are able to custom their span CustomSpanFunc func(span opentracing.Span, ctx context.Context, req *httplib.BeegoHTTPRequest, resp *http.Response, err error) @@ -50,13 +53,19 @@ func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filt } span.SetTag("http.method", method) span.SetTag("peer.hostname", req.GetRequest().URL.Host) - span.SetTag("http.url", req.GetRequest().URL.String()) + span.SetTag("http.scheme", req.GetRequest().URL.Scheme) span.SetTag("span.kind", "client") span.SetTag("component", "beego") + + if builder.TagURL { + span.SetTag("http.url", req.GetRequest().URL.String()) + } + span.LogFields(log.String("http.url", req.GetRequest().URL.String())) + if err != nil { span.SetTag("error", true) - span.SetTag("message", err.Error()) + span.LogFields(log.String("message", err.Error())) } else if resp != nil && !(resp.StatusCode < 300 && resp.StatusCode >= 200) { span.SetTag("error", true) } diff --git a/client/httplib/filter/opentracing/filter_test.go b/client/httplib/filter/opentracing/filter_test.go index b9c1e1e2..a9b9cbb0 100644 --- a/client/httplib/filter/opentracing/filter_test.go +++ b/client/httplib/filter/opentracing/filter_test.go @@ -26,14 +26,16 @@ import ( "github.com/beego/beego/v2/client/httplib" ) -func TestFilterChainBuilder_FilterChain(t *testing.T) { +func TestFilterChainBuilderFilterChain(t *testing.T) { next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { time.Sleep(100 * time.Millisecond) return &http.Response{ StatusCode: 404, }, errors.New("hello") } - builder := &FilterChainBuilder{} + builder := &FilterChainBuilder{ + TagURL: true, + } filter := builder.FilterChain(next) req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego") resp, err := filter(context.Background(), req) diff --git a/client/httplib/filter/prometheus/filter_test.go b/client/httplib/filter/prometheus/filter_test.go index 1e7935d0..4a7b29f2 100644 --- a/client/httplib/filter/prometheus/filter_test.go +++ b/client/httplib/filter/prometheus/filter_test.go @@ -25,7 +25,7 @@ import ( "github.com/beego/beego/v2/client/httplib" ) -func TestFilterChainBuilder_FilterChain(t *testing.T) { +func TestFilterChainBuilderFilterChain(t *testing.T) { next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { time.Sleep(100 * time.Millisecond) return &http.Response{ diff --git a/client/httplib/http_response.go b/client/httplib/http_response.go new file mode 100644 index 00000000..89930cb1 --- /dev/null +++ b/client/httplib/http_response.go @@ -0,0 +1,39 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" +) + +// NewHttpResponseWithJsonBody will try to convert the data to json format +// usually you only use this when you want to mock http Response +func NewHttpResponseWithJsonBody(data interface{}) *http.Response { + var body []byte + if str, ok := data.(string); ok { + body = []byte(str) + } else if bts, ok := data.([]byte); ok { + body = bts + } else { + body, _ = json.Marshal(data) + } + return &http.Response{ + ContentLength: int64(len(body)), + Body: ioutil.NopCloser(bytes.NewReader(body)), + } +} diff --git a/client/httplib/http_response_test.go b/client/httplib/http_response_test.go new file mode 100644 index 00000000..a62dd42c --- /dev/null +++ b/client/httplib/http_response_test.go @@ -0,0 +1,35 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewHttpResponseWithJsonBody(t *testing.T) { + // string + resp := NewHttpResponseWithJsonBody("{}") + assert.Equal(t, int64(2), resp.ContentLength) + + resp = NewHttpResponseWithJsonBody([]byte("{}")) + assert.Equal(t, int64(2), resp.ContentLength) + + resp = NewHttpResponseWithJsonBody(&user{ + Name: "Tom", + }) + assert.True(t, resp.ContentLength > 0) +} diff --git a/client/httplib/httplib.go b/client/httplib/httplib.go index 56c50cd2..cef2294c 100644 --- a/client/httplib/httplib.go +++ b/client/httplib/httplib.go @@ -15,7 +15,7 @@ // Package httplib is used as http.Client // Usage: // -// import "github.com/beego/beego/v2/httplib" +// import "github.com/beego/beego/v2/client/httplib" // // b := httplib.Post("http://beego.me/") // b.Param("username","astaxie") @@ -40,58 +40,36 @@ import ( "encoding/xml" "io" "io/ioutil" - "log" "mime/multipart" "net" "net/http" - "net/http/cookiejar" - "net/http/httputil" "net/url" "os" "path" "strings" - "sync" "time" "gopkg.in/yaml.v2" + + "github.com/beego/beego/v2/core/berror" + "github.com/beego/beego/v2/core/logs" ) -var defaultSetting = BeegoHTTPSettings{ - UserAgent: "beegoServer", - ConnectTimeout: 60 * time.Second, - ReadWriteTimeout: 60 * time.Second, - Gzip: true, - DumpBody: true, -} - -var defaultCookieJar http.CookieJar -var settingMutex sync.Mutex - +const contentTypeKey = "Content-Type" // it will be the last filter and execute request.Do var doRequestFilter = func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { return req.doRequest(ctx) } -// createDefaultCookie creates a global cookiejar to store cookies. -func createDefaultCookie() { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultCookieJar, _ = cookiejar.New(nil) -} - -// SetDefaultSetting overwrites default settings -func SetDefaultSetting(setting BeegoHTTPSettings) { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultSetting = setting -} - // NewBeegoRequest returns *BeegoHttpRequest with specific method +// TODO add error as return value +// I think if we don't return error +// users are hard to check whether we create Beego request successfully func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { var resp http.Response u, err := url.Parse(rawurl) if err != nil { - log.Println("Httplib:", err) + logs.Error("%+v", berror.Wrapf(err, InvalidUrl, "invalid raw url: %s", rawurl)) } req := http.Request{ URL: u, @@ -136,24 +114,6 @@ func Head(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "HEAD") } -// BeegoHTTPSettings is the http.Client setting -type BeegoHTTPSettings struct { - ShowDebug bool - UserAgent string - ConnectTimeout time.Duration - ReadWriteTimeout time.Duration - TLSClientConfig *tls.Config - Proxy func(*http.Request) (*url.URL, error) - Transport http.RoundTripper - CheckRedirect func(req *http.Request, via []*http.Request) error - EnableCookie bool - Gzip bool - DumpBody bool - Retries int // if set to -1 means will retry forever - RetryDelay time.Duration - FilterChains []FilterChain -} - // BeegoHTTPRequest provides more useful methods than http.Request for requesting a url. type BeegoHTTPRequest struct { url string @@ -195,12 +155,6 @@ func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { return b } -// Debug sets show debug or not when executing request. -func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { - b.setting.ShowDebug = isdebug - return b -} - // Retries sets Retries times. // default is 0 (never retry) // -1 retry indefinitely (forever) @@ -216,17 +170,6 @@ func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { return b } -// DumpBody sets the DumbBody field -func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { - b.setting.DumpBody = isdump - return b -} - -// DumpRequest returns the DumpRequest -func (b *BeegoHTTPRequest) DumpRequest() []byte { - return b.dump -} - // SetTimeout sets connect time out and read-write time out for BeegoRequest. func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { b.setting.ConnectTimeout = connectTimeout @@ -253,7 +196,7 @@ func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { } // SetProtocolVersion sets the protocol version for incoming requests. -// Client requests always use HTTP/1.1. +// Client requests always use HTTP/1.1 func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { if len(vers) == 0 { vers = "HTTP/1.1" @@ -264,8 +207,9 @@ func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { b.req.Proto = vers b.req.ProtoMajor = major b.req.ProtoMinor = minor + return b } - + logs.Error("%+v", berror.Errorf(InvalidUrlProtocolVersion, "invalid protocol: %s", vers)) return b } @@ -333,16 +277,25 @@ func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest // Body adds request raw body. // Supports string and []byte. +// TODO return error if data is invalid func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { switch t := data.(type) { case string: bf := bytes.NewBufferString(t) b.req.Body = ioutil.NopCloser(bf) + b.req.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bf), nil + } b.req.ContentLength = int64(len(t)) case []byte: bf := bytes.NewBuffer(t) b.req.Body = ioutil.NopCloser(bf) + b.req.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bf), nil + } b.req.ContentLength = int64(len(t)) + default: + logs.Error("%+v", berror.Errorf(UnsupportedBodyType, "unsupported body data type: %s", t)) } return b } @@ -352,11 +305,14 @@ func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { byts, err := xml.Marshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidXMLBody, "obj could not be converted to XML data") } b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewReader(byts)), nil + } b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/xml") + b.req.Header.Set(contentTypeKey, "application/xml") } return b, nil } @@ -366,11 +322,11 @@ func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) if b.req.Body == nil && obj != nil { byts, err := yaml.Marshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidYAMLBody, "obj could not be converted to YAML data") } b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/x+yaml") + b.req.Header.Set(contentTypeKey, "application/x+yaml") } return b, nil } @@ -380,11 +336,11 @@ func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) if b.req.Body == nil && obj != nil { byts, err := json.Marshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidJSONBody, "obj could not be converted to JSON body") } b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/json") + b.req.Header.Set(contentTypeKey, "application/json") } return b, nil } @@ -404,47 +360,61 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) { if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil { // with files if len(b.files) > 0 { - pr, pw := io.Pipe() - bodyWriter := multipart.NewWriter(pw) - go func() { - for formname, filename := range b.files { - fileWriter, err := bodyWriter.CreateFormFile(formname, filename) - if err != nil { - log.Println("Httplib:", err) - } - fh, err := os.Open(filename) - if err != nil { - log.Println("Httplib:", err) - } - // iocopy - _, err = io.Copy(fileWriter, fh) - fh.Close() - if err != nil { - log.Println("Httplib:", err) - } - } - for k, v := range b.params { - for _, vv := range v { - bodyWriter.WriteField(k, vv) - } - } - bodyWriter.Close() - pw.Close() - }() - b.Header("Content-Type", bodyWriter.FormDataContentType()) - b.req.Body = ioutil.NopCloser(pr) - b.Header("Transfer-Encoding", "chunked") + b.handleFiles() return } // with params if len(paramBody) > 0 { - b.Header("Content-Type", "application/x-www-form-urlencoded") + b.Header(contentTypeKey, "application/x-www-form-urlencoded") b.Body(paramBody) } } } +func (b *BeegoHTTPRequest) handleFiles() { + pr, pw := io.Pipe() + bodyWriter := multipart.NewWriter(pw) + go func() { + for formname, filename := range b.files { + b.handleFileToBody(bodyWriter, formname, filename) + } + for k, v := range b.params { + for _, vv := range v { + _ = bodyWriter.WriteField(k, vv) + } + } + _ = bodyWriter.Close() + _ = pw.Close() + }() + b.Header(contentTypeKey, bodyWriter.FormDataContentType()) + b.req.Body = ioutil.NopCloser(pr) + b.Header("Transfer-Encoding", "chunked") +} + +func (b *BeegoHTTPRequest) handleFileToBody(bodyWriter *multipart.Writer, formname string, filename string) { + fileWriter, err := bodyWriter.CreateFormFile(formname, filename) + const errFmt = "Httplib: %+v" + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CreateFormFileFailed, + "could not create form file, formname: %s, filename: %s", formname, filename)) + } + fh, err := os.Open(filename) + + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, ReadFileFailed, "could not open this file %s", filename)) + } + // iocopy + _, err = io.Copy(fileWriter, fh) + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CopyFileFailed, "could not copy this file %s", filename)) + } + err = fh.Close() + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CloseFileFailed, "could not close this file %s", filename)) + } +} + func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { if b.resp.StatusCode != 0 { return b.resp, nil @@ -463,7 +433,6 @@ func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { } func (b *BeegoHTTPRequest) DoRequestWithCtx(ctx context.Context) (resp *http.Response, err error) { - root := doRequestFilter if len(b.setting.FilterChains) > 0 { for i := len(b.setting.FilterChains) - 1; i >= 0; i-- { @@ -473,62 +442,20 @@ func (b *BeegoHTTPRequest) DoRequestWithCtx(ctx context.Context) (resp *http.Res return root(ctx, b) } -func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, err error) { - var paramBody string - if len(b.params) > 0 { - var buf bytes.Buffer - for k, v := range b.params { - for _, vv := range v { - buf.WriteString(url.QueryEscape(k)) - buf.WriteByte('=') - buf.WriteString(url.QueryEscape(vv)) - buf.WriteByte('&') - } - } - paramBody = buf.String() - paramBody = paramBody[0 : len(paramBody)-1] - } +func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (*http.Response, error) { + paramBody := b.buildParamBody() b.buildURL(paramBody) urlParsed, err := url.Parse(b.url) if err != nil { - return nil, err + return nil, berror.Wrapf(err, InvalidUrl, "parse url failed, the url is %s", b.url) } b.req.URL = urlParsed - trans := b.setting.Transport + trans := b.buildTrans() - if trans == nil { - // create default transport - trans = &http.Transport{ - TLSClientConfig: b.setting.TLSClientConfig, - Proxy: b.setting.Proxy, - Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), - MaxIdleConnsPerHost: 100, - } - } else { - // if b.transport is *http.Transport then set the settings. - if t, ok := trans.(*http.Transport); ok { - if t.TLSClientConfig == nil { - t.TLSClientConfig = b.setting.TLSClientConfig - } - if t.Proxy == nil { - t.Proxy = b.setting.Proxy - } - if t.Dial == nil { - t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) - } - } - } - - var jar http.CookieJar - if b.setting.EnableCookie { - if defaultCookieJar == nil { - createDefaultCookie() - } - jar = defaultCookieJar - } + jar := b.buildCookieJar() client := &http.Client{ Transport: trans, @@ -543,13 +470,10 @@ func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, client.CheckRedirect = b.setting.CheckRedirect } - if b.setting.ShowDebug { - dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) - if err != nil { - log.Println(err.Error()) - } - b.dump = dump - } + return b.sendRequest(client) +} + +func (b *BeegoHTTPRequest) sendRequest(client *http.Client) (resp *http.Response, err error) { // retries default value is 0, it will run once. // retries equal to -1, it will run forever until success // retries is setted, it will retries fixed times. @@ -557,11 +481,68 @@ func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { resp, err = client.Do(b.req) if err == nil { - break + return } time.Sleep(b.setting.RetryDelay) } - return resp, err + return nil, berror.Wrap(err, SendRequestFailed, "sending request fail") +} + +func (b *BeegoHTTPRequest) buildCookieJar() http.CookieJar { + var jar http.CookieJar + if b.setting.EnableCookie { + if defaultCookieJar == nil { + createDefaultCookie() + } + jar = defaultCookieJar + } + return jar +} + +func (b *BeegoHTTPRequest) buildTrans() http.RoundTripper { + trans := b.setting.Transport + + if trans == nil { + // create default transport + trans = &http.Transport{ + TLSClientConfig: b.setting.TLSClientConfig, + Proxy: b.setting.Proxy, + DialContext: TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), + MaxIdleConnsPerHost: 100, + } + } else { + // if b.transport is *http.Transport then set the settings. + if t, ok := trans.(*http.Transport); ok { + if t.TLSClientConfig == nil { + t.TLSClientConfig = b.setting.TLSClientConfig + } + if t.Proxy == nil { + t.Proxy = b.setting.Proxy + } + if t.DialContext == nil { + t.DialContext = TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) + } + } + } + return trans +} + +func (b *BeegoHTTPRequest) buildParamBody() string { + var paramBody string + if len(b.params) > 0 { + var buf bytes.Buffer + for k, v := range b.params { + for _, vv := range v { + buf.WriteString(url.QueryEscape(k)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(vv)) + buf.WriteByte('&') + } + } + paramBody = buf.String() + paramBody = paramBody[0 : len(paramBody)-1] + } + return paramBody } // String returns the body string in response. @@ -592,10 +573,10 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" { reader, err := gzip.NewReader(resp.Body) if err != nil { - return nil, err + return nil, berror.Wrap(err, ReadGzipBodyFailed, "building gzip reader failed") } b.body, err = ioutil.ReadAll(reader) - return b.body, err + return b.body, berror.Wrap(err, ReadGzipBodyFailed, "reading gzip data failed") } b.body, err = ioutil.ReadAll(resp.Body) return b.body, err @@ -638,7 +619,7 @@ func pathExistAndMkdir(filename string) (err error) { return nil } } - return err + return berror.Wrapf(err, CreateFileIfNotExistFailed, "try to create(if not exist) failed: %s", filename) } // ToJSON returns the map that marshals from the body bytes as json in response. @@ -648,7 +629,8 @@ func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { if err != nil { return err } - return json.Unmarshal(data, v) + return berror.Wrap(json.Unmarshal(data, v), + UnmarshalJSONResponseToObjectFailed, "unmarshal json body to object failed.") } // ToXML returns the map that marshals from the body bytes as xml in response . @@ -658,7 +640,8 @@ func (b *BeegoHTTPRequest) ToXML(v interface{}) error { if err != nil { return err } - return xml.Unmarshal(data, v) + return berror.Wrap(xml.Unmarshal(data, v), + UnmarshalXMLResponseToObjectFailed, "unmarshal xml body to object failed.") } // ToYAML returns the map that marshals from the body bytes as yaml in response . @@ -668,7 +651,8 @@ func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { if err != nil { return err } - return yaml.Unmarshal(data, v) + return berror.Wrap(yaml.Unmarshal(data, v), + UnmarshalYAMLResponseToObjectFailed, "unmarshal yaml body to object failed.") } // Response executes request client gets response manually. @@ -677,8 +661,18 @@ func (b *BeegoHTTPRequest) Response() (*http.Response, error) { } // TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +// Deprecated +// we will move this at the end of 2021 +// please use TimeoutDialerCtx func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { return func(netw, addr string) (net.Conn, error) { + return TimeoutDialerCtx(cTimeout, rwTimeout)(context.Background(), netw, addr) + } +} + +func TimeoutDialerCtx(cTimeout time.Duration, + rwTimeout time.Duration) func(ctx context.Context, net, addr string) (c net.Conn, err error) { + return func(ctx context.Context, netw, addr string) (net.Conn, error) { conn, err := net.DialTimeout(netw, addr, cTimeout) if err != nil { return nil, err diff --git a/client/httplib/httplib_test.go b/client/httplib/httplib_test.go index d0f826cb..9133ad5f 100644 --- a/client/httplib/httplib_test.go +++ b/client/httplib/httplib_test.go @@ -15,6 +15,7 @@ package httplib import ( + "bytes" "context" "errors" "io/ioutil" @@ -259,7 +260,7 @@ func TestToFile(t *testing.T) { } defer os.Remove(f) b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { + if n := bytes.Index(b, []byte("origin")); n == -1 { t.Fatal(err) } } @@ -273,7 +274,7 @@ func TestToFileDir(t *testing.T) { } defer os.RemoveAll("./files") b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { + if n := bytes.Index(b, []byte("origin")); n == -1 { t.Fatal(err) } } @@ -300,3 +301,135 @@ func TestAddFilter(t *testing.T) { r := Get("http://beego.me") assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains)) } + +func TestFilterChainOrder(t *testing.T) { + req := Get("http://beego.me") + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("first"), nil + } + }) + + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("second"), nil + } + }) + + resp, err := req.DoRequestWithCtx(context.Background()) + assert.Nil(t, err) + data := make([]byte, 5) + _, _ = resp.Body.Read(data) + assert.Equal(t, "first", string(data)) +} + +func TestHead(t *testing.T) { + req := Head("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "HEAD", req.req.Method) +} + +func TestDelete(t *testing.T) { + req := Delete("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "DELETE", req.req.Method) +} + +func TestPost(t *testing.T) { + req := Post("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "POST", req.req.Method) +} + +func TestNewBeegoRequest(t *testing.T) { + req := NewBeegoRequest("http://beego.me", "GET") + assert.NotNil(t, req) + assert.Equal(t, "GET", req.req.Method) + + // invalid case but still go request + req = NewBeegoRequest("httpa\ta://beego.me", "GET") + assert.NotNil(t, req) +} + +func TestBeegoHTTPRequest_SetProtocolVersion(t *testing.T) { + req := NewBeegoRequest("http://beego.me", "GET") + req.SetProtocolVersion("HTTP/3.10") + assert.Equal(t, "HTTP/3.10", req.req.Proto) + assert.Equal(t, 3, req.req.ProtoMajor) + assert.Equal(t, 10, req.req.ProtoMinor) + + req.SetProtocolVersion("") + assert.Equal(t, "HTTP/1.1", req.req.Proto) + assert.Equal(t, 1, req.req.ProtoMajor) + assert.Equal(t, 1, req.req.ProtoMinor) + + // invalid case + req.SetProtocolVersion("HTTP/aaa1.1") + assert.Equal(t, "HTTP/1.1", req.req.Proto) + assert.Equal(t, 1, req.req.ProtoMajor) + assert.Equal(t, 1, req.req.ProtoMinor) +} + +func TestPut(t *testing.T) { + req := Put("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "PUT", req.req.Method) +} + +func TestBeegoHTTPRequest_Header(t *testing.T) { + req := Post("http://beego.me") + key, value := "test-header", "test-header-value" + req.Header(key, value) + assert.Equal(t, value, req.req.Header.Get(key)) +} + +func TestBeegoHTTPRequest_SetHost(t *testing.T) { + req := Post("http://beego.me") + host := "test-hose" + req.SetHost(host) + assert.Equal(t, host, req.req.Host) +} + +func TestBeegoHTTPRequest_Param(t *testing.T) { + req := Post("http://beego.me") + key, value := "test-param", "test-param-value" + req.Param(key, value) + assert.Equal(t, value, req.params[key][0]) + + value1 := "test-param-value-1" + req.Param(key, value1) + assert.Equal(t, value1, req.params[key][1]) +} + +func TestBeegoHTTPRequest_Body(t *testing.T) { + req := Post("http://beego.me") + body := `hello, world` + req.Body([]byte(body)) + assert.Equal(t, int64(len(body)), req.req.ContentLength) + assert.NotNil(t, req.req.GetBody) + assert.NotNil(t, req.req.Body) + + body = "hhhh, i am test" + req.Body(body) + assert.Equal(t, int64(len(body)), req.req.ContentLength) + assert.NotNil(t, req.req.GetBody) + assert.NotNil(t, req.req.Body) + + // invalid case + req.Body(13) +} + +type user struct { + Name string `xml:"name"` +} + +func TestBeegoHTTPRequest_XMLBody(t *testing.T) { + req := Post("http://beego.me") + body := &user{ + Name: "Tom", + } + _, err := req.XMLBody(body) + assert.True(t, req.req.ContentLength > 0) + assert.Nil(t, err) + assert.NotNil(t, req.req.GetBody) +} diff --git a/client/httplib/mock/mock.go b/client/httplib/mock/mock.go new file mode 100644 index 00000000..421a7a45 --- /dev/null +++ b/client/httplib/mock/mock.go @@ -0,0 +1,78 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "net/http" + + "github.com/beego/beego/v2/client/httplib" + "github.com/beego/beego/v2/core/logs" +) + +const mockCtxKey = "beego-httplib-mock" + +func init() { + InitMockSetting() +} + +type Stub interface { + Mock(cond RequestCondition, resp *http.Response, err error) + Clear() + MockByPath(path string, resp *http.Response, err error) +} + +var mockFilter = &MockResponseFilter{} + +func InitMockSetting() { + httplib.AddDefaultFilter(mockFilter.FilterChain) +} + +func StartMock() Stub { + return mockFilter +} + +func CtxWithMock(ctx context.Context, mock ...*Mock) context.Context { + return context.WithValue(ctx, mockCtxKey, mock) +} + +func mockFromCtx(ctx context.Context) []*Mock { + ms := ctx.Value(mockCtxKey) + if ms != nil { + if res, ok := ms.([]*Mock); ok { + return res + } + logs.Error("mockCtxKey found in context, but value is not type []*Mock") + } + return nil +} + +type Mock struct { + cond RequestCondition + resp *http.Response + err error +} + +func NewMockByPath(path string, resp *http.Response, err error) *Mock { + return NewMock(NewSimpleCondition(path), resp, err) +} + +func NewMock(con RequestCondition, resp *http.Response, err error) *Mock { + return &Mock{ + cond: con, + resp: resp, + err: err, + } +} diff --git a/client/httplib/mock/mock_condition.go b/client/httplib/mock/mock_condition.go new file mode 100644 index 00000000..53d3d703 --- /dev/null +++ b/client/httplib/mock/mock_condition.go @@ -0,0 +1,176 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "encoding/json" + "net/textproto" + "regexp" + + "github.com/beego/beego/v2/client/httplib" +) + +type RequestCondition interface { + Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool +} + +// reqCondition create condition +// - path: same path +// - pathReg: request path match pathReg +// - method: same method +// - Query parameters (key, value) +// - header (key, value) +// - Body json format, contains specific (key, value). +type SimpleCondition struct { + pathReg string + path string + method string + query map[string]string + header map[string]string + body map[string]interface{} +} + +func NewSimpleCondition(path string, opts ...simpleConditionOption) *SimpleCondition { + sc := &SimpleCondition{ + path: path, + query: make(map[string]string), + header: make(map[string]string), + body: map[string]interface{}{}, + } + for _, o := range opts { + o(sc) + } + return sc +} + +func (sc *SimpleCondition) Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + res := true + if len(sc.path) > 0 { + res = sc.matchPath(ctx, req) + } else if len(sc.pathReg) > 0 { + res = sc.matchPathReg(ctx, req) + } else { + return false + } + return res && + sc.matchMethod(ctx, req) && + sc.matchQuery(ctx, req) && + sc.matchHeader(ctx, req) && + sc.matchBodyFields(ctx, req) +} + +func (sc *SimpleCondition) matchPath(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + path := req.GetRequest().URL.Path + return path == sc.path +} + +func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + path := req.GetRequest().URL.Path + if b, err := regexp.Match(sc.pathReg, []byte(path)); err == nil { + return b + } + return false +} + +func (sc *SimpleCondition) matchQuery(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + qs := req.GetRequest().URL.Query() + for k, v := range sc.query { + if uv, ok := qs[k]; !ok || uv[0] != v { + return false + } + } + return true +} + +func (sc *SimpleCondition) matchHeader(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + headers := req.GetRequest().Header + for k, v := range sc.header { + if uv, ok := headers[k]; !ok || uv[0] != v { + return false + } + } + return true +} + +func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + if len(sc.body) == 0 { + return true + } + getBody := req.GetRequest().GetBody + body, err := getBody() + if err != nil { + return false + } + bytes := make([]byte, req.GetRequest().ContentLength) + _, err = body.Read(bytes) + if err != nil { + return false + } + + m := make(map[string]interface{}) + + err = json.Unmarshal(bytes, &m) + + if err != nil { + return false + } + + for k, v := range sc.body { + if uv, ok := m[k]; !ok || uv != v { + return false + } + } + return true +} + +func (sc *SimpleCondition) matchMethod(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + if len(sc.method) > 0 { + return sc.method == req.GetRequest().Method + } + return true +} + +type simpleConditionOption func(sc *SimpleCondition) + +func WithPathReg(pathReg string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.pathReg = pathReg + } +} + +func WithQuery(key, value string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.query[key] = value + } +} + +func WithHeader(key, value string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.header[textproto.CanonicalMIMEHeaderKey(key)] = value + } +} + +func WithJsonBodyFields(field string, value interface{}) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.body[field] = value + } +} + +func WithMethod(method string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.method = method + } +} diff --git a/client/httplib/mock/mock_condition_test.go b/client/httplib/mock/mock_condition_test.go new file mode 100644 index 00000000..4fc6d377 --- /dev/null +++ b/client/httplib/mock/mock_condition_test.go @@ -0,0 +1,124 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func init() { + +} + +func TestSimpleCondition_MatchPath(t *testing.T) { + sc := NewSimpleCondition("/abc/s") + res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s")) + assert.True(t, res) +} + +func TestSimpleCondition_MatchQuery(t *testing.T) { + k, v := "my-key", "my-value" + sc := NewSimpleCondition("/abc/s") + res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value")) + assert.True(t, res) + + sc = NewSimpleCondition("/abc/s", WithQuery(k, v)) + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value")) + assert.True(t, res) + + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-valuesss")) + assert.False(t, res) + + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key-a=my-value")) + assert.False(t, res) + + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value&abc=hello")) + assert.True(t, res) +} + +func TestSimpleCondition_MatchHeader(t *testing.T) { + k, v := "my-header", "my-header-value" + sc := NewSimpleCondition("/abc/s") + req := httplib.Get("http://localhost:8080/abc/s") + assert.True(t, sc.Match(context.Background(), req)) + + req = httplib.Get("http://localhost:8080/abc/s") + req.Header(k, v) + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithHeader(k, v)) + req.Header(k, v) + assert.True(t, sc.Match(context.Background(), req)) + + req.Header(k, "invalid") + assert.False(t, sc.Match(context.Background(), req)) +} + +func TestSimpleCondition_MatchBodyField(t *testing.T) { + + sc := NewSimpleCondition("/abc/s") + req := httplib.Post("http://localhost:8080/abc/s") + + assert.True(t, sc.Match(context.Background(), req)) + + req.Body(`{ + "body-field": 123 +}`) + assert.True(t, sc.Match(context.Background(), req)) + + k := "body-field" + v := float64(123) + sc = NewSimpleCondition("/abc/s", WithJsonBodyFields(k, v)) + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithJsonBodyFields(k, v)) + req.Body(`{ + "body-field": abc +}`) + assert.False(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithJsonBodyFields("body-field", "abc")) + req.Body(`{ + "body-field": "abc" +}`) + assert.True(t, sc.Match(context.Background(), req)) +} + +func TestSimpleCondition_Match(t *testing.T) { + sc := NewSimpleCondition("/abc/s") + req := httplib.Post("http://localhost:8080/abc/s") + + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithMethod("POST")) + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithMethod("GET")) + assert.False(t, sc.Match(context.Background(), req)) +} + +func TestSimpleCondition_MatchPathReg(t *testing.T) { + sc := NewSimpleCondition("", WithPathReg(`\/abc\/.*`)) + req := httplib.Post("http://localhost:8080/abc/s") + assert.True(t, sc.Match(context.Background(), req)) + + req = httplib.Post("http://localhost:8080/abcd/s") + assert.False(t, sc.Match(context.Background(), req)) +} diff --git a/client/httplib/mock/mock_filter.go b/client/httplib/mock/mock_filter.go new file mode 100644 index 00000000..225d65f3 --- /dev/null +++ b/client/httplib/mock/mock_filter.go @@ -0,0 +1,61 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "net/http" + + "github.com/beego/beego/v2/client/httplib" +) + +// MockResponse will return mock response if find any suitable mock data +// if you want to test your code using httplib, you need this. +type MockResponseFilter struct { + ms []*Mock +} + +func NewMockResponseFilter() *MockResponseFilter { + return &MockResponseFilter{ + ms: make([]*Mock, 0, 1), + } +} + +func (m *MockResponseFilter) FilterChain(next httplib.Filter) httplib.Filter { + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + ms := mockFromCtx(ctx) + ms = append(ms, m.ms...) + for _, mock := range ms { + if mock.cond.Match(ctx, req) { + return mock.resp, mock.err + } + } + return next(ctx, req) + } +} + +func (m *MockResponseFilter) MockByPath(path string, resp *http.Response, err error) { + m.Mock(NewSimpleCondition(path), resp, err) +} + +func (m *MockResponseFilter) Clear() { + m.ms = make([]*Mock, 0, 1) +} + +// Mock add mock data +// If the cond.Match(...) = true, the resp and err will be returned +func (m *MockResponseFilter) Mock(cond RequestCondition, resp *http.Response, err error) { + m.ms = append(m.ms, NewMock(cond, resp, err)) +} diff --git a/client/httplib/mock/mock_filter_test.go b/client/httplib/mock/mock_filter_test.go new file mode 100644 index 00000000..b27e772e --- /dev/null +++ b/client/httplib/mock/mock_filter_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestMockResponseFilter_FilterChain(t *testing.T) { + req := httplib.Get("http://localhost:8080/abc/s") + ft := NewMockResponseFilter() + + expectedResp := httplib.NewHttpResponseWithJsonBody(`{}`) + expectedErr := errors.New("expected error") + ft.Mock(NewSimpleCondition("/abc/s"), expectedResp, expectedErr) + + req.AddFilters(ft.FilterChain) + + resp, err := req.DoRequest() + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) + + req = httplib.Get("http://localhost:8080/abcd/s") + req.AddFilters(ft.FilterChain) + + resp, err = req.DoRequest() + assert.NotEqual(t, expectedErr, err) + assert.NotEqual(t, expectedResp, resp) + + req = httplib.Get("http://localhost:8080/abc/s") + req.AddFilters(ft.FilterChain) + expectedResp1 := httplib.NewHttpResponseWithJsonBody(map[string]string{}) + expectedErr1 := errors.New("expected error") + ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1) + + resp, err = req.DoRequest() + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) + + req = httplib.Get("http://localhost:8080/abc/abs/bbc") + req.AddFilters(ft.FilterChain) + ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1) + resp, err = req.DoRequest() + assert.Equal(t, expectedErr1, err) + assert.Equal(t, expectedResp1, resp) +} diff --git a/client/httplib/mock/mock_test.go b/client/httplib/mock/mock_test.go new file mode 100644 index 00000000..2972cf8f --- /dev/null +++ b/client/httplib/mock/mock_test.go @@ -0,0 +1,77 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestStartMock(t *testing.T) { + + // httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain} + + stub := StartMock() + // defer stub.Clear() + + expectedResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`)) + expectedErr := errors.New("expected err") + + stub.Mock(NewSimpleCondition("/abc"), expectedResp, expectedErr) + + resp, err := OriginalCodeUsingHttplib() + + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) + +} + +// TestStartMock_Isolation Test StartMock that +// mock only work for this request +func TestStartMock_Isolation(t *testing.T) { + // httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain} + // setup global stub + stub := StartMock() + globalMockResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`)) + globalMockErr := errors.New("expected err") + stub.Mock(NewSimpleCondition("/abc"), globalMockResp, globalMockErr) + + expectedResp := httplib.NewHttpResponseWithJsonBody(struct { + A string `json:"a"` + }{ + A: "aaa", + }) + expectedErr := errors.New("expected err aa") + m := NewMockByPath("/abc", expectedResp, expectedErr) + ctx := CtxWithMock(context.Background(), m) + + resp, err := OriginnalCodeUsingHttplibPassCtx(ctx) + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) +} + +func OriginnalCodeUsingHttplibPassCtx(ctx context.Context) (*http.Response, error) { + return httplib.Get("http://localhost:7777/abc").DoRequestWithCtx(ctx) +} + +func OriginalCodeUsingHttplib() (*http.Response, error) { + return httplib.Get("http://localhost:7777/abc").DoRequest() +} diff --git a/client/httplib/module.go b/client/httplib/module.go new file mode 100644 index 00000000..8503133c --- /dev/null +++ b/client/httplib/module.go @@ -0,0 +1,17 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +const moduleName = "httplib" diff --git a/client/httplib/setting.go b/client/httplib/setting.go new file mode 100644 index 00000000..df8eff4b --- /dev/null +++ b/client/httplib/setting.go @@ -0,0 +1,78 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "crypto/tls" + "net/http" + "net/http/cookiejar" + "net/url" + "sync" + "time" +) + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings struct { + UserAgent string + ConnectTimeout time.Duration + ReadWriteTimeout time.Duration + TLSClientConfig *tls.Config + Proxy func(*http.Request) (*url.URL, error) + Transport http.RoundTripper + CheckRedirect func(req *http.Request, via []*http.Request) error + EnableCookie bool + Gzip bool + Retries int // if set to -1 means will retry forever + RetryDelay time.Duration + FilterChains []FilterChain +} + +// createDefaultCookie creates a global cookiejar to store cookies. +func createDefaultCookie() { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultCookieJar, _ = cookiejar.New(nil) +} + +// SetDefaultSetting overwrites default settings +// Keep in mind that when you invoke the SetDefaultSetting +// some methods invoked before SetDefaultSetting +func SetDefaultSetting(setting BeegoHTTPSettings) { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultSetting = setting +} + +var defaultSetting = BeegoHTTPSettings{ + UserAgent: "beegoServer", + ConnectTimeout: 60 * time.Second, + ReadWriteTimeout: 60 * time.Second, + Gzip: true, + FilterChains: make([]FilterChain, 0, 4), +} + +var defaultCookieJar http.CookieJar +var settingMutex sync.Mutex + +// AddDefaultFilter add a new filter into defaultSetting +// Be careful about using this method if you invoke SetDefaultSetting somewhere +func AddDefaultFilter(fc FilterChain) { + settingMutex.Lock() + defer settingMutex.Unlock() + if defaultSetting.FilterChains == nil { + defaultSetting.FilterChains = make([]FilterChain, 0, 4) + } + defaultSetting.FilterChains = append(defaultSetting.FilterChains, fc) +} diff --git a/client/orm/README.md b/client/orm/README.md index 58669e1f..bb11d1c6 100644 --- a/client/orm/README.md +++ b/client/orm/README.md @@ -27,7 +27,7 @@ more features please read the docs **Install:** - go get github.com/beego/beego/v2/orm + go get github.com/beego/beego/v2/client/orm ## Changelog @@ -45,7 +45,7 @@ package main import ( "fmt" - "github.com/beego/beego/v2/orm" + "github.com/beego/beego/v2/client/orm" _ "github.com/go-sql-driver/mysql" // import your used driver ) diff --git a/client/orm/clauses/const.go b/client/orm/clauses/const.go new file mode 100644 index 00000000..38a6d556 --- /dev/null +++ b/client/orm/clauses/const.go @@ -0,0 +1,6 @@ +package clauses + +const ( + ExprSep = "__" + ExprDot = "." +) diff --git a/client/orm/clauses/order_clause/order.go b/client/orm/clauses/order_clause/order.go new file mode 100644 index 00000000..bdb2d1ca --- /dev/null +++ b/client/orm/clauses/order_clause/order.go @@ -0,0 +1,104 @@ +package order_clause + +import ( + "strings" + + "github.com/beego/beego/v2/client/orm/clauses" +) + +type Sort int8 + +const ( + None Sort = 0 + Ascending Sort = 1 + Descending Sort = 2 +) + +type Option func(order *Order) + +type Order struct { + column string + sort Sort + isRaw bool +} + +func Clause(options ...Option) *Order { + o := &Order{} + for _, option := range options { + option(o) + } + + return o +} + +func (o *Order) GetColumn() string { + return o.column +} + +func (o *Order) GetSort() Sort { + return o.sort +} + +func (o *Order) SortString() string { + switch o.GetSort() { + case Ascending: + return "ASC" + case Descending: + return "DESC" + } + + return `` +} + +func (o *Order) IsRaw() bool { + return o.isRaw +} + +func ParseOrder(expressions ...string) []*Order { + var orders []*Order + for _, expression := range expressions { + sort := Ascending + column := strings.ReplaceAll(expression, clauses.ExprSep, clauses.ExprDot) + if column[0] == '-' { + sort = Descending + column = column[1:] + } + + orders = append(orders, &Order{ + column: column, + sort: sort, + }) + } + + return orders +} + +func Column(column string) Option { + return func(order *Order) { + order.column = strings.ReplaceAll(column, clauses.ExprSep, clauses.ExprDot) + } +} + +func sort(sort Sort) Option { + return func(order *Order) { + order.sort = sort + } +} + +func SortAscending() Option { + return sort(Ascending) +} + +func SortDescending() Option { + return sort(Descending) +} + +func SortNone() Option { + return sort(None) +} + +func Raw() Option { + return func(order *Order) { + order.isRaw = true + } +} diff --git a/client/orm/clauses/order_clause/order_test.go b/client/orm/clauses/order_clause/order_test.go new file mode 100644 index 00000000..172e7492 --- /dev/null +++ b/client/orm/clauses/order_clause/order_test.go @@ -0,0 +1,144 @@ +package order_clause + +import ( + "testing" +) + +func TestClause(t *testing.T) { + var ( + column = `a` + ) + + o := Clause( + Column(column), + ) + + if o.GetColumn() != column { + t.Error() + } +} + +func TestSortAscending(t *testing.T) { + o := Clause( + SortAscending(), + ) + + if o.GetSort() != Ascending { + t.Error() + } +} + +func TestSortDescending(t *testing.T) { + o := Clause( + SortDescending(), + ) + + if o.GetSort() != Descending { + t.Error() + } +} + +func TestSortNone(t *testing.T) { + o1 := Clause( + SortNone(), + ) + + if o1.GetSort() != None { + t.Error() + } + + o2 := Clause() + + if o2.GetSort() != None { + t.Error() + } +} + +func TestRaw(t *testing.T) { + o1 := Clause() + + if o1.IsRaw() { + t.Error() + } + + o2 := Clause( + Raw(), + ) + + if !o2.IsRaw() { + t.Error() + } +} + +func TestColumn(t *testing.T) { + o1 := Clause( + Column(`aaa`), + ) + + if o1.GetColumn() != `aaa` { + t.Error() + } +} + +func TestParseOrder(t *testing.T) { + orders := ParseOrder( + `-user__status`, + `status`, + `user__status`, + ) + + t.Log(orders) + + if orders[0].GetSort() != Descending { + t.Error() + } + + if orders[0].GetColumn() != `user.status` { + t.Error() + } + + if orders[1].GetColumn() != `status` { + t.Error() + } + + if orders[1].GetSort() != Ascending { + t.Error() + } + + if orders[2].GetColumn() != `user.status` { + t.Error() + } + +} + +func TestOrder_GetColumn(t *testing.T) { + o := Clause( + Column(`user__id`), + ) + if o.GetColumn() != `user.id` { + t.Error() + } +} + +func TestOrder_GetSort(t *testing.T) { + o := Clause( + SortDescending(), + ) + if o.GetSort() != Descending { + t.Error() + } +} + +func TestOrder_IsRaw(t *testing.T) { + o1 := Clause() + if o1.IsRaw() { + t.Error() + } + + o2 := Clause( + Raw(), + ) + if !o2.IsRaw() { + t.Error() + } +} diff --git a/client/orm/cmd.go b/client/orm/cmd.go index b0661971..b377a5f2 100644 --- a/client/orm/cmd.go +++ b/client/orm/cmd.go @@ -15,6 +15,7 @@ package orm import ( + "context" "flag" "fmt" "os" @@ -141,6 +142,7 @@ func (d *commandSyncDb) Run() error { fmt.Printf(" %s\n", err.Error()) } + ctx := context.Background() for i, mi := range modelCache.allOrdered() { if !isApplicableTableForDB(mi.addrField, d.al.Name) { @@ -154,7 +156,7 @@ func (d *commandSyncDb) Run() error { } var fields []*fieldInfo - columns, err := d.al.DbBaser.GetColumns(db, mi.table) + columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table) if err != nil { if d.rtOnError { return err @@ -188,7 +190,7 @@ func (d *commandSyncDb) Run() error { } for _, idx := range indexes[mi.table] { - if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { + if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) { if !d.noInfo { fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) } diff --git a/client/orm/db.go b/client/orm/db.go index 4080f292..a49d6df7 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "errors" "fmt" @@ -268,7 +269,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } // create insert sql preparation statement object. -func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { +func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { Q := d.ins.TableQuote() dbcols := make([]string, 0, len(mi.fields.dbcols)) @@ -289,12 +290,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, d.ins.HasReturningID(mi, &query) - stmt, err := q.Prepare(query) + stmt, err := q.PrepareContext(ctx, query) return stmt, query, err } // insert struct with prepared statement and given struct reflect value. -func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { +func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { return 0, err @@ -306,7 +307,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, err := row.Scan(&id) return id, err } - res, err := stmt.Exec(values...) + res, err := stmt.ExecContext(ctx, values...) if err == nil { return res.LastInsertId() } @@ -314,7 +315,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, } // query sql ,read records and persist in dbBaser. -func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { var whereCols []string var args []interface{} @@ -360,7 +361,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo d.ins.ReplaceMarks(&query) - row := q.QueryRow(query, args...) + row := q.QueryRowContext(ctx, query, args...) if err := row.Scan(refs...); err != nil { if err == sql.ErrNoRows { return ErrNoRows @@ -375,26 +376,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo } // execute insert sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { +func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { names := make([]string, 0, len(mi.fields.dbcols)) values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) if err != nil { return 0, err } - id, err := d.InsertValue(q, mi, false, names, values) + id, err := d.InsertValue(ctx, q, mi, false, names, values) if err != nil { return 0, err } if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return id, err } // multi-insert sql with given slice struct reflect.Value. -func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { +func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { var ( cnt int64 nums int @@ -440,7 +441,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul } if i > 1 && i%bulk == 0 || length == i { - num, err := d.InsertValue(q, mi, true, names, values[:nums]) + num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums]) if err != nil { return cnt, err } @@ -451,7 +452,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul var err error if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return cnt, err @@ -459,7 +460,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -482,7 +483,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -498,7 +499,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err := row.Scan(&id) return id, err @@ -507,7 +508,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s // InsertOrUpdate a row // If your primary key or unique column conflict will update // If no will insert -func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { args0 := "" iouStr := "" argsMap := map[string]string{} @@ -590,7 +591,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -607,7 +608,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err = row.Scan(&id) if err != nil && err.Error() == `pq: syntax error at or near "ON"` { @@ -617,7 +618,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a } // execute update sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if !ok { return 0, ErrMissPK @@ -674,7 +675,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, setValues...) + res, err := q.ExecContext(ctx, query, setValues...) if err == nil { return res.RowsAffected() } @@ -683,7 +684,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. // execute delete sql dbQuerier with given struct reflect.Value. // delete index is pk. -func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { var whereCols []string var args []interface{} // if specify cols length > 0, then use it for where condition. @@ -712,7 +713,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, args...) + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { @@ -726,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) } } - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -738,7 +739,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. // update table-related record by querySet. // need querySet not struct reflect.Value to update related records. -func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { +func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) for col, val := range params { @@ -819,13 +820,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } d.ins.ReplaceMarks(&query) - var err error - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, values...) - } else { - res, err = q.Exec(query, values...) - } + res, err := q.ExecContext(ctx, query, values...) if err == nil { return res.RowsAffected() } @@ -834,13 +829,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con // delete related records. // do UpdateBanch or DeleteBanch by condition of tables' relationship. -func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { +func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { for _, fi := range mi.fields.fieldsReverse { fi = fi.reverseFieldInfo switch fi.onDelete { case odCascade: cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) + _, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz) if err != nil { return err } @@ -850,7 +845,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * if fi.onDelete == odSetDefault { params[fi.column] = fi.initial.String() } - _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) + _, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz) if err != nil { return err } @@ -861,7 +856,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * } // delete table-related records. -func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { +func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) tables.skipEnd = true @@ -886,7 +881,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con d.ins.ReplaceMarks(&query) var rs *sql.Rows - r, err := q.Query(query, args...) + r, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -920,19 +915,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) d.ins.ReplaceMarks(&query) - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, args...) - } else { - res, err = q.Exec(query, args...) - } + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } if num > 0 { - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -943,14 +933,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } // read related records. -func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { val := reflect.ValueOf(container) ind := reflect.Indirect(val) - errTyp := true + unregister := true one := true isPtr := true + name := "" if val.Kind() == reflect.Ptr { fn := "" @@ -963,19 +954,17 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi case reflect.Struct: isPtr = false fn = getFullName(typ) + name = getTableName(reflect.New(typ)) } } else { fn = getFullName(ind.Type()) + name = getTableName(ind) } - errTyp = fn != mi.fullName + unregister = fn != mi.fullName } - if errTyp { - if one { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) - } else { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) - } + if unregister { + RegisterModel(container) } rlimit := qs.limit @@ -1040,6 +1029,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi if qs.distinct { sqlSelect += " DISTINCT" } + if qs.aggregate != "" { + sels = qs.aggregate + } query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, specifyIndexes, join, where, groupBy, orderBy, limit) @@ -1050,18 +1042,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi d.ins.ReplaceMarks(&query) - var rs *sql.Rows - var err error - if qs != nil && qs.forContext { - rs, err = q.QueryContext(qs.ctx, query, args...) - if err != nil { - return 0, err - } - } else { - rs, err = q.Query(query, args...) - if err != nil { - return 0, err - } + rs, err := q.QueryContext(ctx, query, args...) + if err != nil { + return 0, err + } + + defer rs.Close() + + slice := ind + if unregister { + mi, _ = modelCache.get(name) + tCols = mi.fields.dbcols + colsNum = len(tCols) } refs := make([]interface{}, colsNum) @@ -1069,11 +1061,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi var ref interface{} refs[i] = &ref } - - defer rs.Close() - - slice := ind - var cnt int64 for rs.Next() { if one && cnt == 0 || !one { @@ -1172,7 +1159,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi } // excute count sql and return count result int64. -func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { +func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) @@ -1194,12 +1181,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition d.ins.ReplaceMarks(&query) - var row *sql.Row - if qs != nil && qs.forContext { - row = q.QueryRowContext(qs.ctx, query, args...) - } else { - row = q.QueryRow(query, args...) - } + row := q.QueryRowContext(ctx, query, args...) err = row.Scan(&cnt) return } @@ -1649,7 +1631,7 @@ setValue: } // query sql, read values , save to *[]ParamList. -func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { +func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { var ( maps []Params @@ -1732,7 +1714,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond d.ins.ReplaceMarks(&query) - rs, err := q.Query(query, args...) + rs, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -1847,7 +1829,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool { } // sync auto key -func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { return nil } @@ -1892,10 +1874,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { } // get all cloumns in table. -func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { columns := make(map[string][3]string) query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return columns, err } @@ -1934,7 +1916,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string { } // not implement. -func (d *dbBase) IndexExists(dbQuerier, string, string) bool { +func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool { panic(ErrNotImplement) } diff --git a/client/orm/db_alias.go b/client/orm/db_alias.go index 29e0904c..72c447b3 100644 --- a/client/orm/db_alias.go +++ b/client/orm/db_alias.go @@ -232,6 +232,14 @@ func (t *TxDB) Rollback() error { return t.tx.Rollback() } +func (t *TxDB) RollbackUnlessCommit() error { + err := t.tx.Rollback() + if err != sql.ErrTxDone { + return err + } + return nil +} + var _ dbQuerier = new(TxDB) var _ txEnder = new(TxDB) diff --git a/client/orm/db_mysql.go b/client/orm/db_mysql.go index ee68baf7..c89b1e52 100644 --- a/client/orm/db_mysql.go +++ b/client/orm/db_mysql.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "reflect" "strings" @@ -93,8 +94,8 @@ func (d *dbBaseMysql) ShowColumnsQuery(table string) string { } // execute sql to check index exist. -func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) @@ -105,7 +106,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool // If your primary key or unique column conflict will update // If no will insert // Add "`" for mysql sql building -func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { var iouStr string argsMap := map[string]string{} @@ -161,7 +162,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -178,7 +179,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err = row.Scan(&id) return id, err diff --git a/client/orm/db_oracle.go b/client/orm/db_oracle.go index 1de440b6..a3b93ff3 100644 --- a/client/orm/db_oracle.go +++ b/client/orm/db_oracle.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "strings" @@ -89,8 +90,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string { } // check index is exist -func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ +func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) @@ -124,7 +125,7 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -147,7 +148,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -163,7 +164,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err := row.Scan(&id) return id, err diff --git a/client/orm/db_postgres.go b/client/orm/db_postgres.go index 12431d6e..b2f321db 100644 --- a/client/orm/db_postgres.go +++ b/client/orm/db_postgres.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "strconv" ) @@ -140,7 +141,7 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { } // sync auto key -func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { if len(autoFields) == 0 { return nil } @@ -151,7 +152,7 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string mi.table, name, Q, name, Q, Q, mi.table, Q) - if _, err := db.Exec(query); err != nil { + if _, err := db.ExecContext(ctx, query); err != nil { return err } } @@ -174,9 +175,9 @@ func (d *dbBasePostgres) DbTypes() map[string]string { } // check index exist in postgresql. -func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) - row := db.QueryRow(query) + row := db.QueryRowContext(ctx, query) var cnt int row.Scan(&cnt) return cnt > 0 diff --git a/client/orm/db_sqlite.go b/client/orm/db_sqlite.go index aff713a5..6a4b3131 100644 --- a/client/orm/db_sqlite.go +++ b/client/orm/db_sqlite.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "fmt" "reflect" @@ -73,11 +74,11 @@ type dbBaseSqlite struct { var _ dbBaser = new(dbBaseSqlite) // override base db read for update behavior as SQlite does not support syntax -func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { if isForUpdate { DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") } - return d.dbBase.Read(q, mi, ind, tz, cols, false) + return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false) } // get sqlite operator. @@ -114,9 +115,9 @@ func (d *dbBaseSqlite) ShowTablesQuery() string { } // get columns in sqlite. -func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return nil, err } @@ -140,9 +141,9 @@ func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { } // check index exist in sqlite. -func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("PRAGMA index_list('%s')", table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { panic(err) } diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index 5fd472d1..f81651ff 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -18,6 +18,9 @@ import ( "fmt" "strings" "time" + + "github.com/beego/beego/v2/client/orm/clauses" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" ) // table info struct. @@ -421,7 +424,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { } // generate order sql. -func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { +func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) { if len(orders) == 0 { return } @@ -430,19 +433,25 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { orderSqls := make([]string, 0, len(orders)) for _, order := range orders { - asc := "ASC" - if order[0] == '-' { - asc = "DESC" - order = order[1:] - } - exprs := strings.Split(order, ExprSep) + column := order.GetColumn() + clause := strings.Split(column, clauses.ExprDot) - index, _, fi, suc := t.parseExprs(t.mi, exprs) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) - } + if order.IsRaw() { + if len(clause) == 2 { + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString())) + } else if len(clause) == 1 { + orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString())) + } else { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) + } + } else { + index, _, fi, suc := t.parseExprs(t.mi, clause) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) + } - orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString())) + } } orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) diff --git a/client/orm/db_tidb.go b/client/orm/db_tidb.go index 6020a488..48c5b4e7 100644 --- a/client/orm/db_tidb.go +++ b/client/orm/db_tidb.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" ) @@ -47,8 +48,8 @@ func (d *dbBaseTidb) ShowColumnsQuery(table string) string { } // execute sql to check index exist. -func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) diff --git a/client/orm/do_nothing_orm.go b/client/orm/do_nothing_orm.go index c6da420d..59ffe877 100644 --- a/client/orm/do_nothing_orm.go +++ b/client/orm/do_nothing_orm.go @@ -66,6 +66,7 @@ func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer { return nil } +// NOTE: this method is deprecated, context parameter will not take effect. func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { return nil } @@ -74,6 +75,7 @@ func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { return nil } +// NOTE: this method is deprecated, context parameter will not take effect. func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { return nil } diff --git a/client/orm/do_nothing_orm_test.go b/client/orm/do_nothing_orm_test.go index 4d477353..e10f70af 100644 --- a/client/orm/do_nothing_orm_test.go +++ b/client/orm/do_nothing_orm_test.go @@ -36,7 +36,6 @@ func TestDoNothingOrm(t *testing.T) { assert.Nil(t, o.Driver()) - assert.Nil(t, o.QueryM2MWithCtx(nil, nil, "")) assert.Nil(t, o.QueryM2M(nil, "")) assert.Nil(t, o.ReadWithCtx(nil, nil)) assert.Nil(t, o.Read(nil)) @@ -92,7 +91,6 @@ func TestDoNothingOrm(t *testing.T) { assert.Nil(t, err) assert.Equal(t, int64(0), i) - assert.Nil(t, o.QueryTableWithCtx(nil, nil)) assert.Nil(t, o.QueryTable(nil)) assert.Nil(t, o.Read(nil)) diff --git a/client/orm/filter/opentracing/filter.go b/client/orm/filter/opentracing/filter.go index 75852c63..7afa07f1 100644 --- a/client/orm/filter/opentracing/filter.go +++ b/client/orm/filter/opentracing/filter.go @@ -27,7 +27,7 @@ import ( // this Filter's behavior looks a little bit strange // for example: // if we want to trace QuerySetter -// actually we trace invoking "QueryTable" and "QueryTableWithCtx" +// actually we trace invoking "QueryTable" // the method Begin*, Commit and Rollback are ignored. // When use using those methods, it means that they want to manager their transaction manually, so we won't handle them. type FilterChainBuilder struct { diff --git a/client/orm/filter/prometheus/filter.go b/client/orm/filter/prometheus/filter.go index 7563f51e..b2c83dcf 100644 --- a/client/orm/filter/prometheus/filter.go +++ b/client/orm/filter/prometheus/filter.go @@ -32,7 +32,7 @@ import ( // this Filter's behavior looks a little bit strange // for example: // if we want to records the metrics of QuerySetter -// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx" +// actually we only records metrics of invoking "QueryTable" type FilterChainBuilder struct { AppName string ServerName string @@ -85,7 +85,7 @@ func (builder *FilterChainBuilder) report(ctx context.Context, inv *orm.Invocati } func (builder *FilterChainBuilder) reportTxn(ctx context.Context, inv *orm.Invocation) { - dur := time.Now().Sub(inv.TxStartTime) / time.Millisecond + dur := time.Since(inv.TxStartTime) / time.Millisecond summaryVec.WithLabelValues(inv.Method, inv.TxName, strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur)) } diff --git a/client/orm/filter_orm_decorator.go b/client/orm/filter_orm_decorator.go index a60390a1..6a9ecc53 100644 --- a/client/orm/filter_orm_decorator.go +++ b/client/orm/filter_orm_decorator.go @@ -20,6 +20,7 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/core/utils" ) @@ -161,36 +162,34 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac } func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer { - return f.QueryM2MWithCtx(context.Background(), md, name) -} - -func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { mi, _ := modelCache.getByMd(md) inv := &Invocation{ - Method: "QueryM2MWithCtx", + Method: "QueryM2M", Args: []interface{}{md, name}, Md: md, mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, f: func(c context.Context) []interface{} { - res := f.ormer.QueryM2MWithCtx(c, md, name) + res := f.ormer.QueryM2M(md, name) return []interface{}{res} }, } - res := f.root(ctx, inv) + res := f.root(context.Background(), inv) if res[0] == nil { return nil } return res[0].(QueryM2Mer) } -func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { - return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName) +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.") + return f.QueryM2M(md, name) } -func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { +func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { var ( name string md interface{} @@ -209,18 +208,18 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT } inv := &Invocation{ - Method: "QueryTableWithCtx", + Method: "QueryTable", Args: []interface{}{ptrStructOrTableName}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, Md: md, mi: mi, f: func(c context.Context) []interface{} { - res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName) + res := f.ormer.QueryTable(ptrStructOrTableName) return []interface{}{res} }, } - res := f.root(ctx, inv) + res := f.root(context.Background(), inv) if res[0] == nil { return nil @@ -228,6 +227,12 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT return res[0].(QuerySeter) } +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.") + return f.QueryTable(ptrStructOrTableName) +} + func (f *filterOrmDecorator) DBStats() *sql.DBStats { inv := &Invocation{ Method: "DBStats", @@ -498,6 +503,22 @@ func (f *filterOrmDecorator) Rollback() error { return f.convertError(res[0]) } +func (f *filterOrmDecorator) RollbackUnlessCommit() error { + inv := &Invocation{ + Method: "RollbackUnlessCommit", + Args: []interface{}{}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + TxName: f.txName, + f: func(c context.Context) []interface{} { + err := f.TxCommitter.RollbackUnlessCommit() + return []interface{}{err} + }, + } + res := f.root(context.Background(), inv) + return f.convertError(res[0]) +} + func (f *filterOrmDecorator) convertError(v interface{}) error { if v == nil { return nil diff --git a/client/orm/filter_orm_decorator_test.go b/client/orm/filter_orm_decorator_test.go index 9e223358..6c3bc72b 100644 --- a/client/orm/filter_orm_decorator_test.go +++ b/client/orm/filter_orm_decorator_test.go @@ -268,7 +268,7 @@ func TestFilterOrmDecorator_QueryM2M(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { - assert.Equal(t, "QueryM2MWithCtx", inv.Method) + assert.Equal(t, "QueryM2M", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) @@ -284,7 +284,7 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { - assert.Equal(t, "QueryTableWithCtx", inv.Method) + assert.Equal(t, "QueryTable", inv.Method) assert.Equal(t, 1, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) @@ -402,6 +402,10 @@ func (f *filterMockOrm) Rollback() error { return errors.New("rollback") } +func (f *filterMockOrm) RollbackUnlessCommit() error { + return errors.New("rollback unless commit") +} + func (f *filterMockOrm) DBStats() *sql.DBStats { return &sql.DBStats{ MaxOpenConnections: -1, diff --git a/client/orm/mock/condition.go b/client/orm/mock/condition.go new file mode 100644 index 00000000..eda88824 --- /dev/null +++ b/client/orm/mock/condition.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +type Mock struct { + cond Condition + resp []interface{} + cb func(inv *orm.Invocation) +} + +func NewMock(cond Condition, resp []interface{}, cb func(inv *orm.Invocation)) *Mock { + return &Mock{ + cond: cond, + resp: resp, + cb: cb, + } +} + +type Condition interface { + Match(ctx context.Context, inv *orm.Invocation) bool +} + +type SimpleCondition struct { + tableName string + method string +} + +func NewSimpleCondition(tableName string, methodName string) Condition { + return &SimpleCondition{ + tableName: tableName, + method: methodName, + } +} + +func (s *SimpleCondition) Match(ctx context.Context, inv *orm.Invocation) bool { + res := true + if len(s.tableName) != 0 { + res = res && (s.tableName == inv.GetTableName()) + } + + if len(s.method) != 0 { + res = res && (s.method == inv.Method) + } + return res +} diff --git a/client/orm/mock/condition_test.go b/client/orm/mock/condition_test.go new file mode 100644 index 00000000..7f646e70 --- /dev/null +++ b/client/orm/mock/condition_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestSimpleCondition_Match(t *testing.T) { + cond := NewSimpleCondition("", "") + res := cond.Match(context.Background(), &orm.Invocation{}) + assert.True(t, res) + cond = NewSimpleCondition("hello", "") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{})) + + cond = NewSimpleCondition("", "A") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{ + Method: "B", + })) + + assert.True(t, cond.Match(context.Background(), &orm.Invocation{ + Method: "A", + })) +} diff --git a/client/orm/mock/context.go b/client/orm/mock/context.go new file mode 100644 index 00000000..ca251c5d --- /dev/null +++ b/client/orm/mock/context.go @@ -0,0 +1,40 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/core/logs" +) + +type mockCtxKeyType string + +const mockCtxKey = mockCtxKeyType("beego-orm-mock") + +func CtxWithMock(ctx context.Context, mock ...*Mock) context.Context { + return context.WithValue(ctx, mockCtxKey, mock) +} + +func mockFromCtx(ctx context.Context) []*Mock { + ms := ctx.Value(mockCtxKey) + if ms != nil { + if res, ok := ms.([]*Mock); ok { + return res + } + logs.Error("mockCtxKey found in context, but value is not type []*Mock") + } + return nil +} diff --git a/client/orm/mock/context_test.go b/client/orm/mock/context_test.go new file mode 100644 index 00000000..a3ed1e90 --- /dev/null +++ b/client/orm/mock/context_test.go @@ -0,0 +1,29 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCtx(t *testing.T) { + ms := make([]*Mock, 0, 4) + ctx := CtxWithMock(context.Background(), ms...) + res := mockFromCtx(ctx) + assert.Equal(t, ms, res) +} diff --git a/client/orm/mock/mock.go b/client/orm/mock/mock.go new file mode 100644 index 00000000..072488b2 --- /dev/null +++ b/client/orm/mock/mock.go @@ -0,0 +1,72 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +var stub = newOrmStub() + +func init() { + orm.AddGlobalFilterChain(stub.FilterChain) +} + +type Stub interface { + Mock(m *Mock) + Clear() +} + +type OrmStub struct { + ms []*Mock +} + +func StartMock() Stub { + return stub +} + +func newOrmStub() *OrmStub { + return &OrmStub{ + ms: make([]*Mock, 0, 4), + } +} + +func (o *OrmStub) Mock(m *Mock) { + o.ms = append(o.ms, m) +} + +func (o *OrmStub) Clear() { + o.ms = make([]*Mock, 0, 4) +} + +func (o *OrmStub) FilterChain(next orm.Filter) orm.Filter { + return func(ctx context.Context, inv *orm.Invocation) []interface{} { + + ms := mockFromCtx(ctx) + ms = append(ms, o.ms...) + + for _, mock := range ms { + if mock.cond.Match(ctx, inv) { + if mock.cb != nil { + mock.cb(inv) + } + return mock.resp + } + } + return next(ctx, inv) + } +} diff --git a/client/orm/mock/mock_orm.go b/client/orm/mock/mock_orm.go new file mode 100644 index 00000000..853a4213 --- /dev/null +++ b/client/orm/mock/mock_orm.go @@ -0,0 +1,167 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "database/sql" + "os" + "path/filepath" + + _ "github.com/mattn/go-sqlite3" + + "github.com/beego/beego/v2/client/orm" +) + +func init() { + RegisterMockDB("default") +} + +// RegisterMockDB create an "virtual DB" by using sqllite +// you should not +func RegisterMockDB(name string) { + source := filepath.Join(os.TempDir(), name+".db") + _ = orm.RegisterDataBase(name, "sqlite3", source) +} + +// MockTable only check table name +func MockTable(tableName string, resp ...interface{}) *Mock { + return NewMock(NewSimpleCondition(tableName, ""), resp, nil) +} + +// MockMethod only check method name +func MockMethod(method string, resp ...interface{}) *Mock { + return NewMock(NewSimpleCondition("", method), resp, nil) +} + +// MockOrmRead support orm.Read and orm.ReadWithCtx +// cb is used to mock read data from DB +func MockRead(tableName string, cb func(data interface{}), err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadWithCtx"), []interface{}{err}, func(inv *orm.Invocation) { + if cb != nil { + cb(inv.Args[0]) + } + }) +} + +// MockReadForUpdateWithCtx support ReadForUpdate and ReadForUpdateWithCtx +// cb is used to mock read data from DB +func MockReadForUpdateWithCtx(tableName string, cb func(data interface{}), err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadForUpdateWithCtx"), + []interface{}{err}, + func(inv *orm.Invocation) { + cb(inv.Args[0]) + }) +} + +// MockReadOrCreateWithCtx support ReadOrCreate and ReadOrCreateWithCtx +// cb is used to mock read data from DB +func MockReadOrCreateWithCtx(tableName string, + cb func(data interface{}), + insert bool, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadOrCreateWithCtx"), + []interface{}{insert, id, err}, + func(inv *orm.Invocation) { + cb(inv.Args[0]) + }) +} + +// MockInsertWithCtx support Insert and InsertWithCtx +func MockInsertWithCtx(tableName string, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertWithCtx"), []interface{}{id, err}, nil) +} + +// MockInsertMultiWithCtx support InsertMulti and InsertMultiWithCtx +func MockInsertMultiWithCtx(tableName string, cnt int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertMultiWithCtx"), []interface{}{cnt, err}, nil) +} + +// MockInsertOrUpdateWithCtx support InsertOrUpdate and InsertOrUpdateWithCtx +func MockInsertOrUpdateWithCtx(tableName string, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertOrUpdateWithCtx"), []interface{}{id, err}, nil) +} + +// MockUpdateWithCtx support UpdateWithCtx and Update +func MockUpdateWithCtx(tableName string, affectedRow int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "UpdateWithCtx"), []interface{}{affectedRow, err}, nil) +} + +// MockDeleteWithCtx support Delete and DeleteWithCtx +func MockDeleteWithCtx(tableName string, affectedRow int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "DeleteWithCtx"), []interface{}{affectedRow, err}, nil) +} + +// MockQueryM2MWithCtx support QueryM2MWithCtx and QueryM2M +// Now you may be need to use golang/mock to generate QueryM2M mock instance +// Or use DoNothingQueryM2Mer +// for example: +// post := Post{Id: 4} +// m2m := Ormer.QueryM2M(&post, "Tags") +// when you write test code: +// MockQueryM2MWithCtx("post", "Tags", mockM2Mer) +// "post" is the table name of model Post structure +// TODO provide orm.QueryM2Mer +func MockQueryM2MWithCtx(tableName string, name string, res orm.QueryM2Mer) *Mock { + return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{res}, nil) +} + +// MockLoadRelatedWithCtx support LoadRelatedWithCtx and LoadRelated +func MockLoadRelatedWithCtx(tableName string, name string, rows int64, err error) *Mock { + return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{rows, err}, nil) +} + +// MockQueryTableWithCtx support QueryTableWithCtx and QueryTable +func MockQueryTableWithCtx(tableName string, qs orm.QuerySeter) *Mock { + return NewMock(NewSimpleCondition(tableName, "QueryTable"), []interface{}{qs}, nil) +} + +// MockRawWithCtx support RawWithCtx and Raw +func MockRawWithCtx(rs orm.RawSeter) *Mock { + return NewMock(NewSimpleCondition("", "RawWithCtx"), []interface{}{rs}, nil) +} + +// MockDriver support Driver +// func MockDriver(driver orm.Driver) *Mock { +// return NewMock(NewSimpleCondition("", "Driver"), []interface{}{driver}) +// } + +// MockDBStats support DBStats +func MockDBStats(stats *sql.DBStats) *Mock { + return NewMock(NewSimpleCondition("", "DBStats"), []interface{}{stats}, nil) +} + +// MockBeginWithCtxAndOpts support Begin, BeginWithCtx, BeginWithOpts, BeginWithCtxAndOpts +// func MockBeginWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock { +// return NewMock(NewSimpleCondition("", "BeginWithCtxAndOpts"), []interface{}{txOrm, err}) +// } + +// MockDoTxWithCtxAndOpts support DoTx, DoTxWithCtx, DoTxWithOpts, DoTxWithCtxAndOpts +// func MockDoTxWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock { +// return MockBeginWithCtxAndOpts(txOrm, err) +// } + +// MockCommit support Commit +func MockCommit(err error) *Mock { + return NewMock(NewSimpleCondition("", "Commit"), []interface{}{err}, nil) +} + +// MockRollback support Rollback +func MockRollback(err error) *Mock { + return NewMock(NewSimpleCondition("", "Rollback"), []interface{}{err}, nil) +} + +// MockRollbackUnlessCommit support RollbackUnlessCommit +func MockRollbackUnlessCommit(err error) *Mock { + return NewMock(NewSimpleCondition("", "RollbackUnlessCommit"), []interface{}{err}, nil) +} \ No newline at end of file diff --git a/client/orm/mock/mock_orm_test.go b/client/orm/mock/mock_orm_test.go new file mode 100644 index 00000000..d34774d0 --- /dev/null +++ b/client/orm/mock/mock_orm_test.go @@ -0,0 +1,312 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "database/sql" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +const mockErrorMsg = "mock error" + +func init() { + orm.RegisterModel(&User{}) +} + +func TestMockDBStats(t *testing.T) { + s := StartMock() + defer s.Clear() + stats := &sql.DBStats{} + s.Mock(MockDBStats(stats)) + + o := orm.NewOrm() + + res := o.DBStats() + + assert.Equal(t, stats, res) +} + +func TestMockDeleteWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + s.Mock(MockDeleteWithCtx((&User{}).TableName(), 12, nil)) + o := orm.NewOrm() + rows, err := o.Delete(&User{}) + assert.Equal(t, int64(12), rows) + assert.Nil(t, err) +} + +func TestMockInsertOrUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + s.Mock(MockInsertOrUpdateWithCtx((&User{}).TableName(), 12, nil)) + o := orm.NewOrm() + id, err := o.InsertOrUpdate(&User{}) + assert.Equal(t, int64(12), id) + assert.Nil(t, err) +} + +func TestMockRead(t *testing.T) { + s := StartMock() + defer s.Clear() + err := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, err)) + o := orm.NewOrm() + u := &User{} + e := o.Read(u) + assert.Equal(t, err, e) + assert.Equal(t, "Tom", u.Name) +} + +func TestMockQueryM2MWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingQueryM2Mer{} + s.Mock(MockQueryM2MWithCtx((&User{}).TableName(), "Tags", mock)) + o := orm.NewOrm() + res := o.QueryM2M(&User{}, "Tags") + assert.Equal(t, mock, res) +} + +func TestMockQueryTableWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingQuerySetter{} + s.Mock(MockQueryTableWithCtx((&User{}).TableName(), mock)) + o := orm.NewOrm() + res := o.QueryTable(&User{}) + assert.Equal(t, mock, res) +} + +func TestMockTable(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockTable((&User{}).TableName(), mock)) + o := orm.NewOrm() + res := o.Read(&User{}) + assert.Equal(t, mock, res) +} + +func TestMockInsertMultiWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockInsertMultiWithCtx((&User{}).TableName(), 12, mock)) + o := orm.NewOrm() + res, err := o.InsertMulti(11, []interface{}{&User{}}) + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockInsertWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockInsertWithCtx((&User{}).TableName(), 13, mock)) + o := orm.NewOrm() + res, err := o.Insert(&User{}) + assert.Equal(t, int64(13), res) + assert.Equal(t, mock, err) +} + +func TestMockUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockUpdateWithCtx((&User{}).TableName(), 12, mock)) + o := orm.NewOrm() + res, err := o.Update(&User{}) + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockLoadRelatedWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockLoadRelatedWithCtx((&User{}).TableName(), "T", 12, mock)) + o := orm.NewOrm() + res, err := o.LoadRelated(&User{}, "T") + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockMethod(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockMethod("ReadWithCtx", mock)) + o := orm.NewOrm() + err := o.Read(&User{}) + assert.Equal(t, mock, err) +} + +func TestMockReadForUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockReadForUpdateWithCtx((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + o := orm.NewOrm() + u := &User{} + err := o.ReadForUpdate(u) + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestMockRawWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingRawSetter{} + s.Mock(MockRawWithCtx(mock)) + o := orm.NewOrm() + res := o.Raw("") + assert.Equal(t, mock, res) +} + +func TestMockReadOrCreateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockReadOrCreateWithCtx((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, false, 12, mock)) + o := orm.NewOrm() + u := &User{} + inserted, id, err := o.ReadOrCreate(u, "") + assert.Equal(t, mock, err) + assert.Equal(t, int64(12), id) + assert.False(t, inserted) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionClosure(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + u, err := originalTxUsingClosure() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionManually(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + u, err := originalTxManually() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionRollback(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), nil, errors.New("read error"))) + s.Mock(MockRollback(mock)) + _, err := originalTx() + assert.Equal(t, mock, err) +} + +func TestTransactionRollbackUnlessCommit(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRollbackUnlessCommit(mock)) + + //u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.RollbackUnlessCommit() + assert.Equal(t, mock, err) +} + +func TestTransactionCommit(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, nil)) + s.Mock(MockCommit(mock)) + u, err := originalTx() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func originalTx() (*User, error) { + u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.Read(u) + if err == nil { + err = txOrm.Commit() + return u, err + } else { + err = txOrm.Rollback() + return nil, err + } +} + +func originalTxManually() (*User, error) { + u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.Read(u) + _ = txOrm.Commit() + return u, err +} + +func originalTxUsingClosure() (*User, error) { + u := &User{} + var err error + o := orm.NewOrm() + _ = o.DoTx(func(ctx context.Context, txOrm orm.TxOrmer) error { + err = txOrm.Read(u) + return nil + }) + return u, err +} + +type User struct { + Id int + Name string +} + +func (u *User) TableName() string { + return "user" +} diff --git a/client/orm/mock/mock_queryM2Mer.go b/client/orm/mock/mock_queryM2Mer.go new file mode 100644 index 00000000..a58f10ae --- /dev/null +++ b/client/orm/mock/mock_queryM2Mer.go @@ -0,0 +1,89 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +// DoNothingQueryM2Mer do nothing +// use it to build mock orm.QueryM2Mer +type DoNothingQueryM2Mer struct { +} + +func (d *DoNothingQueryM2Mer) AddWithCtx(ctx context.Context, i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) RemoveWithCtx(ctx context.Context, i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) ExistWithCtx(ctx context.Context, i interface{}) bool { + return true +} + +func (d *DoNothingQueryM2Mer) ClearWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) CountWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Add(i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Remove(i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Exist(i interface{}) bool { + return true +} + +func (d *DoNothingQueryM2Mer) Clear() (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Count() (int64, error) { + return 0, nil +} + +type QueryM2MerCondition struct { + tableName string + name string +} + +func NewQueryM2MerCondition(tableName string, name string) *QueryM2MerCondition { + return &QueryM2MerCondition{ + tableName: tableName, + name: name, + } +} + +func (q *QueryM2MerCondition) Match(ctx context.Context, inv *orm.Invocation) bool { + res := true + if len(q.tableName) > 0 { + res = res && (q.tableName == inv.GetTableName()) + } + if len(q.name) > 0 { + res = res && (len(inv.Args) > 1) && (q.name == inv.Args[1].(string)) + } + return res +} diff --git a/client/orm/mock/mock_queryM2Mer_test.go b/client/orm/mock/mock_queryM2Mer_test.go new file mode 100644 index 00000000..ef754092 --- /dev/null +++ b/client/orm/mock/mock_queryM2Mer_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestDoNothingQueryM2Mer(t *testing.T) { + m2m := &DoNothingQueryM2Mer{} + + i, err := m2m.Clear() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Count() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Add() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Remove() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + assert.True(t, m2m.Exist(nil)) +} + +func TestNewQueryM2MerCondition(t *testing.T) { + cond := NewQueryM2MerCondition("", "") + res := cond.Match(context.Background(), &orm.Invocation{}) + assert.True(t, res) + cond = NewQueryM2MerCondition("hello", "") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{})) + + cond = NewQueryM2MerCondition("", "A") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{ + Args: []interface{}{0, "B"}, + })) + + assert.True(t, cond.Match(context.Background(), &orm.Invocation{ + Args: []interface{}{0, "A"}, + })) +} diff --git a/client/orm/mock/mock_querySetter.go b/client/orm/mock/mock_querySetter.go new file mode 100644 index 00000000..074b6211 --- /dev/null +++ b/client/orm/mock/mock_querySetter.go @@ -0,0 +1,183 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" +) + +// DoNothingQuerySetter do nothing +// usually you use this to build your mock QuerySetter +type DoNothingQuerySetter struct { +} + +func (d *DoNothingQuerySetter) OrderClauses(orders ...*order_clause.Order) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) CountWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ExistWithCtx(ctx context.Context) bool { + return true +} + +func (d *DoNothingQuerySetter) UpdateWithCtx(ctx context.Context, values orm.Params) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) DeleteWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) PrepareInsertWithCtx(ctx context.Context) (orm.Inserter, error) { + return nil, nil +} + +func (d *DoNothingQuerySetter) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingQuerySetter) ValuesWithCtx(ctx context.Context, results *[]orm.Params, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesListWithCtx(ctx context.Context, results *[]orm.ParamsList, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesFlatWithCtx(ctx context.Context, result *orm.ParamsList, expr string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Aggregate(s string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Filter(s string, i ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) FilterRaw(s string, s2 string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Exclude(s string, i ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) SetCond(condition *orm.Condition) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) GetCond() *orm.Condition { + return orm.NewCondition() +} + +func (d *DoNothingQuerySetter) Limit(limit interface{}, args ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Offset(offset interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) GroupBy(exprs ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) OrderBy(exprs ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) UseIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) RelatedSel(params ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Distinct() orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) ForUpdate() orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Count() (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Exist() bool { + return true +} + +func (d *DoNothingQuerySetter) Update(values orm.Params) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Delete() (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) PrepareInsert() (orm.Inserter, error) { + return nil, nil +} + +func (d *DoNothingQuerySetter) All(container interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) One(container interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingQuerySetter) Values(results *[]orm.Params, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesList(results *[]orm.ParamsList, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesFlat(result *orm.ParamsList, expr string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return 0, nil +} diff --git a/client/orm/mock/mock_querySetter_test.go b/client/orm/mock/mock_querySetter_test.go new file mode 100644 index 00000000..09e5ad8c --- /dev/null +++ b/client/orm/mock/mock_querySetter_test.go @@ -0,0 +1,74 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDoNothingQuerySetter(t *testing.T) { + setter := &DoNothingQuerySetter{} + setter.GroupBy().Filter("").Limit(10). + Distinct().Exclude("a").FilterRaw("", ""). + ForceIndex().ForUpdate().IgnoreIndex(). + Offset(11).OrderBy().RelatedSel().SetCond(nil).UseIndex() + + assert.True(t, setter.Exist()) + err := setter.One(nil) + assert.Nil(t, err) + i, err := setter.Count() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Delete() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.All(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Update(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.RowsToMap(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.RowsToStruct(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Values(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.ValuesFlat(nil, "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.ValuesList(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + ins, err := setter.PrepareInsert() + assert.Nil(t, err) + assert.Nil(t, ins) + + assert.NotNil(t, setter.GetCond()) +} diff --git a/client/orm/mock/mock_rawSetter.go b/client/orm/mock/mock_rawSetter.go new file mode 100644 index 00000000..00311e80 --- /dev/null +++ b/client/orm/mock/mock_rawSetter.go @@ -0,0 +1,64 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "database/sql" + + "github.com/beego/beego/v2/client/orm" +) + +type DoNothingRawSetter struct { +} + +func (d *DoNothingRawSetter) Exec() (sql.Result, error) { + return nil, nil +} + +func (d *DoNothingRawSetter) QueryRow(containers ...interface{}) error { + return nil +} + +func (d *DoNothingRawSetter) QueryRows(containers ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) SetArgs(i ...interface{}) orm.RawSeter { + return d +} + +func (d *DoNothingRawSetter) Values(container *[]orm.Params, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) ValuesList(container *[]orm.ParamsList, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) ValuesFlat(container *orm.ParamsList, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) Prepare() (orm.RawPreparer, error) { + return nil, nil +} diff --git a/client/orm/mock/mock_rawSetter_test.go b/client/orm/mock/mock_rawSetter_test.go new file mode 100644 index 00000000..dd98edbd --- /dev/null +++ b/client/orm/mock/mock_rawSetter_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDoNothingRawSetter(t *testing.T) { + rs := &DoNothingRawSetter{} + i, err := rs.ValuesList(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.Values(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.ValuesFlat(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.RowsToStruct(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.RowsToMap(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.QueryRows() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + err = rs.QueryRow() + // assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + s, err := rs.Exec() + assert.Nil(t, err) + assert.Nil(t, s) + + p, err := rs.Prepare() + assert.Nil(t, err) + assert.Nil(t, p) + + rrs := rs.SetArgs() + assert.Equal(t, rrs, rs) +} diff --git a/client/orm/mock/mock_test.go b/client/orm/mock/mock_test.go new file mode 100644 index 00000000..73bce4e5 --- /dev/null +++ b/client/orm/mock/mock_test.go @@ -0,0 +1,58 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestOrmStub_FilterChain(t *testing.T) { + os := newOrmStub() + inv := &orm.Invocation{ + Args: []interface{}{10}, + } + i := 1 + os.FilterChain(func(ctx context.Context, inv *orm.Invocation) []interface{} { + i++ + return nil + })(context.Background(), inv) + + assert.Equal(t, 2, i) + + m := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) { + arg := inv.Args[0] + j := arg.(int) + inv.Args[0] = j + 1 + }) + os.Mock(m) + + os.FilterChain(nil)(context.Background(), inv) + assert.Equal(t, 11, inv.Args[0]) + + inv.Args[0] = 10 + ctxMock := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) { + arg := inv.Args[0] + j := arg.(int) + inv.Args[0] = j + 3 + }) + + os.FilterChain(nil)(CtxWithMock(context.Background(), ctxMock), inv) + assert.Equal(t, 13, inv.Args[0]) +} diff --git a/client/orm/models.go b/client/orm/models.go index 64dfab09..31cdc4a1 100644 --- a/client/orm/models.go +++ b/client/orm/models.go @@ -332,10 +332,6 @@ end: // register register models to model cache func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) { - if mc.done { - err = fmt.Errorf("register must be run before BootStrap") - return - } for _, model := range models { val := reflect.ValueOf(model) @@ -352,7 +348,9 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m err = fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ) return } - + if val.Elem().Kind() == reflect.Slice { + val = reflect.New(val.Elem().Type().Elem()) + } table := getTableName(val) if prefixOrSuffixStr != "" { @@ -371,8 +369,7 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m } if _, ok := mc.get(table); ok { - err = fmt.Errorf(" table name `%s` repeat register, must be unique\n", table) - return + return nil } mi := newModelInfo(val) @@ -389,12 +386,6 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m } } } - - if mi.fields.pk == nil { - err = fmt.Errorf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) - return - } - } mi.table = table diff --git a/client/orm/models_info_f.go b/client/orm/models_info_f.go index c7ad4801..6d1263e2 100644 --- a/client/orm/models_info_f.go +++ b/client/orm/models_info_f.go @@ -101,29 +101,30 @@ func newFields() *fields { // single field info type fieldInfo struct { - mi *modelInfo - fieldIndex []int - fieldType int dbcol bool // table column fk and onetoone inModel bool - name string - fullName string - column string - addrValue reflect.Value - sf reflect.StructField auto bool pk bool null bool index bool unique bool - colDefault bool // whether has default tag - initial StrTo // store the default value - size int + colDefault bool // whether has default tag toText bool autoNow bool autoNowAdd bool rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true reverse bool + isFielder bool // implement Fielder interface + mi *modelInfo + fieldIndex []int + fieldType int + name string + fullName string + column string + addrValue reflect.Value + sf reflect.StructField + initial StrTo // store the default value + size int reverseField string reverseFieldInfo *fieldInfo reverseFieldInfoTwo *fieldInfo @@ -134,7 +135,6 @@ type fieldInfo struct { relModelInfo *modelInfo digits int decimals int - isFielder bool // implement Fielder interface onDelete string description string timePrecision *int diff --git a/client/orm/models_info_m.go b/client/orm/models_info_m.go index c9a979af..b94480ca 100644 --- a/client/orm/models_info_m.go +++ b/client/orm/models_info_m.go @@ -22,16 +22,16 @@ import ( // single model info type modelInfo struct { + manual bool + isThrough bool pkg string name string fullName string table string model interface{} fields *fields - manual bool addrField reflect.Value // store the original struct value uniques []string - isThrough bool } // new model info diff --git a/client/orm/models_test.go b/client/orm/models_test.go index e3f74c0b..0051c126 100644 --- a/client/orm/models_test.go +++ b/client/orm/models_test.go @@ -118,6 +118,10 @@ var _ Fielder = new(JSONFieldTest) type Data struct { ID int `orm:"column(id)"` Boolean bool + Byte byte + Int8 int8 + Uint8 uint8 + Rune rune Char string `orm:"size(50)"` Text string `orm:"type(text)"` JSON string `orm:"type(json);default({\"name\":\"json\"})"` @@ -125,26 +129,21 @@ type Data struct { Time time.Time `orm:"type(time)"` Date time.Time `orm:"type(date)"` DateTime time.Time `orm:"column(datetime)"` - Byte byte - Rune rune Int int - Int8 int8 + Uint uint Int16 int16 + Uint16 uint16 Int32 int32 Int64 int64 - Uint uint - Uint8 uint8 - Uint16 uint16 Uint32 uint32 - Uint64 uint64 Float32 float32 + Uint64 uint64 Float64 float64 Decimal float64 `orm:"digits(8);decimals(4)"` } type DataNull struct { ID int `orm:"column(id)"` - Boolean bool `orm:"null"` Char string `orm:"null;size(50)"` Text string `orm:"null;type(text)"` JSON string `orm:"type(json);null"` @@ -153,19 +152,20 @@ type DataNull struct { Date time.Time `orm:"null;type(date)"` DateTime time.Time `orm:"null;column(datetime)"` DateTimePrecision time.Time `orm:"null;type(datetime);precision(4)"` + Boolean bool `orm:"null"` Byte byte `orm:"null"` + Int8 int8 `orm:"null"` + Uint8 uint8 `orm:"null"` Rune rune `orm:"null"` Int int `orm:"null"` - Int8 int8 `orm:"null"` + Uint uint `orm:"null"` Int16 int16 `orm:"null"` + Uint16 uint16 `orm:"null"` Int32 int32 `orm:"null"` Int64 int64 `orm:"null"` - Uint uint `orm:"null"` - Uint8 uint8 `orm:"null"` - Uint16 uint16 `orm:"null"` Uint32 uint32 `orm:"null"` - Uint64 uint64 `orm:"null"` Float32 float32 `orm:"null"` + Uint64 uint64 `orm:"null"` Float64 float64 `orm:"null"` Decimal float64 `orm:"digits(8);decimals(4);null"` NullString sql.NullString `orm:"null"` @@ -215,21 +215,21 @@ type Float64 float64 type DataCustom struct { ID int `orm:"column(id)"` Boolean Boolean + Byte Byte + Int8 Int8 + Uint8 Uint8 + Rune Rune Char string `orm:"size(50)"` Text string `orm:"type(text)"` - Byte Byte - Rune Rune Int Int - Int8 Int8 + Uint Uint Int16 Int16 + Uint16 Uint16 Int32 Int32 Int64 Int64 - Uint Uint - Uint8 Uint8 - Uint16 Uint16 Uint32 Uint32 - Uint64 Uint64 Float32 Float32 + Uint64 Uint64 Float64 Float64 Decimal Float64 `orm:"digits(8);decimals(4)"` } @@ -255,24 +255,40 @@ func NewTM() *TM { return obj } +type DeptInfo struct { + ID int `orm:"column(id)"` + Created time.Time `orm:"auto_now_add"` + DeptName string + EmployeeName string + Salary int +} + +type UnregisterModel struct { + ID int `orm:"column(id)"` + Created time.Time `orm:"auto_now_add"` + DeptName string + EmployeeName string + Salary int +} + type User struct { - ID int `orm:"column(id)"` - UserName string `orm:"size(30);unique"` - Email string `orm:"size(100)"` - Password string `orm:"size(100)"` - Status int16 `orm:"column(Status)"` - IsStaff bool - IsActive bool `orm:"default(true)"` - Created time.Time `orm:"auto_now_add;type(date)"` - Updated time.Time `orm:"auto_now"` - Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` - Posts []*Post `orm:"reverse(many)" json:"-"` - ShouldSkip string `orm:"-"` - Nums int - Langs SliceStringField `orm:"size(100)"` - Extra JSONFieldTest `orm:"type(text)"` - unexport bool `orm:"-"` - unexportBool bool + ID int `orm:"column(id)"` + UserName string `orm:"size(30);unique"` + Email string `orm:"size(100)"` + Password string `orm:"size(100)"` + Status int16 `orm:"column(Status)"` + IsStaff bool + IsActive bool `orm:"default(true)"` + Unexported bool `orm:"-"` + UnexportedBool bool + Created time.Time `orm:"auto_now_add;type(date)"` + Updated time.Time `orm:"auto_now"` + Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` + Posts []*Post `orm:"reverse(many)" json:"-"` + ShouldSkip string `orm:"-"` + Nums int + Langs SliceStringField `orm:"size(100)"` + Extra JSONFieldTest `orm:"type(text)"` } func (u *User) TableIndex() [][]string { @@ -476,45 +492,45 @@ var ( helpinfo = `need driver and source! Default DB Drivers. - + driver: url mysql: https://github.com/go-sql-driver/mysql sqlite3: https://github.com/mattn/go-sqlite3 postgres: https://github.com/lib/pq tidb: https://github.com/pingcap/tidb - + usage: - + go get -u github.com/beego/beego/v2/client/orm go get -u github.com/go-sql-driver/mysql go get -u github.com/mattn/go-sqlite3 go get -u github.com/lib/pq go get -u github.com/pingcap/tidb - + #### MySQL mysql -u root -e 'create database orm_test;' export ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" go test -v github.com/beego/beego/v2/client/orm - - + + #### Sqlite3 export ORM_DRIVER=sqlite3 export ORM_SOURCE='file:memory_test?mode=memory' go test -v github.com/beego/beego/v2/client/orm - - + + #### PostgreSQL psql -c 'create database orm_test;' -U postgres export ORM_DRIVER=postgres export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" go test -v github.com/beego/beego/v2/client/orm - + #### TiDB export ORM_DRIVER=tidb export ORM_SOURCE='memory://test/test' go test -v github.com/beego/beego/v2/pgk/orm - + ` ) diff --git a/client/orm/orm.go b/client/orm/orm.go index 1adf84e2..fa96de4f 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -62,6 +62,8 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/hints" "github.com/beego/beego/v2/core/utils" @@ -135,7 +137,7 @@ func (o *ormBase) Read(md interface{}, cols ...string) error { } func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) } // read data to model, like Read(), but use "SELECT FOR UPDATE" form @@ -144,7 +146,7 @@ func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { } func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) + return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true) } // Try to read a row from the database, or insert one if it doesn't exist @@ -154,7 +156,7 @@ func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (boo func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) if err == ErrNoRows { // Create id, err := o.InsertWithCtx(ctx, md) @@ -179,7 +181,7 @@ func (o *ormBase) Insert(md interface{}) (int64, error) { } func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return id, err } @@ -222,7 +224,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac for i := 0; i < sind.Len(); i++ { ind := reflect.Indirect(sind.Index(i)) mi, _ := o.getMiInd(ind.Interface(), false) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return cnt, err } @@ -233,7 +235,7 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac } } else { mi, _ := o.getMiInd(sind.Index(0).Interface(), false) - return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) + return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ) } return cnt, nil } @@ -244,7 +246,7 @@ func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) ( } func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) + id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...) if err != nil { return id, err } @@ -261,7 +263,7 @@ func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) + return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols) } // delete model in database @@ -271,7 +273,7 @@ func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { } func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { mi, ind := o.getMiInd(md, true) - num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) + num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols) if err != nil { return num, err } @@ -283,9 +285,6 @@ func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...str // create a models to models queryer func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { - return o.QueryM2MWithCtx(context.Background(), md, name) -} -func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { mi, ind := o.getMiInd(md, true) fi := o.getFieldInfo(mi, name) @@ -299,6 +298,12 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri return newQueryM2M(md, o, mi, fi, ind) } +// NOTE: this method is deprecated, context parameter will not take effect. +func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.") + return o.QueryM2M(md, name) +} + // load related models to md model. // args are limit, offset int and order string. // @@ -351,7 +356,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s qs.relDepth = relDepth if len(order) > 0 { - qs.orders = []string{order} + qs.orders = order_clause.ParseOrder(order) } find := ind.FieldByIndex(fi.fieldIndex) @@ -451,9 +456,6 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { - return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName) -} -func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { var name string if table, ok := ptrStructOrTableName.(string); ok { name = nameStrategyMap[defaultNameStrategy](table) @@ -469,7 +471,13 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in if qs == nil { panic(fmt.Errorf(" table name: `%s` not exists", name)) } - return + return qs +} + +// NOTE: this method is deprecated, context parameter will not take effect. +func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.") + return o.QueryTable(ptrStructOrTableName) } // return a raw query seter for raw sql string. @@ -585,6 +593,10 @@ func (t *txOrm) Rollback() error { return t.db.(txEnder).Rollback() } +func (t *txOrm) RollbackUnlessCommit() error { + return t.db.(txEnder).RollbackUnlessCommit() +} + // NewOrm create new orm func NewOrm() Ormer { BootStrap() // execute only once @@ -595,9 +607,8 @@ func NewOrm() Ormer { func NewOrmUsingDB(aliasName string) Ormer { if al, ok := dataBaseCache.get(aliasName); ok { return newDBWithAlias(al) - } else { - panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } + panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } // NewOrmWithDB create a new ormer object with specify *sql.DB for query diff --git a/client/orm/orm_conds.go b/client/orm/orm_conds.go index 5409406e..eeb5538a 100644 --- a/client/orm/orm_conds.go +++ b/client/orm/orm_conds.go @@ -17,11 +17,13 @@ package orm import ( "fmt" "strings" + + "github.com/beego/beego/v2/client/orm/clauses" ) // ExprSep define the expression separation const ( - ExprSep = "__" + ExprSep = clauses.ExprSep ) type condValue struct { diff --git a/client/orm/orm_log.go b/client/orm/orm_log.go index 61addeb5..da3ef732 100644 --- a/client/orm/orm_log.go +++ b/client/orm/orm_log.go @@ -41,7 +41,7 @@ func NewLog(out io.Writer) *Log { func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { var logMap = make(map[string]interface{}) - sub := time.Now().Sub(t) / 1e5 + sub := time.Since(t) / 1e5 elsp := float64(int(sub)) / 10.0 logMap["cost_time"] = elsp flag := " OK" @@ -85,20 +85,31 @@ func (d *stmtQueryLog) Close() error { } func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { + return d.ExecContext(context.Background(), args...) +} + +func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { a := time.Now() - res, err := d.stmt.Exec(args...) + res, err := d.stmt.ExecContext(ctx, args...) debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) return res, err } - func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { + return d.QueryContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { a := time.Now() - res, err := d.stmt.Query(args...) + res, err := d.stmt.QueryContext(ctx, args...) debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) return res, err } func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { + return d.QueryRowContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { a := time.Now() res := d.stmt.QueryRow(args...) debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) @@ -195,6 +206,13 @@ func (d *dbQueryLog) Rollback() error { return err } +func (d *dbQueryLog) RollbackUnlessCommit() error { + a := time.Now() + err := d.db.(txEnder).RollbackUnlessCommit() + debugLogQueies(d.alias, "tx.RollbackUnlessCommit", "ROLLBACK UNLESS COMMIT", a, err) + return err +} + func (d *dbQueryLog) SetDB(db dbQuerier) { d.db = db } diff --git a/client/orm/orm_object.go b/client/orm/orm_object.go index 6f9798d3..50c1ca41 100644 --- a/client/orm/orm_object.go +++ b/client/orm/orm_object.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "reflect" ) @@ -31,6 +32,10 @@ var _ Inserter = new(insertSet) // insert model ignore it's registered or not. func (o *insertSet) Insert(md interface{}) (int64, error) { + return o.InsertWithCtx(context.Background(), md) +} + +func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { if o.closed { return 0, ErrStmtClosed } @@ -44,7 +49,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { if name != o.mi.fullName { panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) } - id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) + id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ) if err != nil { return id, err } @@ -70,11 +75,11 @@ func (o *insertSet) Close() error { } // create new insert queryer. -func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) { +func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) { bi := new(insertSet) bi.orm = orm bi.mi = mi - st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) + st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi) if err != nil { return nil, err } diff --git a/client/orm/orm_querym2m.go b/client/orm/orm_querym2m.go index 17e1b5d1..9da49bba 100644 --- a/client/orm/orm_querym2m.go +++ b/client/orm/orm_querym2m.go @@ -14,7 +14,10 @@ package orm -import "reflect" +import ( + "context" + "reflect" +) // model to model struct type queryM2M struct { @@ -33,6 +36,10 @@ type queryM2M struct { // // make sure the relation is defined in post model struct tag. func (o *queryM2M) Add(mds ...interface{}) (int64, error) { + return o.AddWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi mi := fi.relThroughModelInfo mfi := fi.reverseFieldInfo @@ -96,11 +103,15 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { } names = append(names, otherNames...) values = append(values, otherValues...) - return dbase.InsertValue(orm.db, mi, true, names, values) + return dbase.InsertValue(ctx, orm.db, mi, true, names, values) } // remove models following the origin model relationship func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { + return o.RemoveWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) @@ -109,21 +120,33 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { // check model is existed in relationship of origin model func (o *queryM2M) Exist(md interface{}) bool { + return o.ExistWithCtx(context.Background(), md) +} + +func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool { fi := o.fi return o.qs.Filter(fi.reverseFieldInfo.name, o.md). - Filter(fi.reverseFieldInfoTwo.name, md).Exist() + Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx) } // clean all models in related of origin model func (o *queryM2M) Clear() (int64, error) { + return o.ClearWithCtx(context.Background()) +} + +func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx) } // count all related models of origin model func (o *queryM2M) Count() (int64, error) { + return o.CountWithCtx(context.Background()) +} + +func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx) } var _ QueryM2Mer = new(queryM2M) diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index 177cfc3a..9f7b8441 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -18,6 +18,7 @@ import ( "context" "fmt" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/hints" ) @@ -64,21 +65,20 @@ func ColValue(opt operator, value interface{}) interface{} { // real query struct type querySet struct { - mi *modelInfo - cond *Condition - related []string - relDepth int - limit int64 - offset int64 - groups []string - orders []string - distinct bool - forUpdate bool - useIndex int - indexes []string - orm *ormBase - ctx context.Context - forContext bool + mi *modelInfo + cond *Condition + related []string + relDepth int + limit int64 + offset int64 + groups []string + orders []*order_clause.Order + distinct bool + forUpdate bool + useIndex int + indexes []string + orm *ormBase + aggregate string } var _ QuerySeter = new(querySet) @@ -139,8 +139,20 @@ func (o querySet) GroupBy(exprs ...string) QuerySeter { // add ORDER expression. // "column" means ASC, "-column" means DESC. -func (o querySet) OrderBy(exprs ...string) QuerySeter { - o.orders = exprs +func (o querySet) OrderBy(expressions ...string) QuerySeter { + if len(expressions) <= 0 { + return &o + } + o.orders = order_clause.ParseOrder(expressions...) + return &o +} + +// add ORDER expression. +func (o querySet) OrderClauses(orders ...*order_clause.Order) QuerySeter { + if len(orders) <= 0 { + return &o + } + o.orders = orders return &o } @@ -210,23 +222,39 @@ func (o querySet) GetCond() *Condition { // return QuerySeter execution result number func (o *querySet) Count() (int64, error) { - return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.CountWithCtx(context.Background()) +} + +func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) { + return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } // check result empty or not after QuerySeter executed func (o *querySet) Exist() bool { - cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.ExistWithCtx(context.Background()) +} + +func (o *querySet) ExistWithCtx(ctx context.Context) bool { + cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return cnt > 0 } // execute update with parameters func (o *querySet) Update(values Params) (int64, error) { - return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) + return o.UpdateWithCtx(context.Background(), values) +} + +func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) { + return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) } // execute delete func (o *querySet) Delete() (int64, error) { - return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.DeleteWithCtx(context.Background()) +} + +func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) { + return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } // return a insert queryer. @@ -235,20 +263,32 @@ func (o *querySet) Delete() (int64, error) { // i,err := sq.PrepareInsert() // i.Add(&user1{},&user2{}) func (o *querySet) PrepareInsert() (Inserter, error) { - return newInsertSet(o.orm, o.mi) + return o.PrepareInsertWithCtx(context.Background()) +} + +func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) { + return newInsertSet(ctx, o.orm, o.mi) } // query all data and map to containers. // cols means the columns when querying. func (o *querySet) All(container interface{}, cols ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + return o.AllWithCtx(context.Background(), container, cols...) +} + +func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) } // query one row data and map to containers. // cols means the columns when querying. func (o *querySet) One(container interface{}, cols ...string) error { + return o.OneWithCtx(context.Background(), container, cols...) +} + +func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error { o.limit = 1 - num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) if err != nil { return err } @@ -266,19 +306,31 @@ func (o *querySet) One(container interface{}, cols ...string) error { // expres means condition expression. // it converts data to []map[column]value. func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) + return o.ValuesWithCtx(context.Background(), results, exprs...) +} + +func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to [][]interface // it converts data to [][column_index]value func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) + return o.ValuesListWithCtx(context.Background(), results, exprs...) +} + +func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to []interface. // it's designed for one row record set, auto change to []value, not [][column]value. func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) + return o.ValuesFlatWithCtx(context.Background(), result, expr) +} + +func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) } // query all rows into map[string]interface with specify key and value column name. @@ -309,13 +361,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) panic(ErrNotImplement) } -// set context to QuerySeter. -func (o querySet) WithContext(ctx context.Context) QuerySeter { - o.ctx = ctx - o.forContext = true - return &o -} - // create new QuerySeter. func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o := new(querySet) @@ -323,3 +368,9 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o.orm = orm return o } + +// aggregate func +func (o querySet) Aggregate(s string) QuerySeter { + o.aggregate = s + return &o +} diff --git a/client/orm/orm_raw.go b/client/orm/orm_raw.go index e11e97fa..af9c00cc 100644 --- a/client/orm/orm_raw.go +++ b/client/orm/orm_raw.go @@ -181,6 +181,12 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { if err == nil { ind.Set(reflect.ValueOf(t)) } + } else if len(str) >= 8 { + str = str[:8] + t, err := time.ParseInLocation(formatTime, str, DefaultTimeLoc) + if err == nil { + ind.Set(reflect.ValueOf(t)) + } } } case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool: diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index 08997177..3254a01b 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -31,6 +31,8 @@ import ( "testing" "time" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/hints" "github.com/stretchr/testify/assert" @@ -84,11 +86,7 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err er } ok = is && ok || !is && !ok if !ok { - if is { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } else { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) } wrongArg: @@ -161,6 +159,7 @@ func throwFail(t *testing.T, err error, args ...interface{}) { } } +// deprecated using assert.XXX func throwFailNow(t *testing.T, err error, args ...interface{}) { if err != nil { con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) @@ -205,6 +204,7 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(Index)) RegisterModel(new(StrPk)) RegisterModel(new(TM)) + RegisterModel(new(DeptInfo)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -232,6 +232,7 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(Index)) RegisterModel(new(StrPk)) RegisterModel(new(TM)) + RegisterModel(new(DeptInfo)) BootStrap() @@ -333,6 +334,73 @@ func TestTM(t *testing.T) { throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC")) } +func TestUnregisterModel(t *testing.T) { + data := []*DeptInfo{ + { + DeptName: "A", + EmployeeName: "A1", + Salary: 1000, + }, + { + DeptName: "A", + EmployeeName: "A2", + Salary: 2000, + }, + { + DeptName: "B", + EmployeeName: "B1", + Salary: 2000, + }, + { + DeptName: "B", + EmployeeName: "B2", + Salary: 4000, + }, + { + DeptName: "B", + EmployeeName: "B3", + Salary: 3000, + }, + } + qs := dORM.QueryTable("dept_info") + i, _ := qs.PrepareInsert() + for _, d := range data { + _, err := i.Insert(d) + if err != nil { + throwFail(t, err) + } + } + + f := func() { + var res []UnregisterModel + n, err := dORM.QueryTable("dept_info").All(&res) + throwFail(t, err) + throwFail(t, AssertIs(n, 5)) + throwFail(t, AssertIs(res[0].EmployeeName, "A1")) + + type Sum struct { + DeptName string + Total int + } + var sun []Sum + qs.Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").OrderBy("dept_name").All(&sun) + throwFail(t, AssertIs(sun[0].DeptName, "A")) + throwFail(t, AssertIs(sun[0].Total, 3000)) + + type Max struct { + DeptName string + Max float64 + } + var max []Max + qs.Aggregate("dept_name,max(salary) as max").GroupBy("dept_name").OrderBy("dept_name").All(&max) + throwFail(t, AssertIs(max[1].DeptName, "B")) + throwFail(t, AssertIs(max[1].Max, 4000)) + } + for i := 0; i < 5; i++ { + f() + } +} + func TestNullDataTypes(t *testing.T) { d := DataNull{} @@ -1077,6 +1145,26 @@ func TestOrderBy(t *testing.T) { num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`profile__age`), + order_clause.SortDescending(), + ), + ).Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsMysql { + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`rand()`), + order_clause.Raw(), + ), + ).Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + } } func TestAll(t *testing.T) { @@ -1163,6 +1251,19 @@ func TestValues(t *testing.T) { throwFail(t, AssertIs(maps[2]["Profile"], nil)) } + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column("Id"), + order_clause.SortAscending(), + ), + ).Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[2]["Profile"], nil)) + } + num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") throwFail(t, err) throwFail(t, AssertIs(num, 3)) @@ -1185,8 +1286,8 @@ func TestValuesList(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 3)) if num == 3 { - throwFail(t, AssertIs(list[0][1], "slene")) - throwFail(t, AssertIs(list[2][9], nil)) + throwFail(t, AssertIs(list[0][1], "slene")) //username + throwFail(t, AssertIs(list[2][10], nil)) //profile } num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") @@ -1746,6 +1847,10 @@ func TestRawQueryRow(t *testing.T) { throwFail(t, AssertIs(id, 1)) break case "time": + v = v.(time.Time).In(DefaultTimeLoc) + value := dataValues[col].(time.Time).In(DefaultTimeLoc) + assert.True(t, v.(time.Time).Sub(value) <= time.Second) + break case "date": case "datetime": v = v.(time.Time).In(DefaultTimeLoc) @@ -2144,27 +2249,64 @@ func TestTransaction(t *testing.T) { } err = to.Rollback() - throwFail(t, err) - + assert.Nil(t, err) num, err = o.QueryTable("tag").Filter("name__in", names).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) + assert.Nil(t, err) + assert.Equal(t, int64(0), num) to, err = o.Begin() - throwFail(t, err) + assert.Nil(t, err) tag.Name = "commit" id, err = to.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) + assert.Nil(t, err) + assert.True(t, id > 0) - to.Commit() - throwFail(t, err) + err = to.Commit() + assert.Nil(t, err) num, err = o.QueryTable("tag").Filter("name", "commit").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) + assert.Nil(t, err) + assert.Equal(t, int64(1), num) + +} + +func TestTxOrmRollbackUnlessCommit(t *testing.T) { + o := NewOrm() + var tag Tag + + // test not commited and call RollbackUnlessCommit + to, err := o.Begin() + assert.Nil(t, err) + tag.Name = "rollback unless commit" + rows, err := to.Insert(&tag) + assert.Nil(t, err) + assert.True(t, rows > 0) + err = to.RollbackUnlessCommit() + assert.Nil(t, err) + num, err := o.QueryTable("tag").Filter("name", tag.Name).Delete() + assert.Nil(t, err) + assert.Equal(t, int64(0), num) + + // test commit and call RollbackUnlessCommit + + to, err = o.Begin() + assert.Nil(t, err) + tag.Name = "rollback unless commit" + rows, err = to.Insert(&tag) + assert.Nil(t, err) + assert.True(t, rows > 0) + + err = to.Commit() + assert.Nil(t, err) + + err = to.RollbackUnlessCommit() + assert.Nil(t, err) + + num, err = o.QueryTable("tag").Filter("name", tag.Name).Delete() + assert.Nil(t, err) + assert.Equal(t, int64(1), num) } func TestTransactionIsolationLevel(t *testing.T) { @@ -2713,3 +2855,23 @@ func TestCondition(t *testing.T) { throwFail(t, AssertIs(!cycleFlag, true)) return } + +func TestContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + user := User{UserName: "slene"} + + err := dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, err) + + cancel() + err = dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, AssertIs(err, context.Canceled)) + + ctx, cancel = context.WithCancel(context.Background()) + cancel() + + qs := dORM.QueryTable(user) + _, err = qs.Filter("UserName", "slene").CountWithCtx(ctx) + throwFail(t, AssertIs(err, context.Canceled)) +} diff --git a/client/orm/types.go b/client/orm/types.go index cb735ac8..f9f74652 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -20,6 +20,7 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/core/utils" ) @@ -109,10 +110,35 @@ type TxBeginner interface { } type TxCommitter interface { + txEnder +} + +// transaction beginner +type txer interface { + Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +// transaction ending +type txEnder interface { Commit() error Rollback() error + + // RollbackUnlessCommit if the transaction has been committed, do nothing, or transaction will be rollback + // For example: + // ```go + // txOrm := orm.Begin() + // defer txOrm.RollbackUnlessCommit() + // err := txOrm.Insert() // do something + // if err != nil { + // return err + // } + // txOrm.Commit() + // ``` + RollbackUnlessCommit() error } + // Data Manipulation Language type DML interface { // insert model data to database @@ -196,12 +222,16 @@ type DQL interface { // post := Post{Id: 4} // m2m := Ormer.QueryM2M(&post, "Tags") QueryM2M(md interface{}, name string) QueryM2Mer + // NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), QueryTable(ptrStructOrTableName interface{}) QuerySeter + // NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter DBStats() *sql.DBStats @@ -211,25 +241,32 @@ type DriverGetter interface { Driver() Driver } + type ormer interface { DQL DML DriverGetter } -type Ormer interface { +//QueryExecutor wrapping for ormer +type QueryExecutor interface { ormer +} + +type Ormer interface { + QueryExecutor TxBeginner } type TxOrmer interface { - ormer + QueryExecutor TxCommitter } // Inserter insert prepared statement type Inserter interface { Insert(interface{}) (int64, error) + InsertWithCtx(context.Context, interface{}) (int64, error) Close() error } @@ -289,6 +326,28 @@ type QuerySeter interface { // for example: // qs.OrderBy("-status") OrderBy(exprs ...string) QuerySeter + // add ORDER expression by order clauses + // for example: + // OrderClauses( + // order_clause.Clause( + // order.Column("Id"), + // order.SortAscending(), + // ), + // order_clause.Clause( + // order.Column("status"), + // order.SortDescending(), + // ), + // ) + // OrderClauses(order_clause.Clause( + // order_clause.Column(`user__status`), + // order_clause.SortDescending(),//default None + // )) + // OrderClauses(order_clause.Clause( + // order_clause.Column(`random()`), + // order_clause.SortNone(),//default None + // order_clause.Raw(),//default false.if true, do not check field is valid or not + // )) + OrderClauses(orders ...*order_clause.Order) QuerySeter // add FORCE INDEX expression. // for example: // qs.ForceIndex(`idx_name1`,`idx_name2`) @@ -327,9 +386,11 @@ type QuerySeter interface { // for example: // num, err = qs.Filter("profile__age__gt", 28).Count() Count() (int64, error) + CountWithCtx(context.Context) (int64, error) // check result empty or not after QuerySeter executed // the same as QuerySeter.Count > 0 Exist() bool + ExistWithCtx(context.Context) bool // execute update with parameters // for example: // num, err = qs.Filter("user_name", "slene").Update(Params{ @@ -339,11 +400,13 @@ type QuerySeter interface { // "user_name": "slene2" // }) // user slene's name will change to slene2 Update(values Params) (int64, error) + UpdateWithCtx(ctx context.Context, values Params) (int64, error) // delete from table // for example: // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() // //delete two user who's name is testing1 or testing2 Delete() (int64, error) + DeleteWithCtx(context.Context) (int64, error) // return a insert queryer. // it can be used in times. // example: @@ -352,18 +415,21 @@ type QuerySeter interface { // num, err = i.Insert(&user2) // user table will add one record user2 at once // err = i.Close() //don't forget call Close PrepareInsert() (Inserter, error) + PrepareInsertWithCtx(context.Context) (Inserter, error) // query all data and map to containers. // cols means the columns when querying. // for example: // var users []*User // qs.All(&users) // users[0],users[1],users[2] ... All(container interface{}, cols ...string) (int64, error) + AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) // query one row data and map to containers. // cols means the columns when querying. // for example: // var user User // qs.One(&user) //user.UserName == "slene" One(container interface{}, cols ...string) error + OneWithCtx(ctx context.Context, container interface{}, cols ...string) error // query all data and map to []map[string]interface. // expres means condition expression. // it converts data to []map[column]value. @@ -371,18 +437,21 @@ type QuerySeter interface { // var maps []Params // qs.Values(&maps) //maps[0]["UserName"]=="slene" Values(results *[]Params, exprs ...string) (int64, error) + ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) // query all data and map to [][]interface // it converts data to [][column_index]value // for example: // var list []ParamsList // qs.ValuesList(&list) // list[0][1] == "slene" ValuesList(results *[]ParamsList, exprs ...string) (int64, error) + ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) // query all data and map to []interface. // it's designed for one column record set, auto change to []value, not [][column]value. // for example: // var list ParamsList // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" ValuesFlat(result *ParamsList, expr string) (int64, error) + ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) // query all rows into map[string]interface with specify key and value column name. // keyCol = "name", valueCol = "value" // table data @@ -405,6 +474,15 @@ type QuerySeter interface { // Found int // } RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) + // aggregate func. + // for example: + // type result struct { + // DeptName string + // Total int + // } + // var res []result + // o.QueryTable("dept_info").Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").All(&res) + Aggregate(s string) QuerySeter } // QueryM2Mer model to model query struct @@ -422,18 +500,23 @@ type QueryM2Mer interface { // insert one or more rows to m2m table // make sure the relation is defined in post model struct tag. Add(...interface{}) (int64, error) + AddWithCtx(context.Context, ...interface{}) (int64, error) // remove models following the origin model relationship // only delete rows from m2m table // for example: // tag3 := &Tag{Id:5,Name: "TestTag3"} // num, err = m2m.Remove(tag3) Remove(...interface{}) (int64, error) + RemoveWithCtx(context.Context, ...interface{}) (int64, error) // check model is existed in relationship of origin model Exist(interface{}) bool + ExistWithCtx(context.Context, interface{}) bool // clean all models in related of origin model Clear() (int64, error) + ClearWithCtx(context.Context) (int64, error) // count all related models of origin model Count() (int64, error) + CountWithCtx(context.Context) (int64, error) } // RawPreparer raw query statement @@ -507,11 +590,11 @@ type RawSeter interface { type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) - // ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) Query(args ...interface{}) (*sql.Rows, error) - // QueryContext(args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) QueryRow(args ...interface{}) *sql.Row - // QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row } // db querier @@ -534,42 +617,30 @@ type dbQuerier interface { // QueryRow(query string, args ...interface{}) *sql.Row // } -// transaction beginner -type txer interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -// transaction ending -type txEnder interface { - Commit() error - Rollback() error -} - // base database struct type dbBaser interface { - Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) - Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error + ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) - Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) - InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) - InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) - InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) + InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) + InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) SupportUpdateJoin() bool OperatorSQL(string) string GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorLeftCol(*fieldInfo, string, *string) - PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) + PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error) MaxLimit() uint64 TableQuote() string ReplaceMarks(*string) @@ -578,12 +649,12 @@ type dbBaser interface { TimeToDB(*time.Time, *time.Location) DbTypes() map[string]string GetTables(dbQuerier) (map[string]bool, error) - GetColumns(dbQuerier, string) (map[string][3]string, error) + GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error) ShowTablesQuery() string ShowColumnsQuery(string) string - IndexExists(dbQuerier, string, string) bool + IndexExists(context.Context, dbQuerier, string, string) bool collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) - setval(dbQuerier, *modelInfo, []string) error + setval(context.Context, dbQuerier, *modelInfo, []string) error GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string } diff --git a/client/orm/utils.go b/client/orm/utils.go index d6c0a8e8..8d05c080 100644 --- a/client/orm/utils.go +++ b/client/orm/utils.go @@ -228,7 +228,7 @@ func snakeStringWithAcronym(s string) string { } data = append(data, d) } - return strings.ToLower(string(data[:])) + return strings.ToLower(string(data)) } // snake string, XxYy to xx_yy , XxYY to xx_y_y @@ -246,7 +246,7 @@ func snakeString(s string) string { } data = append(data, d) } - return strings.ToLower(string(data[:])) + return strings.ToLower(string(data)) } // SetNameStrategy set different name strategy @@ -274,7 +274,7 @@ func camelString(s string) string { } data = append(data, d) } - return string(data[:]) + return string(data) } type argString []string diff --git a/core/admin/profile.go b/core/admin/profile.go index 5b3fdb21..6162a2d4 100644 --- a/core/admin/profile.go +++ b/core/admin/profile.go @@ -108,7 +108,7 @@ func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) { if gcstats.NumGC > 0 { lastPause := gcstats.Pause[0] - elapsed := time.Now().Sub(startTime) + elapsed := time.Since(startTime) overhead := float64(gcstats.PauseTotal) / float64(elapsed) * 100 allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() @@ -125,7 +125,7 @@ func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) { utils.ToShortTimeFormat(gcstats.PauseQuantiles[99])) } else { // while GC has disabled - elapsed := time.Now().Sub(startTime) + elapsed := time.Since(startTime) allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() fmt.Fprintf(w, "Alloc:%s Sys:%s Alloc(Rate):%s/s\n", diff --git a/core/bean/tag_auto_wire_bean_factory_test.go b/core/bean/tag_auto_wire_bean_factory_test.go index bcdada67..b5744af7 100644 --- a/core/bean/tag_auto_wire_bean_factory_test.go +++ b/core/bean/tag_auto_wire_bean_factory_test.go @@ -51,24 +51,25 @@ func TestTagAutoWireBeanFactory_AutoWire(t *testing.T) { } type ComplicateStruct struct { - IntValue int `default:"12"` - StrValue string `default:"hello, strValue"` - Int8Value int8 `default:"8"` - Int16Value int16 `default:"16"` - Int32Value int32 `default:"32"` - Int64Value int64 `default:"64"` + BoolValue bool `default:"true"` + Int8Value int8 `default:"8"` + Uint8Value uint8 `default:"88"` - UintValue uint `default:"13"` - Uint8Value uint8 `default:"88"` + Int16Value int16 `default:"16"` Uint16Value uint16 `default:"1616"` + Int32Value int32 `default:"32"` Uint32Value uint32 `default:"3232"` + + IntValue int `default:"12"` + UintValue uint `default:"13"` + Int64Value int64 `default:"64"` Uint64Value uint64 `default:"6464"` + StrValue string `default:"hello, strValue"` + Float32Value float32 `default:"32.32"` Float64Value float64 `default:"64.64"` - BoolValue bool `default:"true"` - ignoreInt int `default:"11"` TimeValue time.Time `default:"2018-02-03 12:13:14.000"` diff --git a/core/berror/codes.go b/core/berror/codes.go new file mode 100644 index 00000000..b6712a84 --- /dev/null +++ b/core/berror/codes.go @@ -0,0 +1,86 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" + "sync" +) + +// A Code is an unsigned 32-bit error code as defined in the beego spec. +type Code interface { + Code() uint32 + Module() string + Desc() string + Name() string +} + +var defaultCodeRegistry = &codeRegistry{ + codes: make(map[uint32]*codeDefinition, 127), +} + +// DefineCode defining a new Code +// Before defining a new code, please read Beego specification. +// desc could be markdown doc +func DefineCode(code uint32, module string, name string, desc string) Code { + res := &codeDefinition{ + code: code, + module: module, + desc: desc, + } + defaultCodeRegistry.lock.Lock() + defer defaultCodeRegistry.lock.Unlock() + + if _, ok := defaultCodeRegistry.codes[code]; ok { + panic(fmt.Sprintf("duplicate code, code %d has been registered", code)) + } + defaultCodeRegistry.codes[code] = res + return res +} + +type codeRegistry struct { + lock sync.RWMutex + codes map[uint32]*codeDefinition +} + +func (cr *codeRegistry) Get(code uint32) (Code, bool) { + cr.lock.RLock() + defer cr.lock.RUnlock() + c, ok := cr.codes[code] + return c, ok +} + +type codeDefinition struct { + code uint32 + module string + desc string + name string +} + +func (c *codeDefinition) Name() string { + return c.name +} + +func (c *codeDefinition) Code() uint32 { + return c.code +} + +func (c *codeDefinition) Module() string { + return c.module +} + +func (c *codeDefinition) Desc() string { + return c.desc +} diff --git a/core/berror/error.go b/core/berror/error.go new file mode 100644 index 00000000..ca09798a --- /dev/null +++ b/core/berror/error.go @@ -0,0 +1,69 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" + "strconv" + "strings" + + "github.com/pkg/errors" +) + +// code, msg +const errFmt = "ERROR-%d, %s" + +// Err returns an error representing c and msg. If c is OK, returns nil. +func Error(c Code, msg string) error { + return fmt.Errorf(errFmt, c.Code(), msg) +} + +// Errorf returns error +func Errorf(c Code, format string, a ...interface{}) error { + return Error(c, fmt.Sprintf(format, a...)) +} + +func Wrap(err error, c Code, msg string) error { + if err == nil { + return nil + } + return errors.Wrap(err, fmt.Sprintf(errFmt, c.Code(), msg)) +} + +func Wrapf(err error, c Code, format string, a ...interface{}) error { + return Wrap(err, c, fmt.Sprintf(format, a...)) +} + +// FromError is very simple. It just parse error msg and check whether code has been register +// if code not being register, return unknown +// if err.Error() is not valid beego error code, return unknown +func FromError(err error) (Code, bool) { + msg := err.Error() + codeSeg := strings.SplitN(msg, ",", 2) + if strings.HasPrefix(codeSeg[0], "ERROR-") { + codeStr := strings.SplitN(codeSeg[0], "-", 2) + if len(codeStr) < 2 { + return Unknown, false + } + codeInt, e := strconv.ParseUint(codeStr[1], 10, 32) + if e != nil { + return Unknown, false + } + if code, ok := defaultCodeRegistry.Get(uint32(codeInt)); ok { + return code, true + } + } + return Unknown, false +} diff --git a/core/berror/error_test.go b/core/berror/error_test.go new file mode 100644 index 00000000..7a14d933 --- /dev/null +++ b/core/berror/error_test.go @@ -0,0 +1,77 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +var testCode1 = DefineCode(1, "unit_test", "TestError", "Hello, test code1") + +var testErr = errors.New("hello, this is error") + +func TestErrorf(t *testing.T) { + msg := Errorf(testCode1, "errorf %s", "aaaa") + assert.NotNil(t, msg) + assert.Equal(t, "ERROR-1, errorf aaaa", msg.Error()) +} + +func TestWrapf(t *testing.T) { + err := Wrapf(testErr, testCode1, "Wrapf %s", "aaaa") + assert.NotNil(t, err) + assert.True(t, errors.Is(err, testErr)) +} + +func TestFromError(t *testing.T) { + err := errors.New("ERROR-1, errorf aaaa") + code, ok := FromError(err) + assert.True(t, ok) + assert.Equal(t, testCode1, code) + assert.Equal(t, "unit_test", code.Module()) + assert.Equal(t, "Hello, test code1", code.Desc()) + + err = errors.New("not beego error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-2, not register") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-aaa, invalid code") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("aaaaaaaaaaaaaa") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-2-3, invalid error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR, invalid error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) +} diff --git a/core/berror/pre_define_code.go b/core/berror/pre_define_code.go new file mode 100644 index 00000000..275f86c1 --- /dev/null +++ b/core/berror/pre_define_code.go @@ -0,0 +1,51 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" +) + +// pre define code + +// Unknown indicates got some error which is not defined +var Unknown = DefineCode(5000001, "error", "Unknown", fmt.Sprintf(` +Unknown error code. Usually you will see this code in three cases: +1. You forget to define Code or function DefineCode not being executed; +2. This is not Beego's error but you call FromError(); +3. Beego got unexpected error and don't know how to handle it, and then return Unknown error + +A common practice to DefineCode looks like: +%s + +In this way, you may forget to import this package, and got Unknown error. + +Sometimes, you believe you got Beego error, but actually you don't, and then you call FromError(err) + +`, goCodeBlock(` +import your_package + +func init() { + DefineCode(5100100, "your_module", "detail") + // ... +} +`))) + +func goCodeBlock(code string) string { + return codeBlock("go", code) +} +func codeBlock(lan string, code string) string { + return fmt.Sprintf("```%s\n%s\n```", lan, code) +} diff --git a/core/config/config.go b/core/config/config.go index d0add317..98080fe3 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -14,7 +14,7 @@ // Package config is used to parse config. // Usage: -// import "github.com/beego/beego/v2/config" +// import "github.com/beego/beego/v2/core/config" // Examples. // // cnf, err := config.NewConfig("ini", "config.conf") diff --git a/core/config/xml/xml.go b/core/config/xml/xml.go index 059ada5c..56bbd428 100644 --- a/core/config/xml/xml.go +++ b/core/config/xml/xml.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/config/xml" -// "github.com/beego/beego/v2/config" +// _ "github.com/beego/beego/v2/core/config/xml" +// "github.com/beego/beego/v2/core/config" // ) // // cnf, err := config.NewConfig("xml", "config.xml") diff --git a/core/config/yaml/yaml.go b/core/config/yaml/yaml.go index 778a4eb1..10335123 100644 --- a/core/config/yaml/yaml.go +++ b/core/config/yaml/yaml.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/config/yaml" -// "github.com/beego/beego/v2/config" +// _ "github.com/beego/beego/v2/core/config/yaml" +// "github.com/beego/beego/v2/core/config" // ) // // cnf, err := config.NewConfig("yaml", "config.yaml") diff --git a/core/logs/README.md b/core/logs/README.md index c7c82110..b2c405ff 100644 --- a/core/logs/README.md +++ b/core/logs/README.md @@ -4,7 +4,7 @@ logs is a Go logs manager. It can use many logs adapters. The repo is inspired b ## How to install? - go get github.com/beego/beego/v2/logs + go get github.com/beego/beego/v2/core/logs ## What adapters are supported? @@ -16,7 +16,7 @@ First you must import it ```golang import ( - "github.com/beego/beego/v2/logs" + "github.com/beego/beego/v2/core/logs" ) ``` diff --git a/core/logs/es/es.go b/core/logs/es/es.go index 2e592ffd..1140da97 100644 --- a/core/logs/es/es.go +++ b/core/logs/es/es.go @@ -29,7 +29,7 @@ func NewES() logs.Logger { // please import this package // usually means that you can import this package in your main package // for example, anonymous: -// import _ "github.com/beego/beego/v2/logs/es" +// import _ "github.com/beego/beego/v2/core/logs/es" type esLogger struct { *elasticsearch.Client DSN string `json:"dsn"` diff --git a/core/logs/file.go b/core/logs/file.go index b01be357..97c4a72d 100644 --- a/core/logs/file.go +++ b/core/logs/file.go @@ -33,6 +33,11 @@ import ( // Writes messages by lines limit, file size limit, or time frequency. type fileLogWriter struct { sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize + + Rotate bool `json:"rotate"` + Daily bool `json:"daily"` + Hourly bool `json:"hourly"` + // The opened file Filename string `json:"filename"` fileWriter *os.File @@ -49,19 +54,15 @@ type fileLogWriter struct { maxSizeCurSize int // Rotate daily - Daily bool `json:"daily"` MaxDays int64 `json:"maxdays"` dailyOpenDate int dailyOpenTime time.Time // Rotate hourly - Hourly bool `json:"hourly"` MaxHours int64 `json:"maxhours"` hourlyOpenDate int hourlyOpenTime time.Time - Rotate bool `json:"rotate"` - Level int `json:"level"` Perm string `json:"perm"` diff --git a/core/logs/log.go b/core/logs/log.go index 52d3007f..043b0f61 100644 --- a/core/logs/log.go +++ b/core/logs/log.go @@ -15,7 +15,7 @@ // Package logs provide a general log interface // Usage: // -// import "github.com/beego/beego/v2/logs" +// import "github.com/beego/beego/v2/core/logs" // // log := NewLogger(10000) // log.SetLogger("console", "") @@ -112,17 +112,17 @@ func Register(name string, log newLoggerFunc) { // Can contain several providers and log message into all providers. type BeeLogger struct { lock sync.Mutex - level int init bool enableFuncCallDepth bool - loggerFuncCallDepth int enableFullFilePath bool asynchronous bool + wg sync.WaitGroup + level int + loggerFuncCallDepth int prefix string msgChanLen int64 msgChan chan *LogMsg signalChan chan string - wg sync.WaitGroup outputs []*nameLogger globalFormatter string } diff --git a/core/validation/README.md b/core/validation/README.md index 46d7c935..dee5a7b1 100644 --- a/core/validation/README.md +++ b/core/validation/README.md @@ -7,18 +7,18 @@ validation is a form validation for a data validation and error collecting using Install: - go get github.com/beego/beego/v2/validation + go get github.com/beego/beego/v2/core/validation Test: - go test github.com/beego/beego/v2/validation + go test github.com/beego/beego/v2/core/validation ## Example Direct Use: import ( - "github.com/beego/beego/v2/validation" + "github.com/beego/beego/v2/core/validation" "log" ) @@ -49,7 +49,7 @@ Direct Use: Struct Tag Use: import ( - "github.com/beego/beego/v2/validation" + "github.com/beego/beego/v2/core/validation" ) // validation function follow with "valid" tag @@ -81,7 +81,7 @@ Struct Tag Use: Use custom function: import ( - "github.com/beego/beego/v2/validation" + "github.com/beego/beego/v2/core/validation" ) type user struct { diff --git a/core/validation/validation.go b/core/validation/validation.go index eb3a1042..6be10ef3 100644 --- a/core/validation/validation.go +++ b/core/validation/validation.go @@ -15,7 +15,7 @@ // Package validation for validations // // import ( -// "github.com/beego/beego/v2/validation" +// "github.com/beego/beego/v2/core/validation" // "log" // ) // @@ -121,7 +121,7 @@ func (v *Validation) Clear() { v.ErrorsMap = nil } -// HasErrors Has ValidationError nor not. +// HasErrors Has ValidationError or not. func (v *Validation) HasErrors() bool { return len(v.Errors) > 0 } @@ -158,7 +158,7 @@ func (v *Validation) Max(obj interface{}, max int, key string) *Result { return v.apply(Max{max, key}, obj) } -// Range Test that the obj is between mni and max if obj's type is int +// Range Test that the obj is between min and max if obj's type is int func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj) } diff --git a/go.mod b/go.mod index a6d30e37..7a933f8d 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect github.com/gomodule/redigo v2.0.0+incompatible github.com/google/go-cmp v0.5.0 // indirect - github.com/google/uuid v1.1.1 // indirect + github.com/google/uuid v1.1.1 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/hashicorp/golang-lru v0.5.4 github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 diff --git a/go.sum b/go.sum index 888160b2..06ea96c6 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,10 @@ github.com/couchbase/go-couchbase v0.0.0-20210126152612-8e416c37c8ef h1:pXh08kdO github.com/couchbase/go-couchbase v0.0.0-20210126152612-8e416c37c8ef/go.mod h1:+/bddYDxXsf9qt0xpDUtRR47A2GjaXmGGAqQ/k3GJ8A= github.com/couchbase/gomemcached v0.0.0-20200526233749-ec430f949808 h1:8s2l8TVUwMXl6tZMe3+hPCRJ25nQXiA3d1x622JtOqc= github.com/couchbase/gomemcached v0.0.0-20200526233749-ec430f949808/go.mod h1:srVSlQLB8iXBVXHgnqemxUXqN6FCvClgCMPCsjBDR7c= +github.com/couchbase/gomemcached v0.1.0 h1:whUde87n8CScx8ckMp2En5liqAlcuG3aKy/BQeBPu84= +github.com/couchbase/gomemcached v0.1.0/go.mod h1:srVSlQLB8iXBVXHgnqemxUXqN6FCvClgCMPCsjBDR7c= +github.com/couchbase/gomemcached v0.1.1 h1:xCS8ZglJDhrlQg3jmK7Rn1V8f7bPjXABLC05CgLQauc= +github.com/couchbase/gomemcached v0.1.1/go.mod h1:mxliKQxOv84gQ0bJWbI+w9Wxdpt9HjDvgW9MjCym5Vo= github.com/couchbase/gomemcached v0.1.2-0.20201215185628-3bc3f73e68cb h1:ZCFku0K/3Xvl7rXkGGM+ioT76Rxko8V9wDEWa0GFp14= github.com/couchbase/gomemcached v0.1.2-0.20201215185628-3bc3f73e68cb/go.mod h1:mxliKQxOv84gQ0bJWbI+w9Wxdpt9HjDvgW9MjCym5Vo= github.com/couchbase/gomemcached v0.1.2-0.20210126151728-840240974836 h1:ZxgtUfduO/Fk2NY1e1YhlgN6tRl0TMdXK9ElddO7uZY= diff --git a/scripts/prepare_etcd.sh b/scripts/prepare_etcd.sh new file mode 100644 index 00000000..d34c05a3 --- /dev/null +++ b/scripts/prepare_etcd.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +etcdctl put current.float 1.23 +etcdctl put current.bool true +etcdctl put current.int 11 +etcdctl put current.string hello +etcdctl put current.serialize.name test +etcdctl put sub.sub.key1 sub.sub.key \ No newline at end of file diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100644 index 00000000..977055a3 --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +docker-compose -f "$(pwd)/scripts/test_docker_compose.yaml" up -d + +export ORM_DRIVER=mysql +export TZ=UTC +export ORM_SOURCE="beego:test@tcp(localhost:13306)/orm_test?charset=utf8" + +# wait for services in images ready +sleep 5 + +go test "$(pwd)/..." + +# clear all container +docker-compose -f "$(pwd)/scripts/test_docker_compose.yaml" down + + diff --git a/scripts/test_docker_compose.yaml b/scripts/test_docker_compose.yaml new file mode 100644 index 00000000..f22b6deb --- /dev/null +++ b/scripts/test_docker_compose.yaml @@ -0,0 +1,55 @@ +version: "3.8" +services: + redis: + container_name: "beego-redis" + image: redis + environment: + - ALLOW_EMPTY_PASSWORD=yes + ports: + - "6379:6379" + + mysql: + container_name: "beego-mysql" + image: mysql:5.7.30 + ports: + - "13306:3306" + environment: + - MYSQL_ROOT_PASSWORD=1q2w3e + - MYSQL_DATABASE=orm_test + - MYSQL_USER=beego + - MYSQL_PASSWORD=test + + postgresql: + container_name: "beego-postgresql" + image: bitnami/postgresql:latest + ports: + - "5432:5432" + environment: + - ALLOW_EMPTY_PASSWORD=yes + ssdb: + container_name: "beego-ssdb" + image: wendal/ssdb + ports: + - "8888:8888" + memcache: + container_name: "beego-memcache" + image: memcached + ports: + - "11211:11211" + etcd: + command: > + sh -c " + etcdctl put current.float 1.23 + && etcdctl put current.bool true + && etcdctl put current.int 11 + && etcdctl put current.string hello + && etcdctl put current.serialize.name test + " + container_name: "beego-etcd" + environment: + - ALLOW_NONE_AUTHENTICATION=yes +# - ETCD_ADVERTISE_CLIENT_URLS=http://etcd:2379 + image: bitnami/etcd + ports: + - "2379:2379" + - "2380:2380" \ No newline at end of file diff --git a/server/web/beego.go b/server/web/beego.go index 14e51a94..17a7ea7b 100644 --- a/server/web/beego.go +++ b/server/web/beego.go @@ -75,7 +75,7 @@ func initBeforeHTTPRun() { registerTemplate, registerAdmin, registerGzip, - registerCommentRouter, + // registerCommentRouter, ) for _, hk := range hooks { diff --git a/server/web/captcha/README.md b/server/web/captcha/README.md index 74e1cf82..07a4dc4d 100644 --- a/server/web/captcha/README.md +++ b/server/web/captcha/README.md @@ -7,8 +7,8 @@ package controllers import ( "github.com/beego/beego/v2" - "github.com/beego/beego/v2/cache" - "github.com/beego/beego/v2/utils/captcha" + "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/server/web/captcha" ) var cpt *captcha.Captcha diff --git a/server/web/captcha/captcha.go b/server/web/captcha/captcha.go index d052af13..e0a9a6ed 100644 --- a/server/web/captcha/captcha.go +++ b/server/web/captcha/captcha.go @@ -20,8 +20,8 @@ // // import ( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/cache" -// "github.com/beego/beego/v2/utils/captcha" +// "github.com/beego/beego/v2/client/cache" +// "github.com/beego/beego/v2/server/web/captcha" // ) // // var cpt *captcha.Captcha diff --git a/server/web/config.go b/server/web/config.go index d89c59cb..bef92cfa 100644 --- a/server/web/config.go +++ b/server/web/config.go @@ -17,6 +17,7 @@ package web import ( "crypto/tls" "fmt" + "net/http" "os" "path/filepath" "reflect" @@ -38,46 +39,46 @@ type Config struct { AppName string // Application name RunMode string // Running Mode: dev | prod RouterCaseSensitive bool - ServerName string RecoverPanic bool - RecoverFunc func(*context.Context, *Config) CopyRequestBody bool EnableGzip bool + EnableErrorsShow bool + EnableErrorsRender bool + ServerName string + RecoverFunc func(*context.Context, *Config) // MaxMemory and MaxUploadSize are used to limit the request body // if the request is not uploading file, MaxMemory is the max size of request body // if the request is uploading file, MaxUploadSize is the max size of request body - MaxMemory int64 - MaxUploadSize int64 - EnableErrorsShow bool - EnableErrorsRender bool - Listen Listen - WebConfig WebConfig - Log LogConfig + MaxMemory int64 + MaxUploadSize int64 + Listen Listen + WebConfig WebConfig + Log LogConfig } // Listen holds for http and https related config type Listen struct { Graceful bool // Graceful means use graceful module to start the server - ServerTimeOut int64 ListenTCP4 bool EnableHTTP bool - HTTPAddr string - HTTPPort int AutoTLS bool - Domains []string - TLSCacheDir string EnableHTTPS bool EnableMutualHTTPS bool + EnableAdmin bool + EnableFcgi bool + EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O + ServerTimeOut int64 + HTTPAddr string + HTTPPort int + Domains []string + TLSCacheDir string HTTPSAddr string HTTPSPort int HTTPSCertFile string HTTPSKeyFile string TrustCaFile string - EnableAdmin bool AdminAddr string AdminPort int - EnableFcgi bool - EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O ClientAuth int } @@ -85,9 +86,10 @@ type Listen struct { type WebConfig struct { AutoRender bool EnableDocs bool + EnableXSRF bool + DirectoryIndex bool FlashName string FlashSeparator string - DirectoryIndex bool StaticDir map[string]string StaticExtensionsToGzip []string StaticCacheFileSize int @@ -96,7 +98,6 @@ type WebConfig struct { TemplateRight string ViewsPath string CommentRouterPath string - EnableXSRF bool XSRFKey string XSRFExpire int Session SessionConfig @@ -105,25 +106,26 @@ type WebConfig struct { // SessionConfig holds session related config type SessionConfig struct { SessionOn bool + SessionAutoSetCookie bool + SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. + SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers + SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params SessionProvider string SessionName string SessionGCMaxLifetime int64 SessionProviderConfig string SessionCookieLifeTime int - SessionAutoSetCookie bool SessionDomain string - SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. - SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers SessionNameInHTTPHeader string - SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params + SessionCookieSameSite http.SameSite } // LogConfig holds Log related config type LogConfig struct { AccessLogs bool - EnableStaticLogs bool // log static files requests default: false - AccessLogsFormat string // access log format: JSON_FORMAT, APACHE_FORMAT or empty string + EnableStaticLogs bool // log static files requests default: false FileLineNum bool + AccessLogsFormat string // access log format: JSON_FORMAT, APACHE_FORMAT or empty string Outputs map[string]string // Store Adaptor : config } @@ -274,6 +276,7 @@ func newBConfig() *Config { SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers SessionNameInHTTPHeader: "Beegosessionid", SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params + SessionCookieSameSite: http.SameSiteDefaultMode, }, }, Log: LogConfig{ diff --git a/server/web/context/context.go b/server/web/context/context.go index 6070c996..bf6da3d2 100644 --- a/server/web/context/context.go +++ b/server/web/context/context.go @@ -15,7 +15,7 @@ // Package context provide the context utils // Usage: // -// import "github.com/beego/beego/v2/context" +// import "github.com/beego/beego/v2/server/web/context" // // ctx := context.Context{Request:req,ResponseWriter:rw} // @@ -35,6 +35,8 @@ import ( "strings" "time" + "github.com/beego/beego/v2/server/web/session" + "github.com/beego/beego/v2/core/utils" ) @@ -195,6 +197,22 @@ func (ctx *Context) RenderMethodResult(result interface{}) { } } +// Session return session store of this context of request +func (ctx *Context) Session() (store session.Store, err error) { + if ctx.Input != nil { + if ctx.Input.CruSession != nil { + store = ctx.Input.CruSession + return + } else { + err = errors.New(`no valid session store(please initialize session)`) + return + } + } else { + err = errors.New(`no valid input`) + return + } +} + // Response is a wrapper for the http.ResponseWriter // Started: if true, response was already written to so the other handler will not be executed type Response struct { diff --git a/server/web/context/context_test.go b/server/web/context/context_test.go index 7c0535e0..3915a853 100644 --- a/server/web/context/context_test.go +++ b/server/web/context/context_test.go @@ -18,6 +18,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/beego/beego/v2/server/web/session" ) func TestXsrfReset_01(t *testing.T) { @@ -45,3 +47,26 @@ func TestXsrfReset_01(t *testing.T) { t.FailNow() } } + +func TestContext_Session(t *testing.T) { + c := NewContext() + if store, err := c.Session(); store != nil || err == nil { + t.FailNow() + } +} + +func TestContext_Session1(t *testing.T) { + c := Context{} + if store, err := c.Session(); store != nil || err == nil { + t.FailNow() + } +} + +func TestContext_Session2(t *testing.T) { + c := NewContext() + c.Input.CruSession = &session.MemSessionStore{} + + if store, err := c.Session(); store == nil || err != nil { + t.FailNow() + } +} diff --git a/server/web/controller.go b/server/web/controller.go index 5983cfbd..32378829 100644 --- a/server/web/controller.go +++ b/server/web/controller.go @@ -28,11 +28,11 @@ import ( "reflect" "strconv" "strings" - - "github.com/beego/beego/v2/server/web/session" + "sync" "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/context/param" + "github.com/beego/beego/v2/server/web/session" ) var ( @@ -40,8 +40,21 @@ var ( ErrAbort = errors.New("user stop run") // GlobalControllerRouter store comments with controller. pkgpath+controller:comments GlobalControllerRouter = make(map[string][]ControllerComments) + copyBufferPool sync.Pool ) +const ( + bytePerKb = 1024 + copyBufferKb = 32 + filePerm = 0666 +) + +func init() { + copyBufferPool.New = func() interface{} { + return make([]byte, bytePerKb*copyBufferKb) + } +} + // ControllerFilter store the filter for controller type ControllerFilter struct { Pattern string @@ -108,9 +121,9 @@ type Controller struct { EnableRender bool // xsrf data + EnableXSRF bool _xsrfToken string XSRFExpire int - EnableXSRF bool // session CruSession session.Store @@ -605,19 +618,31 @@ func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { // SaveToFile saves uploaded file to new path. // it only operates the first one of mutil-upload form file field. -func (c *Controller) SaveToFile(fromfile, tofile string) error { - file, _, err := c.Ctx.Request.FormFile(fromfile) +func (c *Controller) SaveToFile(fromFile, toFile string) error { + buf := copyBufferPool.Get().([]byte) + defer copyBufferPool.Put(buf) + return c.SaveToFileWithBuffer(fromFile, toFile, buf) +} + +type onlyWriter struct { + io.Writer +} + +func (c *Controller) SaveToFileWithBuffer(fromFile string, toFile string, buf []byte) error { + src, _, err := c.Ctx.Request.FormFile(fromFile) if err != nil { return err } - defer file.Close() - f, err := os.OpenFile(tofile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + defer src.Close() + + dst, err := os.OpenFile(toFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, filePerm) if err != nil { return err } - defer f.Close() - io.Copy(f, file) - return nil + defer dst.Close() + + _, err = io.CopyBuffer(onlyWriter{dst}, src, buf) + return err } // StartSession starts session and load old session data info this controller. diff --git a/server/web/controller_test.go b/server/web/controller_test.go index 4dd203f6..4f8b6d1c 100644 --- a/server/web/controller_test.go +++ b/server/web/controller_test.go @@ -21,8 +21,6 @@ import ( "strconv" "testing" - "github.com/stretchr/testify/assert" - "github.com/beego/beego/v2/server/web/context" ) @@ -127,10 +125,9 @@ func TestGetUint64(t *testing.T) { } func TestAdditionalViewPaths(t *testing.T) { - wkdir, err := os.Getwd() - assert.Nil(t, err) - dir1 := filepath.Join(wkdir, "_beeTmp", "TestAdditionalViewPaths") - dir2 := filepath.Join(wkdir, "_beeTmp2", "TestAdditionalViewPaths") + tmpDir := os.TempDir() + dir1 := filepath.Join(tmpDir, "_beeTmp", "TestAdditionalViewPaths") + dir2 := filepath.Join(tmpDir, "_beeTmp2", "TestAdditionalViewPaths") defer os.RemoveAll(dir1) defer os.RemoveAll(dir2) diff --git a/server/web/filter/apiauth/apiauth.go b/server/web/filter/apiauth/apiauth.go index 9e6c30dc..55a914a2 100644 --- a/server/web/filter/apiauth/apiauth.go +++ b/server/web/filter/apiauth/apiauth.go @@ -17,12 +17,12 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/apiauth" +// "github.com/beego/beego/v2/server/web/filter/apiauth" // ) // // func main(){ // // apiauth every request -// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBasicAuth("appid","appkey")) // beego.Run() // } // diff --git a/server/web/filter/auth/basic.go b/server/web/filter/auth/basic.go index 5a01f260..403d4b2c 100644 --- a/server/web/filter/auth/basic.go +++ b/server/web/filter/auth/basic.go @@ -16,7 +16,7 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/auth" +// "github.com/beego/beego/v2/server/web/filter/auth" // ) // // func main(){ diff --git a/server/web/filter/authz/authz.go b/server/web/filter/authz/authz.go index 8009c976..4ff2d6bf 100644 --- a/server/web/filter/authz/authz.go +++ b/server/web/filter/authz/authz.go @@ -16,7 +16,7 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/authz" +// "github.com/beego/beego/v2/server/web/filter/authz" // "github.com/casbin/casbin" // ) // diff --git a/server/web/filter/cors/cors.go b/server/web/filter/cors/cors.go index 0eb9aa30..f6c68ca0 100644 --- a/server/web/filter/cors/cors.go +++ b/server/web/filter/cors/cors.go @@ -16,7 +16,7 @@ // Usage // import ( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/cors" +// "github.com/beego/beego/v2/server/web/filter/cors" // ) // // func main() { @@ -69,10 +69,10 @@ var ( type Options struct { // If set, all origins are allowed. AllowAllOrigins bool - // A list of allowed origins. Wild cards and FQDNs are supported. - AllowOrigins []string // If set, allows to share auth credentials such as cookies. AllowCredentials bool + // A list of allowed origins. Wild cards and FQDNs are supported. + AllowOrigins []string // A list of allowed HTTP methods. AllowMethods []string // A list of allowed HTTP headers. @@ -143,7 +143,7 @@ func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map rHeader = strings.TrimSpace(rHeader) lookupLoop: for _, allowedHeader := range o.AllowHeaders { - if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) { + if strings.EqualFold(rHeader, allowedHeader) { allowed = append(allowed, rHeader) break lookupLoop } diff --git a/server/web/filter/prometheus/filter_test.go b/server/web/filter/prometheus/filter_test.go index 8da63df9..618ce5af 100644 --- a/server/web/filter/prometheus/filter_test.go +++ b/server/web/filter/prometheus/filter_test.go @@ -38,6 +38,7 @@ func TestFilterChain(t *testing.T) { ctx.Input.SetData("RouterPattern", "my-route") filter(ctx) assert.True(t, ctx.Input.GetData("invocation").(bool)) + time.Sleep(1 * time.Second) } func TestFilterChainBuilder_report(t *testing.T) { @@ -52,4 +53,4 @@ func TestFilterChainBuilder_report(t *testing.T) { ctx.Input.SetData("RouterPattern", "my-route") report(time.Second, ctx, fb.buildVec()) -} \ No newline at end of file +} diff --git a/server/web/filter/ratelimit/bucket.go b/server/web/filter/ratelimit/bucket.go new file mode 100644 index 00000000..67a3907e --- /dev/null +++ b/server/web/filter/ratelimit/bucket.go @@ -0,0 +1,14 @@ +package ratelimit + +import "time" + +// bucket is an interface store ratelimit info +type bucket interface { + take(amount uint) bool + getCapacity() uint + getRemaining() uint + getRate() time.Duration +} + +// bucketOption is constructor option +type bucketOption func(bucket) diff --git a/server/web/filter/ratelimit/limiter.go b/server/web/filter/ratelimit/limiter.go new file mode 100644 index 00000000..c7f156bf --- /dev/null +++ b/server/web/filter/ratelimit/limiter.go @@ -0,0 +1,169 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit + +import ( + "net/http" + "sync" + "time" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" +) + +// Limiter is an interface used to ratelimit +type Limiter interface { + take(amount uint, r *http.Request) bool +} + +// limiterOption is constructor option +type limiterOption func(l *limiter) + +type limiter struct { + sync.RWMutex + capacity uint + rate time.Duration + buckets map[string]bucket + bucketFactory func(opts ...bucketOption) bucket + sessionKey func(r *http.Request) string + resp RejectionResponse +} + +// RejectionResponse stores response information +// for the request rejected by limiter +type RejectionResponse struct { + code int + body string +} + +const perRequestConsumedAmount = 1 + +var defaultRejectionResponse = RejectionResponse{ + code: 429, + body: "too many requests", +} + +// NewLimiter return FilterFunc, the limiter enables rate limit +// according to the configuration. +func NewLimiter(opts ...limiterOption) web.FilterFunc { + l := &limiter{ + buckets: make(map[string]bucket), + sessionKey: func(r *http.Request) string { + return defaultSessionKey(r) + }, + bucketFactory: NewTokenBucket, + resp: defaultRejectionResponse, + } + for _, o := range opts { + o(l) + } + + return func(ctx *context.Context) { + if !l.take(perRequestConsumedAmount, ctx.Request) { + ctx.ResponseWriter.WriteHeader(l.resp.code) + ctx.WriteString(l.resp.body) + } + } +} + +// WithSessionKey return limiterOption. WithSessionKey config func +// which defines the request characteristic againstthe limit is applied +func WithSessionKey(f func(r *http.Request) string) limiterOption { + return func(l *limiter) { + l.sessionKey = f + } +} + +// WithRate return limiterOption. WithRate config how long it takes to +// generate a token. +func WithRate(r time.Duration) limiterOption { + return func(l *limiter) { + l.rate = r + } +} + +// WithCapacity return limiterOption. WithCapacity config the capacity size. +// The bucket with a capacity of n has n tokens after initialization. The capacity +// defines how many requests a client can make in excess of the rate. +func WithCapacity(c uint) limiterOption { + return func(l *limiter) { + l.capacity = c + } +} + +// WithBucketFactory return limiterOption. WithBucketFactory customize the +// implementation of Bucket. +func WithBucketFactory(f func(opts ...bucketOption) bucket) limiterOption { + return func(l *limiter) { + l.bucketFactory = f + } +} + +// WithRejectionResponse return limiterOption. WithRejectionResponse +// customize the response for the request rejected by the limiter. +func WithRejectionResponse(resp RejectionResponse) limiterOption { + return func(l *limiter) { + l.resp = resp + } +} + +func (l *limiter) take(amount uint, r *http.Request) bool { + bucket := l.getBucket(r) + if bucket == nil { + return true + } + return bucket.take(amount) +} + +func (l *limiter) getBucket(r *http.Request) bucket { + key := l.sessionKey(r) + l.RLock() + b, ok := l.buckets[key] + l.RUnlock() + if !ok { + b = l.createBucket(key) + } + + return b +} + +func (l *limiter) createBucket(key string) bucket { + l.Lock() + defer l.Unlock() + // double check avoid overwriting + b, ok := l.buckets[key] + if ok { + return b + } + b = l.bucketFactory(withCapacity(l.capacity), withRate(l.rate)) + l.buckets[key] = b + return b +} + + +func defaultSessionKey(r *http.Request) string { + return "" +} + +func RemoteIPSessionKey(r *http.Request) string { + IPAddress := r.Header.Get("X-Real-Ip") + if IPAddress == "" { + IPAddress = r.Header.Get("X-Forwarded-For") + } + if IPAddress == "" { + IPAddress = r.RemoteAddr + } + return IPAddress +} diff --git a/server/web/filter/ratelimit/limiter_test.go b/server/web/filter/ratelimit/limiter_test.go new file mode 100644 index 00000000..cafede5e --- /dev/null +++ b/server/web/filter/ratelimit/limiter_test.go @@ -0,0 +1,76 @@ +package ratelimit + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" +) + +func testRequest(t *testing.T, handler *web.ControllerRegister, requestIP, method, path string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.Header.Set("X-Real-Ip", requestIP) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", requestIP, method, path, w.Code, code) + } +} + +func TestLimiter(t *testing.T) { + handler := web.NewControllerRegister() + err := handler.InsertFilter("/foo/*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(1), WithSessionKey(RemoteIPSessionKey))) + if err != nil { + t.Error(err) + } + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + route := "/foo/1" + ip := "127.0.0.1" + testRequest(t, handler, ip, "GET", route, 200) + testRequest(t, handler, ip, "GET", route, 429) + testRequest(t, handler, "127.0.0.2", "GET", route, 200) + time.Sleep(1 * time.Millisecond) + testRequest(t, handler, ip, "GET", route, 200) +} + +func BenchmarkWithoutLimiter(b *testing.B) { + recorder := httptest.NewRecorder() + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + handler.ServeHTTP(recorder, r) + } + }) +} + +func BenchmarkWithLimiter(b *testing.B) { + recorder := httptest.NewRecorder() + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD + err := handler.InsertFilter("*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(100))) + if err != nil { + b.Error(err) + } + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + handler.ServeHTTP(recorder, r) + } + }) +} diff --git a/server/web/filter/ratelimit/token_bucket.go b/server/web/filter/ratelimit/token_bucket.go new file mode 100644 index 00000000..5906ee9e --- /dev/null +++ b/server/web/filter/ratelimit/token_bucket.go @@ -0,0 +1,76 @@ +package ratelimit + +import ( + "sync" + "time" +) + +type tokenBucket struct { + sync.RWMutex + remaining uint + capacity uint + lastCheckAt time.Time + rate time.Duration +} + +// NewTokenBucket return an bucket that implements token bucket +func NewTokenBucket(opts ...bucketOption) bucket { + b := &tokenBucket{lastCheckAt: time.Now()} + for _, o := range opts { + o(b) + } + return b +} + +func withCapacity(capacity uint) bucketOption { + return func(b bucket) { + bucket := b.(*tokenBucket) + bucket.capacity = capacity + bucket.remaining = capacity + } +} + +func withRate(rate time.Duration) bucketOption { + return func(b bucket) { + bucket := b.(*tokenBucket) + bucket.rate = rate + } +} + +func (b *tokenBucket) getRemaining() uint { + b.RLock() + defer b.RUnlock() + return b.remaining +} + +func (b *tokenBucket) getRate() time.Duration { + b.RLock() + defer b.RUnlock() + return b.rate +} + +func (b *tokenBucket) getCapacity() uint { + b.RLock() + defer b.RUnlock() + return b.capacity +} + +func (b *tokenBucket) take(amount uint) bool { + if b.rate <= 0 { + return true + } + b.Lock() + defer b.Unlock() + now := time.Now() + times := uint(now.Sub(b.lastCheckAt) / b.rate) + b.lastCheckAt = b.lastCheckAt.Add(time.Duration(times) * b.rate) + b.remaining += times + if b.remaining < amount { + return false + } + b.remaining -= amount + if b.remaining > b.capacity { + b.remaining = b.capacity + } + return true +} diff --git a/server/web/filter/ratelimit/token_bucket_test.go b/server/web/filter/ratelimit/token_bucket_test.go new file mode 100644 index 00000000..93a1b3bd --- /dev/null +++ b/server/web/filter/ratelimit/token_bucket_test.go @@ -0,0 +1,32 @@ +package ratelimit + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGetRate(t *testing.T) { + b := NewTokenBucket(withRate(1 * time.Second)).(*tokenBucket) + assert.Equal(t, b.getRate(), 1*time.Second) +} + +func TestGetRemainingAndCapacity(t *testing.T) { + b := NewTokenBucket(withCapacity(10)) + assert.Equal(t, b.getRemaining(), uint(10)) + assert.Equal(t, b.getCapacity(), uint(10)) +} + +func TestTake(t *testing.T) { + b := NewTokenBucket(withCapacity(10), withRate(10*time.Millisecond)).(*tokenBucket) + for i := 0; i < 10; i++ { + assert.True(t, b.take(1)) + } + assert.False(t, b.take(1)) + assert.Equal(t, b.getRemaining(), uint(0)) + b = NewTokenBucket(withCapacity(1), withRate(1*time.Millisecond)).(*tokenBucket) + assert.True(t, b.take(1)) + time.Sleep(2 * time.Millisecond) + assert.True(t, b.take(1)) +} diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go new file mode 100644 index 00000000..b26e4d53 --- /dev/null +++ b/server/web/filter/session/filter.go @@ -0,0 +1,36 @@ +package session + +import ( + "context" + + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" +) + +//Session maintain session for web service +//Session new a session storage and store it into webContext.Context +func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { + sessionConfig := session.NewManagerConfig(options...) + sessionManager, _ := session.NewManager(string(providerType), sessionConfig) + go sessionManager.GC() + + return func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if ctx.Input.CruSession != nil { + return + } + + if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil { + logs.Error(`init session error:%s`, err.Error()) + } else { + //release session at the end of request + defer sess.SessionRelease(context.Background(), ctx.ResponseWriter) + ctx.Input.CruSession = sess + } + + next(ctx) + } + } +} diff --git a/server/web/filter/session/filter_test.go b/server/web/filter/session/filter_test.go new file mode 100644 index 00000000..43046bf3 --- /dev/null +++ b/server/web/filter/session/filter_test.go @@ -0,0 +1,87 @@ +package session + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" + "github.com/google/uuid" +) + +func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code) + } +} + +func TestSession(t *testing.T) { + storeKey := uuid.New().String() + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store := ctx.Input.GetData(storeKey); store == nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} + +func TestSession1(t *testing.T) { + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store, err := ctx.Session(); store == nil || err != nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} diff --git a/server/web/filter_chain_test.go b/server/web/filter_chain_test.go index 2a428b78..d76d8cbf 100644 --- a/server/web/filter_chain_test.go +++ b/server/web/filter_chain_test.go @@ -15,9 +15,12 @@ package web import ( + "fmt" "net/http" "net/http/httptest" + "strconv" "testing" + "time" "github.com/stretchr/testify/assert" @@ -36,13 +39,45 @@ func TestControllerRegister_InsertFilterChain(t *testing.T) { ns := NewNamespace("/chain") ns.Get("/*", func(ctx *context.Context) { - ctx.Output.Body([]byte("hello")) + _ = ctx.Output.Body([]byte("hello")) }) r, _ := http.NewRequest("GET", "/chain/user", nil) w := httptest.NewRecorder() + BeeApp.Handlers.Init() BeeApp.Handlers.ServeHTTP(w, r) assert.Equal(t, "filter-chain", w.Header().Get("filter")) } + +func TestControllerRegister_InsertFilterChain_Order(t *testing.T) { + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + r, _ := http.NewRequest("GET", "/abc", nil) + w := httptest.NewRecorder() + + BeeApp.Handlers.Init() + BeeApp.Handlers.ServeHTTP(w, r) + first := w.Header().Get("first") + second := w.Header().Get("second") + + ft, _ := strconv.ParseInt(first, 10, 64) + st, _ := strconv.ParseInt(second, 10, 64) + + assert.True(t, st > ft) +} diff --git a/server/web/flash_test.go b/server/web/flash_test.go index 2deef54e..3e20c8fb 100644 --- a/server/web/flash_test.go +++ b/server/web/flash_test.go @@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) { // setup the handler handler := NewControllerRegister() - handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") + handler.Add("/", &TestFlashController{}, WithRouterMethods(&TestFlashController{}, "get:TestWriteFlash")) handler.ServeHTTP(w, r) // get the Set-Cookie value diff --git a/server/web/grace/grace.go b/server/web/grace/grace.go index 0adc8654..96ae10ef 100644 --- a/server/web/grace/grace.go +++ b/server/web/grace/grace.go @@ -22,7 +22,7 @@ // "net/http" // "os" // -// "github.com/beego/beego/v2/grace" +// "github.com/beego/beego/v2/server/web/grace" // ) // // func handler(w http.ResponseWriter, r *http.Request) { diff --git a/server/web/hooks.go b/server/web/hooks.go index 438496a0..0f72e711 100644 --- a/server/web/hooks.go +++ b/server/web/hooks.go @@ -6,8 +6,6 @@ import ( "net/http" "path/filepath" - "github.com/coreos/etcd/pkg/fileutil" - "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/session" @@ -63,6 +61,7 @@ func registerSession() error { conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery + conf.CookieSameSite = BConfig.WebConfig.Session.SessionCookieSameSite } else { if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { return err @@ -97,18 +96,3 @@ func registerGzip() error { } return nil } - -func registerCommentRouter() error { - if BConfig.RunMode == DEV { - ctrlDir := filepath.Join(WorkPath, BConfig.WebConfig.CommentRouterPath) - if !fileutil.Exist(ctrlDir) { - logs.Warn("controller package not found, won't generate router: ", ctrlDir) - return nil - } - if err := parserPkg(ctrlDir); err != nil { - return err - } - } - - return nil -} diff --git a/server/web/namespace.go b/server/web/namespace.go index 4e0c3b85..96037b4d 100644 --- a/server/web/namespace.go +++ b/server/web/namespace.go @@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { // Router same as beego.Rourer // refer: https://godoc.org/github.com/beego/beego/v2#Router func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { - n.handlers.Add(rootpath, c, mappingMethods...) + n.handlers.Add(rootpath, c, WithRouterMethods(c, mappingMethods...)) return n } @@ -187,6 +187,54 @@ func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { return n } +// RouterGet same as beego.RouterGet +func (n *Namespace) RouterGet(rootpath string, f interface{}) *Namespace { + n.handlers.RouterGet(rootpath, f) + return n +} + +// RouterPost same as beego.RouterPost +func (n *Namespace) RouterPost(rootpath string, f interface{}) *Namespace { + n.handlers.RouterPost(rootpath, f) + return n +} + +// RouterDelete same as beego.RouterDelete +func (n *Namespace) RouterDelete(rootpath string, f interface{}) *Namespace { + n.handlers.RouterDelete(rootpath, f) + return n +} + +// RouterPut same as beego.RouterPut +func (n *Namespace) RouterPut(rootpath string, f interface{}) *Namespace { + n.handlers.RouterPut(rootpath, f) + return n +} + +// RouterHead same as beego.RouterHead +func (n *Namespace) RouterHead(rootpath string, f interface{}) *Namespace { + n.handlers.RouterHead(rootpath, f) + return n +} + +// RouterOptions same as beego.RouterOptions +func (n *Namespace) RouterOptions(rootpath string, f interface{}) *Namespace { + n.handlers.RouterOptions(rootpath, f) + return n +} + +// RouterPatch same as beego.RouterPatch +func (n *Namespace) RouterPatch(rootpath string, f interface{}) *Namespace { + n.handlers.RouterPatch(rootpath, f) + return n +} + +// Any same as beego.RouterAny +func (n *Namespace) RouterAny(rootpath string, f interface{}) *Namespace { + n.handlers.RouterAny(rootpath, f) + return n +} + // Namespace add nest Namespace // usage: // ns := beego.NewNamespace(“/v1”). @@ -366,6 +414,62 @@ func NSPatch(rootpath string, f FilterFunc) LinkNamespace { } } +// NSRouterGet call Namespace RouterGet +func NSRouterGet(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterGet(rootpath, f) + } +} + +// NSRouterPost call Namespace RouterPost +func NSRouterPost(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterPost(rootpath, f) + } +} + +// NSRouterHead call Namespace RouterHead +func NSRouterHead(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterHead(rootpath, f) + } +} + +// NSRouterPut call Namespace RouterPut +func NSRouterPut(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterPut(rootpath, f) + } +} + +// NSRouterDelete call Namespace RouterDelete +func NSRouterDelete(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterDelete(rootpath, f) + } +} + +// NSRouterAny call Namespace RouterAny +func NSRouterAny(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterAny(rootpath, f) + } +} + +// NSRouterOptions call Namespace RouterOptions +func NSRouterOptions(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterOptions(rootpath, f) + } +} + +// NSRouterPatch call Namespace RouterPatch +func NSRouterPatch(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.RouterPatch(rootpath, f) + } +} + // NSAutoRouter call Namespace AutoRouter func NSAutoRouter(c ControllerInterface) LinkNamespace { return func(ns *Namespace) { diff --git a/server/web/namespace_test.go b/server/web/namespace_test.go index 05042c96..30d17cb2 100644 --- a/server/web/namespace_test.go +++ b/server/web/namespace_test.go @@ -15,6 +15,7 @@ package web import ( + "fmt" "net/http" "net/http/httptest" "strconv" @@ -23,6 +24,40 @@ import ( "github.com/beego/beego/v2/server/web/context" ) +const ( + exampleBody = "hello world" + examplePointerBody = "hello world pointer" + + nsNamespace = "/router" + nsPath = "/user" + nsNamespacePath = "/router/user" +) + +type ExampleController struct { + Controller +} + +func (m ExampleController) Ping() { + err := m.Ctx.Output.Body([]byte(exampleBody)) + if err != nil { + fmt.Println(err) + } +} + +func (m *ExampleController) PingPointer() { + err := m.Ctx.Output.Body([]byte(examplePointerBody)) + if err != nil { + fmt.Println(err) + } +} + +func (m ExampleController) ping() { + err := m.Ctx.Output.Body([]byte("ping method")) + if err != nil { + fmt.Println(err) + } +} + func TestNamespaceGet(t *testing.T) { r, _ := http.NewRequest("GET", "/v1/user", nil) w := httptest.NewRecorder() @@ -166,3 +201,215 @@ func TestNamespaceInside(t *testing.T) { t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String()) } } + +func TestNamespaceRouterGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterGet(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterPost(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterPost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterDelete(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterDelete can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterPut(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterPut can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterHead(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterHead can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterOptions(t *testing.T) { + r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterOptions(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterOptions can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.RouterPatch(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterPatch can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouterAny(t *testing.T) { + ns := NewNamespace(nsNamespace) + ns.RouterAny(nsPath, ExampleController.Ping) + AddNamespace(ns) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, nsNamespacePath, nil) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceRouterAny can't run, get the response is " + w.Body.String()) + } + } +} + +func TestNamespaceNSRouterGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterGet(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/router") + NSRouterPost(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterPost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterDelete(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterDelete can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterPut(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterPut can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterHead(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterHead can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterOptions(t *testing.T) { + r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterOptions(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterOptions can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSRouterPatch("/user", ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterPatch can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSRouterAny(t *testing.T) { + ns := NewNamespace(nsNamespace) + NSRouterAny(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, nsNamespacePath, nil) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSRouterAny can't run, get the response is " + w.Body.String()) + } + } +} diff --git a/server/web/parser.go b/server/web/parser.go deleted file mode 100644 index 803880fe..00000000 --- a/server/web/parser.go +++ /dev/null @@ -1,600 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package web - -import ( - "encoding/json" - "errors" - "fmt" - "go/ast" - "io/ioutil" - "os" - "path/filepath" - "regexp" - "sort" - "strconv" - "strings" - "unicode" - - "golang.org/x/tools/go/packages" - - "github.com/beego/beego/v2/core/logs" - - "github.com/beego/beego/v2/core/utils" - "github.com/beego/beego/v2/server/web/context/param" -) - -var globalRouterTemplate = `package {{.routersDir}} - -import ( - beego "github.com/beego/beego/v2/server/web" - "github.com/beego/beego/v2/server/web/context/param"{{.globalimport}} -) - -func init() { -{{.globalinfo}} -} -` - -var ( - lastupdateFilename = "lastupdate.tmp" - commentFilename string - pkgLastupdate map[string]int64 - genInfoList map[string][]ControllerComments - - routerHooks = map[string]int{ - "beego.BeforeStatic": BeforeStatic, - "beego.BeforeRouter": BeforeRouter, - "beego.BeforeExec": BeforeExec, - "beego.AfterExec": AfterExec, - "beego.FinishRouter": FinishRouter, - } - - routerHooksMapping = map[int]string{ - BeforeStatic: "beego.BeforeStatic", - BeforeRouter: "beego.BeforeRouter", - BeforeExec: "beego.BeforeExec", - AfterExec: "beego.AfterExec", - FinishRouter: "beego.FinishRouter", - } -) - -const commentPrefix = "commentsRouter_" - -func init() { - pkgLastupdate = make(map[string]int64) -} - -func parserPkg(pkgRealpath string) error { - rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_") - commentFilename, _ = filepath.Rel(AppPath, pkgRealpath) - commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go" - if !compareFile(pkgRealpath) { - logs.Info(pkgRealpath + " no changed") - return nil - } - genInfoList = make(map[string][]ControllerComments) - pkgs, err := packages.Load(&packages.Config{ - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedSyntax, - Dir: pkgRealpath, - }, "./...") - - if err != nil { - return err - } - for _, pkg := range pkgs { - for _, fl := range pkg.Syntax { - for _, d := range fl.Decls { - switch specDecl := d.(type) { - case *ast.FuncDecl: - if specDecl.Recv != nil { - exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser - if ok { - parserComments(specDecl, fmt.Sprint(exp.X), pkg.PkgPath) - } - } - } - } - } - } - genRouterCode(pkgRealpath) - savetoFile(pkgRealpath) - return nil -} - -type parsedComment struct { - routerPath string - methods []string - params map[string]parsedParam - filters []parsedFilter - imports []parsedImport -} - -type parsedImport struct { - importPath string - importAlias string -} - -type parsedFilter struct { - pattern string - pos int - filter string - params []bool -} - -type parsedParam struct { - name string - datatype string - location string - defValue string - required bool -} - -func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { - if f.Doc != nil { - parsedComments, err := parseComment(f.Doc.List) - if err != nil { - return err - } - for _, parsedComment := range parsedComments { - if parsedComment.routerPath != "" { - key := pkgpath + ":" + controllerName - cc := ControllerComments{} - cc.Method = f.Name.String() - cc.Router = parsedComment.routerPath - cc.AllowHTTPMethods = parsedComment.methods - cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) - cc.FilterComments = buildFilters(parsedComment.filters) - cc.ImportComments = buildImports(parsedComment.imports) - genInfoList[key] = append(genInfoList[key], cc) - } - } - } - return nil -} - -func buildImports(pis []parsedImport) []*ControllerImportComments { - var importComments []*ControllerImportComments - - for _, pi := range pis { - importComments = append(importComments, &ControllerImportComments{ - ImportPath: pi.importPath, - ImportAlias: pi.importAlias, - }) - } - - return importComments -} - -func buildFilters(pfs []parsedFilter) []*ControllerFilterComments { - var filterComments []*ControllerFilterComments - - for _, pf := range pfs { - var ( - returnOnOutput bool - resetParams bool - ) - - if len(pf.params) >= 1 { - returnOnOutput = pf.params[0] - } - - if len(pf.params) >= 2 { - resetParams = pf.params[1] - } - - filterComments = append(filterComments, &ControllerFilterComments{ - Filter: pf.filter, - Pattern: pf.pattern, - Pos: pf.pos, - ReturnOnOutput: returnOnOutput, - ResetParams: resetParams, - }) - } - - return filterComments -} - -func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { - result := make([]*param.MethodParam, 0, len(funcParams)) - for _, fparam := range funcParams { - for _, pName := range fparam.Names { - methodParam := buildMethodParam(fparam, pName.Name, pc) - result = append(result, methodParam) - } - } - return result -} - -func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam { - options := []param.MethodParamOption{} - if cparam, ok := pc.params[name]; ok { - // Build param from comment info - name = cparam.name - if cparam.required { - options = append(options, param.IsRequired) - } - switch cparam.location { - case "body": - options = append(options, param.InBody) - case "header": - options = append(options, param.InHeader) - case "path": - options = append(options, param.InPath) - } - if cparam.defValue != "" { - options = append(options, param.Default(cparam.defValue)) - } - } else { - if paramInPath(name, pc.routerPath) { - options = append(options, param.InPath) - } - } - return param.New(name, options...) -} - -func paramInPath(name, route string) bool { - return strings.HasSuffix(route, ":"+name) || - strings.Contains(route, ":"+name+"/") -} - -var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) - -func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { - pcs = []*parsedComment{} - params := map[string]parsedParam{} - filters := []parsedFilter{} - imports := []parsedImport{} - - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Param") { - pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) - if len(pv) < 4 { - logs.Error("Invalid @Param format. Needs at least 4 parameters") - } - p := parsedParam{} - names := strings.SplitN(pv[0], "=>", 2) - p.name = names[0] - funcParamName := p.name - if len(names) > 1 { - funcParamName = names[1] - } - p.location = pv[1] - p.datatype = pv[2] - switch len(pv) { - case 5: - p.required, _ = strconv.ParseBool(pv[3]) - case 6: - p.defValue = pv[3] - p.required, _ = strconv.ParseBool(pv[4]) - } - params[funcParamName] = p - } - } - - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Import") { - iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import"))) - if len(iv) == 0 || len(iv) > 2 { - logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters") - continue - } - - p := parsedImport{} - p.importPath = iv[0] - - if len(iv) == 2 { - p.importAlias = iv[1] - } - - imports = append(imports, p) - } - } - -filterLoop: - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Filter") { - fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter"))) - if len(fv) < 3 { - logs.Error("Invalid @Filter format. Needs at least 3 parameters") - continue filterLoop - } - - p := parsedFilter{} - p.pattern = fv[0] - posName := fv[1] - if pos, exists := routerHooks[posName]; exists { - p.pos = pos - } else { - logs.Error("Invalid @Filter pos: ", posName) - continue filterLoop - } - - p.filter = fv[2] - fvParams := fv[3:] - for _, fvParam := range fvParams { - switch fvParam { - case "true": - p.params = append(p.params, true) - case "false": - p.params = append(p.params, false) - default: - logs.Error("Invalid @Filter param: ", fvParam) - continue filterLoop - } - } - - filters = append(filters, p) - } - } - - for _, c := range lines { - var pc = &parsedComment{} - pc.params = params - pc.filters = filters - pc.imports = imports - - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@router") { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - matches := routeRegex.FindStringSubmatch(t) - if len(matches) == 3 { - pc.routerPath = matches[1] - methods := matches[2] - if methods == "" { - pc.methods = []string{"get"} - // pc.hasGet = true - } else { - pc.methods = strings.Split(methods, ",") - // pc.hasGet = strings.Contains(methods, "get") - } - pcs = append(pcs, pc) - } else { - return nil, errors.New("Router information is missing") - } - } - } - return -} - -// direct copy from bee\g_docs.go -// analysis params return []string -// @Param query form string true "The email for login" -// [query form string true "The email for login"] -func getparams(str string) []string { - var s []rune - var j int - var start bool - var r []string - var quoted int8 - for _, c := range str { - if unicode.IsSpace(c) && quoted == 0 { - if !start { - continue - } else { - start = false - j++ - r = append(r, string(s)) - s = make([]rune, 0) - continue - } - } - - start = true - if c == '"' { - quoted ^= 1 - continue - } - s = append(s, c) - } - if len(s) > 0 { - r = append(r, string(s)) - } - return r -} - -func genRouterCode(pkgRealpath string) { - os.Mkdir(getRouterDir(pkgRealpath), 0755) - logs.Info("generate router from comments") - var ( - globalinfo string - globalimport string - sortKey []string - ) - for k := range genInfoList { - sortKey = append(sortKey, k) - } - sort.Strings(sortKey) - for _, k := range sortKey { - cList := genInfoList[k] - sort.Sort(ControllerCommentsSlice(cList)) - for _, c := range cList { - allmethod := "nil" - if len(c.AllowHTTPMethods) > 0 { - allmethod = "[]string{" - for _, m := range c.AllowHTTPMethods { - allmethod += `"` + m + `",` - } - allmethod = strings.TrimRight(allmethod, ",") + "}" - } - - params := "nil" - if len(c.Params) > 0 { - params = "[]map[string]string{" - for _, p := range c.Params { - for k, v := range p { - params = params + `map[string]string{` + k + `:"` + v + `"},` - } - } - params = strings.TrimRight(params, ",") + "}" - } - - methodParams := "param.Make(" - if len(c.MethodParams) > 0 { - lines := make([]string, 0, len(c.MethodParams)) - for _, m := range c.MethodParams { - lines = append(lines, fmt.Sprint(m)) - } - methodParams += "\n " + - strings.Join(lines, ",\n ") + - ",\n " - } - methodParams += ")" - - imports := "" - if len(c.ImportComments) > 0 { - for _, i := range c.ImportComments { - var s string - if i.ImportAlias != "" { - s = fmt.Sprintf(` - %s "%s"`, i.ImportAlias, i.ImportPath) - } else { - s = fmt.Sprintf(` - "%s"`, i.ImportPath) - } - if !strings.Contains(globalimport, s) { - imports += s - } - } - } - - filters := "" - if len(c.FilterComments) > 0 { - for _, f := range c.FilterComments { - filters += fmt.Sprintf(` &beego.ControllerFilter{ - Pattern: "%s", - Pos: %s, - Filter: %s, - ReturnOnOutput: %v, - ResetParams: %v, - },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams) - } - } - - if filters == "" { - filters = "nil" - } else { - filters = fmt.Sprintf(`[]*beego.ControllerFilter{ -%s - }`, filters) - } - - globalimport += imports - - globalinfo = globalinfo + ` - beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], - beego.ControllerComments{ - Method: "` + strings.TrimSpace(c.Method) + `", - ` + `Router: "` + c.Router + `"` + `, - AllowHTTPMethods: ` + allmethod + `, - MethodParams: ` + methodParams + `, - Filters: ` + filters + `, - Params: ` + params + `}) -` - } - } - - if globalinfo != "" { - f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) - if err != nil { - panic(err) - } - defer f.Close() - - routersDir := AppConfig.DefaultString("routersdir", "routers") - content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) - content = strings.Replace(content, "{{.routersDir}}", routersDir, -1) - content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) - f.WriteString(content) - } -} - -func compareFile(pkgRealpath string) bool { - if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) { - return true - } - if utils.FileExists(lastupdateFilename) { - content, err := ioutil.ReadFile(lastupdateFilename) - if err != nil { - return true - } - json.Unmarshal(content, &pkgLastupdate) - lastupdate, err := getpathTime(pkgRealpath) - if err != nil { - return true - } - if v, ok := pkgLastupdate[pkgRealpath]; ok { - if lastupdate <= v { - return false - } - } - } - return true -} - -func savetoFile(pkgRealpath string) { - lastupdate, err := getpathTime(pkgRealpath) - if err != nil { - return - } - pkgLastupdate[pkgRealpath] = lastupdate - d, err := json.Marshal(pkgLastupdate) - if err != nil { - return - } - ioutil.WriteFile(lastupdateFilename, d, os.ModePerm) -} - -func getpathTime(pkgRealpath string) (lastupdate int64, err error) { - fl, err := ioutil.ReadDir(pkgRealpath) - if err != nil { - return lastupdate, err - } - for _, f := range fl { - var t int64 - if f.IsDir() { - t, err = getpathTime(filepath.Join(pkgRealpath, f.Name())) - if err != nil { - return lastupdate, err - } - } else { - t = f.ModTime().UnixNano() - } - if lastupdate < t { - lastupdate = t - } - } - return lastupdate, nil -} - -func getRouterDir(pkgRealpath string) string { - dir := filepath.Dir(pkgRealpath) - for { - routersDir := AppConfig.DefaultString("routersdir", "routers") - d := filepath.Join(dir, routersDir) - if utils.FileExists(d) { - return d - } - - if r, _ := filepath.Rel(dir, AppPath); r == "." { - return d - } - // Parent dir. - dir = filepath.Dir(dir) - } -} diff --git a/server/web/router.go b/server/web/router.go index 5a663386..f9f8b322 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -20,6 +20,7 @@ import ( "net/http" "path" "reflect" + "runtime" "strconv" "strings" "sync" @@ -118,24 +119,48 @@ type ControllerInfo struct { routerType int initialize func() ControllerInterface methodParams []*param.MethodParam + sessionOn bool } +type ControllerOption func(*ControllerInfo) + func (c *ControllerInfo) GetPattern() string { return c.pattern } +func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption { + return func(c *ControllerInfo) { + c.methods = parseMappingMethods(ctrlInterface, mappingMethod) + } +} + +func WithRouterSessionOn(sessionOn bool) ControllerOption { + return func(c *ControllerInfo) { + c.sessionOn = sessionOn + } +} + +type filterChainConfig struct { + pattern string + chain FilterChain + opts []FilterOpt +} + // ControllerRegister containers registered router rules, controller handlers and filters. type ControllerRegister struct { routers map[string]*Tree enablePolicy bool - policies map[string]*Tree enableFilter bool + policies map[string]*Tree filters [FinishRouter + 1][]*FilterRouter pool sync.Pool // the filter created by FilterChain chainRoot *FilterRouter + // keep registered chain and build it when serve http + filterChains []filterChainConfig + cfg *Config } @@ -155,12 +180,24 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { return beecontext.NewContext() }, }, - cfg: cfg, + cfg: cfg, + filterChains: make([]filterChainConfig, 0, 4), } res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) return res } +// Init will be executed when HttpServer start running +func (p *ControllerRegister) Init() { + for i := len(p.filterChains) - 1; i >= 0; i-- { + fc := p.filterChains[i] + root := p.chainRoot + filterFunc := fc.chain(root.filterFunc) + p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...) + p.chainRoot.next = root + } +} + // Add controller handler and pattern rules to ControllerRegister. // usage: // default methods is the same name as method @@ -171,41 +208,64 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { // Add("/api/delete",&RestController{},"delete:DeleteFood") // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") -func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - p.addWithMethodParams(pattern, c, nil, mappingMethods...) +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOption) { + p.addWithMethodParams(pattern, c, nil, opts...) } -func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { +func parseMappingMethods(c ControllerInterface, mappingMethods []string) map[string]string { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() methods := make(map[string]string) - if len(mappingMethods) > 0 { - semi := strings.Split(mappingMethods[0], ";") - for _, v := range semi { - colon := strings.Split(v, ":") - if len(colon) != 2 { - panic("method mapping format is invalid") + + if len(mappingMethods) == 0 { + return methods + } + + semi := strings.Split(mappingMethods[0], ";") + for _, v := range semi { + colon := strings.Split(v, ":") + if len(colon) != 2 { + panic("method mapping format is invalid") + } + comma := strings.Split(colon[0], ",") + for _, m := range comma { + if m != "*" && !HTTPMETHOD[strings.ToUpper(m)] { + panic(v + " is an invalid method mapping. Method doesn't exist " + m) } - comma := strings.Split(colon[0], ",") - for _, m := range comma { - if m == "*" || HTTPMETHOD[strings.ToUpper(m)] { - if val := reflectVal.MethodByName(colon[1]); val.IsValid() { - methods[strings.ToUpper(m)] = colon[1] - } else { - panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) - } - } else { - panic(v + " is an invalid method mapping. Method doesn't exist " + m) - } + if val := reflectVal.MethodByName(colon[1]); val.IsValid() { + methods[strings.ToUpper(m)] = colon[1] + continue } + panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) } } - route := &ControllerInfo{} - route.pattern = pattern - route.methods = methods - route.routerType = routerTypeBeego - route.controllerType = t + return methods +} + +func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) { + if len(route.methods) == 0 { + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + return + } + for k := range route.methods { + if k != "*" { + p.addToRouter(k, route.pattern, route) + continue + } + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + } +} + +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOption) { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + + route := p.createBeegoRouter(t, pattern) route.initialize = func() ControllerInterface { vc := reflect.New(route.controllerType) execController, ok := vc.Interface().(ControllerInterface) @@ -229,23 +289,18 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt return execController } - route.methodParams = methodParams - if len(methods) == 0 { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } + for i := range opts { + opts[i](route) } + + globalSessionOn := p.cfg.WebConfig.Session.SessionOn + if !globalSessionOn && route.sessionOn { + logs.Warn("global sessionOn is false, sessionOn of router [%s] can't be set to true", route.pattern) + route.sessionOn = globalSessionOn + } + + p.addRouterForMethod(route) } func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { @@ -273,7 +328,8 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { for _, f := range a.Filters { p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) } - p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) + + p.addWithMethodParams(a.Router, c, a.MethodParams, WithRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) } } } @@ -294,6 +350,261 @@ func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { p.pool.Put(ctx) } +// RouterGet add get method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterGet("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterGet(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodGet, pattern, f) +} + +// RouterPost add post method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPost("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterPost(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPost, pattern, f) +} + +// RouterHead add head method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterHead("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterHead(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodHead, pattern, f) +} + +// RouterPut add put method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPut("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterPut(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPut, pattern, f) +} + +// RouterPatch add patch method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPatch("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterPatch(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPatch, pattern, f) +} + +// RouterDelete add delete method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterDelete("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterDelete(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodDelete, pattern, f) +} + +// RouterOptions add options method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterOptions("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterOptions(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodOptions, pattern, f) +} + +// RouterAny add all method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterAny("/api/:id", MyController.Ping) +func (p *ControllerRegister) RouterAny(pattern string, f interface{}) { + p.AddRouterMethod("*", pattern, f) +} + +// AddRouterMethod add http method router +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// AddRouterMethod("get","/api/:id", MyController.Ping) +func (p *ControllerRegister) AddRouterMethod(httpMethod, pattern string, f interface{}) { + httpMethod = p.getUpperMethodString(httpMethod) + ct, methodName := getReflectTypeAndMethod(f) + + p.addBeegoTypeRouter(ct, methodName, httpMethod, pattern) +} + +// addBeegoTypeRouter add beego type router +func (p *ControllerRegister) addBeegoTypeRouter(ct reflect.Type, ctMethod, httpMethod, pattern string) { + route := p.createBeegoRouter(ct, pattern) + methods := p.getHttpMethodMapMethod(httpMethod, ctMethod) + route.methods = methods + + p.addRouterForMethod(route) +} + +// createBeegoRouter create beego router base on reflect type and pattern +func (p *ControllerRegister) createBeegoRouter(ct reflect.Type, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeBeego + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.controllerType = ct + return route +} + +// createRestfulRouter create restful router with filter function and pattern +func (p *ControllerRegister) createRestfulRouter(f FilterFunc, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeRESTFul + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.runFunction = f + return route +} + +// createHandlerRouter create handler router with handler and pattern +func (p *ControllerRegister) createHandlerRouter(h http.Handler, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeHandler + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.handler = h + return route +} + +// getHttpMethodMapMethod based on http method and controller method, if ctMethod is empty, then it will +// use http method as the controller method +func (p *ControllerRegister) getHttpMethodMapMethod(httpMethod, ctMethod string) map[string]string { + methods := make(map[string]string) + // not match-all sign, only add for the http method + if httpMethod != "*" { + + if ctMethod == "" { + ctMethod = httpMethod + } + methods[httpMethod] = ctMethod + return methods + } + + // add all http method + for val := range HTTPMETHOD { + if ctMethod == "" { + methods[val] = val + } else { + methods[val] = ctMethod + } + } + return methods +} + +// getUpperMethodString get upper string of method, and panic if the method +// is not valid +func (p *ControllerRegister) getUpperMethodString(method string) string { + method = strings.ToUpper(method) + if method != "*" && !HTTPMETHOD[method] { + panic("not support http method: " + method) + } + return method +} + +// get reflect controller type and method by controller method expression +func getReflectTypeAndMethod(f interface{}) (controllerType reflect.Type, method string) { + // check f is a function + funcType := reflect.TypeOf(f) + if funcType.Kind() != reflect.Func { + panic("not a method") + } + + // get function name + funcObj := runtime.FuncForPC(reflect.ValueOf(f).Pointer()) + if funcObj == nil { + panic("cannot find the method") + } + funcNameSli := strings.Split(funcObj.Name(), ".") + lFuncSli := len(funcNameSli) + if lFuncSli == 0 { + panic("invalid method full name: " + funcObj.Name()) + } + + method = funcNameSli[lFuncSli-1] + if len(method) == 0 { + panic("method name is empty") + } else if method[0] > 96 || method[0] < 65 { + panic(fmt.Sprintf("%s is not a public method", method)) + } + + // check only one param which is the method receiver + if numIn := funcType.NumIn(); numIn != 1 { + panic("invalid number of param in") + } + + controllerType = funcType.In(0) + + // check controller has the method + _, exists := controllerType.MethodByName(method) + if !exists { + panic(controllerType.String() + " has no method " + method) + } + + // check the receiver implement ControllerInterface + if controllerType.Kind() == reflect.Ptr { + controllerType = controllerType.Elem() + } + controller := reflect.New(controllerType) + _, ok := controller.Interface().(ControllerInterface) + if !ok { + panic(controllerType.String() + " is not implemented ControllerInterface") + } + + return +} + // Get add get method // usage: // Get("/", func(ctx *context.Context){ @@ -372,40 +683,18 @@ func (p *ControllerRegister) Any(pattern string, f FilterFunc) { // ctx.Output.Body("hello world") // }) func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { - method = strings.ToUpper(method) - if method != "*" && !HTTPMETHOD[method] { - panic("not support http method: " + method) - } - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeRESTFul - route.runFunction = f - methods := make(map[string]string) - if method == "*" { - for val := range HTTPMETHOD { - methods[val] = val - } - } else { - methods[method] = method - } + method = p.getUpperMethodString(method) + + route := p.createRestfulRouter(f, pattern) + methods := p.getHttpMethodMapMethod(method, "") route.methods = methods - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } + + p.addRouterForMethod(route) } // Handler add user defined Handler func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeHandler - route.handler = h + route := p.createHandlerRouter(h, pattern) if len(options) > 0 { if _, ok := options[0].(bool); ok { pattern = path.Join(pattern, "?:all(.*)") @@ -437,15 +726,13 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) controllerName := strings.TrimSuffix(ct.Name(), "Controller") for i := 0; i < rt.NumMethod(); i++ { if !utils.InSlice(rt.Method(i).Name, exceptMethod) { - route := &ControllerInfo{} - route.routerType = routerTypeBeego - route.methods = map[string]string{"*": rt.Method(i).Name} - route.controllerType = ct pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) - route.pattern = pattern + + route := p.createBeegoRouter(ct, pattern) + route.methods = map[string]string{"*": rt.Method(i).Name} for m := range HTTPMETHOD { p.addToRouter(m, pattern, route) p.addToRouter(m, patternInit, route) @@ -478,12 +765,13 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // } // } func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { - root := p.chainRoot - filterFunc := chain(root.filterFunc) - opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) - p.chainRoot = newFilterRouter(pattern, filterFunc, opts...) - p.chainRoot.next = root + opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) + p.filterChains = append(p.filterChains, filterChainConfig{ + pattern: pattern, + chain: chain, + opts: opts, + }) } // add Filter into @@ -548,7 +836,7 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str for _, l := range t.leaves { if c, ok := l.runObject.(*ControllerInfo); ok { if c.routerType == routerTypeBeego && - strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { + strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), `/`+controllerName) { find := false if HTTPMETHOD[strings.ToUpper(methodName)] { if len(c.methods) == 0 { @@ -670,12 +958,15 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { r := ctx.Request rw := ctx.ResponseWriter.ResponseWriter var ( - runRouter reflect.Type - findRouter bool - runMethod string - methodParams []*param.MethodParam - routerInfo *ControllerInfo - isRunnable bool + runRouter reflect.Type + findRouter bool + runMethod string + methodParams []*param.MethodParam + routerInfo *ControllerInfo + isRunnable bool + currentSessionOn bool + originRouterInfo *ControllerInfo + originFindRouter bool ) if p.cfg.RecoverFunc != nil { @@ -741,7 +1032,12 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } // session init - if p.cfg.WebConfig.Session.SessionOn { + currentSessionOn = p.cfg.WebConfig.Session.SessionOn + originRouterInfo, originFindRouter = p.FindRouter(ctx) + if originFindRouter { + currentSessionOn = originRouterInfo.sessionOn + } + if currentSessionOn { ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { logs.Error(err) diff --git a/server/web/router_test.go b/server/web/router_test.go index 87997322..3633aee7 100644 --- a/server/web/router_test.go +++ b/server/web/router_test.go @@ -16,6 +16,7 @@ package web import ( "bytes" + "fmt" "net/http" "net/http/httptest" "strings" @@ -26,6 +27,25 @@ import ( "github.com/beego/beego/v2/server/web/context" ) +type PrefixTestController struct { + Controller +} + +func (ptc *PrefixTestController) PrefixList() { + ptc.Ctx.Output.Body([]byte("i am list in prefix test")) +} + +type TestControllerWithInterface struct { +} + +func (m TestControllerWithInterface) Ping() { + fmt.Println("pong") +} + +func (m *TestControllerWithInterface) PingPointer() { + fmt.Println("pong pointer") +} + type TestController struct { Controller } @@ -87,10 +107,24 @@ func (jc *JSONController) Get() { jc.Ctx.Output.Body([]byte("ok")) } +func TestPrefixUrlFor(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/my/prefix/list", &PrefixTestController{}, WithRouterMethods(&PrefixTestController{}, "get:PrefixList")) + + if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` { + logs.Info(a) + t.Errorf("PrefixTestController.PrefixList must equal to /my/prefix/list") + } + if a := handler.URLFor(`TestController.PrefixList`); a != `` { + logs.Info(a) + t.Errorf("TestController.PrefixList must equal to empty string") + } +} + func TestUrlFor(t *testing.T) { handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") - handler.Add("/person/:last/:first", &TestController{}, "*:Param") + handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) + handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) if a := handler.URLFor("TestController.List"); a != "/api/list" { logs.Info(a) t.Errorf("TestController.List must equal to /api/list") @@ -113,9 +147,9 @@ func TestUrlFor3(t *testing.T) { func TestUrlFor2(t *testing.T) { handler := NewControllerRegister() - handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") - handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") - handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") + handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) + handler.Add("/v1/:username/edit", &TestController{}, WithRouterMethods(&TestController{}, "get:GetURL")) + handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { logs.Info(handler.URLFor("TestController.GetURL")) @@ -145,7 +179,7 @@ func TestUserFunc(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") + handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) handler.ServeHTTP(w, r) if w.Body.String() != "i am list" { t.Errorf("user define func can't run") @@ -235,7 +269,7 @@ func TestRouteOk(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/person/:last/:first", &TestController{}, "get:GetParams") + handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "get:GetParams")) handler.ServeHTTP(w, r) body := w.Body.String() if body != "anderson+thomas+kungfu" { @@ -249,7 +283,7 @@ func TestManyRoute(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter") + handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetManyRouter")) handler.ServeHTTP(w, r) body := w.Body.String() @@ -266,7 +300,7 @@ func TestEmptyResponse(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody") + handler.Add("/beego-empty.html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetEmptyBody")) handler.ServeHTTP(w, r) if body := w.Body.String(); body != "" { @@ -750,3 +784,321 @@ func TestRouterEntityTooLargeCopyBody(t *testing.T) { t.Errorf("TestRouterRequestEntityTooLarge can't run") } } + +func TestRouterSessionSet(t *testing.T) { + oldGlobalSessionOn := BConfig.WebConfig.Session.SessionOn + defer func() { + BConfig.WebConfig.Session.SessionOn = oldGlobalSessionOn + }() + + // global sessionOn = false, router sessionOn = false + r, _ := http.NewRequest("GET", "/user", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = false, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + BConfig.WebConfig.Session.SessionOn = true + if err := registerSession(); err != nil { + t.Errorf("register session failed, error: %s", err.Error()) + } + // global sessionOn = true, router sessionOn = false + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = true, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") == "" { + t.Errorf("TestRotuerSessionSet failed") + } + +} + +func TestRouterRouterGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterGet("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterGet can't run") + } +} + +func TestRouterRouterPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPost("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterPost can't run") + } +} + +func TestRouterRouterHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterHead("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterHead can't run") + } +} + +func TestRouterRouterPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPut("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterPut can't run") + } +} + +func TestRouterRouterPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPatch("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterPatch can't run") + } +} + +func TestRouterRouterDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterDelete("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterDelete can't run") + } +} + +func TestRouterRouterAny(t *testing.T) { + handler := NewControllerRegister() + handler.RouterAny("/user", ExampleController.Ping) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, "/user", nil) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterRouterAny can't run, get the response is " + w.Body.String()) + } + } +} + +func TestRouterRouterGetPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterGet("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterGetPointerMethod can't run") + } +} + +func TestRouterRouterPostPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPost("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterPostPointerMethod can't run") + } +} + +func TestRouterRouterHeadPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterHead("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterHeadPointerMethod can't run") + } +} + +func TestRouterRouterPutPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPut("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterPutPointerMethod can't run") + } +} + +func TestRouterRouterPatchPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterPatch("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterPatchPointerMethod can't run") + } +} + +func TestRouterRouterDeletePointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.RouterDelete("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterDeletePointerMethod can't run") + } +} + +func TestRouterRouterAnyPointerMethod(t *testing.T) { + handler := NewControllerRegister() + handler.RouterAny("/user", (*ExampleController).PingPointer) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, "/user", nil) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterRouterAnyPointerMethod can't run, get the response is " + w.Body.String()) + } + } +} + +func TestRouterAddRouterMethodPanicInvalidMethod(t *testing.T) { + method := "some random method" + message := "not support http method: " + strings.ToUpper(method) + defer func() { + err := recover() + if err != nil { //产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicInvalidMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController.Ping) +} + +func TestRouterAddRouterMethodPanicNotAMethod(t *testing.T) { + method := http.MethodGet + message := "not a method" + defer func() { + err := recover() + if err != nil { //产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotAMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController{}) +} + +func TestRouterAddRouterMethodPanicNotPublicMethod(t *testing.T) { + method := http.MethodGet + message := "ping is not a public method" + defer func() { + err := recover() + if err != nil { //产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotPublicMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController.ping) +} + +func TestRouterAddRouterMethodPanicNotImplementInterface(t *testing.T) { + method := http.MethodGet + message := "web.TestControllerWithInterface is not implemented ControllerInterface" + defer func() { + err := recover() + if err != nil { //产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotImplementInterface failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", TestControllerWithInterface.Ping) +} + +func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) { + method := http.MethodGet + message := "web.TestControllerWithInterface is not implemented ControllerInterface" + defer func() { + err := recover() + if err != nil { //产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterPointerMethodPanicNotImplementInterface failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer) +} diff --git a/server/web/server.go b/server/web/server.go index f0a4f4ea..e1ef1a03 100644 --- a/server/web/server.go +++ b/server/web/server.go @@ -84,7 +84,9 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { initBeforeHTTPRun() + // init... app.initAddr(addr) + app.Handlers.Init() addr = app.Cfg.Listen.HTTPAddr @@ -267,7 +269,11 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { // Router see HttpServer.Router func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer { - return BeeApp.Router(rootpath, c, mappingMethods...) + return RouterWithOpts(rootpath, c, WithRouterMethods(c, mappingMethods...)) +} + +func RouterWithOpts(rootpath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { + return BeeApp.RouterWithOpts(rootpath, c, opts...) } // Router adds a patterned controller handler to BeeApp. @@ -287,7 +293,11 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *H // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer { - app.Handlers.Add(rootPath, c, mappingMethods...) + return app.RouterWithOpts(rootPath, c, WithRouterMethods(c, mappingMethods...)) +} + +func (app *HttpServer) RouterWithOpts(rootPath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { + app.Handlers.Add(rootPath, c, opts...) return app } @@ -453,6 +463,166 @@ func (app *HttpServer) AutoPrefix(prefix string, c ControllerInterface) *HttpSer return app } +// RouterGet see HttpServer.RouterGet +func RouterGet(rootpath string, f interface{}) { + BeeApp.RouterGet(rootpath, f) +} + +// RouterGet used to register router for RouterGet method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterGet("/api/:id", MyController.Ping) +func (app *HttpServer) RouterGet(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterGet(rootpath, f) + return app +} + +// RouterPost see HttpServer.RouterGet +func RouterPost(rootpath string, f interface{}) { + BeeApp.RouterPost(rootpath, f) +} + +// RouterPost used to register router for RouterPost method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPost("/api/:id", MyController.Ping) +func (app *HttpServer) RouterPost(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterPost(rootpath, f) + return app +} + +// RouterHead see HttpServer.RouterHead +func RouterHead(rootpath string, f interface{}) { + BeeApp.RouterHead(rootpath, f) +} + +// RouterHead used to register router for RouterHead method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterHead("/api/:id", MyController.Ping) +func (app *HttpServer) RouterHead(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterHead(rootpath, f) + return app +} + +// RouterPut see HttpServer.RouterPut +func RouterPut(rootpath string, f interface{}) { + BeeApp.RouterPut(rootpath, f) +} + +// RouterPut used to register router for RouterPut method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPut("/api/:id", MyController.Ping) +func (app *HttpServer) RouterPut(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterPut(rootpath, f) + return app +} + +// RouterPatch see HttpServer.RouterPatch +func RouterPatch(rootpath string, f interface{}) { + BeeApp.RouterPatch(rootpath, f) +} + +// RouterPatch used to register router for RouterPatch method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterPatch("/api/:id", MyController.Ping) +func (app *HttpServer) RouterPatch(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterPatch(rootpath, f) + return app +} + +// RouterDelete see HttpServer.RouterDelete +func RouterDelete(rootpath string, f interface{}) { + BeeApp.RouterDelete(rootpath, f) +} + +// RouterDelete used to register router for RouterDelete method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterDelete("/api/:id", MyController.Ping) +func (app *HttpServer) RouterDelete(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterDelete(rootpath, f) + return app +} + +// RouterOptions see HttpServer.RouterOptions +func RouterOptions(rootpath string, f interface{}) { + BeeApp.RouterOptions(rootpath, f) +} + +// RouterOptions used to register router for RouterOptions method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterOptions("/api/:id", MyController.Ping) +func (app *HttpServer) RouterOptions(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterOptions(rootpath, f) + return app +} + +// RouterAny see HttpServer.RouterAny +func RouterAny(rootpath string, f interface{}) { + BeeApp.RouterAny(rootpath, f) +} + +// RouterAny used to register router for RouterAny method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// RouterAny("/api/:id", MyController.Ping) +func (app *HttpServer) RouterAny(rootpath string, f interface{}) *HttpServer { + app.Handlers.RouterAny(rootpath, f) + return app +} + // Get see HttpServer.Get func Get(rootpath string, f FilterFunc) *HttpServer { return BeeApp.Get(rootpath, f) diff --git a/server/web/server_test.go b/server/web/server_test.go index 0b0c601c..0734be77 100644 --- a/server/web/server_test.go +++ b/server/web/server_test.go @@ -15,6 +15,8 @@ package web import ( + "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -28,3 +30,82 @@ func TestNewHttpServerWithCfg(t *testing.T) { assert.Equal(t, "hello", BConfig.AppName) } + +func TestServerRouterGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + RouterGet("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterGet can't run") + } +} + +func TestServerRouterPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + RouterPost("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterPost can't run") + } +} + +func TestServerRouterHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + RouterHead("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterHead can't run") + } +} + +func TestServerRouterPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + RouterPut("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterPut can't run") + } +} + +func TestServerRouterPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + RouterPatch("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterPatch can't run") + } +} + +func TestServerRouterDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + RouterDelete("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterDelete can't run") + } +} + +func TestServerRouterAny(t *testing.T) { + RouterAny("/user", ExampleController.Ping) + + for method := range HTTPMETHOD { + r, _ := http.NewRequest(method, "/user", nil) + w := httptest.NewRecorder() + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerRouterAny can't run") + } + } +} diff --git a/server/web/session/README.md b/server/web/session/README.md index 854fb590..8dd70f67 100644 --- a/server/web/session/README.md +++ b/server/web/session/README.md @@ -6,7 +6,7 @@ and `database/sql/driver`. ## How to install? - go get github.com/beego/beego/v2/session + go get github.com/beego/beego/v2/server/web/session ## What providers are supported? @@ -17,7 +17,7 @@ As of now this session manager support memory, file, Redis and MySQL. First you must import it import ( - "github.com/beego/beego/v2/session" + "github.com/beego/beego/v2/server/web/session" ) Then in you web app init the global session manager diff --git a/server/web/session/couchbase/sess_couchbase.go b/server/web/session/couchbase/sess_couchbase.go index ea94f501..b9075040 100644 --- a/server/web/session/couchbase/sess_couchbase.go +++ b/server/web/session/couchbase/sess_couchbase.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/couchbase" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/couchbase" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/server/web/session/memcache/sess_memcache.go b/server/web/session/memcache/sess_memcache.go index 3f4c9842..cf191fdf 100644 --- a/server/web/session/memcache/sess_memcache.go +++ b/server/web/session/memcache/sess_memcache.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/memcache" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/memcache" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/server/web/session/mysql/sess_mysql.go b/server/web/session/mysql/sess_mysql.go index d76ec287..6df2737d 100644 --- a/server/web/session/mysql/sess_mysql.go +++ b/server/web/session/mysql/sess_mysql.go @@ -28,8 +28,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/mysql" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/mysql" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { @@ -150,6 +150,8 @@ func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, if err == sql.ErrNoRows { c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix()) + } else if err != nil { + return nil, err } var kv map[interface{}]interface{} if len(sessiondata) == 0 { @@ -189,7 +191,10 @@ func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) ( if err == sql.ErrNoRows { c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) } - c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) + _, err = c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) + if err != nil { + return nil, err + } var kv map[interface{}]interface{} if len(sessiondata) == 0 { kv = make(map[interface{}]interface{}) diff --git a/server/web/session/postgres/sess_postgresql.go b/server/web/session/postgres/sess_postgresql.go index 7745ff5f..b26b82ce 100644 --- a/server/web/session/postgres/sess_postgresql.go +++ b/server/web/session/postgres/sess_postgresql.go @@ -38,8 +38,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/postgresql" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/postgresql" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/server/web/session/redis/sess_redis.go b/server/web/session/redis/sess_redis.go index e3d38be3..acc25f78 100644 --- a/server/web/session/redis/sess_redis.go +++ b/server/web/session/redis/sess_redis.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/server/web/session/redis/sess_redis_test.go b/server/web/session/redis/sess_redis_test.go index fe5c363b..2b15eef1 100644 --- a/server/web/session/redis/sess_redis_test.go +++ b/server/web/session/redis/sess_redis_test.go @@ -15,21 +15,22 @@ import ( ) func TestRedis(t *testing.T) { - sessionConfig := &session.ManagerConfig{ - CookieName: "gosessionid", - EnableSetCookie: true, - Gclifetime: 3600, - Maxlifetime: 3600, - Secure: false, - CookieLifeTime: 3600, - } - redisAddr := os.Getenv("REDIS_ADDR") if redisAddr == "" { redisAddr = "127.0.0.1:6379" } + redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr) + + sessionConfig := session.NewManagerConfig( + session.CfgCookieName(`gosessionid`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + session.CfgProviderConfig(redisConfig), + ) - sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr) globalSession, err := session.NewManager("redis", sessionConfig) if err != nil { t.Fatal("could not create manager:", err) diff --git a/server/web/session/redis_cluster/redis_cluster.go b/server/web/session/redis_cluster/redis_cluster.go index e94dccc3..4db3bbe9 100644 --- a/server/web/session/redis_cluster/redis_cluster.go +++ b/server/web/session/redis_cluster/redis_cluster.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis_cluster" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis_cluster" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel.go b/server/web/session/redis_sentinel/sess_redis_sentinel.go index 2d64c6b4..d18a3773 100644 --- a/server/web/session/redis_sentinel/sess_redis_sentinel.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis_sentinel" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis_sentinel" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go index 0a8030ce..489e8998 100644 --- a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go @@ -13,15 +13,15 @@ import ( ) func TestRedisSentinel(t *testing.T) { - sessionConfig := &session.ManagerConfig{ - CookieName: "gosessionid", - EnableSetCookie: true, - Gclifetime: 3600, - Maxlifetime: 3600, - Secure: false, - CookieLifeTime: 3600, - ProviderConfig: "127.0.0.1:6379,100,,0,master", - } + sessionConfig := session.NewManagerConfig( + session.CfgCookieName(`gosessionid`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"), + ) globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) if e != nil { t.Log(e) diff --git a/server/web/session/sess_file.go b/server/web/session/sess_file.go index 90de9a79..a96bacb8 100644 --- a/server/web/session/sess_file.go +++ b/server/web/session/sess_file.go @@ -211,9 +211,7 @@ func (fp *FileProvider) SessionGC(context.Context) { // it walks save path to count files. func (fp *FileProvider) SessionAll(context.Context) int { a := &activeSession{} - err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { - return a.visit(path, f, err) - }) + err := filepath.Walk(fp.savePath, a.visit) if err != nil { SLogger.Printf("filepath.Walk() returned %v\n", err) return 0 diff --git a/server/web/session/session.go b/server/web/session/session.go index de63ed75..f0b7e292 100644 --- a/server/web/session/session.go +++ b/server/web/session/session.go @@ -16,7 +16,7 @@ // // Usage: // import( -// "github.com/beego/beego/v2/session" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { @@ -91,24 +91,6 @@ func GetProvider(name string) (Provider, error) { return provider, nil } -// ManagerConfig define the session config -type ManagerConfig struct { - CookieName string `json:"cookieName"` - EnableSetCookie bool `json:"enableSetCookie,omitempty"` - Gclifetime int64 `json:"gclifetime"` - Maxlifetime int64 `json:"maxLifetime"` - DisableHTTPOnly bool `json:"disableHTTPOnly"` - Secure bool `json:"secure"` - CookieLifeTime int `json:"cookieLifeTime"` - ProviderConfig string `json:"providerConfig"` - Domain string `json:"domain"` - SessionIDLength int64 `json:"sessionIDLength"` - EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` - SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` - EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` - SessionIDPrefix string `json:"sessionIDPrefix"` -} - // Manager contains Provider and its configuration. type Manager struct { provider Provider @@ -239,6 +221,7 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se HttpOnly: !manager.config.DisableHTTPOnly, Secure: manager.isSecure(r), Domain: manager.config.Domain, + SameSite: manager.config.CookieSameSite, } if manager.config.CookieLifeTime > 0 { cookie.MaxAge = manager.config.CookieLifeTime @@ -278,7 +261,9 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { HttpOnly: !manager.config.DisableHTTPOnly, Expires: expiration, MaxAge: -1, - Domain: manager.config.Domain} + Domain: manager.config.Domain, + SameSite: manager.config.CookieSameSite, + } http.SetCookie(w, cookie) } @@ -319,6 +304,7 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque HttpOnly: !manager.config.DisableHTTPOnly, Secure: manager.isSecure(r), Domain: manager.config.Domain, + SameSite: manager.config.CookieSameSite, } } else { oldsid, err := url.QueryUnescape(cookie.Value) diff --git a/server/web/session/session_config.go b/server/web/session/session_config.go new file mode 100644 index 00000000..aedfc559 --- /dev/null +++ b/server/web/session/session_config.go @@ -0,0 +1,143 @@ +package session + +import "net/http" + +// ManagerConfig define the session config +type ManagerConfig struct { + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + DisableHTTPOnly bool `json:"disableHTTPOnly"` + Secure bool `json:"secure"` + EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` + EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` + CookieName string `json:"cookieName"` + Gclifetime int64 `json:"gclifetime"` + Maxlifetime int64 `json:"maxLifetime"` + CookieLifeTime int `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` + Domain string `json:"domain"` + SessionIDLength int64 `json:"sessionIDLength"` + SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` + SessionIDPrefix string `json:"sessionIDPrefix"` + CookieSameSite http.SameSite `json:"cookieSameSite"` +} + +func (c *ManagerConfig) Opts(opts ...ManagerConfigOpt) { + for _, opt := range opts { + opt(c) + } +} + +type ManagerConfigOpt func(config *ManagerConfig) + +func NewManagerConfig(opts ...ManagerConfigOpt) *ManagerConfig { + config := &ManagerConfig{} + for _, opt := range opts { + opt(config) + } + return config +} + +// CfgCookieName set key of session id +func CfgCookieName(cookieName string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieName = cookieName + } +} + +// CfgCookieName set len of session id +func CfgSessionIdLength(len int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionIDLength = len + } +} + +// CfgSessionIdPrefix set prefix of session id +func CfgSessionIdPrefix(prefix string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionIDPrefix = prefix + } +} + +//CfgSetCookie whether set `Set-Cookie` header in HTTP response +func CfgSetCookie(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSetCookie = enable + } +} + +//CfgGcLifeTime set session gc lift time +func CfgGcLifeTime(lifeTime int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Gclifetime = lifeTime + } +} + +//CfgMaxLifeTime set session lift time +func CfgMaxLifeTime(lifeTime int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Maxlifetime = lifeTime + } +} + +//CfgGcLifeTime set session lift time +func CfgCookieLifeTime(lifeTime int) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieLifeTime = lifeTime + } +} + +//CfgProviderConfig configure session provider +func CfgProviderConfig(providerConfig string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.ProviderConfig = providerConfig + } +} + +//CfgDomain set cookie domain +func CfgDomain(domain string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Domain = domain + } +} + +//CfgSessionIdInHTTPHeader enable session id in http header +func CfgSessionIdInHTTPHeader(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSidInHTTPHeader = enable + } +} + +//CfgSetSessionNameInHTTPHeader set key of session id in http header +func CfgSetSessionNameInHTTPHeader(name string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionNameInHTTPHeader = name + } +} + +//EnableSidInURLQuery enable session id in query string +func CfgEnableSidInURLQuery(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSidInURLQuery = enable + } +} + +//DisableHTTPOnly set HTTPOnly for http.Cookie +func CfgHTTPOnly(HTTPOnly bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.DisableHTTPOnly = !HTTPOnly + } +} + +//CfgSecure set Secure for http.Cookie +func CfgSecure(Enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Secure = Enable + } +} + +//CfgSameSite set http.SameSite +func CfgSameSite(sameSite http.SameSite) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieSameSite = sameSite + } +} diff --git a/server/web/session/session_config_test.go b/server/web/session/session_config_test.go new file mode 100644 index 00000000..0ea7d22b --- /dev/null +++ b/server/web/session/session_config_test.go @@ -0,0 +1,222 @@ +package session + +import ( + "net/http" + "testing" +) + +func TestCfgCookieLifeTime(t *testing.T) { + value := 8754 + c := NewManagerConfig( + CfgCookieLifeTime(value), + ) + + if c.CookieLifeTime != value { + t.Error() + } +} + +func TestCfgDomain(t *testing.T) { + value := `http://domain.com` + c := NewManagerConfig( + CfgDomain(value), + ) + + if c.Domain != value { + t.Error() + } +} + +func TestCfgSameSite(t *testing.T) { + value := http.SameSiteLaxMode + c := NewManagerConfig( + CfgSameSite(value), + ) + + if c.CookieSameSite != value { + t.Error() + } +} + +func TestCfgSecure(t *testing.T) { + c := NewManagerConfig( + CfgSecure(true), + ) + + if c.Secure != true { + t.Error() + } +} + +func TestCfgSecure1(t *testing.T) { + c := NewManagerConfig( + CfgSecure(false), + ) + + if c.Secure != false { + t.Error() + } +} + +func TestCfgSessionIdPrefix(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgSessionIdPrefix(value), + ) + + if c.SessionIDPrefix != value { + t.Error() + } +} + +func TestCfgSetSessionNameInHTTPHeader(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgSetSessionNameInHTTPHeader(value), + ) + + if c.SessionNameInHTTPHeader != value { + t.Error() + } +} + +func TestCfgCookieName(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgCookieName(value), + ) + + if c.CookieName != value { + t.Error() + } +} + +func TestCfgEnableSidInURLQuery(t *testing.T) { + c := NewManagerConfig( + CfgEnableSidInURLQuery(true), + ) + + if c.EnableSidInURLQuery != true { + t.Error() + } +} + +func TestCfgGcLifeTime(t *testing.T) { + value := int64(5454) + c := NewManagerConfig( + CfgGcLifeTime(value), + ) + + if c.Gclifetime != value { + t.Error() + } +} + +func TestCfgHTTPOnly(t *testing.T) { + c := NewManagerConfig( + CfgHTTPOnly(true), + ) + + if c.DisableHTTPOnly != false { + t.Error() + } +} + +func TestCfgHTTPOnly2(t *testing.T) { + c := NewManagerConfig( + CfgHTTPOnly(false), + ) + + if c.DisableHTTPOnly != true { + t.Error() + } +} + +func TestCfgMaxLifeTime(t *testing.T) { + value := int64(5454) + c := NewManagerConfig( + CfgMaxLifeTime(value), + ) + + if c.Maxlifetime != value { + t.Error() + } +} + +func TestCfgProviderConfig(t *testing.T) { + value := `asodiuasldkj12i39809as` + c := NewManagerConfig( + CfgProviderConfig(value), + ) + + if c.ProviderConfig != value { + t.Error() + } +} + +func TestCfgSessionIdInHTTPHeader(t *testing.T) { + c := NewManagerConfig( + CfgSessionIdInHTTPHeader(true), + ) + + if c.EnableSidInHTTPHeader != true { + t.Error() + } +} + +func TestCfgSessionIdInHTTPHeader1(t *testing.T) { + c := NewManagerConfig( + CfgSessionIdInHTTPHeader(false), + ) + + if c.EnableSidInHTTPHeader != false { + t.Error() + } +} + +func TestCfgSessionIdLength(t *testing.T) { + value := int64(100) + c := NewManagerConfig( + CfgSessionIdLength(value), + ) + + if c.SessionIDLength != value { + t.Error() + } +} + +func TestCfgSetCookie(t *testing.T) { + c := NewManagerConfig( + CfgSetCookie(true), + ) + + if c.EnableSetCookie != true { + t.Error() + } +} + +func TestCfgSetCookie1(t *testing.T) { + c := NewManagerConfig( + CfgSetCookie(false), + ) + + if c.EnableSetCookie != false { + t.Error() + } +} + +func TestNewManagerConfig(t *testing.T) { + c := NewManagerConfig() + if c == nil { + t.Error() + } +} + +func TestManagerConfig_Opts(t *testing.T) { + c := NewManagerConfig() + c.Opts(CfgSetCookie(true)) + + if c.EnableSetCookie != true { + t.Error() + } +} diff --git a/server/web/session/session_provider_type.go b/server/web/session/session_provider_type.go new file mode 100644 index 00000000..78dc116d --- /dev/null +++ b/server/web/session/session_provider_type.go @@ -0,0 +1,18 @@ +package session + +type ProviderType string + +const ( + ProviderCookie ProviderType = `cookie` + ProviderFile ProviderType = `file` + ProviderMemory ProviderType = `memory` + ProviderCouchbase ProviderType = `couchbase` + ProviderLedis ProviderType = `ledis` + ProviderMemcache ProviderType = `memcache` + ProviderMysql ProviderType = `mysql` + ProviderPostgresql ProviderType = `postgresql` + ProviderRedis ProviderType = `redis` + ProviderRedisCluster ProviderType = `redis_cluster` + ProviderRedisSentinel ProviderType = `redis_sentinel` + ProviderSsdb ProviderType = `ssdb` +) diff --git a/server/web/template.go b/server/web/template.go index 65935ca8..78ea958a 100644 --- a/server/web/template.go +++ b/server/web/template.go @@ -202,9 +202,7 @@ func BuildTemplate(dir string, files ...string) error { root: dir, files: make(map[string][]string), } - err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error { - return self.visit(path, f, err) - }) + err = Walk(fs, dir, self.visit) if err != nil { fmt.Printf("Walk() returned %v\n", err) return err diff --git a/server/web/template_test.go b/server/web/template_test.go index 1d82c2e2..9ccacfcd 100644 --- a/server/web/template_test.go +++ b/server/web/template_test.go @@ -49,9 +49,8 @@ var block = `{{define "block"}} {{end}}` func TestTemplate(t *testing.T) { - wkdir, err := os.Getwd() - assert.Nil(t, err) - dir := filepath.Join(wkdir, "_beeTmp", "TestTemplate") + tmpDir := os.TempDir() + dir := filepath.Join(tmpDir, "_beeTmp", "TestTemplate") files := []string{ "header.tpl", "index.tpl", @@ -113,9 +112,8 @@ var user = ` ` func TestRelativeTemplate(t *testing.T) { - wkdir, err := os.Getwd() - assert.Nil(t, err) - dir := filepath.Join(wkdir, "_beeTmp") + tmpDir := os.TempDir() + dir := filepath.Join(tmpDir, "_beeTmp") // Just add dir to known viewPaths if err := AddViewPath(dir); err != nil { @@ -226,10 +224,10 @@ var output = ` ` func TestTemplateLayout(t *testing.T) { - wkdir, err := os.Getwd() + tmpDir, err := os.Getwd() assert.Nil(t, err) - dir := filepath.Join(wkdir, "_beeTmp", "TestTemplateLayout") + dir := filepath.Join(tmpDir, "_beeTmp", "TestTemplateLayout") files := []string{ "add.tpl", "layout_blog.tpl", diff --git a/server/web/tree.go b/server/web/tree.go index dc459c49..79f3da7a 100644 --- a/server/web/tree.go +++ b/server/web/tree.go @@ -210,9 +210,9 @@ func (t *Tree) AddRouter(pattern string, runObject interface{}) { func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) { if len(segments) == 0 { if reg != "" { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}) + t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}}, t.leaves...) } else { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards}) + t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards}}, t.leaves...) } } else { seg := segments[0] @@ -342,8 +342,9 @@ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string if runObject == nil && len(t.fixrouters) > 0 { // Filter the .json .xml .html extension for _, str := range allowSuffixExt { - if strings.HasSuffix(seg, str) { + if strings.HasSuffix(seg, str) && strings.HasSuffix(treePattern, seg) { for _, subTree := range t.fixrouters { + // strings.HasSuffix(treePattern, seg) avoid cases: /aaa.html/bbb could access /aaa/bbb if subTree.prefix == seg[:len(seg)-len(str)] { runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) if runObject != nil { diff --git a/server/web/tree_test.go b/server/web/tree_test.go index 3cb39c60..0885ffe8 100644 --- a/server/web/tree_test.go +++ b/server/web/tree_test.go @@ -17,6 +17,7 @@ package web import ( "strings" "testing" + "time" "github.com/beego/beego/v2/server/web/context" ) @@ -49,7 +50,7 @@ func notMatchTestInfo(pattern, url string) testInfo { } func init() { - routers = make([]testInfo, 0) + routers = make([]testInfo, 0, 128) // match example routers = append(routers, matchTestInfo("/topic/?:auth:int", "/topic", nil)) routers = append(routers, matchTestInfo("/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"})) @@ -90,7 +91,17 @@ func init() { routers = append(routers, matchTestInfo("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"})) routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"})) routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"})) - + routers = append(routers, matchTestInfo("/?:year/?:month/?:day", "/2020/11/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"})) + routers = append(routers, matchTestInfo("/?:year/?:month/?:day", "/2020/11", map[string]string{":year": "2020", ":month": "11"})) + routers = append(routers, matchTestInfo("/?:year", "/2020", map[string]string{":year": "2020"})) + routers = append(routers, matchTestInfo("/?:year([0-9]+)/?:month([0-9]+)/mid/?:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"})) + routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/mid/10", map[string]string{":year": "2020", ":day": "10"})) + routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/11/mid", map[string]string{":year": "2020", ":month": "11"})) + routers = append(routers, matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/mid/10/24", map[string]string{":day": "10", ":hour": "24"})) + routers = append(routers, matchTestInfo("/?:year([0-9]+)/:month([0-9]+)/mid/:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"})) + routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10/24", map[string]string{":month": "11", ":day": "10"})) + routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/2020/11/mid/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"})) + routers = append(routers, matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10", map[string]string{":month": "11", ":day": "10"})) // not match example // https://github.com/beego/beego/v2/issues/3865 @@ -98,12 +109,23 @@ func init() { routers = append(routers, notMatchTestInfo("/read_:id:int\\.htm", "/read_222_htm")) routers = append(routers, notMatchTestInfo("/read_:id:int\\.htm", " /read_262shtm")) + // test .html, .json not suffix + const abcHtml = "/suffix/abc.html" + routers = append(routers, notMatchTestInfo(abcHtml, "/suffix.html/abc")) + routers = append(routers, matchTestInfo("/suffix/abc", abcHtml, nil)) + routers = append(routers, matchTestInfo("/suffix/*", abcHtml, nil)) + routers = append(routers, notMatchTestInfo("/suffix/*", "/suffix.html/a")) + const abcSuffix = "/abc/suffix/*" + routers = append(routers, notMatchTestInfo(abcSuffix, "/abc/suffix.html/a")) + routers = append(routers, matchTestInfo(abcSuffix, "/abc/suffix/a", nil)) + routers = append(routers, notMatchTestInfo(abcSuffix, "/abc.j/suffix/a")) + } func TestTreeRouters(t *testing.T) { for _, r := range routers { - shouldMatch := r.shouldMatchOrNot + shouldMatch := r.shouldMatchOrNot tr := NewTree() tr.AddRouter(r.pattern, "astaxie") ctx := context.NewContext() @@ -112,7 +134,7 @@ func TestTreeRouters(t *testing.T) { if obj != nil { t.Fatal("pattern:", r.pattern, ", should not match", r.requestUrl) } else { - return + continue } } if obj == nil || obj.(string) != "astaxie" { @@ -128,6 +150,7 @@ func TestTreeRouters(t *testing.T) { } } } + time.Sleep(time.Second) } func TestStaticPath(t *testing.T) { diff --git a/server/web/unregroute_test.go b/server/web/unregroute_test.go index c675ae7d..226cffb8 100644 --- a/server/web/unregroute_test.go +++ b/server/web/unregroute_test.go @@ -75,9 +75,9 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, "Test original root", @@ -96,7 +96,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { // Replace the root path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot") + handler.Add("/", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot")) // Test replacement root (expect change) testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) @@ -117,9 +117,9 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -146,7 +146,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { // Replace the "level1" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) @@ -167,9 +167,9 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { var method = "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { // Replace the "/level1/level2" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2") + handler.Add("/level1/level2", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 00000000..1a12fb33 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,7 @@ +sonar.organization=beego +sonar.projectKey=beego_beego + +# relative paths to source directories. More details and properties are described +# in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/ +sonar.sources=. +sonar.exclusions=**/*_test.go \ No newline at end of file diff --git a/task/governor_command_test.go b/task/governor_command_test.go index 00ed37f2..c3547cdf 100644 --- a/task/governor_command_test.go +++ b/task/governor_command_test.go @@ -55,6 +55,10 @@ func (c *countTask) GetPrev(ctx context.Context) time.Time { return time.Now() } +func (c *countTask) GetTimeout(ctx context.Context) time.Duration { + return 0 +} + func TestRunTaskCommand_Execute(t *testing.T) { task := &countTask{} AddTask("count", task) diff --git a/task/task.go b/task/task.go index 2ea34f24..00e67c4b 100644 --- a/task/task.go +++ b/task/task.go @@ -109,6 +109,7 @@ type Tasker interface { GetNext(ctx context.Context) time.Time SetPrev(context.Context, time.Time) GetPrev(ctx context.Context) time.Time + GetTimeout(ctx context.Context) time.Duration } // task error @@ -127,13 +128,14 @@ type Task struct { DoFunc TaskFunc Prev time.Time Next time.Time - Errlist []*taskerr // like errtime:errinfo - ErrLimit int // max length for the errlist, 0 stand for no limit - errCnt int // records the error count during the execution + Timeout time.Duration // timeout duration + Errlist []*taskerr // like errtime:errinfo + ErrLimit int // max length for the errlist, 0 stand for no limit + errCnt int // records the error count during the execution } // NewTask add new task with name, time and func -func NewTask(tname string, spec string, f TaskFunc) *Task { +func NewTask(tname string, spec string, f TaskFunc, opts ...Option) *Task { task := &Task{ Taskname: tname, @@ -144,6 +146,11 @@ func NewTask(tname string, spec string, f TaskFunc) *Task { // we only store the pointer, so it won't use too many space Errlist: make([]*taskerr, 100, 100), } + + for _, opt := range opts { + opt.apply(task) + } + task.SetCron(spec) return task } @@ -196,6 +203,31 @@ func (t *Task) GetPrev(context.Context) time.Time { return t.Prev } +// GetTimeout get timeout duration of this task +func (t *Task) GetTimeout(context.Context) time.Duration { + return t.Timeout +} + +// Option interface +type Option interface { + apply(*Task) +} + +// optionFunc return a function to set task element +type optionFunc func(*Task) + +// apply option to task +func (f optionFunc) apply(t *Task) { + f(t) +} + +// TimeoutOption return a option to set timeout duration for task +func TimeoutOption(timeout time.Duration) Option { + return optionFunc(func(t *Task) { + t.Timeout = timeout + }) +} + // six columns mean: // second:0-59 // minute:0-59 @@ -455,14 +487,12 @@ func (m *taskManager) StartTask() { func (m *taskManager) run() { now := time.Now().Local() - m.taskLock.Lock() - for _, t := range m.adminTaskList { - t.SetNext(nil, now) - } - m.taskLock.Unlock() + // first run the tasks, so set all tasks next run time. + m.setTasksStartTime(now) for { // we only use RLock here because NewMapSorter copy the reference, do not change any thing + // here, we sort all task and get first task running time (effective). m.taskLock.RLock() sortList := NewMapSorter(m.adminTaskList) m.taskLock.RUnlock() @@ -475,37 +505,75 @@ func (m *taskManager) run() { } else { effective = sortList.Vals[0].GetNext(context.Background()) } + select { - case now = <-time.After(effective.Sub(now)): - // Run every entry whose next time was this effective time. - for _, e := range sortList.Vals { - if e.GetNext(context.Background()) != effective { - break - } - go e.Run(nil) - e.SetPrev(context.Background(), e.GetNext(context.Background())) - e.SetNext(nil, effective) - } + case now = <-time.After(effective.Sub(now)): // wait for effective time + runNextTasks(sortList, effective) continue - case <-m.changed: + case <-m.changed: // tasks have been changed, set all tasks run again now now = time.Now().Local() - m.taskLock.Lock() - for _, t := range m.adminTaskList { - t.SetNext(nil, now) - } - m.taskLock.Unlock() + m.setTasksStartTime(now) continue - case <-m.stop: - m.taskLock.Lock() - if m.started { - m.started = false - } - m.taskLock.Unlock() + case <-m.stop: // manager is stopped, and mark manager is stopped + m.markManagerStop() return } } } +// setTasksStartTime is set all tasks next running time +func (m *taskManager) setTasksStartTime(now time.Time) { + m.taskLock.Lock() + for _, task := range m.adminTaskList { + task.SetNext(context.Background(), now) + } + m.taskLock.Unlock() +} + +// markManagerStop it sets manager to be stopped +func (m *taskManager) markManagerStop() { + m.taskLock.Lock() + if m.started { + m.started = false + } + m.taskLock.Unlock() +} + +// runNextTasks it runs next task which next run time is equal to effective +func runNextTasks(sortList *MapSorter, effective time.Time) { + // Run every entry whose next time was this effective time. + var i = 0 + for _, e := range sortList.Vals { + i++ + if e.GetNext(context.Background()) != effective { + break + } + + // check if timeout is on, if yes passing the timeout context + ctx := context.Background() + if duration := e.GetTimeout(ctx); duration != 0 { + go func(e Tasker) { + ctx, cancelFunc := context.WithTimeout(ctx, duration) + defer cancelFunc() + err := e.Run(ctx) + if err != nil { + log.Printf("tasker.run err: %s\n", err.Error()) + } + }(e) + } else { + go func(e Tasker) { + err := e.Run(ctx) + if err != nil { + log.Printf("tasker.run err: %s\n", err.Error()) + } + }(e) + } + + e.SetPrev(context.Background(), e.GetNext(context.Background())) + e.SetNext(context.Background(), effective) + } +} + // StopTask stop all tasks func (m *taskManager) StopTask() { go func() { diff --git a/task/task_test.go b/task/task_test.go index 5e117cbd..1078aa01 100644 --- a/task/task_test.go +++ b/task/task_test.go @@ -90,6 +90,57 @@ func TestSpec(t *testing.T) { } } +func TestTimeout(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + wg := &sync.WaitGroup{} + wg.Add(2) + once1, once2 := sync.Once{}, sync.Once{} + + tk1 := NewTask("tk1", "0/10 * * ? * *", + func(ctx context.Context) error { + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + once1.Do(func() { + fmt.Println("tk1 done") + wg.Done() + }) + return errors.New("timeout") + default: + } + return nil + }, TimeoutOption(3*time.Second), + ) + + tk2 := NewTask("tk2", "0/11 * * ? * *", + func(ctx context.Context) error { + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + return errors.New("timeout") + default: + once2.Do(func() { + fmt.Println("tk2 done") + wg.Done() + }) + } + return nil + }, + ) + + m.AddTask("tk1", tk1) + m.AddTask("tk2", tk2) + m.StartTask() + defer m.StopTask() + + select { + case <-time.After(19 * time.Second): + t.Error("TestTimeout failed") + case <-wait(wg): + } +} + func TestTask_Run(t *testing.T) { cnt := -1 task := func(ctx context.Context) error { @@ -109,6 +160,23 @@ func TestTask_Run(t *testing.T) { assert.Equal(t, "Hello, world! 101", l[1].errinfo) } +func TestCrudTask(t *testing.T) { + m := newTaskManager() + m.AddTask("my-task1", NewTask("my-task1", "0/30 * * * * *", func(ctx context.Context) error { + return nil + })) + + m.AddTask("my-task2", NewTask("my-task2", "0/30 * * * * *", func(ctx context.Context) error { + return nil + })) + + m.DeleteTask("my-task1") + assert.Equal(t, 1, len(m.adminTaskList)) + + m.ClearTask() + assert.Equal(t, 0, len(m.adminTaskList)) +} + func wait(wg *sync.WaitGroup) chan bool { ch := make(chan bool) go func() { diff --git a/test/bindata.go b/test/bindata.go index 6dbc08ab..120d327c 100644 --- a/test/bindata.go +++ b/test/bindata.go @@ -287,9 +287,7 @@ func _filePath(dir, name string) string { } func assetFS() *assetfs.AssetFS { - assetInfo := func(path string) (os.FileInfo, error) { - return os.Stat(path) - } + assetInfo := os.Stat for k := range _bintree.Children { return &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, AssetInfo: assetInfo, Prefix: k} }