diff --git a/.github/ISSUE_TEMPLATE b/.github/ISSUE_TEMPLATE index 7a8d1691..b99b58a2 100644 --- a/.github/ISSUE_TEMPLATE +++ b/.github/ISSUE_TEMPLATE @@ -1,4 +1,4 @@ -**English Only**. Please use English because other could join the discussion if they got similar issue! +**English Only**. Please use English because others could join the discussion if they got similar issue! Please answer these questions before submitting your issue. Thanks! @@ -12,6 +12,7 @@ Please answer these questions before submitting your issue. Thanks! If possible, provide a recipe for reproducing the error. A complete runnable program is good. +If this is ORM issue, please provide the DB schemas. 4. What did you expect to see? diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..d85c6373 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,17 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "daily" + open-pull-requests-limit: 10 + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/linters/.golangci.yml b/.github/linters/.golangci.yml new file mode 100644 index 00000000..efed1ac4 --- /dev/null +++ b/.github/linters/.golangci.yml @@ -0,0 +1,43 @@ +run: + timeout: 5m + skip-files: + - generated.* + +issues: + new: true + +linters: + enable: + - asciicheck + - bodyclose + - deadcode + - depguard + - gci + - gocritic + - gofmt + - gofumpt + - goimports + - goprintffuncname + - gosimple + - govet + - ineffassign + - misspell + - nilerr + - rowserrcheck + - staticcheck + - structcheck + - stylecheck + - typecheck + - unconvert + - unused + - unparam + - varcheck + - whitespace + disable: + - errcheck + +linters-settings: + gci: + local-prefixes: github.com/beego/beego + goimports: + local-prefixes: github.com/beego/beego diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml new file mode 100644 index 00000000..50e91510 --- /dev/null +++ b/.github/workflows/changelog.yml @@ -0,0 +1,34 @@ +# This action requires that any PR targeting the master branch should touch at +# least one CHANGELOG file. If a CHANGELOG entry is not required, add the "Skip +# Changelog" label to disable this action. + +name: changelog + +on: + pull_request: + types: [opened, synchronize, reopened, labeled, unlabeled] + branches: + - develop + +jobs: + changelog: + runs-on: ubuntu-latest + if: "!contains(github.event.pull_request.labels.*.name, 'Skip Changelog')" + + steps: + - uses: actions/checkout@v2 + + - name: Check for CHANGELOG changes + run: | + # Only the latest commit of the feature branch is available + # automatically. To diff with the base branch, we need to + # fetch that too (and we only need its latest commit). + git fetch origin ${{ github.base_ref }} --depth=1 + if [[ $(git diff --name-only FETCH_HEAD | grep CHANGELOG) ]] + then + echo "A CHANGELOG was modified. Looks good!" + else + echo "No CHANGELOG was modified." + echo "Please add a CHANGELOG entry, or add the \"Skip Changelog\" label if not required." + false + fi \ No newline at end of file diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml new file mode 100644 index 00000000..39901ce8 --- /dev/null +++ b/.github/workflows/golangci-lint.yml @@ -0,0 +1,26 @@ +name: golangci-lint +on: + push: + branches: + - master + paths: + - "**/*.go" + - ".github/workflows/golangci-lint.yml" + pull_request: + types: [opened, synchronize, reopened] + paths: + - "**/*.go" + - ".github/workflows/golangci-lint.yml" +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout codebase + uses: actions/checkout@v2 + + - name: golangci-lint + uses: golangci/golangci-lint-action@v2 + with: + version: latest + args: --config=.github/linters/.golangci.yml + only-new-issues: true diff --git a/.github/workflows/need-feedback.yml b/.github/workflows/need-feedback.yml new file mode 100644 index 00000000..0ee0dbd4 --- /dev/null +++ b/.github/workflows/need-feedback.yml @@ -0,0 +1,19 @@ +name: need-feeback-issues + +on: + schedule: + - cron: "0 * * * *" # pick a cron here, this is every 1h + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: luanpotter/changes-requested@master + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + # these are optional, if you want to configure: + days-until-close: 5 + trigger-label: status/need-feedback + closing-comment: This issue was closed by the need-feedback bot due to without feebacks. + dry-run: false \ No newline at end of file diff --git a/.github/workflows/need-translation.yml b/.github/workflows/need-translation.yml new file mode 100644 index 00000000..31d54c22 --- /dev/null +++ b/.github/workflows/need-translation.yml @@ -0,0 +1,19 @@ +name: need-translation-issues + +on: + schedule: + - cron: "0 * * * *" # pick a cron here, this is every 1h + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: luanpotter/changes-requested@master + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + # these are optional, if you want to configure: + days-until-close: 5 + trigger-label: status/need-translation + closing-comment: This issue was closed by the need-translation bot. Please transalate your issue to English so it could help others. + dry-run: false \ No newline at end of file diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 412274a3..66c8ce4a 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/stale@v1 + - uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: 'This issue is inactive for a long time.' diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..7ebaddd1 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,124 @@ +name: Test +on: + push: + branches: + - master + - develop + paths: + - "**/*.go" + - "go.mod" + - "go.sum" + - ".github/workflows/test.yml" + pull_request: + types: [opened, synchronize, reopened] + branches: + - master + - develop + paths: + - "**/*.go" + - "go.mod" + - "go.sum" + - ".github/workflows/test.yml" + +jobs: + test: + strategy: + fail-fast: false + matrix: + go-version: [1.14, 1.15, 1.16] + runs-on: ubuntu-latest + services: + redis: + image: redis:latest + ports: + - 6379:6379 + memcached: + image: memcached:latest + ports: + - 11211:11211 + ssdb: + image: wendal/ssdb:latest + ports: + - 8888:8888 + postgres: + image: postgres:latest + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: orm_test + ports: + - 5432/tcp + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + + steps: + - name: Set up Go + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go-version }} + + - name: Checkout codebase + uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Run etcd + env: + ETCD_VERSION: v3.4.16 + run: | + rm -rf /tmp/etcd-data.tmp + mkdir -p /tmp/etcd-data.tmp + docker rmi gcr.io/etcd-development/etcd:${ETCD_VERSION} || true && \ + docker run -d \ + -p 2379:2379 \ + -p 2380:2380 \ + --mount type=bind,source=/tmp/etcd-data.tmp,destination=/etcd-data \ + --name etcd-gcr-${ETCD_VERSION} \ + gcr.io/etcd-development/etcd:${ETCD_VERSION} \ + /usr/local/bin/etcd \ + --name s1 \ + --data-dir /etcd-data \ + --listen-client-urls http://0.0.0.0:2379 \ + --advertise-client-urls http://0.0.0.0:2379 \ + --listen-peer-urls http://0.0.0.0:2380 \ + --initial-advertise-peer-urls http://0.0.0.0:2380 \ + --initial-cluster s1=http://0.0.0.0:2380 \ + --initial-cluster-token tkn \ + --initial-cluster-state new + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.float 1.23" + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.bool true" + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.int 11" + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.string hello" + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.serialize.name test" + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put sub.sub.key1 sub.sub.key" + + - name: Run ORM tests on sqlite3 + env: + GOPATH: /home/runner/go + ORM_DRIVER: sqlite3 + ORM_SOURCE: /tmp/sqlite3/orm_test.db + run: | + mkdir -p /tmp/sqlite3 && touch /tmp/sqlite3/orm_test.db + go test -coverprofile=coverage_sqlite3.txt -covermode=atomic $(go list ./... | grep client/orm) + + - name: Run ORM tests on postgres + env: + GOPATH: /home/runner/go + ORM_DRIVER: postgres + ORM_SOURCE: host=localhost port=${{ job.services.postgres.ports[5432] }} user=postgres password=postgres dbname=orm_test sslmode=disable + run: | + go test -coverprofile=coverage_postgres.txt -covermode=atomic $(go list ./... | grep client/orm) + + - name: Run tests on mysql + env: + GOPATH: /home/runner/go + ORM_DRIVER: mysql + ORM_SOURCE: root:root@/orm_test?charset=utf8 + run: | + sudo systemctl start mysql + mysql -u root -proot -e 'create database orm_test;' + go test -coverprofile=coverage.txt -covermode=atomic ./... + + - name: Upload codecov + env: + CODECOV_TOKEN: 4f4bc484-32a8-43b7-9f48-20966bd48ceb + run: bash <(curl -s https://codecov.io/bash) diff --git a/.gitignore b/.gitignore index 304c4b73..9a2af3cf 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,8 @@ _beeTmp2/ pkg/_beeTmp/ pkg/_beeTmp2/ test/tmp/ +core/config/env/pkg/ + +my save path/ + +profile.out diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 8e495e66..00000000 --- a/.travis.yml +++ /dev/null @@ -1,105 +0,0 @@ -language: go - -go: - - "1.14.x" -services: - - redis-server - - mysql - - postgresql - - memcached - - docker -env: - global: - - GO_REPO_FULLNAME="github.com/beego/beego/v2" - matrix: - - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db - - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - - ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" -before_install: - # link the local repo with ${GOPATH}/src// - - GO_REPO_NAMESPACE=${GO_REPO_FULLNAME%/*} - # relies on GOPATH to contain only one directory... - - mkdir -p ${GOPATH}/src/${GO_REPO_NAMESPACE} - - ln -sv ${TRAVIS_BUILD_DIR} ${GOPATH}/src/${GO_REPO_FULLNAME} - - cd ${GOPATH}/src/${GO_REPO_FULLNAME} - # get and build ssdb - - git clone git://github.com/ideawu/ssdb.git - - cd ssdb - - make - - cd .. - # - prepare etcd - # - prepare for etcd unit tests - - rm -rf /tmp/etcd-data.tmp - - mkdir -p /tmp/etcd-data.tmp - - docker rmi gcr.io/etcd-development/etcd:v3.3.25 || true && - docker run -d - -p 2379:2379 - -p 2380:2380 - --mount type=bind,source=/tmp/etcd-data.tmp,destination=/etcd-data - --name etcd-gcr-v3.3.25 - gcr.io/etcd-development/etcd:v3.3.25 - /usr/local/bin/etcd - --name s1 - --data-dir /etcd-data - --listen-client-urls http://0.0.0.0:2379 - --advertise-client-urls http://0.0.0.0:2379 - --listen-peer-urls http://0.0.0.0:2380 - --initial-advertise-peer-urls http://0.0.0.0:2380 - --initial-cluster s1=http://0.0.0.0:2380 - --initial-cluster-token tkn - --initial-cluster-state new - - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.float 1.23" - - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.bool true" - - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.int 11" - - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.string hello" - - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.serialize.name test" - - docker exec etcd-gcr-v3.3.25 /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put sub.sub.key1 sub.sub.key" -install: - - go get github.com/lib/pq - - go get github.com/go-sql-driver/mysql - - go get github.com/mattn/go-sqlite3 - - go get github.com/bradfitz/gomemcache/memcache - - go get github.com/gomodule/redigo/redis - - go get github.com/beego/x2j - - go get github.com/couchbase/go-couchbase - - go get github.com/beego/goyaml2 - - go get gopkg.in/yaml.v2 - - go get github.com/belogik/goes - - go get github.com/ledisdb/ledisdb - - go get github.com/ssdb/gossdb/ssdb - - go get github.com/cloudflare/golz4 - - go get github.com/gogo/protobuf/proto - - go get github.com/Knetic/govaluate - - go get github.com/casbin/casbin - - go get github.com/elazarl/go-bindata-assetfs - - go get github.com/OwnLocal/goes - - go get github.com/shiena/ansicolor - - go get -u honnef.co/go/tools/cmd/staticcheck - - go get -u github.com/mdempsky/unconvert - - go get -u github.com/gordonklaus/ineffassign - - go get -u golang.org/x/lint/golint - - go get -u github.com/go-redis/redis -before_script: - - # - - - psql --version - # - prepare for orm unit tests - - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" - - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" - - sh -c "go get github.com/golang/lint/golint; golint ./...;" - - sh -c "go list ./... | grep -v vendor | xargs go vet -v" - - mkdir -p res/var - - ./ssdb/ssdb-server ./ssdb/ssdb.conf -d -after_script: - - killall -w ssdb-server - - rm -rf ./res/var/* -script: - - go test ./... - - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" ./ - - unconvert $(go list ./... | grep -v /vendor/) - - ineffassign . - - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s - - golint ./... -addons: - postgresql: "9.6" diff --git a/CHANGELOG.md b/CHANGELOG.md index 989bcede..a5201e3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,28 @@ # developing + +- Add a custom option for whether to escape HTML special characters when processing http request parameters. [4701](https://github.com/beego/beego/pull/4701) +- Always set the response status in the CustomAbort function. [4686](https://github.com/beego/beego/pull/4686) +- Add template functions eq,lt to support uint and int compare. [4607](https://github.com/beego/beego/pull/4607) +- Migrate tests to GitHub Actions. [4663](https://github.com/beego/beego/issues/4663) +- Add http client and option func. [4455](https://github.com/beego/beego/issues/4455) +- Add: Convenient way to generate mock object [4620](https://github.com/beego/beego/issues/4620) +- Infra: use dependabot to update dependencies. [4623](https://github.com/beego/beego/pull/4623) +- Lint: use golangci-lint. [4619](https://github.com/beego/beego/pull/4619) +- Chore: format code. [4615](https://github.com/beego/beego/pull/4615) +- Test on Go v1.15.x & v1.16.x. [4614](https://github.com/beego/beego/pull/4614) +- Env: non-empty GOBIN & GOPATH. [4613](https://github.com/beego/beego/pull/4613) +- Chore: update dependencies. [4611](https://github.com/beego/beego/pull/4611) +- Update orm_test.go/TestInsertOrUpdate with table-driven. [4609](https://github.com/beego/beego/pull/4609) +- Add: Resp() method for web.Controller. [4588](https://github.com/beego/beego/pull/4588) +- Web mock and test support. [4565](https://github.com/beego/beego/pull/4565) [4574](https://github.com/beego/beego/pull/4574) +- Error codes definition of cache module. [4493](https://github.com/beego/beego/pull/4493) +- Remove generateCommentRoute http hook. Using `bee generate routers` commands instead.[4486](https://github.com/beego/beego/pull/4486) [bee PR 762](https://github.com/beego/bee/pull/762) +- Fix: /abc.html/aaa match /abc/aaa. [4459](https://github.com/beego/beego/pull/4459) +- ORM mock. [4407](https://github.com/beego/beego/pull/4407) +- Add sonar check and ignore test. [4432](https://github.com/beego/beego/pull/4432) [4433](https://github.com/beego/beego/pull/4433) +- Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427) +- Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398) +- Support `RollbackUnlessCommit` API. [4542](https://github.com/beego/beego/pull/4542) - Fix 4503 and 4504: Add `when` to `Write([]byte)` method and add `prefix` to `writeMsg`. [4507](https://github.com/beego/beego/pull/4507) - Fix 4480: log format incorrect. [4482](https://github.com/beego/beego/pull/4482) - Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391) @@ -7,6 +31,65 @@ - Fix 4727: CSS when request URI is invalid. [4729](https://github.com/beego/beego/pull/4729) - Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385) - Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) +- Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294) +- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404) +- Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416) +- Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424) +- Finish timeout option for tasks #4441 [4441](https://github.com/beego/beego/pull/4441) +- Error Module brief design & using httplib module to validate this design. [4453](https://github.com/beego/beego/pull/4453) - Fix 4444: panic when 404 not found. [4446](https://github.com/beego/beego/pull/4446) - Fix 4435: fix panic when controller dir not found. [4452](https://github.com/beego/beego/pull/4452) -- Hotfix:reflect.ValueOf(nil) in getFlatParams [4716](https://github.com/beego/beego/issues/4716) \ No newline at end of file +- Hotfix:reflect.ValueOf(nil) in getFlatParams [4716](https://github.com/beego/beego/issues/4716) +- Fix 4456: Fix router method expression [4456](https://github.com/beego/beego/pull/4456) +- Remove some `go get` lines in `.travis.yml` file [4469](https://github.com/beego/beego/pull/4469) +- Fix 4451: support QueryExecutor interface. [4461](https://github.com/beego/beego/pull/4461) +- Add some testing scripts [4461](https://github.com/beego/beego/pull/4461) +- Refactor httplib: Move debug code to a filter [4440](https://github.com/beego/beego/issues/4440) +- fix: code quality issues [4513](https://github.com/beego/beego/pull/4513) +- Optimize maligned structs to reduce memory foot-print [4525](https://github.com/beego/beego/pull/4525) +- Feat: add token bucket ratelimit filter [4508](https://github.com/beego/beego/pull/4508) +- Improve: Avoid ignoring mistakes that need attention [4548](https://github.com/beego/beego/pull/4548) +- Integration: DeepSource [4560](https://github.com/beego/beego/pull/4560) +- Integration: Remove unnecessary function call [4577](https://github.com/beego/beego/pull/4577) +- Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416) +- Proposal: Add Bind() method for `web.Controller` [4491](https://github.com/beego/beego/issues/4579) +- Optimize AddAutoPrefix: only register one router in case-insensitive mode. [4582](https://github.com/beego/beego/pull/4582) +- Init exceptMethod by using reflection. [4583](https://github.com/beego/beego/pull/4583) +- Deprecated BeeMap and replace all usage with `sync.map` [4616](https://github.com/beego/beego/pull/4616) +- TaskManager support graceful shutdown [4635](https://github.com/beego/beego/pull/4635) +- Add comments to `web.Config`, rename `RouterXXX` to `CtrlXXX`, define `HandleFunc` [4714](https://github.com/beego/beego/pull/4714) +- Refactor: Move `BindXXX` and `XXXResp` methods to `context.Context`. [4718](https://github.com/beego/beego/pull/4718) +- Fix 4728: Print wrong file name. [4737](https://github.com/beego/beego/pull/4737) +- fix bug:reflect.ValueOf(nil) in getFlatParams [4715](https://github.com/beego/beego/pull/4715) +- Fix 4736: set a fixed value "/" to the "Path" of "_xsrf" cookie. [4736](https://github.com/beego/beego/issues/4735) [4739](https://github.com/beego/beego/issues/4739) +- Fix 4734: do not reset id in Delete function. [4738](https://github.com/beego/beego/pull/4738) [4742](https://github.com/beego/beego/pull/4742) +- Fix 4699: Remove Remove goyaml2 dependency. [4755](https://github.com/beego/beego/pull/4755) +- Fix 4698: Prompt error when config format is incorrect. [4757](https://github.com/beego/beego/pull/4757) +- Fix 4674: Tx Orm missing debug log [4756](https://github.com/beego/beego/pull/4756) +- Fix 4759: fix numeric notation of permissions [4759](https://github.com/beego/beego/pull/4759) +## Fix Sonar + +- [4677](https://github.com/beego/beego/pull/4677) +- [4624](https://github.com/beego/beego/pull/4624) +- [4608](https://github.com/beego/beego/pull/4608) +- [4473](https://github.com/beego/beego/pull/4473) +- [4474](https://github.com/beego/beego/pull/4474) +- [4479](https://github.com/beego/beego/pull/4479) +- [4639](https://github.com/beego/beego/pull/4639) +- [4668](https://github.com/beego/beego/pull/4668) + +## Fix lint and format code + +- [4644](https://github.com/beego/beego/pull/4644) +- [4645](https://github.com/beego/beego/pull/4645) +- [4646](https://github.com/beego/beego/pull/4646) +- [4647](https://github.com/beego/beego/pull/4647) +- [4648](https://github.com/beego/beego/pull/4648) +- [4649](https://github.com/beego/beego/pull/4649) +- [4651](https://github.com/beego/beego/pull/4651) +- [4652](https://github.com/beego/beego/pull/4652) +- [4653](https://github.com/beego/beego/pull/4653) +- [4654](https://github.com/beego/beego/pull/4654) +- [4655](https://github.com/beego/beego/pull/4655) +- [4656](https://github.com/beego/beego/pull/4656) +- [4660](https://github.com/beego/beego/pull/4660) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2f9189e2..83a7eaea 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -36,7 +36,7 @@ We provide docker compose file to start all middlewares. You can run: ```shell script -docker-compose -f scripts/test_docker_compose.yml up -d +docker-compose -f scripts/test_docker_compose.yaml up -d ``` Unit tests read addresses from environment, here is an example: @@ -53,7 +53,7 @@ export SSDB_ADDR="192.168.0.105:8888" ### Pull requests -First of all. beego follow the gitflow. So please send you pull request to **develop-2** branch. We will close the pull +First of all. beego follow the gitflow. So please send you pull request to **develop** branch. We will close the pull request to master branch. We are always happy to receive pull requests, and do our best to review them as fast as possible. Not sure if that typo diff --git a/ERROR_SPECIFICATION.md b/ERROR_SPECIFICATION.md new file mode 100644 index 00000000..68a04bd1 --- /dev/null +++ b/ERROR_SPECIFICATION.md @@ -0,0 +1,5 @@ +# Error Module + +## Module code +- httplib 1 +- cache 2 \ No newline at end of file diff --git a/README.md b/README.md index 85b3df2a..0133510f 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,12 @@ -# Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/beego/beego/v2?status.svg)](http://godoc.org/github.com/beego/beego/v2) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) [![Go Report Card](https://goreportcard.com/badge/github.com/beego/beego/v2)](https://goreportcard.com/report/github.com/beego/beego/v2) +# Beego [![Test](https://github.com/beego/beego/actions/workflows/test.yml/badge.svg?branch=develop)](https://github.com/beego/beego/actions/workflows/test.yml) [![Go Report Card](https://goreportcard.com/badge/github.com/beego/beego)](https://goreportcard.com/report/github.com/beego/beego) [![Go Reference](https://pkg.go.dev/badge/github.com/beego/beego/v2.svg)](https://pkg.go.dev/github.com/beego/beego/v2) -Beego is used for rapid development of enterprise application in Go, including RESTful APIs, web apps and backend -services. +Beego is used for rapid development of enterprise application in Go, including RESTful APIs, web apps and backend services. -It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct -embedding. +It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding. ![architecture](https://cdn.nlark.com/yuque/0/2020/png/755700/1607857489109-1e267fce-d65f-4c5e-b915-5c475df33c58.png) -Beego is compos of four parts: +Beego is composed of four parts: 1. Base modules: including log module, config module, governor module; 2. Task: is used for running timed tasks or periodic tasks; @@ -19,7 +17,7 @@ Beego is compos of four parts: ## Quick Start -[Officail website](http://beego.me) +[Official website](http://beego.me) [Example](https://github.com/beego/beego-example) @@ -40,7 +38,7 @@ Beego is compos of four parts: #### Download and install - go get github.com/beego/beego/v2@v2.0.0 + go get github.com/beego/beego/v2@latest #### Create file `hello.go` @@ -90,12 +88,11 @@ Congratulations! You've just built your first **beego** app. ## Community * [http://beego.me/community](http://beego.me/community) -* Welcome to join us in Slack: [https://beego.slack.com](https://beego.slack.com), you can get invited - from [here](https://github.com/beego/beedoc/issues/232) +* Welcome to join us in Slack: [https://beego.slack.com invite](https://join.slack.com/t/beego/shared_invite/zt-fqlfjaxs-_CRmiITCSbEqQG9NeBqXKA), * QQ Group Group ID:523992905 * [Contribution Guide](https://github.com/beego/beedoc/blob/master/en-US/intro/contributing.md). ## License beego source code is licensed under the Apache Licence, Version 2.0 -(http://www.apache.org/licenses/LICENSE-2.0.html). +([https://www.apache.org/licenses/LICENSE-2.0.html](https://www.apache.org/licenses/LICENSE-2.0.html)). diff --git a/adapter/app.go b/adapter/app.go index 8502256b..35570616 100644 --- a/adapter/app.go +++ b/adapter/app.go @@ -22,10 +22,8 @@ import ( "github.com/beego/beego/v2/server/web/context" ) -var ( - // BeeApp is an application instance - BeeApp *App -) +// BeeApp is an application instance +var BeeApp *App func init() { // create beego application @@ -139,7 +137,7 @@ func RESTRouter(rootpath string, c ControllerInterface) *App { // AutoRouter adds defined controller handler to BeeApp. // it's same to HttpServer.AutoRouter. -// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, +// if beego.AddAuto(&MainController{}) and MainController has methods List and Page, // visit the url /main/list to exec List function or /main/page to exec Page function. func AutoRouter(c ControllerInterface) *App { return (*App)(web.AutoRouter(c)) @@ -147,7 +145,7 @@ func AutoRouter(c ControllerInterface) *App { // AutoPrefix adds controller handler to BeeApp with prefix. // it's same to HttpServer.AutoRouterWithPrefix. -// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// if beego.AutoPrefix("/admin",&MainController{}) and MainController has methods List and Page, // visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. func AutoPrefix(prefix string, c ControllerInterface) *App { return (*App)(web.AutoPrefix(prefix, c)) diff --git a/adapter/beego.go b/adapter/beego.go index 48eed9bc..0c914241 100644 --- a/adapter/beego.go +++ b/adapter/beego.go @@ -36,10 +36,6 @@ type M web.M // Hook function to run type hookfunc func() error -var ( - hooks = make([]hookfunc, 0) // hook function slice to store the hookfunc -) - // AddAPPStartHook is used to register the hookfunc // The hookfuncs will run in beego.Run() // such as initiating session , starting middleware , building template, starting admin control and so on. diff --git a/adapter/cache/cache.go b/adapter/cache/cache.go index f615b26f..1291568c 100644 --- a/adapter/cache/cache.go +++ b/adapter/cache/cache.go @@ -16,7 +16,7 @@ // Usage: // // import( -// "github.com/beego/beego/v2/cache" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("memory", `{"interval":60}`) diff --git a/adapter/cache/cache_test.go b/adapter/cache/cache_test.go index f6217e1a..bdb7e41f 100644 --- a/adapter/cache/cache_test.go +++ b/adapter/cache/cache_test.go @@ -19,12 +19,16 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) +const initError = "init err" + func TestCacheIncr(t *testing.T) { bm, err := NewCache("memory", `{"interval":20}`) if err != nil { - t.Error("init err") + t.Error(initError) } // timeoutDuration := 10 * time.Second @@ -45,147 +49,94 @@ func TestCacheIncr(t *testing.T) { func TestCache(t *testing.T) { bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } + assert.Nil(t, err) - time.Sleep(30 * time.Second) + timeoutDuration := 5 * time.Second + err = bm.Put("astaxie", 1, timeoutDuration) + assert.Nil(t, err) - if bm.IsExist("astaxie") { - t.Error("check err") - } + assert.True(t, bm.IsExist("astaxie")) - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.Equal(t, 1, bm.Get("astaxie")) - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } + time.Sleep(10 * time.Second) - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } + assert.False(t, bm.IsExist("astaxie")) - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } + err = bm.Put("astaxie", 1, timeoutDuration) + assert.Nil(t, err) - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } + err = bm.Incr("astaxie") + assert.Nil(t, err) - // test GetMulti - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } + assert.Equal(t, 2, bm.Get("astaxie")) - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } + assert.Nil(t, bm.Decr("astaxie")) + + assert.Equal(t, 1, bm.Get("astaxie")) + + assert.Nil(t, bm.Delete("astaxie")) + + assert.False(t, bm.IsExist("astaxie")) + + assert.Nil(t, bm.Put("astaxie", "author", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie")) + + assert.Equal(t, "author", bm.Get("astaxie")) + + assert.Nil(t, bm.Put("astaxie1", "author1", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie1")) vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } + + assert.Equal(t, 2, len(vv)) + + assert.Equal(t, "author", vv[0]) + + assert.Equal(t, "author1", vv[1]) } func TestFileCache(t *testing.T) { bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } + assert.Nil(t, err) + timeoutDuration := 5 * time.Second - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } + assert.Nil(t, bm.Put("astaxie", 1, timeoutDuration)) - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } + assert.True(t, bm.IsExist("astaxie")) - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } + assert.Equal(t, 1, bm.Get("astaxie")) - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } + assert.Nil(t, bm.Incr("astaxie")) - // test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } + assert.Equal(t, 2, bm.Get("astaxie")) - // test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } + assert.Nil(t, bm.Decr("astaxie")) + + assert.Equal(t, 1, bm.Get("astaxie")) + assert.Nil(t, bm.Delete("astaxie")) + + assert.False(t, bm.IsExist("astaxie")) + + assert.Nil(t, bm.Put("astaxie", "author", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie")) + + assert.Equal(t, "author", bm.Get("astaxie")) + + assert.Nil(t, bm.Put("astaxie1", "author1", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie1")) vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - os.RemoveAll("cache") + assert.Equal(t, 2, len(vv)) + + assert.Equal(t, "author", vv[0]) + assert.Equal(t, "author1", vv[1]) + assert.Nil(t, os.RemoveAll("cache")) } diff --git a/adapter/cache/conv_test.go b/adapter/cache/conv_test.go index b90e224a..af49e92c 100644 --- a/adapter/cache/conv_test.go +++ b/adapter/cache/conv_test.go @@ -19,15 +19,15 @@ import ( ) func TestGetString(t *testing.T) { - var t1 = "test1" + t1 := "test1" if "test1" != GetString(t1) { t.Error("get string from string error") } - var t2 = []byte("test2") + t2 := []byte("test2") if "test2" != GetString(t2) { t.Error("get string from byte array error") } - var t3 = 1 + t3 := 1 if "1" != GetString(t3) { t.Error("get string from int error") } @@ -35,7 +35,7 @@ func TestGetString(t *testing.T) { if "1" != GetString(t4) { t.Error("get string from int64 error") } - var t5 = 1.1 + t5 := 1.1 if "1.1" != GetString(t5) { t.Error("get string from float64 error") } @@ -46,7 +46,7 @@ func TestGetString(t *testing.T) { } func TestGetInt(t *testing.T) { - var t1 = 1 + t1 := 1 if 1 != GetInt(t1) { t.Error("get int from int error") } @@ -58,7 +58,7 @@ func TestGetInt(t *testing.T) { if 64 != GetInt(t3) { t.Error("get int from int64 error") } - var t4 = "128" + t4 := "128" if 128 != GetInt(t4) { t.Error("get int from num string error") } @@ -69,7 +69,7 @@ func TestGetInt(t *testing.T) { func TestGetInt64(t *testing.T) { var i int64 = 1 - var t1 = 1 + t1 := 1 if i != GetInt64(t1) { t.Error("get int64 from int error") } @@ -81,7 +81,7 @@ func TestGetInt64(t *testing.T) { if i != GetInt64(t3) { t.Error("get int64 from int64 error") } - var t4 = "1" + t4 := "1" if i != GetInt64(t4) { t.Error("get int64 from num string error") } @@ -91,22 +91,22 @@ func TestGetInt64(t *testing.T) { } func TestGetFloat64(t *testing.T) { - var f = 1.11 + f := 1.11 var t1 float32 = 1.11 if f != GetFloat64(t1) { t.Error("get float64 from float32 error") } - var t2 = 1.11 + t2 := 1.11 if f != GetFloat64(t2) { t.Error("get float64 from float64 error") } - var t3 = "1.11" + t3 := "1.11" if f != GetFloat64(t3) { t.Error("get float64 from string error") } var f2 float64 = 1 - var t4 = 1 + t4 := 1 if f2 != GetFloat64(t4) { t.Error("get float64 from int error") } @@ -117,11 +117,11 @@ func TestGetFloat64(t *testing.T) { } func TestGetBool(t *testing.T) { - var t1 = true + t1 := true if !GetBool(t1) { t.Error("get bool from bool error") } - var t2 = "true" + t2 := "true" if !GetBool(t2) { t.Error("get bool from string error") } diff --git a/adapter/cache/memcache/memcache.go b/adapter/cache/memcache/memcache.go index 16948f65..37b4b282 100644 --- a/adapter/cache/memcache/memcache.go +++ b/adapter/cache/memcache/memcache.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/cache/memcache" -// "github.com/beego/beego/v2/cache" +// _ "github.com/beego/beego/v2/client/cache/memcache" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) diff --git a/adapter/cache/memcache/memcache_test.go b/adapter/cache/memcache/memcache_test.go index 6382543e..13663907 100644 --- a/adapter/cache/memcache/memcache_test.go +++ b/adapter/cache/memcache/memcache_test.go @@ -21,94 +21,64 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/beego/beego/v2/adapter/cache" ) func TestMemcacheCache(t *testing.T) { - addr := os.Getenv("MEMCACHE_ADDR") if addr == "" { addr = "127.0.0.1:11211" } bm, err := cache.NewCache("memcache", fmt.Sprintf(`{"conn": "%s"}`, addr)) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } + assert.Nil(t, err) + timeoutDuration := 5 * time.Second + + assert.Nil(t, bm.Put("astaxie", "1", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie")) time.Sleep(11 * time.Second) - if bm.IsExist("astaxie") { - t.Error("check err") - } - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.False(t, bm.IsExist("astaxie")) - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { - t.Error("get err") - } + assert.Nil(t, bm.Put("astaxie", "1", timeoutDuration)) + v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))) + assert.Nil(t, err) + assert.Equal(t, 1, v) - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } + assert.Nil(t, bm.Incr("astaxie")) - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 { - t.Error("get err") - } + v, err = strconv.Atoi(string(bm.Get("astaxie").([]byte))) + assert.Nil(t, err) + assert.Equal(t, 2, v) - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } + assert.Nil(t, bm.Decr("astaxie")) - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } + v, err = strconv.Atoi(string(bm.Get("astaxie").([]byte))) + assert.Nil(t, err) + assert.Equal(t, 1, v) - // test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } + assert.Nil(t, bm.Delete("astaxie")) - if v := bm.Get("astaxie").([]byte); string(v) != "author" { - t.Error("get err") - } + assert.False(t, bm.IsExist("astaxie")) - // test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } + assert.Nil(t, bm.Put("astaxie", "author", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie")) + + assert.Equal(t, []byte("author"), bm.Get("astaxie")) + + assert.Nil(t, bm.Put("astaxie1", "author1", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie1")) vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { - t.Error("GetMulti ERROR") - } - if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Equal(t, []byte("author"), vv[0]) + assert.Equal(t, []byte("author1"), vv[1]) - // test clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } + assert.Nil(t, bm.ClearAll()) } diff --git a/adapter/cache/redis/redis.go b/adapter/cache/redis/redis.go index bfbeeb9c..a511f53a 100644 --- a/adapter/cache/redis/redis.go +++ b/adapter/cache/redis/redis.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/cache/redis" -// "github.com/beego/beego/v2/cache" +// _ "github.com/beego/beego/v2/client/cache/redis" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) @@ -34,10 +34,8 @@ import ( redis2 "github.com/beego/beego/v2/client/cache/redis" ) -var ( - // DefaultKey the collection name of redis for cache adapter. - DefaultKey = "beecacheRedis" -) +// DefaultKey the collection name of redis for cache adapter. +var DefaultKey = "beecacheRedis" // NewRedisCache create new redis cache with default collection name. func NewRedisCache() cache.Cache { diff --git a/adapter/cache/redis/redis_test.go b/adapter/cache/redis/redis_test.go index 39a30e8e..a4200fab 100644 --- a/adapter/cache/redis/redis_test.go +++ b/adapter/cache/redis/redis_test.go @@ -21,10 +21,16 @@ import ( "time" "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/assert" "github.com/beego/beego/v2/adapter/cache" ) +const ( + initError = "init err" + setError = "set Error" +) + func TestRedisCache(t *testing.T) { redisAddr := os.Getenv("REDIS_ADDR") if redisAddr == "" { @@ -32,98 +38,79 @@ func TestRedisCache(t *testing.T) { } bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, redisAddr)) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } + assert.Nil(t, err) + timeoutDuration := 5 * time.Second - time.Sleep(11 * time.Second) + assert.Nil(t, bm.Put("astaxie", 1, timeoutDuration)) - if bm.IsExist("astaxie") { - t.Error("check err") - } - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.True(t, bm.IsExist("astaxie")) - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { - t.Error("get err") - } + time.Sleep(7 * time.Second) - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } + assert.False(t, bm.IsExist("astaxie")) - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 { - t.Error("get err") - } + assert.Nil(t, bm.Put("astaxie", 1, timeoutDuration)) - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } + v, err := redis.Int(bm.Get("astaxie"), err) + assert.Nil(t, err) + assert.Equal(t, 1, v) - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } + assert.Nil(t, bm.Incr("astaxie")) - // test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } + v, err = redis.Int(bm.Get("astaxie"), err) + assert.Nil(t, err) + assert.Equal(t, 2, v) - if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" { - t.Error("get err") - } + assert.Nil(t, bm.Decr("astaxie")) - // test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } + v, err = redis.Int(bm.Get("astaxie"), err) + assert.Nil(t, err) + assert.Equal(t, 1, v) + + assert.Nil(t, bm.Delete("astaxie")) + + assert.False(t, bm.IsExist("astaxie")) + + assert.Nil(t, bm.Put("astaxie", "author", timeoutDuration)) + assert.True(t, bm.IsExist("astaxie")) + + vs, err := redis.String(bm.Get("astaxie"), err) + assert.Nil(t, err) + assert.Equal(t, "author", vs) + + assert.Nil(t, bm.Put("astaxie1", "author1", timeoutDuration)) + + assert.True(t, bm.IsExist("astaxie1")) vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[0], nil); v != "author" { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[1], nil); v != "author1" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + + vs, err = redis.String(vv[0], nil) + + assert.Nil(t, err) + assert.Equal(t, "author", vs) + + vs, err = redis.String(vv[1], nil) + + assert.Nil(t, err) + assert.Equal(t, "author1", vs) + + assert.Nil(t, bm.ClearAll()) // test clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } } -func TestCache_Scan(t *testing.T) { +func TestCacheScan(t *testing.T) { timeoutDuration := 10 * time.Second // init bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) if err != nil { - t.Error("init err") + t.Error(initError) } // insert all for i := 0; i < 10000; i++ { if err = bm.Put(fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { - t.Error("set Error", err) + t.Error(setError, err) } } @@ -131,5 +118,4 @@ func TestCache_Scan(t *testing.T) { if err = bm.ClearAll(); err != nil { t.Error("clear all err") } - } diff --git a/adapter/cache/ssdb/ssdb_test.go b/adapter/cache/ssdb/ssdb_test.go index 98e805d1..9f00e5c7 100644 --- a/adapter/cache/ssdb/ssdb_test.go +++ b/adapter/cache/ssdb/ssdb_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/beego/beego/v2/adapter/cache" ) @@ -17,95 +19,59 @@ func TestSsdbcacheCache(t *testing.T) { } ssdb, err := cache.NewCache("ssdb", fmt.Sprintf(`{"conn": "%s"}`, ssdbAddr)) - if err != nil { - t.Error("init err") - } + assert.Nil(t, err) + + assert.False(t, ssdb.IsExist("ssdb")) // test put and exist - if ssdb.IsExist("ssdb") { - t.Error("check err") - } - timeoutDuration := 10 * time.Second + timeoutDuration := 3 * time.Second // timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb") { - t.Error("check err") - } + assert.Nil(t, ssdb.Put("ssdb", "ssdb", timeoutDuration)) + assert.True(t, ssdb.IsExist("ssdb")) - // Get test done - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.Nil(t, ssdb.Put("ssdb", "ssdb", timeoutDuration)) - if v := ssdb.Get("ssdb"); v != "ssdb" { - t.Error("get Error") - } + assert.Equal(t, "ssdb", ssdb.Get("ssdb")) // inc/dec test done - if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if err = ssdb.Incr("ssdb"); err != nil { - t.Error("incr Error", err) - } + assert.Nil(t, ssdb.Put("ssdb", "2", timeoutDuration)) - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { - t.Error("get err") - } + assert.Nil(t, ssdb.Incr("ssdb")) - if err = ssdb.Decr("ssdb"); err != nil { - t.Error("decr error") - } + v, err := strconv.Atoi(ssdb.Get("ssdb").(string)) + assert.Nil(t, err) + assert.Equal(t, 3, v) + + assert.Nil(t, ssdb.Decr("ssdb")) + + assert.Nil(t, ssdb.Put("ssdb", "3", timeoutDuration)) // test del - if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { - t.Error("get err") - } - if err := ssdb.Delete("ssdb"); err == nil { - if ssdb.IsExist("ssdb") { - t.Error("delete err") - } - } + v, err = strconv.Atoi(ssdb.Get("ssdb").(string)) + assert.Nil(t, err) + assert.Equal(t, 3, v) + + assert.Nil(t, ssdb.Delete("ssdb")) + assert.False(t, ssdb.IsExist("ssdb")) // test string - if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb") { - t.Error("check err") - } - if v := ssdb.Get("ssdb").(string); v != "ssdb" { - t.Error("get err") - } + assert.Nil(t, ssdb.Put("ssdb", "ssdb", -10*time.Second)) + + assert.True(t, ssdb.IsExist("ssdb")) + assert.Equal(t, "ssdb", ssdb.Get("ssdb")) // test GetMulti done - if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb1") { - t.Error("check err") - } - vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) - if len(vv) != 2 { - t.Error("getmulti error") - } - if vv[0].(string) != "ssdb" { - t.Error("getmulti error") - } - if vv[1].(string) != "ssdb1" { - t.Error("getmulti error") - } + assert.Nil(t, ssdb.Put("ssdb1", "ssdb1", -10*time.Second)) + assert.True(t, ssdb.IsExist("ssdb1")) + vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) + assert.Equal(t, 2, len(vv)) + + assert.Equal(t, "ssdb", vv[0]) + assert.Equal(t, "ssdb1", vv[1]) + + assert.Nil(t, ssdb.ClearAll()) + assert.False(t, ssdb.IsExist("ssdb")) + assert.False(t, ssdb.IsExist("ssdb1")) // test clear all done - if err = ssdb.ClearAll(); err != nil { - t.Error("clear all err") - } - if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") { - t.Error("check err") - } } diff --git a/adapter/config/config.go b/adapter/config/config.go index a935e281..bf2496fc 100644 --- a/adapter/config/config.go +++ b/adapter/config/config.go @@ -14,7 +14,7 @@ // Package config is used to parse config. // Usage: -// import "github.com/beego/beego/v2/config" +// import "github.com/beego/beego/v2/core/config" // Examples. // // cnf, err := config.NewConfig("ini", "config.conf") diff --git a/adapter/config/config_test.go b/adapter/config/config_test.go index 15d6ffa6..86d3a2c5 100644 --- a/adapter/config/config_test.go +++ b/adapter/config/config_test.go @@ -20,7 +20,6 @@ import ( ) func TestExpandValueEnv(t *testing.T) { - testCases := []struct { item string want string @@ -51,5 +50,4 @@ func TestExpandValueEnv(t *testing.T) { t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) } } - } diff --git a/adapter/config/fake.go b/adapter/config/fake.go index a2b53167..b87ead34 100644 --- a/adapter/config/fake.go +++ b/adapter/config/fake.go @@ -20,6 +20,6 @@ import ( // NewFakeConfig return a fake Configer func NewFakeConfig() Configer { - new := config.NewFakeConfig() - return &newToOldConfigerAdapter{delegate: new} + config := config.NewFakeConfig() + return &newToOldConfigerAdapter{delegate: config} } diff --git a/adapter/config/ini_test.go b/adapter/config/ini_test.go index 60f1febd..997d3f68 100644 --- a/adapter/config/ini_test.go +++ b/adapter/config/ini_test.go @@ -23,7 +23,6 @@ import ( ) func TestIni(t *testing.T) { - var ( inicontext = ` ;comment one @@ -81,7 +80,8 @@ password = ${GOPATH} } ) - f, err := os.Create("testini.conf") + cfgFile := "testini.conf" + f, err := os.Create(cfgFile) if err != nil { t.Fatal(err) } @@ -91,8 +91,8 @@ password = ${GOPATH} t.Fatal(err) } f.Close() - defer os.Remove("testini.conf") - iniconf, err := NewConfig("ini", "testini.conf") + defer os.Remove(cfgFile) + iniconf, err := NewConfig("ini", cfgFile) if err != nil { t.Fatal(err) } @@ -128,11 +128,9 @@ password = ${GOPATH} if iniconf.String("name") != "astaxie" { t.Fatal("get name error") } - } func TestIniSave(t *testing.T) { - const ( inicontext = ` app = app diff --git a/adapter/config/json_test.go b/adapter/config/json_test.go index 16f42409..2f2c27c3 100644 --- a/adapter/config/json_test.go +++ b/adapter/config/json_test.go @@ -18,10 +18,11 @@ import ( "fmt" "os" "testing" + + "github.com/stretchr/testify/assert" ) func TestJsonStartsWithArray(t *testing.T) { - const jsoncontextwitharray = `[ { "url": "user", @@ -32,7 +33,8 @@ func TestJsonStartsWithArray(t *testing.T) { "serviceAPI": "http://www.test.com/employee" } ]` - f, err := os.Create("testjsonWithArray.conf") + cfgFileName := "testjsonWithArray.conf" + f, err := os.Create(cfgFileName) if err != nil { t.Fatal(err) } @@ -42,8 +44,8 @@ func TestJsonStartsWithArray(t *testing.T) { t.Fatal(err) } f.Close() - defer os.Remove("testjsonWithArray.conf") - jsonconf, err := NewConfig("json", "testjsonWithArray.conf") + defer os.Remove(cfgFileName) + jsonconf, err := NewConfig("json", cfgFileName) if err != nil { t.Fatal(err) } @@ -68,7 +70,6 @@ func TestJsonStartsWithArray(t *testing.T) { } func TestJson(t *testing.T) { - var ( jsoncontext = `{ "appname": "beeapi", @@ -132,7 +133,8 @@ func TestJson(t *testing.T) { } ) - f, err := os.Create("testjson.conf") + cfgFileName := "testjson.conf" + f, err := os.Create(cfgFileName) if err != nil { t.Fatal(err) } @@ -142,8 +144,8 @@ func TestJson(t *testing.T) { t.Fatal(err) } f.Close() - defer os.Remove("testjson.conf") - jsonconf, err := NewConfig("json", "testjson.conf") + defer os.Remove(cfgFileName) + jsonconf, err := NewConfig("json", cfgFileName) if err != nil { t.Fatal(err) } @@ -167,56 +169,39 @@ func TestJson(t *testing.T) { default: value, err = jsonconf.DIY(k) } - if err != nil { - t.Fatalf("get key %q value fatal,%v err %s", k, v, err) - } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { - t.Fatalf("get key %q value, want %v got %v .", k, v, value) - } - } - if err = jsonconf.Set("name", "astaxie"); err != nil { - t.Fatal(err) - } - if jsonconf.String("name") != "astaxie" { - t.Fatal("get name error") + assert.Nil(t, err) + assert.Equal(t, fmt.Sprintf("%v", v), fmt.Sprintf("%v", value)) } - if db, err := jsonconf.DIY("database"); err != nil { - t.Fatal(err) - } else if m, ok := db.(map[string]interface{}); !ok { - t.Log(db) - t.Fatal("db not map[string]interface{}") - } else { - if m["host"].(string) != "host" { - t.Fatal("get host err") - } - } + assert.Nil(t, jsonconf.Set("name", "astaxie")) - if _, err := jsonconf.Int("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an Int") - } + assert.Equal(t, "astaxie", jsonconf.String("name")) - if _, err := jsonconf.Int64("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an Int64") - } + db, err := jsonconf.DIY("database") + assert.Nil(t, err) - if _, err := jsonconf.Float("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting a Float") - } + m, ok := db.(map[string]interface{}) + assert.True(t, ok) + assert.Equal(t, "host", m["host"]) - if _, err := jsonconf.DIY("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting an interface{}") - } + _, err = jsonconf.Int("unknown") + assert.NotNil(t, err) - if val := jsonconf.String("unknown"); val != "" { - t.Error("unknown keys should return an empty string when expecting a String") - } + _, err = jsonconf.Int64("unknown") + assert.NotNil(t, err) - if _, err := jsonconf.Bool("unknown"); err == nil { - t.Error("unknown keys should return an error when expecting a Bool") - } + _, err = jsonconf.Float("unknown") + assert.NotNil(t, err) - if !jsonconf.DefaultBool("unknown", true) { - t.Error("unknown keys with default value wrong") - } + _, err = jsonconf.DIY("unknown") + assert.NotNil(t, err) + + val := jsonconf.String("unknown") + assert.Equal(t, "", val) + + _, err = jsonconf.Bool("unknown") + assert.NotNil(t, err) + + assert.True(t, jsonconf.DefaultBool("unknown", true)) } diff --git a/adapter/config/xml/xml.go b/adapter/config/xml/xml.go index 190cee97..8c623033 100644 --- a/adapter/config/xml/xml.go +++ b/adapter/config/xml/xml.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/config/xml" -// "github.com/beego/beego/v2/config" +// _ "github.com/beego/beego/v2/core/config/xml" +// "github.com/beego/beego/v2/core/config" // ) // // cnf, err := config.NewConfig("xml", "config.xml") diff --git a/adapter/config/xml/xml_test.go b/adapter/config/xml/xml_test.go index 5e43ca0f..48424ef9 100644 --- a/adapter/config/xml/xml_test.go +++ b/adapter/config/xml/xml_test.go @@ -23,7 +23,6 @@ import ( ) func TestXML(t *testing.T) { - var ( // xml parse should incluce in tags xmlcontext = ` @@ -58,7 +57,8 @@ func TestXML(t *testing.T) { } ) - f, err := os.Create("testxml.conf") + cfgFileName := "testxml.conf" + f, err := os.Create(cfgFileName) if err != nil { t.Fatal(err) } @@ -68,9 +68,9 @@ func TestXML(t *testing.T) { t.Fatal(err) } f.Close() - defer os.Remove("testxml.conf") + defer os.Remove(cfgFileName) - xmlconf, err := config.NewConfig("xml", "testxml.conf") + xmlconf, err := config.NewConfig("xml", cfgFileName) if err != nil { t.Fatal(err) } diff --git a/adapter/config/yaml/yaml.go b/adapter/config/yaml/yaml.go index 8d0bb697..538f1178 100644 --- a/adapter/config/yaml/yaml.go +++ b/adapter/config/yaml/yaml.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/config/yaml" -// "github.com/beego/beego/v2/config" +// _ "github.com/beego/beego/v2/core/config/yaml" +// "github.com/beego/beego/v2/core/config" // ) // // cnf, err := config.NewConfig("yaml", "config.yaml") diff --git a/adapter/config/yaml/yaml_test.go b/adapter/config/yaml/yaml_test.go index d567b554..ac0245dd 100644 --- a/adapter/config/yaml/yaml_test.go +++ b/adapter/config/yaml/yaml_test.go @@ -23,7 +23,6 @@ import ( ) func TestYaml(t *testing.T) { - var ( yamlcontext = ` "appname": beeapi @@ -54,7 +53,8 @@ func TestYaml(t *testing.T) { "emptystrings": []string{}, } ) - f, err := os.Create("testyaml.conf") + cfgFileName := "testyaml.conf" + f, err := os.Create(cfgFileName) if err != nil { t.Fatal(err) } @@ -64,8 +64,8 @@ func TestYaml(t *testing.T) { t.Fatal(err) } f.Close() - defer os.Remove("testyaml.conf") - yamlconf, err := config.NewConfig("yaml", "testyaml.conf") + defer os.Remove(cfgFileName) + yamlconf, err := config.NewConfig("yaml", cfgFileName) if err != nil { t.Fatal(err) } @@ -111,5 +111,4 @@ func TestYaml(t *testing.T) { if yamlconf.String("name") != "astaxie" { t.Fatal("get name error") } - } diff --git a/adapter/context/context.go b/adapter/context/context.go index 16e631fc..bb8f7cd9 100644 --- a/adapter/context/context.go +++ b/adapter/context/context.go @@ -15,7 +15,7 @@ // Package context provide the context utils // Usage: // -// import "github.com/beego/beego/v2/context" +// import "github.com/beego/beego/v2/server/web/context" // // ctx := context.Context{Request:req,ResponseWriter:rw} // diff --git a/adapter/context/param/conv.go b/adapter/context/param/conv.go new file mode 100644 index 00000000..ec4c6b7e --- /dev/null +++ b/adapter/context/param/conv.go @@ -0,0 +1,18 @@ +package param + +import ( + "reflect" + + beecontext "github.com/beego/beego/v2/adapter/context" + "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/context/param" +) + +// ConvertParams converts http method params to values that will be passed to the method controller as arguments +func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) { + nps := make([]*param.MethodParam, 0, len(methodParams)) + for _, mp := range methodParams { + nps = append(nps, (*param.MethodParam)(mp)) + } + return param.ConvertParams(nps, methodType, (*context.Context)(ctx)) +} diff --git a/adapter/context/param/conv_test.go b/adapter/context/param/conv_test.go new file mode 100644 index 00000000..b31a8afc --- /dev/null +++ b/adapter/context/param/conv_test.go @@ -0,0 +1,39 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package param + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/adapter/context" +) + +// Demo is used to test, it's empty +func Demo(i int) { +} + +func TestConvertParams(t *testing.T) { + res := ConvertParams(nil, reflect.TypeOf(Demo), context.NewContext()) + assert.Equal(t, 0, len(res)) + ctx := context.NewContext() + ctx.Input.RequestBody = []byte("11") + res = ConvertParams([]*MethodParam{ + New("A", InBody), + }, reflect.TypeOf(Demo), ctx) + assert.Equal(t, int64(11), res[0].Int()) +} diff --git a/adapter/context/param/methodparams.go b/adapter/context/param/methodparams.go new file mode 100644 index 00000000..000539db --- /dev/null +++ b/adapter/context/param/methodparams.go @@ -0,0 +1,29 @@ +package param + +import ( + "github.com/beego/beego/v2/server/web/context/param" +) + +// MethodParam keeps param information to be auto passed to controller methods +type MethodParam param.MethodParam + +// New creates a new MethodParam with name and specific options +func New(name string, opts ...MethodParamOption) *MethodParam { + newOps := make([]param.MethodParamOption, 0, len(opts)) + for _, o := range opts { + newOps = append(newOps, oldMpoToNew(o)) + } + return (*MethodParam)(param.New(name, newOps...)) +} + +// Make creates an array of MethodParmas or an empty array +func Make(list ...*MethodParam) []*MethodParam { + if len(list) > 0 { + return list + } + return nil +} + +func (mp *MethodParam) String() string { + return (*param.MethodParam)(mp).String() +} diff --git a/adapter/context/param/methodparams_test.go b/adapter/context/param/methodparams_test.go new file mode 100644 index 00000000..9d5155bf --- /dev/null +++ b/adapter/context/param/methodparams_test.go @@ -0,0 +1,34 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package param + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMethodParamString(t *testing.T) { + method := New("myName", IsRequired, InHeader, Default("abc")) + s := method.String() + assert.Equal(t, `param.New("myName", param.IsRequired, param.InHeader, param.Default("abc"))`, s) +} + +func TestMake(t *testing.T) { + res := Make() + assert.Equal(t, 0, len(res)) + res = Make(New("myName", InBody)) + assert.Equal(t, 1, len(res)) +} diff --git a/adapter/context/param/options.go b/adapter/context/param/options.go new file mode 100644 index 00000000..1d9364c2 --- /dev/null +++ b/adapter/context/param/options.go @@ -0,0 +1,45 @@ +package param + +import ( + "github.com/beego/beego/v2/server/web/context/param" +) + +// MethodParamOption defines a func which apply options on a MethodParam +type MethodParamOption func(*MethodParam) + +// IsRequired indicates that this param is required and can not be omitted from the http request +var IsRequired MethodParamOption = func(p *MethodParam) { + param.IsRequired((*param.MethodParam)(p)) +} + +// InHeader indicates that this param is passed via an http header +var InHeader MethodParamOption = func(p *MethodParam) { + param.InHeader((*param.MethodParam)(p)) +} + +// InPath indicates that this param is part of the URL path +var InPath MethodParamOption = func(p *MethodParam) { + param.InPath((*param.MethodParam)(p)) +} + +// InBody indicates that this param is passed as an http request body +var InBody MethodParamOption = func(p *MethodParam) { + param.InBody((*param.MethodParam)(p)) +} + +// Default provides a default value for the http param +func Default(defaultValue interface{}) MethodParamOption { + return newMpoToOld(param.Default(defaultValue)) +} + +func newMpoToOld(n param.MethodParamOption) MethodParamOption { + return func(methodParam *MethodParam) { + n((*param.MethodParam)(methodParam)) + } +} + +func oldMpoToNew(old MethodParamOption) param.MethodParamOption { + return func(methodParam *param.MethodParam) { + old((*MethodParam)(methodParam)) + } +} diff --git a/adapter/controller.go b/adapter/controller.go index 15d813ab..6d2d5f64 100644 --- a/adapter/controller.go +++ b/adapter/controller.go @@ -19,9 +19,8 @@ import ( "net/url" "github.com/beego/beego/v2/adapter/session" - webContext "github.com/beego/beego/v2/server/web/context" - "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" ) var ( @@ -49,9 +48,11 @@ type ControllerCommentsSlice web.ControllerCommentsSlice func (p ControllerCommentsSlice) Len() int { return (web.ControllerCommentsSlice)(p).Len() } + func (p ControllerCommentsSlice) Less(i, j int) bool { return (web.ControllerCommentsSlice)(p).Less(i, j) } + func (p ControllerCommentsSlice) Swap(i, j int) { (web.ControllerCommentsSlice)(p).Swap(i, j) } diff --git a/adapter/error.go b/adapter/error.go index a4d0bc00..67f2ab82 100644 --- a/adapter/error.go +++ b/adapter/error.go @@ -18,9 +18,8 @@ import ( "net/http" "github.com/beego/beego/v2/adapter/context" - beecontext "github.com/beego/beego/v2/server/web/context" - "github.com/beego/beego/v2/server/web" + beecontext "github.com/beego/beego/v2/server/web/context" ) const ( diff --git a/adapter/grace/grace.go b/adapter/grace/grace.go index 6e582bac..de047eb1 100644 --- a/adapter/grace/grace.go +++ b/adapter/grace/grace.go @@ -22,7 +22,7 @@ // "net/http" // "os" // -// "github.com/beego/beego/v2/grace" +// "github.com/beego/beego/v2/server/web/grace" // ) // // func handler(w http.ResponseWriter, r *http.Request) { diff --git a/adapter/httplib/httplib.go b/adapter/httplib/httplib.go index 0a182cae..005eee0f 100644 --- a/adapter/httplib/httplib.go +++ b/adapter/httplib/httplib.go @@ -15,7 +15,7 @@ // Package httplib is used as http.Client // Usage: // -// import "github.com/beego/beego/v2/httplib" +// import "github.com/beego/beego/v2/client/httplib" // // b := httplib.Post("http://beego.me/") // b.Param("username","astaxie") @@ -115,12 +115,6 @@ func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { return b } -// Debug sets show debug or not when executing request. -func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { - b.delegate.Debug(isdebug) - return b -} - // Retries sets Retries times. // default is 0 means no retried. // -1 means retried forever. @@ -135,17 +129,6 @@ func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { return b } -// DumpBody setting whether need to Dump the Body. -func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { - b.delegate.DumpBody(isdump) - return b -} - -// DumpRequest return the DumpRequest -func (b *BeegoHTTPRequest) DumpRequest() []byte { - return b.delegate.DumpRequest() -} - // SetTimeout sets connect time out and read-write time out for BeegoRequest. func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { b.delegate.SetTimeout(connectTimeout, readWriteTimeout) diff --git a/adapter/httplib/httplib_test.go b/adapter/httplib/httplib_test.go index e7605c87..41018c19 100644 --- a/adapter/httplib/httplib_test.go +++ b/adapter/httplib/httplib_test.go @@ -15,6 +15,7 @@ package httplib import ( + "bytes" "errors" "io/ioutil" "net" @@ -25,8 +26,13 @@ import ( "time" ) +const ( + getURL = "http://httpbin.org/get" + ipURL = "http://httpbin.org/ip" +) + func TestResponse(t *testing.T) { - req := Get("http://httpbin.org/get") + req := Get(getURL) resp, err := req.Response() if err != nil { t.Fatal(err) @@ -59,11 +65,10 @@ func TestDoRequest(t *testing.T) { if elapsedTime < delayedTime { t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) } - } func TestGet(t *testing.T) { - req := Get("http://httpbin.org/get") + req := Get(getURL) b, err := req.Bytes() if err != nil { t.Fatal(err) @@ -205,7 +210,7 @@ func TestWithSetting(t *testing.T) { setting.ReadWriteTimeout = 5 * time.Second SetDefaultSetting(setting) - str, err := Get("http://httpbin.org/get").String() + str, err := Get(getURL).String() if err != nil { t.Fatal(err) } @@ -218,7 +223,7 @@ func TestWithSetting(t *testing.T) { } func TestToJson(t *testing.T) { - req := Get("http://httpbin.org/ip") + req := Get(ipURL) resp, err := req.Response() if err != nil { t.Fatal(err) @@ -244,33 +249,32 @@ func TestToJson(t *testing.T) { t.Fatal("response is not valid ip") } } - } func TestToFile(t *testing.T) { f := "beego_testfile" - req := Get("http://httpbin.org/ip") + req := Get(ipURL) err := req.ToFile(f) if err != nil { t.Fatal(err) } defer os.Remove(f) b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { + if n := bytes.Index(b, []byte("origin")); n == -1 { t.Fatal(err) } } func TestToFileDir(t *testing.T) { f := "./files/beego_testfile" - req := Get("http://httpbin.org/ip") + req := Get(ipURL) err := req.ToFile(f) if err != nil { t.Fatal(err) } defer os.RemoveAll("./files") b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { + if n := bytes.Index(b, []byte("origin")); n == -1 { t.Fatal(err) } } diff --git a/adapter/log.go b/adapter/log.go index 25e82d26..9fc3a551 100644 --- a/adapter/log.go +++ b/adapter/log.go @@ -18,108 +18,106 @@ import ( "strings" "github.com/beego/beego/v2/core/logs" - - webLog "github.com/beego/beego/v2/core/logs" ) // Log levels to control the logging output. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. const ( - LevelEmergency = webLog.LevelEmergency - LevelAlert = webLog.LevelAlert - LevelCritical = webLog.LevelCritical - LevelError = webLog.LevelError - LevelWarning = webLog.LevelWarning - LevelNotice = webLog.LevelNotice - LevelInformational = webLog.LevelInformational - LevelDebug = webLog.LevelDebug + LevelEmergency = logs.LevelEmergency + LevelAlert = logs.LevelAlert + LevelCritical = logs.LevelCritical + LevelError = logs.LevelError + LevelWarning = logs.LevelWarning + LevelNotice = logs.LevelNotice + LevelInformational = logs.LevelInformational + LevelDebug = logs.LevelDebug ) // BeeLogger references the used application logger. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. var BeeLogger = logs.GetBeeLogger() // SetLevel sets the global log level used by the simple logger. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func SetLevel(l int) { logs.SetLevel(l) } // SetLogFuncCall set the CallDepth, default is 3 -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func SetLogFuncCall(b bool) { logs.SetLogFuncCall(b) } // SetLogger sets a new logger. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func SetLogger(adaptername string, config string) error { return logs.SetLogger(adaptername, config) } // Emergency logs a message at emergency level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Emergency(v ...interface{}) { logs.Emergency(generateFmtStr(len(v)), v...) } // Alert logs a message at alert level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Alert(v ...interface{}) { logs.Alert(generateFmtStr(len(v)), v...) } // Critical logs a message at critical level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Critical(v ...interface{}) { logs.Critical(generateFmtStr(len(v)), v...) } // Error logs a message at error level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Error(v ...interface{}) { logs.Error(generateFmtStr(len(v)), v...) } // Warning logs a message at warning level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Warning(v ...interface{}) { logs.Warning(generateFmtStr(len(v)), v...) } // Warn compatibility alias for Warning() -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Warn(v ...interface{}) { logs.Warn(generateFmtStr(len(v)), v...) } // Notice logs a message at notice level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Notice(v ...interface{}) { logs.Notice(generateFmtStr(len(v)), v...) } // Informational logs a message at info level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Informational(v ...interface{}) { logs.Informational(generateFmtStr(len(v)), v...) } // Info compatibility alias for Warning() -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Info(v ...interface{}) { logs.Info(generateFmtStr(len(v)), v...) } // Debug logs a message at debug level. -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Debug(v ...interface{}) { logs.Debug(generateFmtStr(len(v)), v...) } // Trace logs a message at trace level. // compatibility alias for Warning() -// Deprecated: use github.com/beego/beego/v2/logs instead. +// Deprecated: use github.com/beego/beego/v2/core/logs instead. func Trace(v ...interface{}) { logs.Trace(generateFmtStr(len(v)), v...) } diff --git a/adapter/logs/log.go b/adapter/logs/log.go index 9d098d8f..d53cc2ce 100644 --- a/adapter/logs/log.go +++ b/adapter/logs/log.go @@ -15,7 +15,7 @@ // Package logs provide a general log interface // Usage: // -// import "github.com/beego/beego/v2/logs" +// import "github.com/beego/beego/v2/core/logs" // // log := NewLogger(10000) // log.SetLogger("console", "") diff --git a/adapter/logs/logger_test.go b/adapter/logs/logger_test.go index 9f2cc5a5..42708fa5 100644 --- a/adapter/logs/logger_test.go +++ b/adapter/logs/logger_test.go @@ -18,7 +18,7 @@ import ( "testing" ) -func TestBeeLogger_Info(t *testing.T) { +func TestBeeLoggerInfo(t *testing.T) { log := NewLogger(1000) log.SetLogger("file", `{"net":"tcp","addr":":7020"}`) } diff --git a/adapter/metric/prometheus_test.go b/adapter/metric/prometheus_test.go index 53984845..72212dd4 100644 --- a/adapter/metric/prometheus_test.go +++ b/adapter/metric/prometheus_test.go @@ -15,6 +15,7 @@ package metric import ( + "fmt" "net/http" "net/url" "testing" @@ -26,7 +27,9 @@ import ( ) func TestPrometheusMiddleWare(t *testing.T) { - middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) + middleware := PrometheusMiddleWare(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + fmt.Print("you are coming") + })) writer := &context.Response{} request := &http.Request{ URL: &url.URL{ diff --git a/adapter/namespace.go b/adapter/namespace.go index 943aace2..4beda252 100644 --- a/adapter/namespace.go +++ b/adapter/namespace.go @@ -18,9 +18,8 @@ import ( "net/http" adtContext "github.com/beego/beego/v2/adapter/context" - "github.com/beego/beego/v2/server/web/context" - "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" ) type namespaceCond func(*adtContext.Context) bool @@ -240,141 +239,158 @@ func AddNamespace(nl ...*Namespace) { // NSCond is Namespace Condition func NSCond(cond namespaceCond) LinkNamespace { + wc := web.NSCond(func(b *context.Context) bool { + return cond((*adtContext.Context)(b)) + }) return func(namespace *Namespace) { - web.NSCond(func(b *context.Context) bool { - return cond((*adtContext.Context)(b)) - }) + wc((*web.Namespace)(namespace)) } } // NSBefore Namespace BeforeRouter filter func NSBefore(filterList ...FilterFunc) LinkNamespace { + nfs := oldToNewFilter(filterList) + wf := web.NSBefore(nfs...) return func(namespace *Namespace) { - nfs := oldToNewFilter(filterList) - web.NSBefore(nfs...) + wf((*web.Namespace)(namespace)) } } // NSAfter add Namespace FinishRouter filter func NSAfter(filterList ...FilterFunc) LinkNamespace { + nfs := oldToNewFilter(filterList) + wf := web.NSAfter(nfs...) return func(namespace *Namespace) { - nfs := oldToNewFilter(filterList) - web.NSAfter(nfs...) + wf((*web.Namespace)(namespace)) } } // NSInclude Namespace Include ControllerInterface func NSInclude(cList ...ControllerInterface) LinkNamespace { + nfs := oldToNewCtrlIntfs(cList) + wi := web.NSInclude(nfs...) return func(namespace *Namespace) { - nfs := oldToNewCtrlIntfs(cList) - web.NSInclude(nfs...) + wi((*web.Namespace)(namespace)) } } // NSRouter call Namespace Router func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) LinkNamespace { + wn := web.NSRouter(rootpath, c, mappingMethods...) return func(namespace *Namespace) { - web.Router(rootpath, c, mappingMethods...) + wn((*web.Namespace)(namespace)) } } // NSGet call Namespace Get func NSGet(rootpath string, f FilterFunc) LinkNamespace { + ln := web.NSGet(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSGet(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + ln((*web.Namespace)(ns)) } } // NSPost call Namespace Post func NSPost(rootpath string, f FilterFunc) LinkNamespace { + wp := web.NSPost(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.Post(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wp((*web.Namespace)(ns)) } } // NSHead call Namespace Head func NSHead(rootpath string, f FilterFunc) LinkNamespace { + wb := web.NSHead(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSHead(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wb((*web.Namespace)(ns)) } } // NSPut call Namespace Put func NSPut(rootpath string, f FilterFunc) LinkNamespace { + wn := web.NSPut(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSPut(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wn((*web.Namespace)(ns)) } } // NSDelete call Namespace Delete func NSDelete(rootpath string, f FilterFunc) LinkNamespace { + wn := web.NSDelete(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSDelete(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wn((*web.Namespace)(ns)) } } // NSAny call Namespace Any func NSAny(rootpath string, f FilterFunc) LinkNamespace { + wn := web.NSAny(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSAny(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wn((*web.Namespace)(ns)) } } // NSOptions call Namespace Options func NSOptions(rootpath string, f FilterFunc) LinkNamespace { + wo := web.NSOptions(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSOptions(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wo((*web.Namespace)(ns)) } } // NSPatch call Namespace Patch func NSPatch(rootpath string, f FilterFunc) LinkNamespace { + wn := web.NSPatch(rootpath, func(ctx *context.Context) { + f((*adtContext.Context)(ctx)) + }) return func(ns *Namespace) { - web.NSPatch(rootpath, func(ctx *context.Context) { - f((*adtContext.Context)(ctx)) - }) + wn((*web.Namespace)(ns)) } } // NSAutoRouter call Namespace AutoRouter func NSAutoRouter(c ControllerInterface) LinkNamespace { + wn := web.NSAutoRouter(c) return func(ns *Namespace) { - web.NSAutoRouter(c) + wn((*web.Namespace)(ns)) } } // NSAutoPrefix call Namespace AutoPrefix func NSAutoPrefix(prefix string, c ControllerInterface) LinkNamespace { + wn := web.NSAutoPrefix(prefix, c) return func(ns *Namespace) { - web.NSAutoPrefix(prefix, c) + wn((*web.Namespace)(ns)) } } // NSNamespace add sub Namespace func NSNamespace(prefix string, params ...LinkNamespace) LinkNamespace { + nps := oldToNewLinkNs(params) + wn := web.NSNamespace(prefix, nps...) return func(ns *Namespace) { - nps := oldToNewLinkNs(params) - web.NSNamespace(prefix, nps...) + wn((*web.Namespace)(ns)) } } // NSHandler add handler func NSHandler(rootpath string, h http.Handler) LinkNamespace { + wn := web.NSHandler(rootpath, h) return func(ns *Namespace) { - web.NSHandler(rootpath, h) + wn((*web.Namespace)(ns)) } } diff --git a/adapter/orm/db.go b/adapter/orm/db.go index 3cdd33cd..c1d1fe92 100644 --- a/adapter/orm/db.go +++ b/adapter/orm/db.go @@ -18,7 +18,5 @@ import ( "github.com/beego/beego/v2/client/orm" ) -var ( - // ErrMissPK missing pk error - ErrMissPK = orm.ErrMissPK -) +// ErrMissPK missing pk error +var ErrMissPK = orm.ErrMissPK diff --git a/adapter/orm/models_boot.go b/adapter/orm/models_boot.go index e004f35a..678b86e6 100644 --- a/adapter/orm/models_boot.go +++ b/adapter/orm/models_boot.go @@ -25,7 +25,7 @@ func RegisterModel(models ...interface{}) { // RegisterModelWithPrefix register models with a prefix func RegisterModelWithPrefix(prefix string, models ...interface{}) { - orm.RegisterModelWithPrefix(prefix, models) + orm.RegisterModelWithPrefix(prefix, models...) } // RegisterModelWithSuffix register models with a suffix diff --git a/adapter/orm/query_setter_adapter.go b/adapter/orm/models_boot_test.go similarity index 56% rename from adapter/orm/query_setter_adapter.go rename to adapter/orm/models_boot_test.go index 7f506759..5471885b 100644 --- a/adapter/orm/query_setter_adapter.go +++ b/adapter/orm/models_boot_test.go @@ -1,4 +1,4 @@ -// Copyright 2020 +// Copyright 2020 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,20 +15,17 @@ package orm import ( - "github.com/beego/beego/v2/client/orm" + "testing" ) -type baseQuerySetter struct { +type User struct { + Id int } -func (b *baseQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter { - panic("you should not invoke this method.") +type Seller struct { + Id int } -func (b *baseQuerySetter) UseIndex(indexes ...string) orm.QuerySeter { - panic("you should not invoke this method.") -} - -func (b *baseQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter { - panic("you should not invoke this method.") +func TestRegisterModelWithPrefix(t *testing.T) { + RegisterModelWithPrefix("test", &User{}, &Seller{}) } diff --git a/adapter/orm/orm.go b/adapter/orm/orm.go index c603de2f..f3283fd4 100644 --- a/adapter/orm/orm.go +++ b/adapter/orm/orm.go @@ -21,7 +21,7 @@ // // import ( // "fmt" -// "github.com/beego/beego/v2/orm" +// "github.com/beego/beego/v2/client/orm" // _ "github.com/go-sql-driver/mysql" // import your used driver // ) // diff --git a/adapter/orm/utils.go b/adapter/orm/utils.go index 22bf8d63..cd54f867 100644 --- a/adapter/orm/utils.go +++ b/adapter/orm/utils.go @@ -195,7 +195,7 @@ func snakeStringWithAcronym(s string) string { } data = append(data, d) } - return strings.ToLower(string(data[:])) + return strings.ToLower(string(data)) } // snake string, XxYy to xx_yy , XxYY to xx_y_y @@ -213,7 +213,7 @@ func snakeString(s string) string { } data = append(data, d) } - return strings.ToLower(string(data[:])) + return strings.ToLower(string(data)) } // SetNameStrategy set different name strategy @@ -241,7 +241,7 @@ func camelString(s string) string { } data = append(data, d) } - return string(data[:]) + return string(data) } type argString []string diff --git a/adapter/orm/utils_test.go b/adapter/orm/utils_test.go index 7d94cada..fbf8663e 100644 --- a/adapter/orm/utils_test.go +++ b/adapter/orm/utils_test.go @@ -16,6 +16,8 @@ package orm import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestCamelString(t *testing.T) { @@ -29,9 +31,7 @@ func TestCamelString(t *testing.T) { for _, v := range snake { res := camelString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } + assert.Equal(t, answer[v], res) } } @@ -46,9 +46,7 @@ func TestSnakeString(t *testing.T) { for _, v := range camel { res := snakeString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } + assert.Equal(t, answer[v], res) } } @@ -63,8 +61,6 @@ func TestSnakeStringWithAcronym(t *testing.T) { for _, v := range camel { res := snakeStringWithAcronym(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } + assert.Equal(t, answer[v], res) } } diff --git a/adapter/plugins/apiauth/apiauth.go b/adapter/plugins/apiauth/apiauth.go index fd0c7ff4..d5511427 100644 --- a/adapter/plugins/apiauth/apiauth.go +++ b/adapter/plugins/apiauth/apiauth.go @@ -17,7 +17,7 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/apiauth" +// "github.com/beego/beego/v2/server/web/filter/apiauth" // ) // // func main(){ diff --git a/adapter/plugins/auth/basic.go b/adapter/plugins/auth/basic.go index 4ef3343f..173252ca 100644 --- a/adapter/plugins/auth/basic.go +++ b/adapter/plugins/auth/basic.go @@ -16,7 +16,7 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/auth" +// "github.com/beego/beego/v2/server/web/filter/auth" // ) // // func main(){ diff --git a/adapter/plugins/authz/authz.go b/adapter/plugins/authz/authz.go index 114c8c9a..096d7efb 100644 --- a/adapter/plugins/authz/authz.go +++ b/adapter/plugins/authz/authz.go @@ -16,7 +16,7 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/authz" +// "github.com/beego/beego/v2/server/web/filter/authz" // "github.com/casbin/casbin" // ) // diff --git a/adapter/plugins/authz/authz_test.go b/adapter/plugins/authz/authz_test.go index fa5410ca..4963ceab 100644 --- a/adapter/plugins/authz/authz_test.go +++ b/adapter/plugins/authz/authz_test.go @@ -26,6 +26,11 @@ import ( "github.com/beego/beego/v2/adapter/plugins/auth" ) +const ( + authCfg = "authz_model.conf" + authCsv = "authz_policy.csv" +) + func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { r, _ := http.NewRequest(method, path, nil) r.SetBasicAuth(user, "123") @@ -40,70 +45,79 @@ func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, p func TestBasic(t *testing.T) { handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + _ = handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) + + _ = handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer(authCfg, authCsv))) handler.Any("*", func(ctx *context.Context) { ctx.Output.SetStatus(200) }) - testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200) - testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200) - testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200) - testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403) + const d1r1 = "/dataset1/resource1" + testRequest(t, handler, "alice", d1r1, "GET", 200) + testRequest(t, handler, "alice", d1r1, "POST", 200) + const d1r2 = "/dataset1/resource2" + testRequest(t, handler, "alice", d1r2, "GET", 200) + testRequest(t, handler, "alice", d1r2, "POST", 403) } func TestPathWildcard(t *testing.T) { handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + _ = handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) + _ = handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer(authCfg, authCsv))) handler.Any("*", func(ctx *context.Context) { ctx.Output.SetStatus(200) }) - testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200) - testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200) - testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200) - testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403) - testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403) + const d2r1 = "/dataset2/resource1" + testRequest(t, handler, "bob", d2r1, "GET", 200) + testRequest(t, handler, "bob", d2r1, "POST", 200) + testRequest(t, handler, "bob", d2r1, "DELETE", 200) + const d2r2 = "/dataset2/resource2" + testRequest(t, handler, "bob", d2r2, "GET", 200) + testRequest(t, handler, "bob", d2r2, "POST", 403) + testRequest(t, handler, "bob", d2r2, "DELETE", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200) - testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403) + const item1 = "/dataset2/folder1/item1" + testRequest(t, handler, "bob", item1, "GET", 403) + testRequest(t, handler, "bob", item1, "POST", 200) + testRequest(t, handler, "bob", item1, "DELETE", 403) + const item2 = "/dataset2/folder1/item2" + testRequest(t, handler, "bob", item2, "GET", 403) + testRequest(t, handler, "bob", item2, "POST", 200) + testRequest(t, handler, "bob", item2, "DELETE", 403) } func TestRBAC(t *testing.T) { handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) - e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) + _ = handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) + e := casbin.NewEnforcer(authCfg, authCsv) + _ = handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) handler.Any("*", func(ctx *context.Context) { ctx.Output.SetStatus(200) }) // cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role. - testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200) - testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200) - testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200) - testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + const dataSet1 = "/dataset1/item" + testRequest(t, handler, "cathy", dataSet1, "GET", 200) + testRequest(t, handler, "cathy", dataSet1, "POST", 200) + testRequest(t, handler, "cathy", dataSet1, "DELETE", 200) + const dataSet2 = "/dataset2/item" + testRequest(t, handler, "cathy", dataSet2, "GET", 403) + testRequest(t, handler, "cathy", dataSet2, "POST", 403) + testRequest(t, handler, "cathy", dataSet2, "DELETE", 403) // delete all roles on user cathy, so cathy cannot access any resources now. e.DeleteRolesForUser("cathy") - testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403) - testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403) + testRequest(t, handler, "cathy", dataSet1, "GET", 403) + testRequest(t, handler, "cathy", dataSet1, "POST", 403) + testRequest(t, handler, "cathy", dataSet1, "DELETE", 403) + testRequest(t, handler, "cathy", dataSet2, "GET", 403) + testRequest(t, handler, "cathy", dataSet2, "POST", 403) + testRequest(t, handler, "cathy", dataSet2, "DELETE", 403) } diff --git a/adapter/plugins/cors/cors.go b/adapter/plugins/cors/cors.go index 6a836585..c02ab877 100644 --- a/adapter/plugins/cors/cors.go +++ b/adapter/plugins/cors/cors.go @@ -16,7 +16,7 @@ // Usage // import ( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/cors" +// "github.com/beego/beego/v2/server/web/filter/cors" // ) // // func main() { @@ -37,10 +37,9 @@ package cors import ( beego "github.com/beego/beego/v2/adapter" + "github.com/beego/beego/v2/adapter/context" beecontext "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/filter/cors" - - "github.com/beego/beego/v2/adapter/context" ) // Options represents Access Control options. diff --git a/adapter/router.go b/adapter/router.go index 900e3eb7..a0add1fe 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -19,9 +19,8 @@ import ( "time" beecontext "github.com/beego/beego/v2/adapter/context" - "github.com/beego/beego/v2/server/web/context" - "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" ) // default filter execution points @@ -87,7 +86,7 @@ func NewControllerRegister() *ControllerRegister { // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - (*web.ControllerRegister)(p).Add(pattern, c, mappingMethods...) + (*web.ControllerRegister)(p).Add(pattern, c, web.WithRouterMethods(c, mappingMethods...)) } // Include only when the Runmode is dev will generate router file in the router/auto.go from the controller @@ -217,7 +216,7 @@ func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ... } // AddAuto router to ControllerRegister. -// example beego.AddAuto(&MainContorlller{}), +// example beego.AddAuto(&MainController{}), // MainController has method List and Page. // visit the url /main/list to execute List function // /main/page to execute Page function. @@ -226,7 +225,7 @@ func (p *ControllerRegister) AddAuto(c ControllerInterface) { } // AddAutoPrefix Add auto router to ControllerRegister with prefix. -// example beego.AddAutoPrefix("/admin",&MainContorlller{}), +// example beego.AddAutoPrefix("/admin",&MainController{}), // MainController has method List and Page. // visit the url /admin/main/list to execute List function // /admin/main/page to execute Page function. diff --git a/adapter/session/couchbase/sess_couchbase.go b/adapter/session/couchbase/sess_couchbase.go index 4ce2d69d..2dc4ce18 100644 --- a/adapter/session/couchbase/sess_couchbase.go +++ b/adapter/session/couchbase/sess_couchbase.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/couchbase" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/couchbase" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/adapter/session/memcache/sess_memcache.go b/adapter/session/memcache/sess_memcache.go index e81d06c6..7cd8d733 100644 --- a/adapter/session/memcache/sess_memcache.go +++ b/adapter/session/memcache/sess_memcache.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/memcache" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/memcache" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { @@ -37,7 +37,6 @@ import ( "net/http" "github.com/beego/beego/v2/adapter/session" - beemem "github.com/beego/beego/v2/server/web/session/memcache" ) diff --git a/adapter/session/mysql/sess_mysql.go b/adapter/session/mysql/sess_mysql.go index d47e7496..71c2d4a7 100644 --- a/adapter/session/mysql/sess_mysql.go +++ b/adapter/session/mysql/sess_mysql.go @@ -28,8 +28,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/mysql" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/mysql" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { @@ -44,11 +44,11 @@ import ( "context" "net/http" - "github.com/beego/beego/v2/adapter/session" - "github.com/beego/beego/v2/server/web/session/mysql" - // import mysql driver _ "github.com/go-sql-driver/mysql" + + "github.com/beego/beego/v2/adapter/session" + "github.com/beego/beego/v2/server/web/session/mysql" ) var ( diff --git a/adapter/session/postgres/sess_postgresql.go b/adapter/session/postgres/sess_postgresql.go index a24794d6..5807d551 100644 --- a/adapter/session/postgres/sess_postgresql.go +++ b/adapter/session/postgres/sess_postgresql.go @@ -38,8 +38,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/postgresql" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/postgresql" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { @@ -54,10 +54,10 @@ import ( "context" "net/http" - "github.com/beego/beego/v2/adapter/session" // import postgresql Driver _ "github.com/lib/pq" + "github.com/beego/beego/v2/adapter/session" "github.com/beego/beego/v2/server/web/session/postgres" ) diff --git a/adapter/session/redis/sess_redis.go b/adapter/session/redis/sess_redis.go index a5fcedf6..220a59cd 100644 --- a/adapter/session/redis/sess_redis.go +++ b/adapter/session/redis/sess_redis.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { @@ -37,7 +37,6 @@ import ( "net/http" "github.com/beego/beego/v2/adapter/session" - beeRedis "github.com/beego/beego/v2/server/web/session/redis" ) diff --git a/adapter/session/redis_cluster/redis_cluster.go b/adapter/session/redis_cluster/redis_cluster.go index f4c8e4d1..623d72cc 100644 --- a/adapter/session/redis_cluster/redis_cluster.go +++ b/adapter/session/redis_cluster/redis_cluster.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis_cluster" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis_cluster" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/adapter/session/redis_sentinel/sess_redis_sentinel.go b/adapter/session/redis_sentinel/sess_redis_sentinel.go index 4498e55d..633fb5fa 100644 --- a/adapter/session/redis_sentinel/sess_redis_sentinel.go +++ b/adapter/session/redis_sentinel/sess_redis_sentinel.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis_sentinel" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis_sentinel" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { @@ -37,7 +37,6 @@ import ( "net/http" "github.com/beego/beego/v2/adapter/session" - sentinel "github.com/beego/beego/v2/server/web/session/redis_sentinel" ) diff --git a/adapter/session/redis_sentinel/sess_redis_sentinel_test.go b/adapter/session/redis_sentinel/sess_redis_sentinel_test.go index 0a6249ee..2d381af6 100644 --- a/adapter/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/adapter/session/redis_sentinel/sess_redis_sentinel_test.go @@ -5,6 +5,8 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" + "github.com/beego/beego/v2/adapter/session" ) @@ -19,72 +21,55 @@ func TestRedisSentinel(t *testing.T) { ProviderConfig: "127.0.0.1:6379,100,,0,master", } globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) + if e != nil { t.Log(e) return } - // todo test if e==nil + go globalSessions.GC() r, _ := http.NewRequest("GET", "/", nil) w := httptest.NewRecorder() sess, err := globalSessions.SessionStart(w, r) - if err != nil { - t.Fatal("session start failed:", err) - } + assert.Nil(t, err) defer sess.SessionRelease(w) // SET AND GET err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set username failed:", err) - } + assert.Nil(t, err) username := sess.Get("username") - if username != "astaxie" { - t.Fatal("get username failed") - } + assert.Equal(t, "astaxie", username) // DELETE err = sess.Delete("username") - if err != nil { - t.Fatal("delete username failed:", err) - } + assert.Nil(t, err) + username = sess.Get("username") - if username != nil { - t.Fatal("delete username failed") - } + assert.Nil(t, username) // FLUSH err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set failed:", err) - } + assert.Nil(t, err) + err = sess.Set("password", "1qaz2wsx") - if err != nil { - t.Fatal("set failed:", err) - } + assert.Nil(t, err) + username = sess.Get("username") - if username != "astaxie" { - t.Fatal("get username failed") - } + assert.Equal(t, "astaxie", username) + password := sess.Get("password") - if password != "1qaz2wsx" { - t.Fatal("get password failed") - } + assert.Equal(t, "1qaz2wsx", password) + err = sess.Flush() - if err != nil { - t.Fatal("flush failed:", err) - } + assert.Nil(t, err) + username = sess.Get("username") - if username != nil { - t.Fatal("flush failed") - } + assert.Nil(t, username) + password = sess.Get("password") - if password != nil { - t.Fatal("flush failed") - } + assert.Nil(t, password) sess.SessionRelease(w) - } diff --git a/adapter/session/sess_cookie_test.go b/adapter/session/sess_cookie_test.go index b6726005..61937f56 100644 --- a/adapter/session/sess_cookie_test.go +++ b/adapter/session/sess_cookie_test.go @@ -22,6 +22,8 @@ import ( "testing" ) +const setCookieKey = "Set-Cookie" + func TestCookie(t *testing.T) { config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` conf := new(ManagerConfig) @@ -46,7 +48,8 @@ func TestCookie(t *testing.T) { t.Fatal("get username error") } sess.SessionRelease(w) - if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { + + if cookiestr := w.Header().Get(setCookieKey); cookiestr == "" { t.Fatal("setcookie error") } else { parts := strings.Split(strings.TrimSpace(cookiestr), ";") @@ -79,7 +82,7 @@ func TestDestorySessionCookie(t *testing.T) { // request again ,will get same sesssion id . r1, _ := http.NewRequest("GET", "/", nil) - r1.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + r1.Header.Set("Cookie", w.Header().Get(setCookieKey)) w = httptest.NewRecorder() newSession, err := globalSessions.SessionStart(w, r1) if err != nil { @@ -92,7 +95,7 @@ func TestDestorySessionCookie(t *testing.T) { // After destroy session , will get a new session id . globalSessions.SessionDestroy(w, r1) r2, _ := http.NewRequest("GET", "/", nil) - r2.Header.Set("Cookie", w.Header().Get("Set-Cookie")) + r2.Header.Set("Cookie", w.Header().Get(setCookieKey)) w = httptest.NewRecorder() newSession, err = globalSessions.SessionStart(w, r2) diff --git a/adapter/session/sess_file_test.go b/adapter/session/sess_file_test.go index 4c90a3ac..a3e3d0b9 100644 --- a/adapter/session/sess_file_test.go +++ b/adapter/session/sess_file_test.go @@ -22,15 +22,15 @@ import ( "time" ) -const sid = "Session_id" -const sidNew = "Session_id_new" -const sessionPath = "./_session_runtime" - -var ( - mutex sync.Mutex +const ( + sid = "Session_id" + sidNew = "Session_id_new" + sessionPath = "./_session_runtime" ) -func TestFileProvider_SessionExist(t *testing.T) { +var mutex sync.Mutex + +func TestFileProviderSessionExist(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -53,7 +53,7 @@ func TestFileProvider_SessionExist(t *testing.T) { } } -func TestFileProvider_SessionExist2(t *testing.T) { +func TestFileProviderSessionExist2(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -75,7 +75,7 @@ func TestFileProvider_SessionExist2(t *testing.T) { } } -func TestFileProvider_SessionRead(t *testing.T) { +func TestFileProviderSessionRead(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -97,7 +97,7 @@ func TestFileProvider_SessionRead(t *testing.T) { } } -func TestFileProvider_SessionRead1(t *testing.T) { +func TestFileProviderSessionRead1(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -117,7 +117,7 @@ func TestFileProvider_SessionRead1(t *testing.T) { } } -func TestFileProvider_SessionAll(t *testing.T) { +func TestFileProviderSessionAll(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -140,7 +140,7 @@ func TestFileProvider_SessionAll(t *testing.T) { } } -func TestFileProvider_SessionRegenerate(t *testing.T) { +func TestFileProviderSessionRegenerate(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -172,7 +172,7 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { } } -func TestFileProvider_SessionDestroy(t *testing.T) { +func TestFileProviderSessionDestroy(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -200,7 +200,7 @@ func TestFileProvider_SessionDestroy(t *testing.T) { } } -func TestFileProvider_SessionGC(t *testing.T) { +func TestFileProviderSessionGC(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -226,7 +226,7 @@ func TestFileProvider_SessionGC(t *testing.T) { } } -func TestFileSessionStore_Set(t *testing.T) { +func TestFileSessionStoreSet(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -245,7 +245,7 @@ func TestFileSessionStore_Set(t *testing.T) { } } -func TestFileSessionStore_Get(t *testing.T) { +func TestFileSessionStoreGet(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -266,7 +266,7 @@ func TestFileSessionStore_Get(t *testing.T) { } } -func TestFileSessionStore_Delete(t *testing.T) { +func TestFileSessionStoreDelete(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -289,7 +289,7 @@ func TestFileSessionStore_Delete(t *testing.T) { } } -func TestFileSessionStore_Flush(t *testing.T) { +func TestFileSessionStoreFlush(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -313,7 +313,7 @@ func TestFileSessionStore_Flush(t *testing.T) { } } -func TestFileSessionStore_SessionID(t *testing.T) { +func TestFileSessionStoreSessionID(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) diff --git a/adapter/session/sess_test.go b/adapter/session/sess_test.go index aba702ca..2ecd2dd9 100644 --- a/adapter/session/sess_test.go +++ b/adapter/session/sess_test.go @@ -18,7 +18,7 @@ import ( "testing" ) -func Test_gob(t *testing.T) { +func TestGob(t *testing.T) { a := make(map[interface{}]interface{}) a["username"] = "astaxie" a[12] = 234 diff --git a/adapter/session/session.go b/adapter/session/session.go index 40e947fd..703adbde 100644 --- a/adapter/session/session.go +++ b/adapter/session/session.go @@ -16,7 +16,7 @@ // // Usage: // import( -// "github.com/beego/beego/v2/session" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/adapter/session/ssdb/sess_ssdb.go b/adapter/session/ssdb/sess_ssdb.go index 73ead908..9d08b2ab 100644 --- a/adapter/session/ssdb/sess_ssdb.go +++ b/adapter/session/ssdb/sess_ssdb.go @@ -5,7 +5,6 @@ import ( "net/http" "github.com/beego/beego/v2/adapter/session" - beeSsdb "github.com/beego/beego/v2/server/web/session/ssdb" ) diff --git a/adapter/templatefunc.go b/adapter/templatefunc.go index 808539e7..c5574d32 100644 --- a/adapter/templatefunc.go +++ b/adapter/templatefunc.go @@ -118,7 +118,6 @@ func AssetsJs(text string) template.HTML { // AssetsCSS returns stylesheet link tag with src string. func AssetsCSS(text string) template.HTML { - text = "" return template.HTML(text) diff --git a/adapter/templatefunc_test.go b/adapter/templatefunc_test.go index f5113606..b3d5e968 100644 --- a/adapter/templatefunc_test.go +++ b/adapter/templatefunc_test.go @@ -19,19 +19,15 @@ import ( "net/url" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestSubstr(t *testing.T) { s := `012345` - if Substr(s, 0, 2) != "01" { - t.Error("should be equal") - } - if Substr(s, 0, 100) != "012345" { - t.Error("should be equal") - } - if Substr(s, 12, 100) != "012345" { - t.Error("should be equal") - } + assert.Equal(t, "01", Substr(s, 0, 2)) + assert.Equal(t, "012345", Substr(s, 0, 100)) + assert.Equal(t, "012345", Substr(s, 12, 100)) } func TestHtml2str(t *testing.T) { @@ -39,73 +35,50 @@ func TestHtml2str(t *testing.T) { \n` - if HTML2str(h) != "123\\n\n\\n" { - t.Error("should be equal") - } + assert.Equal(t, "123\\n\n\\n", HTML2str(h)) } func TestDateFormat(t *testing.T) { ts := "Mon, 01 Jul 2013 13:27:42 CST" tt, _ := time.Parse(time.RFC1123, ts) - if ss := DateFormat(tt, "2006-01-02 15:04:05"); ss != "2013-07-01 13:27:42" { - t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) - } + assert.Equal(t, "2013-07-01 13:27:42", DateFormat(tt, "2006-01-02 15:04:05")) } func TestDate(t *testing.T) { ts := "Mon, 01 Jul 2013 13:27:42 CST" tt, _ := time.Parse(time.RFC1123, ts) - if ss := Date(tt, "Y-m-d H:i:s"); ss != "2013-07-01 13:27:42" { - t.Errorf("2013-07-01 13:27:42 does not equal %v", ss) - } - if ss := Date(tt, "y-n-j h:i:s A"); ss != "13-7-1 01:27:42 PM" { - t.Errorf("13-7-1 01:27:42 PM does not equal %v", ss) - } - if ss := Date(tt, "D, d M Y g:i:s a"); ss != "Mon, 01 Jul 2013 1:27:42 pm" { - t.Errorf("Mon, 01 Jul 2013 1:27:42 pm does not equal %v", ss) - } - if ss := Date(tt, "l, d F Y G:i:s"); ss != "Monday, 01 July 2013 13:27:42" { - t.Errorf("Monday, 01 July 2013 13:27:42 does not equal %v", ss) - } + assert.Equal(t, "2013-07-01 13:27:42", Date(tt, "Y-m-d H:i:s")) + + assert.Equal(t, "13-7-1 01:27:42 PM", Date(tt, "y-n-j h:i:s A")) + assert.Equal(t, "Mon, 01 Jul 2013 1:27:42 pm", Date(tt, "D, d M Y g:i:s a")) + assert.Equal(t, "Monday, 01 July 2013 13:27:42", Date(tt, "l, d F Y G:i:s")) } func TestCompareRelated(t *testing.T) { - if !Compare("abc", "abc") { - t.Error("should be equal") - } - if Compare("abc", "aBc") { - t.Error("should be not equal") - } - if !Compare("1", 1) { - t.Error("should be equal") - } - if CompareNot("abc", "abc") { - t.Error("should be equal") - } - if !CompareNot("abc", "aBc") { - t.Error("should be not equal") - } - if !NotNil("a string") { - t.Error("should not be nil") - } + assert.True(t, Compare("abc", "abc")) + + assert.False(t, Compare("abc", "aBc")) + + assert.True(t, Compare("1", 1)) + + assert.False(t, CompareNot("abc", "abc")) + + assert.True(t, CompareNot("abc", "aBc")) + assert.True(t, NotNil("a string")) } func TestHtmlquote(t *testing.T) { h := `<' ”“&">` s := `<' ”“&">` - if Htmlquote(s) != h { - t.Error("should be equal") - } + assert.Equal(t, h, Htmlquote(s)) } func TestHtmlunquote(t *testing.T) { h := `<' ”“&">` s := `<' ”“&">` - if Htmlunquote(h) != s { - t.Error("should be equal") - } + assert.Equal(t, s, Htmlunquote(h)) } func TestParseForm(t *testing.T) { @@ -148,55 +121,42 @@ func TestParseForm(t *testing.T) { "hobby": []string{"", "Basketball", "Football"}, "memo": []string{"nothing"}, } - if err := ParseForm(form, u); err == nil { - t.Fatal("nothing will be changed") - } - if err := ParseForm(form, &u); err != nil { - t.Fatal(err) - } - if u.ID != 0 { - t.Errorf("ID should equal 0 but got %v", u.ID) - } - if len(u.tag) != 0 { - t.Errorf("tag's length should equal 0 but got %v", len(u.tag)) - } - if u.Name.(string) != "test" { - t.Errorf("Name should equal `test` but got `%v`", u.Name.(string)) - } - if u.Age != 40 { - t.Errorf("Age should equal 40 but got %v", u.Age) - } - if u.Email != "test@gmail.com" { - t.Errorf("Email should equal `test@gmail.com` but got `%v`", u.Email) - } - if u.Intro != "I am an engineer!" { - t.Errorf("Intro should equal `I am an engineer!` but got `%v`", u.Intro) - } - if !u.StrBool { - t.Errorf("strboll should equal `true`, but got `%v`", u.StrBool) - } + + assert.NotNil(t, ParseForm(form, u)) + + assert.Nil(t, ParseForm(form, &u)) + + assert.Equal(t, 0, u.ID) + + assert.Equal(t, 0, len(u.tag)) + + assert.Equal(t, "test", u.Name) + + assert.Equal(t, 40, u.Age) + + assert.Equal(t, "test@gmail.com", u.Email) + + assert.Equal(t, "I am an engineer!", u.Intro) + + assert.True(t, u.StrBool) + y, m, d := u.Date.Date() - if y != 2014 || m.String() != "November" || d != 12 { - t.Errorf("Date should equal `2014-11-12`, but got `%v`", u.Date.String()) - } - if u.Organization != "beego" { - t.Errorf("Organization should equal `beego`, but got `%v`", u.Organization) - } - if u.Title != "CXO" { - t.Errorf("Title should equal `CXO`, but got `%v`", u.Title) - } - if u.Hobby[0] != "" { - t.Errorf("Hobby should equal ``, but got `%v`", u.Hobby[0]) - } - if u.Hobby[1] != "Basketball" { - t.Errorf("Hobby should equal `Basketball`, but got `%v`", u.Hobby[1]) - } - if u.Hobby[2] != "Football" { - t.Errorf("Hobby should equal `Football`, but got `%v`", u.Hobby[2]) - } - if len(u.Memo) != 0 { - t.Errorf("Memo's length should equal 0 but got %v", len(u.Memo)) - } + + assert.Equal(t, 2014, y) + assert.Equal(t, "November", m.String()) + assert.Equal(t, 12, d) + + assert.Equal(t, "beego", u.Organization) + + assert.Equal(t, "CXO", u.Title) + + assert.Equal(t, "", u.Hobby[0]) + + assert.Equal(t, "Basketball", u.Hobby[1]) + + assert.Equal(t, "Football", u.Hobby[2]) + + assert.Equal(t, 0, len(u.Memo)) } func TestRenderForm(t *testing.T) { @@ -212,18 +172,14 @@ func TestRenderForm(t *testing.T) { u := user{Name: "test", Intro: "Some Text"} output := RenderForm(u) - if output != template.HTML("") { - t.Errorf("output should be empty but got %v", output) - } + assert.Equal(t, template.HTML(""), output) output = RenderForm(&u) result := template.HTML( `Name:
` + `年龄:
` + `Sex:
` + `Intro: `) - if output != result { - t.Errorf("output should equal `%v` but got `%v`", result, output) - } + assert.Equal(t, result, output) } func TestMapGet(t *testing.T) { @@ -233,29 +189,17 @@ func TestMapGet(t *testing.T) { "1": 2, } - if res, err := MapGet(m1, "a"); err == nil { - if res.(int64) != 1 { - t.Errorf("Should return 1, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + res, err := MapGet(m1, "a") + assert.Nil(t, err) + assert.Equal(t, int64(1), res) - if res, err := MapGet(m1, "1"); err == nil { - if res.(int64) != 2 { - t.Errorf("Should return 2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + res, err = MapGet(m1, "1") + assert.Nil(t, err) + assert.Equal(t, int64(2), res) - if res, err := MapGet(m1, 1); err == nil { - if res.(int64) != 2 { - t.Errorf("Should return 2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + res, err = MapGet(m1, 1) + assert.Nil(t, err) + assert.Equal(t, int64(2), res) // test 2 level map m2 := M{ @@ -264,13 +208,9 @@ func TestMapGet(t *testing.T) { }, } - if res, err := MapGet(m2, 1, 2); err == nil { - if res.(float64) != 3.5 { - t.Errorf("Should return 3.5, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + res, err = MapGet(m2, 1, 2) + assert.Nil(t, err) + assert.Equal(t, 3.5, res) // test 5 level map m5 := M{ @@ -285,20 +225,12 @@ func TestMapGet(t *testing.T) { }, } - if res, err := MapGet(m5, 1, 2, 3, 4, 5); err == nil { - if res.(float64) != 1.2 { - t.Errorf("Should return 1.2, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + res, err = MapGet(m5, 1, 2, 3, 4, 5) + assert.Nil(t, err) + assert.Equal(t, 1.2, res) // check whether element not exists in map - if res, err := MapGet(m5, 5, 4, 3, 2, 1); err == nil { - if res != nil { - t.Errorf("Should return nil, but return %v", res) - } - } else { - t.Errorf("Error happens %v", err) - } + res, err = MapGet(m5, 5, 4, 3, 2, 1) + assert.Nil(t, err) + assert.Nil(t, res) } diff --git a/adapter/testing/client.go b/adapter/testing/client.go index 356a0f68..a773c9a6 100644 --- a/adapter/testing/client.go +++ b/adapter/testing/client.go @@ -14,12 +14,7 @@ package testing -import ( - "github.com/beego/beego/v2/client/httplib/testing" -) - -var port = "" -var baseURL = "http://localhost:" +import "github.com/beego/beego/v2/client/httplib/testing" // TestHTTPRequest beego test request client type TestHTTPRequest testing.TestHTTPRequest diff --git a/adapter/toolbox/profile.go b/adapter/toolbox/profile.go index 15d7010a..00b0eef7 100644 --- a/adapter/toolbox/profile.go +++ b/adapter/toolbox/profile.go @@ -16,19 +16,10 @@ package toolbox import ( "io" - "os" - "time" "github.com/beego/beego/v2/core/admin" ) -var startTime = time.Now() -var pid int - -func init() { - pid = os.Getpid() -} - // ProcessInput parse input command string func ProcessInput(input string, w io.Writer) { admin.ProcessInput(input, w) diff --git a/adapter/toolbox/statistics_test.go b/adapter/toolbox/statistics_test.go index ac29476c..f4371c3f 100644 --- a/adapter/toolbox/statistics_test.go +++ b/adapter/toolbox/statistics_test.go @@ -21,13 +21,16 @@ import ( ) func TestStatics(t *testing.T) { - StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(2000)) - StatisticsMap.AddStatistics("POST", "/api/user", "&admin.user", time.Duration(120000)) - StatisticsMap.AddStatistics("GET", "/api/user", "&admin.user", time.Duration(13000)) - StatisticsMap.AddStatistics("POST", "/api/admin", "&admin.user", time.Duration(14000)) - StatisticsMap.AddStatistics("POST", "/api/user/astaxie", "&admin.user", time.Duration(12000)) - StatisticsMap.AddStatistics("POST", "/api/user/xiemengjun", "&admin.user", time.Duration(13000)) - StatisticsMap.AddStatistics("DELETE", "/api/user", "&admin.user", time.Duration(1400)) + userApi := "/api/user" + post := "POST" + adminUser := "&admin.user" + StatisticsMap.AddStatistics(post, userApi, adminUser, time.Duration(2000)) + StatisticsMap.AddStatistics(post, userApi, adminUser, time.Duration(120000)) + StatisticsMap.AddStatistics("GET", userApi, adminUser, time.Duration(13000)) + StatisticsMap.AddStatistics(post, "/api/admin", adminUser, time.Duration(14000)) + StatisticsMap.AddStatistics(post, "/api/user/astaxie", adminUser, time.Duration(12000)) + StatisticsMap.AddStatistics(post, "/api/user/xiemengjun", adminUser, time.Duration(13000)) + StatisticsMap.AddStatistics("DELETE", userApi, adminUser, time.Duration(1400)) t.Log(StatisticsMap.GetMap()) data := StatisticsMap.GetMapData() diff --git a/adapter/toolbox/task.go b/adapter/toolbox/task.go index bdd6679f..199956f8 100644 --- a/adapter/toolbox/task.go +++ b/adapter/toolbox/task.go @@ -80,7 +80,6 @@ type Task struct { // NewTask add new task with name, time and func func NewTask(tname string, spec string, f TaskFunc) *Task { - task := task.NewTask(tname, spec, func(ctx context.Context) error { return f() }) @@ -98,7 +97,6 @@ func (t *Task) GetSpec() string { // GetStatus get current task status func (t *Task) GetStatus() string { - t.initDelegate() return t.delegate.GetStatus(context.Background()) @@ -222,7 +220,6 @@ type MapSorter task.MapSorter // NewMapSorter create new tasker map func NewMapSorter(m map[string]Tasker) *MapSorter { - newTaskerMap := make(map[string]task.Tasker, len(m)) for key, value := range m { @@ -249,6 +246,7 @@ func (ms *MapSorter) Less(i, j int) bool { } return ms.Vals[i].GetNext(context.Background()).Before(ms.Vals[j].GetNext(context.Background())) } + func (ms *MapSorter) Swap(i, j int) { ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] @@ -289,3 +287,7 @@ func (o *oldToNewAdapter) SetPrev(ctx context.Context, t time.Time) { func (o *oldToNewAdapter) GetPrev(ctx context.Context) time.Time { return o.delegate.GetPrev() } + +func (o *oldToNewAdapter) GetTimeout(ctx context.Context) time.Duration { + return 0 +} diff --git a/adapter/tree.go b/adapter/tree.go index fe9f6933..3d22d9ac 100644 --- a/adapter/tree.go +++ b/adapter/tree.go @@ -16,9 +16,8 @@ package adapter import ( "github.com/beego/beego/v2/adapter/context" - beecontext "github.com/beego/beego/v2/server/web/context" - "github.com/beego/beego/v2/server/web" + beecontext "github.com/beego/beego/v2/server/web/context" ) // Tree has three elements: FixRouter/wildcard/leaves diff --git a/adapter/utils/captcha/README.md b/adapter/utils/captcha/README.md index 74e1cf82..07a4dc4d 100644 --- a/adapter/utils/captcha/README.md +++ b/adapter/utils/captcha/README.md @@ -7,8 +7,8 @@ package controllers import ( "github.com/beego/beego/v2" - "github.com/beego/beego/v2/cache" - "github.com/beego/beego/v2/utils/captcha" + "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/server/web/captcha" ) var cpt *captcha.Captcha diff --git a/adapter/utils/captcha/captcha.go b/adapter/utils/captcha/captcha.go index 4f5dd867..7cdcab2d 100644 --- a/adapter/utils/captcha/captcha.go +++ b/adapter/utils/captcha/captcha.go @@ -20,8 +20,8 @@ // // import ( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/cache" -// "github.com/beego/beego/v2/utils/captcha" +// "github.com/beego/beego/v2/client/cache" +// "github.com/beego/beego/v2/server/web/captcha" // ) // // var cpt *captcha.Captcha @@ -63,16 +63,13 @@ import ( "net/http" "time" - "github.com/beego/beego/v2/server/web/captcha" - beecontext "github.com/beego/beego/v2/server/web/context" - "github.com/beego/beego/v2/adapter/cache" "github.com/beego/beego/v2/adapter/context" + "github.com/beego/beego/v2/server/web/captcha" + beecontext "github.com/beego/beego/v2/server/web/context" ) -var ( - defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} -) +var defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} const ( // default captcha attributes diff --git a/adapter/utils/debug_test.go b/adapter/utils/debug_test.go index efb8924e..a748d20a 100644 --- a/adapter/utils/debug_test.go +++ b/adapter/utils/debug_test.go @@ -28,8 +28,8 @@ func TestPrint(t *testing.T) { } func TestPrintPoint(t *testing.T) { - var v1 = new(mytype) - var v2 = new(mytype) + v1 := new(mytype) + v2 := new(mytype) v1.prev = nil v1.next = v2 diff --git a/adapter/utils/pagination/doc.go b/adapter/utils/pagination/doc.go index d180b093..43ee78b6 100644 --- a/adapter/utils/pagination/doc.go +++ b/adapter/utils/pagination/doc.go @@ -8,7 +8,7 @@ In your beego.Controller: package controllers - import "github.com/beego/beego/v2/utils/pagination" + import "github.com/beego/beego/v2/server/web/pagination" type PostsController struct { beego.Controller diff --git a/adapter/utils/rand_test.go b/adapter/utils/rand_test.go index 6c238b5e..1cb26029 100644 --- a/adapter/utils/rand_test.go +++ b/adapter/utils/rand_test.go @@ -16,7 +16,7 @@ package utils import "testing" -func TestRand_01(t *testing.T) { +func TestRand01(t *testing.T) { bs0 := RandomCreateBytes(16) bs1 := RandomCreateBytes(16) diff --git a/adapter/utils/slice.go b/adapter/utils/slice.go index cdbfcca8..082b22ce 100644 --- a/adapter/utils/slice.go +++ b/adapter/utils/slice.go @@ -19,6 +19,7 @@ import ( ) type reducetype func(interface{}) interface{} + type filtertype func(interface{}) bool // InSlice checks given string in string slice or not. diff --git a/adapter/validation/util.go b/adapter/validation/util.go index 502be750..5ff43ebc 100644 --- a/adapter/validation/util.go +++ b/adapter/validation/util.go @@ -27,9 +27,7 @@ const ( LabelTag = validation.LabelTag ) -var ( - ErrInt64On32 = validation.ErrInt64On32 -) +var ErrInt64On32 = validation.ErrInt64On32 // CustomFunc is for custom validate function type CustomFunc func(v *Validation, obj interface{}, key string) diff --git a/adapter/validation/validation.go b/adapter/validation/validation.go index 8226fa20..eadd4361 100644 --- a/adapter/validation/validation.go +++ b/adapter/validation/validation.go @@ -15,7 +15,7 @@ // Package validation for validations // // import ( -// "github.com/beego/beego/v2/validation" +// "github.com/beego/beego/v2/core/validation" // "log" // ) // diff --git a/adapter/validation/validation_test.go b/adapter/validation/validation_test.go index b4b5b1b6..547e8635 100644 --- a/adapter/validation/validation_test.go +++ b/adapter/validation/validation_test.go @@ -18,131 +18,82 @@ import ( "regexp" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestRequired(t *testing.T) { valid := Validation{} - if valid.Required(nil, "nil").Ok { - t.Error("nil object should be false") - } - if !valid.Required(true, "bool").Ok { - t.Error("Bool value should always return true") - } - if !valid.Required(false, "bool").Ok { - t.Error("Bool value should always return true") - } - if valid.Required("", "string").Ok { - t.Error("\"'\" string should be false") - } - if valid.Required(" ", "string").Ok { - t.Error("\" \" string should be false") // For #2361 - } - if valid.Required("\n", "string").Ok { - t.Error("new line string should be false") // For #2361 - } - if !valid.Required("astaxie", "string").Ok { - t.Error("string should be true") - } - if valid.Required(0, "zero").Ok { - t.Error("Integer should not be equal 0") - } - if !valid.Required(1, "int").Ok { - t.Error("Integer except 0 should be true") - } - if !valid.Required(time.Now(), "time").Ok { - t.Error("time should be true") - } - if valid.Required([]string{}, "emptySlice").Ok { - t.Error("empty slice should be false") - } - if !valid.Required([]interface{}{"ok"}, "slice").Ok { - t.Error("slice should be true") - } + assert.False(t, valid.Required(nil, "nil").Ok) + assert.True(t, valid.Required(true, "bool").Ok) + + assert.True(t, valid.Required(false, "bool").Ok) + assert.False(t, valid.Required("", "string").Ok) + assert.False(t, valid.Required(" ", "string").Ok) + assert.False(t, valid.Required("\n", "string").Ok) + + assert.True(t, valid.Required("astaxie", "string").Ok) + assert.False(t, valid.Required(0, "zero").Ok) + + assert.True(t, valid.Required(1, "int").Ok) + + assert.True(t, valid.Required(time.Now(), "time").Ok) + + assert.False(t, valid.Required([]string{}, "emptySlice").Ok) + + assert.True(t, valid.Required([]interface{}{"ok"}, "slice").Ok) } func TestMin(t *testing.T) { valid := Validation{} - if valid.Min(-1, 0, "min0").Ok { - t.Error("-1 is less than the minimum value of 0 should be false") - } - if !valid.Min(1, 0, "min0").Ok { - t.Error("1 is greater or equal than the minimum value of 0 should be true") - } + assert.False(t, valid.Min(-1, 0, "min0").Ok) + assert.True(t, valid.Min(1, 0, "min0").Ok) } func TestMax(t *testing.T) { valid := Validation{} - if valid.Max(1, 0, "max0").Ok { - t.Error("1 is greater than the minimum value of 0 should be false") - } - if !valid.Max(-1, 0, "max0").Ok { - t.Error("-1 is less or equal than the maximum value of 0 should be true") - } + assert.False(t, valid.Max(1, 0, "max0").Ok) + assert.True(t, valid.Max(-1, 0, "max0").Ok) } func TestRange(t *testing.T) { valid := Validation{} - if valid.Range(-1, 0, 1, "range0_1").Ok { - t.Error("-1 is between 0 and 1 should be false") - } - if !valid.Range(1, 0, 1, "range0_1").Ok { - t.Error("1 is between 0 and 1 should be true") - } + assert.False(t, valid.Range(-1, 0, 1, "range0_1").Ok) + + assert.True(t, valid.Range(1, 0, 1, "range0_1").Ok) } func TestMinSize(t *testing.T) { valid := Validation{} - if valid.MinSize("", 1, "minSize1").Ok { - t.Error("the length of \"\" is less than the minimum value of 1 should be false") - } - if !valid.MinSize("ok", 1, "minSize1").Ok { - t.Error("the length of \"ok\" is greater or equal than the minimum value of 1 should be true") - } - if valid.MinSize([]string{}, 1, "minSize1").Ok { - t.Error("the length of empty slice is less than the minimum value of 1 should be false") - } - if !valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok { - t.Error("the length of [\"ok\"] is greater or equal than the minimum value of 1 should be true") - } + assert.False(t, valid.MinSize("", 1, "minSize1").Ok) + + assert.True(t, valid.MinSize("ok", 1, "minSize1").Ok) + assert.False(t, valid.MinSize([]string{}, 1, "minSize1").Ok) + assert.True(t, valid.MinSize([]interface{}{"ok"}, 1, "minSize1").Ok) } func TestMaxSize(t *testing.T) { valid := Validation{} - if valid.MaxSize("ok", 1, "maxSize1").Ok { - t.Error("the length of \"ok\" is greater than the maximum value of 1 should be false") - } - if !valid.MaxSize("", 1, "maxSize1").Ok { - t.Error("the length of \"\" is less or equal than the maximum value of 1 should be true") - } - if valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok { - t.Error("the length of [\"ok\", false] is greater than the maximum value of 1 should be false") - } - if !valid.MaxSize([]string{}, 1, "maxSize1").Ok { - t.Error("the length of empty slice is less or equal than the maximum value of 1 should be true") - } + assert.False(t, valid.MaxSize("ok", 1, "maxSize1").Ok) + assert.True(t, valid.MaxSize("", 1, "maxSize1").Ok) + assert.False(t, valid.MaxSize([]interface{}{"ok", false}, 1, "maxSize1").Ok) + assert.True(t, valid.MaxSize([]string{}, 1, "maxSize1").Ok) } func TestLength(t *testing.T) { valid := Validation{} - if valid.Length("", 1, "length1").Ok { - t.Error("the length of \"\" must equal 1 should be false") - } - if !valid.Length("1", 1, "length1").Ok { - t.Error("the length of \"1\" must equal 1 should be true") - } - if valid.Length([]string{}, 1, "length1").Ok { - t.Error("the length of empty slice must equal 1 should be false") - } - if !valid.Length([]interface{}{"ok"}, 1, "length1").Ok { - t.Error("the length of [\"ok\"] must equal 1 should be true") - } + assert.False(t, valid.Length("", 1, "length1").Ok) + assert.True(t, valid.Length("1", 1, "length1").Ok) + + assert.False(t, valid.Length([]string{}, 1, "length1").Ok) + assert.True(t, valid.Length([]interface{}{"ok"}, 1, "length1").Ok) } func TestAlpha(t *testing.T) { @@ -178,13 +129,16 @@ func TestAlphaNumeric(t *testing.T) { } } +const email = "suchuangji@gmail.com" + func TestMatch(t *testing.T) { valid := Validation{} if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") } - if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { + + if !valid.Match(email, regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok { t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") } } @@ -217,7 +171,7 @@ func TestEmail(t *testing.T) { if valid.Email("not@a email", "email").Ok { t.Error("\"not@a email\" is a valid email address should be false") } - if !valid.Email("suchuangji@gmail.com", "email").Ok { + if !valid.Email(email, "email").Ok { t.Error("\"suchuangji@gmail.com\" is a valid email address should be true") } if valid.Email("@suchuangji@gmail.com", "email").Ok { @@ -242,7 +196,7 @@ func TestIP(t *testing.T) { func TestBase64(t *testing.T) { valid := Validation{} - if valid.Base64("suchuangji@gmail.com", "base64").Ok { + if valid.Base64(email, "base64").Ok { t.Error("\"suchuangji@gmail.com\" are a valid base64 characters should be false") } if !valid.Base64("c3VjaHVhbmdqaUBnbWFpbC5jb20=", "base64").Ok { @@ -370,44 +324,25 @@ func TestValid(t *testing.T) { u := user{Name: "test@/test/;com", Age: 40} b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Error("validation should be passed") - } + assert.Nil(t, err) + assert.True(t, b) uptr := &user{Name: "test", Age: 40} valid.Clear() b, err = valid.Valid(uptr) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } - if len(valid.Errors) != 1 { - t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) - } - if valid.Errors[0].Key != "Name.Match" { - t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) - } + + assert.Nil(t, err) + assert.False(t, b) + assert.Equal(t, 1, len(valid.Errors)) + assert.Equal(t, "Name.Match", valid.Errors[0].Key) u = user{Name: "test@/test/;com", Age: 180} valid.Clear() b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } - if len(valid.Errors) != 1 { - t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) - } - if valid.Errors[0].Key != "Age.Range." { - t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) - } + assert.Nil(t, err) + assert.False(t, b) + assert.Equal(t, 1, len(valid.Errors)) + assert.Equal(t, "Age.Range.", valid.Errors[0].Key) } func TestRecursiveValid(t *testing.T) { @@ -432,12 +367,8 @@ func TestRecursiveValid(t *testing.T) { u := Account{Password: "abc123_", U: User{}} b, err := valid.RecursiveValid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Error("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) } func TestSkipValid(t *testing.T) { @@ -474,21 +405,13 @@ func TestSkipValid(t *testing.T) { valid := Validation{} b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) valid = Validation{RequiredFirst: true} b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } + assert.Nil(t, err) + assert.True(t, b) } func TestPointer(t *testing.T) { @@ -506,12 +429,8 @@ func TestPointer(t *testing.T) { valid := Validation{} b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) validEmail := "a@a.com" u = User{ @@ -521,12 +440,8 @@ func TestPointer(t *testing.T) { valid = Validation{RequiredFirst: true} b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } + assert.Nil(t, err) + assert.True(t, b) u = User{ ReqEmail: &validEmail, @@ -535,12 +450,8 @@ func TestPointer(t *testing.T) { valid = Validation{} b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) invalidEmail := "a@a" u = User{ @@ -550,12 +461,8 @@ func TestPointer(t *testing.T) { valid = Validation{RequiredFirst: true} b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) u = User{ ReqEmail: &validEmail, @@ -564,12 +471,8 @@ func TestPointer(t *testing.T) { valid = Validation{} b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) } func TestCanSkipAlso(t *testing.T) { @@ -589,21 +492,13 @@ func TestCanSkipAlso(t *testing.T) { valid := Validation{RequiredFirst: true} b, err := valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if b { - t.Fatal("validation should not be passed") - } + assert.Nil(t, err) + assert.False(t, b) valid = Validation{RequiredFirst: true} valid.CanSkipAlso("Range") b, err = valid.Valid(u) - if err != nil { - t.Fatal(err) - } - if !b { - t.Fatal("validation should be passed") - } + assert.Nil(t, err) + assert.True(t, b) } diff --git a/client/cache/README.md b/client/cache/README.md index 7e65cbbf..df1ea095 100644 --- a/client/cache/README.md +++ b/client/cache/README.md @@ -4,7 +4,7 @@ cache is a Go cache manager. It can use many cache adapters. The repo is inspire ## How to install? - go get github.com/beego/beego/v2/cache + go get github.com/beego/beego/v2/client/cache ## What adapters are supported? @@ -15,7 +15,7 @@ As of now this cache support memory, Memcache and Redis. First you must import it import ( - "github.com/beego/beego/v2/cache" + "github.com/beego/beego/v2/client/cache" ) Then init a Cache (example with memory adapter) diff --git a/client/cache/cache.go b/client/cache/cache.go index e73a1c1a..87f7ba62 100644 --- a/client/cache/cache.go +++ b/client/cache/cache.go @@ -16,7 +16,7 @@ // Usage: // // import( -// "github.com/beego/beego/v2/cache" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("memory", `{"interval":60}`) @@ -33,8 +33,9 @@ package cache import ( "context" - "fmt" "time" + + "github.com/beego/beego/v2/core/berror" ) // Cache interface contains all behaviors for cache adapter. @@ -55,12 +56,14 @@ type Cache interface { // Set a cached value with key and expire time. Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error // Delete cached value by key. + // Should not return error if key not found Delete(ctx context.Context, key string) error // Increment a cached int value by key, as a counter. Incr(ctx context.Context, key string) error // Decrement a cached int value by key, as a counter. Decr(ctx context.Context, key string) error // Check if a cached value exists or not. + // if key is expired, return (false, nil) IsExist(ctx context.Context, key string) (bool, error) // Clear all cache. ClearAll(ctx context.Context) error @@ -78,7 +81,7 @@ var adapters = make(map[string]Instance) // it panics. func Register(name string, adapter Instance) { if adapter == nil { - panic("cache: Register adapter is nil") + panic(berror.Error(NilCacheAdapter, "cache: Register adapter is nil").Error()) } if _, ok := adapters[name]; ok { panic("cache: Register called twice for adapter " + name) @@ -92,7 +95,7 @@ func Register(name string, adapter Instance) { func NewCache(adapterName, config string) (adapter Cache, err error) { instanceFunc, ok := adapters[adapterName] if !ok { - err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + err = berror.Errorf(UnknownAdapter, "cache: unknown adapter name %s (forgot to import?)", adapterName) return } adapter = instanceFunc() diff --git a/client/cache/cache_test.go b/client/cache/cache_test.go index 85f83fc4..f037a410 100644 --- a/client/cache/cache_test.go +++ b/client/cache/cache_test.go @@ -16,17 +16,19 @@ package cache import ( "context" + "math" "os" + "strings" "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestCacheIncr(t *testing.T) { bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } + assert.Nil(t, err) // timeoutDuration := 10 * time.Second bm.Put(context.Background(), "edwardhey", 0, time.Second*20) @@ -46,11 +48,9 @@ func TestCacheIncr(t *testing.T) { } func TestCache(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second + bm, err := NewCache("memory", `{"interval":1}`) + assert.Nil(t, err) + timeoutDuration := 5 * time.Second if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } @@ -62,7 +62,7 @@ func TestCache(t *testing.T) { t.Error("get err") } - time.Sleep(30 * time.Second) + time.Sleep(7 * time.Second) if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("check err") @@ -73,130 +73,95 @@ func TestCache(t *testing.T) { } // test different integer type for incr & decr - testMultiIncrDecr(t, bm, timeoutDuration) + testMultiTypeIncrDecr(t, bm, timeoutDuration) + + // test overflow of incr&decr + testIncrOverFlow(t, bm, timeoutDuration) + testDecrOverFlow(t, bm, timeoutDuration) bm.Delete(context.Background(), "astaxie") - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("delete err") - } + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) - // test GetMulti - if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } - if v, _ := bm.Get(context.Background(), "astaxie"); v.(string) != "author" { - t.Error("get err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie", "author", timeoutDuration)) - if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { - t.Error("check err") - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + v, _ := bm.Get(context.Background(), "astaxie") + assert.Equal(t, "author", v) + + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Equal(t, "author", vv[0]) + assert.Equal(t, "author1", vv[1]) vv, err = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0] != nil { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - if err != nil && err.Error() != "key [astaxie0] error: the key isn't exist" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Nil(t, vv[0]) + assert.Equal(t, "author1", vv[1]) + + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "key isn't exist")) } func TestFileCache(t *testing.T) { bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) - if err != nil { - t.Error("init err") - } + assert.Nil(t, err) timeoutDuration := 10 * time.Second - if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie", 1, timeoutDuration)) - if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 1 { - t.Error("get err") - } + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + v, _ := bm.Get(context.Background(), "astaxie") + assert.Equal(t, 1, v) // test different integer type for incr & decr - testMultiIncrDecr(t, bm, timeoutDuration) + testMultiTypeIncrDecr(t, bm, timeoutDuration) + + // test overflow of incr&decr + testIncrOverFlow(t, bm, timeoutDuration) + testDecrOverFlow(t, bm, timeoutDuration) bm.Delete(context.Background(), "astaxie") - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("delete err") - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) // test string - if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } - if v, _ := bm.Get(context.Background(), "astaxie"); v.(string) != "author" { - t.Error("get err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie", "author", timeoutDuration)) + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + v, _ = bm.Get(context.Background(), "astaxie") + assert.Equal(t, "author", v) // test GetMulti - if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { - t.Error("check err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Equal(t, "author", vv[0]) + assert.Equal(t, "author1", vv[1]) vv, err = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0] != nil { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - if err == nil { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Nil(t, vv[0]) + + assert.Equal(t, "author1", vv[1]) + assert.NotNil(t, err) os.RemoveAll("cache") } -func testMultiIncrDecr(t *testing.T, c Cache, timeout time.Duration) { +func testMultiTypeIncrDecr(t *testing.T, c Cache, timeout time.Duration) { testIncrDecr(t, c, 1, 2, timeout) testIncrDecr(t, c, int32(1), int32(2), timeout) testIncrDecr(t, c, int64(1), int64(2), timeout) @@ -206,30 +171,51 @@ func testMultiIncrDecr(t *testing.T, c Cache, timeout time.Duration) { } func testIncrDecr(t *testing.T, c Cache, beforeIncr interface{}, afterIncr interface{}, timeout time.Duration) { - var err error ctx := context.Background() key := "incDecKey" - if err = c.Put(ctx, key, beforeIncr, timeout); err != nil { - t.Error("Get Error", err) - } - if err = c.Incr(ctx, key); err != nil { - t.Error("Incr Error", err) - } + assert.Nil(t, c.Put(ctx, key, beforeIncr, timeout)) + assert.Nil(t, c.Incr(ctx, key)) - if v, _ := c.Get(ctx, key); v != afterIncr { - t.Error("Get Error") - } + v, _ := c.Get(ctx, key) + assert.Equal(t, afterIncr, v) - if err = c.Decr(ctx, key); err != nil { - t.Error("Decr Error", err) - } + assert.Nil(t, c.Decr(ctx, key)) - if v, _ := c.Get(ctx, key); v != beforeIncr { - t.Error("Get Error") - } + v, _ = c.Get(ctx, key) + assert.Equal(t, v, beforeIncr) + assert.Nil(t, c.Delete(ctx, key)) +} - if err := c.Delete(ctx, key); err != nil { - t.Error("Delete Error") +func testIncrOverFlow(t *testing.T, c Cache, timeout time.Duration) { + ctx := context.Background() + key := "incKey" + + assert.Nil(t, c.Put(ctx, key, int64(math.MaxInt64), timeout)) + // int64 + defer func() { + assert.Nil(t, c.Delete(ctx, key)) + }() + assert.NotNil(t, c.Incr(ctx, key)) +} + +func testDecrOverFlow(t *testing.T, c Cache, timeout time.Duration) { + var err error + ctx := context.Background() + key := "decKey" + + // int64 + if err = c.Put(ctx, key, int64(math.MinInt64), timeout); err != nil { + t.Error("Put Error: ", err.Error()) + return + } + defer func() { + if err = c.Delete(ctx, key); err != nil { + t.Errorf("Delete error: %s", err.Error()) + } + }() + if err = c.Decr(ctx, key); err == nil { + t.Error("Decr error") + return } } diff --git a/client/cache/calc_utils.go b/client/cache/calc_utils.go new file mode 100644 index 00000000..f8b7f24a --- /dev/null +++ b/client/cache/calc_utils.go @@ -0,0 +1,95 @@ +package cache + +import ( + "math" + + "github.com/beego/beego/v2/core/berror" +) + +var ( + ErrIncrementOverflow = berror.Error(IncrementOverflow, "this incr invocation will overflow.") + ErrDecrementOverflow = berror.Error(DecrementOverflow, "this decr invocation will overflow.") + ErrNotIntegerType = berror.Error(NotIntegerType, "item val is not (u)int (u)int32 (u)int64") +) + +const ( + MinUint32 uint32 = 0 + MinUint64 uint64 = 0 +) + +func incr(originVal interface{}) (interface{}, error) { + switch val := originVal.(type) { + case int: + tmp := val + 1 + if val > 0 && tmp < 0 { + return nil, ErrIncrementOverflow + } + return tmp, nil + case int32: + if val == math.MaxInt32 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + case int64: + if val == math.MaxInt64 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + case uint: + tmp := val + 1 + if tmp < val { + return nil, ErrIncrementOverflow + } + return tmp, nil + case uint32: + if val == math.MaxUint32 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + case uint64: + if val == math.MaxUint64 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + default: + return nil, ErrNotIntegerType + } +} + +func decr(originVal interface{}) (interface{}, error) { + switch val := originVal.(type) { + case int: + tmp := val - 1 + if val < 0 && tmp > 0 { + return nil, ErrDecrementOverflow + } + return tmp, nil + case int32: + if val == math.MinInt32 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case int64: + if val == math.MinInt64 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case uint: + if val == 0 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case uint32: + if val == MinUint32 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case uint64: + if val == MinUint64 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + default: + return nil, ErrNotIntegerType + } +} diff --git a/client/cache/calc_utils_test.go b/client/cache/calc_utils_test.go new file mode 100644 index 00000000..1f8d3377 --- /dev/null +++ b/client/cache/calc_utils_test.go @@ -0,0 +1,140 @@ +package cache + +import ( + "math" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIncr(t *testing.T) { + // int + var originVal interface{} = int(1) + var updateVal interface{} = int(2) + val, err := incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(int(1<<(strconv.IntSize-1) - 1)) + assert.Equal(t, ErrIncrementOverflow, err) + + // int32 + originVal = int32(1) + updateVal = int32(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(int32(math.MaxInt32)) + assert.Equal(t, ErrIncrementOverflow, err) + + // int64 + originVal = int64(1) + updateVal = int64(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(int64(math.MaxInt64)) + assert.Equal(t, ErrIncrementOverflow, err) + + // uint + originVal = uint(1) + updateVal = uint(2) + val, err = incr(originVal) + + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(uint(1<<(strconv.IntSize) - 1)) + assert.Equal(t, ErrIncrementOverflow, err) + + // uint32 + originVal = uint32(1) + updateVal = uint32(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(uint32(math.MaxUint32)) + assert.Equal(t, ErrIncrementOverflow, err) + + // uint64 + originVal = uint64(1) + updateVal = uint64(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(uint64(math.MaxUint64)) + assert.Equal(t, ErrIncrementOverflow, err) + // other type + _, err = incr("string") + assert.Equal(t, ErrNotIntegerType, err) +} + +func TestDecr(t *testing.T) { + // int + var originVal interface{} = int(2) + var updateVal interface{} = int(1) + val, err := decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(int(-1 << (strconv.IntSize - 1))) + assert.Equal(t, ErrDecrementOverflow, err) + // int32 + originVal = int32(2) + updateVal = int32(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(int32(math.MinInt32)) + assert.Equal(t, ErrDecrementOverflow, err) + + // int64 + originVal = int64(2) + updateVal = int64(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(int64(math.MinInt64)) + assert.Equal(t, ErrDecrementOverflow, err) + + // uint + originVal = uint(2) + updateVal = uint(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(uint(0)) + assert.Equal(t, ErrDecrementOverflow, err) + + // uint32 + originVal = uint32(2) + updateVal = uint32(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(uint32(0)) + assert.Equal(t, ErrDecrementOverflow, err) + + // uint64 + originVal = uint64(2) + updateVal = uint64(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(uint64(0)) + assert.Equal(t, ErrDecrementOverflow, err) + + // other type + _, err = decr("string") + assert.Equal(t, ErrNotIntegerType, err) +} diff --git a/client/cache/conv_test.go b/client/cache/conv_test.go index b90e224a..c261c440 100644 --- a/client/cache/conv_test.go +++ b/client/cache/conv_test.go @@ -16,128 +16,74 @@ package cache import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestGetString(t *testing.T) { - var t1 = "test1" - if "test1" != GetString(t1) { - t.Error("get string from string error") - } - var t2 = []byte("test2") - if "test2" != GetString(t2) { - t.Error("get string from byte array error") - } - var t3 = 1 - if "1" != GetString(t3) { - t.Error("get string from int error") - } - var t4 int64 = 1 - if "1" != GetString(t4) { - t.Error("get string from int64 error") - } - var t5 = 1.1 - if "1.1" != GetString(t5) { - t.Error("get string from float64 error") - } + t1 := "test1" - if "" != GetString(nil) { - t.Error("get string from nil error") - } + assert.Equal(t, "test1", GetString(t1)) + t2 := []byte("test2") + assert.Equal(t, "test2", GetString(t2)) + t3 := 1 + assert.Equal(t, "1", GetString(t3)) + var t4 int64 = 1 + assert.Equal(t, "1", GetString(t4)) + t5 := 1.1 + assert.Equal(t, "1.1", GetString(t5)) + assert.Equal(t, "", GetString(nil)) } func TestGetInt(t *testing.T) { - var t1 = 1 - if 1 != GetInt(t1) { - t.Error("get int from int error") - } + t1 := 1 + assert.Equal(t, 1, GetInt(t1)) var t2 int32 = 32 - if 32 != GetInt(t2) { - t.Error("get int from int32 error") - } + assert.Equal(t, 32, GetInt(t2)) + var t3 int64 = 64 - if 64 != GetInt(t3) { - t.Error("get int from int64 error") - } - var t4 = "128" - if 128 != GetInt(t4) { - t.Error("get int from num string error") - } - if 0 != GetInt(nil) { - t.Error("get int from nil error") - } + assert.Equal(t, 64, GetInt(t3)) + t4 := "128" + + assert.Equal(t, 128, GetInt(t4)) + assert.Equal(t, 0, GetInt(nil)) } func TestGetInt64(t *testing.T) { var i int64 = 1 - var t1 = 1 - if i != GetInt64(t1) { - t.Error("get int64 from int error") - } + t1 := 1 + assert.Equal(t, i, GetInt64(t1)) var t2 int32 = 1 - if i != GetInt64(t2) { - t.Error("get int64 from int32 error") - } + + assert.Equal(t, i, GetInt64(t2)) var t3 int64 = 1 - if i != GetInt64(t3) { - t.Error("get int64 from int64 error") - } - var t4 = "1" - if i != GetInt64(t4) { - t.Error("get int64 from num string error") - } - if 0 != GetInt64(nil) { - t.Error("get int64 from nil") - } + assert.Equal(t, i, GetInt64(t3)) + t4 := "1" + assert.Equal(t, i, GetInt64(t4)) + assert.Equal(t, int64(0), GetInt64(nil)) } func TestGetFloat64(t *testing.T) { - var f = 1.11 + f := 1.11 var t1 float32 = 1.11 - if f != GetFloat64(t1) { - t.Error("get float64 from float32 error") - } - var t2 = 1.11 - if f != GetFloat64(t2) { - t.Error("get float64 from float64 error") - } - var t3 = "1.11" - if f != GetFloat64(t3) { - t.Error("get float64 from string error") - } + assert.Equal(t, f, GetFloat64(t1)) + t2 := 1.11 + assert.Equal(t, f, GetFloat64(t2)) + t3 := "1.11" + assert.Equal(t, f, GetFloat64(t3)) var f2 float64 = 1 - var t4 = 1 - if f2 != GetFloat64(t4) { - t.Error("get float64 from int error") - } + t4 := 1 + assert.Equal(t, f2, GetFloat64(t4)) - if 0 != GetFloat64(nil) { - t.Error("get float64 from nil error") - } + assert.Equal(t, float64(0), GetFloat64(nil)) } func TestGetBool(t *testing.T) { - var t1 = true - if !GetBool(t1) { - t.Error("get bool from bool error") - } - var t2 = "true" - if !GetBool(t2) { - t.Error("get bool from string error") - } - if GetBool(nil) { - t.Error("get bool from nil error") - } -} + t1 := true + assert.True(t, GetBool(t1)) + t2 := "true" + assert.True(t, GetBool(t2)) -func byteArrayEquals(a []byte, b []byte) bool { - if len(a) != len(b) { - return false - } - for i, v := range a { - if v != b[i] { - return false - } - } - return true + assert.False(t, GetBool(nil)) } diff --git a/client/cache/error_code.go b/client/cache/error_code.go new file mode 100644 index 00000000..5611f065 --- /dev/null +++ b/client/cache/error_code.go @@ -0,0 +1,176 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "github.com/beego/beego/v2/core/berror" +) + +var NilCacheAdapter = berror.DefineCode(4002001, moduleName, "NilCacheAdapter", ` +It means that you register cache adapter by pass nil. +A cache adapter is an instance of Cache interface. +`) + +var DuplicateAdapter = berror.DefineCode(4002002, moduleName, "DuplicateAdapter", ` +You register two adapter with same name. In beego cache module, one name one adapter. +Once you got this error, please check the error stack, search adapter +`) + +var UnknownAdapter = berror.DefineCode(4002003, moduleName, "UnknownAdapter", ` +Unknown adapter, do you forget to register the adapter? +You must register adapter before use it. For example, if you want to use redis implementation, +you must import the cache/redis package. +`) + +var IncrementOverflow = berror.DefineCode(4002004, moduleName, "IncrementOverflow", ` +The increment operation will overflow. +`) + +var DecrementOverflow = berror.DefineCode(4002005, moduleName, "DecrementOverflow", ` +The decrement operation will overflow. +`) + +var NotIntegerType = berror.DefineCode(4002006, moduleName, "NotIntegerType", ` +The type of value is not (u)int (u)int32 (u)int64. +When you want to call Incr or Decr function of Cache API, you must confirm that the value's type is one of (u)int (u)int32 (u)int64. +`) + +var InvalidFileCacheDirectoryLevelCfg = berror.DefineCode(4002007, moduleName, "InvalidFileCacheDirectoryLevelCfg", ` +You pass invalid DirectoryLevel parameter when you try to StartAndGC file cache instance. +This parameter must be a integer, and please check your input. +`) + +var InvalidFileCacheEmbedExpiryCfg = berror.DefineCode(4002008, moduleName, "InvalidFileCacheEmbedExpiryCfg", ` +You pass invalid EmbedExpiry parameter when you try to StartAndGC file cache instance. +This parameter must be a integer, and please check your input. +`) + +var CreateFileCacheDirFailed = berror.DefineCode(4002009, moduleName, "CreateFileCacheDirFailed", ` +Beego failed to create file cache directory. There are two cases: +1. You pass invalid CachePath parameter. Please check your input. +2. Beego doesn't have the permission to create this directory. Please check your file mode. +`) + +var InvalidFileCachePath = berror.DefineCode(4002010, moduleName, "InvalidFilePath", ` +The file path of FileCache is invalid. Please correct the config. +`) + +var ReadFileCacheContentFailed = berror.DefineCode(4002011, moduleName, "ReadFileCacheContentFailed", ` +Usually you won't got this error. It means that Beego cannot read the data from the file. +You need to check whether the file exist. Sometimes it may be deleted by other processes. +If the file exists, please check the permission that Beego is able to read data from the file. +`) + +var InvalidGobEncodedData = berror.DefineCode(4002012, moduleName, "InvalidEncodedData", ` +The data is invalid. When you try to decode the invalid data, you got this error. +Please confirm that the data is encoded by GOB correctly. +`) + +var GobEncodeDataFailed = berror.DefineCode(4002013, moduleName, "GobEncodeDataFailed", ` +Beego could not encode the data to GOB byte array. In general, the data type is invalid. +For example, GOB doesn't support function type. +Basic types, string, structure, structure pointer are supported. +`) + +var KeyExpired = berror.DefineCode(4002014, moduleName, "KeyExpired", ` +Cache key is expired. +You should notice that, a key is expired and then it may be deleted by GC goroutine. +So when you query a key which may be expired, you may got this code, or KeyNotExist. +`) + +var KeyNotExist = berror.DefineCode(4002015, moduleName, "KeyNotExist", ` +Key not found. +`) + +var MultiGetFailed = berror.DefineCode(4002016, moduleName, "MultiGetFailed", ` +Get multiple keys failed. Please check the detail msg to find out the root cause. +`) + +var InvalidMemoryCacheCfg = berror.DefineCode(4002017, moduleName, "InvalidMemoryCacheCfg", ` +The config is invalid. Please check your input. It must be a json string. +`) + +var InvalidMemCacheCfg = berror.DefineCode(4002018, moduleName, "InvalidMemCacheCfg", ` +The config is invalid. Please check your input, it must be json string and contains "conn" field. +`) + +var InvalidMemCacheValue = berror.DefineCode(4002019, moduleName, "InvalidMemCacheValue", ` +The value must be string or byte[], please check your input. +`) + +var InvalidRedisCacheCfg = berror.DefineCode(4002020, moduleName, "InvalidRedisCacheCfg", ` +The config must be json string, and has "conn" field. +`) + +var InvalidSsdbCacheCfg = berror.DefineCode(4002021, moduleName, "InvalidSsdbCacheCfg", ` +The config must be json string, and has "conn" field. The value of "conn" field should be "host:port". +"port" must be a valid integer. +`) + +var InvalidSsdbCacheValue = berror.DefineCode(4002022, moduleName, "InvalidSsdbCacheValue", ` +SSDB cache only accept string value. Please check your input. +`) + +var DeleteFileCacheItemFailed = berror.DefineCode(5002001, moduleName, "DeleteFileCacheItemFailed", ` +Beego try to delete file cache item failed. +Please check whether Beego generated file correctly. +And then confirm whether this file is already deleted by other processes or other people. +`) + +var MemCacheCurdFailed = berror.DefineCode(5002002, moduleName, "MemCacheError", ` +When you want to get, put, delete key-value from remote memcache servers, you may get error: +1. You pass invalid servers address, so Beego could not connect to remote server; +2. The servers address is correct, but there is some net issue. Typically there is some firewalls between application and memcache server; +3. Key is invalid. The key's length should be less than 250 and must not contains special characters; +4. The response from memcache server is invalid; +`) + +var RedisCacheCurdFailed = berror.DefineCode(5002003, moduleName, "RedisCacheCurdFailed", ` +When Beego uses client to send request to redis server, it failed. +1. The server addresses is invalid; +2. Network issue, firewall issue or network is unstable; +3. Client failed to manage connection. In extreme cases, Beego's redis client didn't maintain connections correctly, for example, Beego try to send request via closed connection; +4. The request are huge and redis server spent too much time to process it, and client is timeout; + +In general, if you always got this error whatever you do, in most cases, it was caused by network issue. +You could check your network state, and confirm that firewall rules are correct. +`) + +var InvalidConnection = berror.DefineCode(5002004, moduleName, "InvalidConnection", ` +The connection is invalid. Please check your connection info, network, firewall. +You could simply uses ping, telnet or write some simple tests to test network. +`) + +var DialFailed = berror.DefineCode(5002005, moduleName, "DialFailed", ` +When Beego try to dial to remote servers, it failed. Please check your connection info and network state, server state. +`) + +var SsdbCacheCurdFailed = berror.DefineCode(5002006, moduleName, "SsdbCacheCurdFailed", ` +When you try to use SSDB cache, it failed. There are many cases: +1. servers unavailable; +2. network issue, including network unstable, firewall; +3. connection issue; +4. request are huge and servers spent too much time to process it, got timeout; +`) + +var SsdbBadResponse = berror.DefineCode(5002007, moduleName, "SsdbBadResponse", ` +The reponse from SSDB server is invalid. +Usually it indicates something wrong on server side. +`) + +var ( + ErrKeyExpired = berror.Error(KeyExpired, "the key is expired") + ErrKeyNotExist = berror.Error(KeyNotExist, "the key isn't exist") +) diff --git a/client/cache/file.go b/client/cache/file.go index 043c4650..ae2bc7cf 100644 --- a/client/cache/file.go +++ b/client/cache/file.go @@ -30,7 +30,7 @@ import ( "strings" "time" - "github.com/pkg/errors" + "github.com/beego/beego/v2/core/berror" ) // FileCacheItem is basic unit of file cache adapter which @@ -67,44 +67,65 @@ func NewFileCache() Cache { // StartAndGC starts gc for file cache. // config must be in the format {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} func (fc *FileCache) StartAndGC(config string) error { - cfg := make(map[string]string) err := json.Unmarshal([]byte(config), &cfg) if err != nil { return err } - if _, ok := cfg["CachePath"]; !ok { - cfg["CachePath"] = FileCachePath - } - if _, ok := cfg["FileSuffix"]; !ok { - cfg["FileSuffix"] = FileCacheFileSuffix - } - if _, ok := cfg["DirectoryLevel"]; !ok { - cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) - } - if _, ok := cfg["EmbedExpiry"]; !ok { - cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) - } - fc.CachePath = cfg["CachePath"] - fc.FileSuffix = cfg["FileSuffix"] - fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"]) - fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"]) - fc.Init() - return nil + const cpKey = "CachePath" + const fsKey = "FileSuffix" + const dlKey = "DirectoryLevel" + const eeKey = "EmbedExpiry" + + if _, ok := cfg[cpKey]; !ok { + cfg[cpKey] = FileCachePath + } + + if _, ok := cfg[fsKey]; !ok { + cfg[fsKey] = FileCacheFileSuffix + } + + if _, ok := cfg[dlKey]; !ok { + cfg[dlKey] = strconv.Itoa(FileCacheDirectoryLevel) + } + + if _, ok := cfg[eeKey]; !ok { + cfg[eeKey] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) + } + fc.CachePath = cfg[cpKey] + fc.FileSuffix = cfg[fsKey] + fc.DirectoryLevel, err = strconv.Atoi(cfg[dlKey]) + if err != nil { + return berror.Wrapf(err, InvalidFileCacheDirectoryLevelCfg, + "invalid directory level config, please check your input, it must be integer: %s", cfg[dlKey]) + } + fc.EmbedExpiry, err = strconv.Atoi(cfg[eeKey]) + if err != nil { + return berror.Wrapf(err, InvalidFileCacheEmbedExpiryCfg, + "invalid embed expiry config, please check your input, it must be integer: %s", cfg[eeKey]) + } + return fc.Init() } // Init makes new a dir for file cache if it does not already exist -func (fc *FileCache) Init() { - if ok, _ := exists(fc.CachePath); !ok { // todo : error handle - _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle +func (fc *FileCache) Init() error { + ok, err := exists(fc.CachePath) + if err != nil || ok { + return err } + err = os.MkdirAll(fc.CachePath, os.ModePerm) + if err != nil { + return berror.Wrapf(err, CreateFileCacheDirFailed, + "could not create directory, please check the config [%s] and file mode.", fc.CachePath) + } + return nil } // getCachedFilename returns an md5 encoded file name. -func (fc *FileCache) getCacheFileName(key string) string { +func (fc *FileCache) getCacheFileName(key string) (string, error) { m := md5.New() - io.WriteString(m, key) + _, _ = io.WriteString(m, key) keyMd5 := hex.EncodeToString(m.Sum(nil)) cachePath := fc.CachePath switch fc.DirectoryLevel { @@ -113,18 +134,29 @@ func (fc *FileCache) getCacheFileName(key string) string { case 1: cachePath = filepath.Join(cachePath, keyMd5[0:2]) } - - if ok, _ := exists(cachePath); !ok { // todo : error handle - _ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle + ok, err := exists(cachePath) + if err != nil { + return "", err + } + if !ok { + err = os.MkdirAll(cachePath, os.ModePerm) + if err != nil { + return "", berror.Wrapf(err, CreateFileCacheDirFailed, + "could not create the directory: %s", cachePath) + } } - return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)) + return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)), nil } // Get value from file cache. // if nonexistent or expired return an empty string. func (fc *FileCache) Get(ctx context.Context, key string) (interface{}, error) { - fileData, err := FileGetContents(fc.getCacheFileName(key)) + fn, err := fc.getCacheFileName(key) + if err != nil { + return nil, err + } + fileData, err := FileGetContents(fn) if err != nil { return nil, err } @@ -136,7 +168,7 @@ func (fc *FileCache) Get(ctx context.Context, key string) (interface{}, error) { } if to.Expired.Before(time.Now()) { - return nil, errors.New("The key is expired") + return nil, ErrKeyExpired } return to.Data, nil } @@ -159,7 +191,7 @@ func (fc *FileCache) GetMulti(ctx context.Context, keys []string) ([]interface{} if len(keysErr) == 0 { return rc, nil } - return rc, errors.New(strings.Join(keysErr, "; ")) + return rc, berror.Error(MultiGetFailed, strings.Join(keysErr, "; ")) } // Put value into file cache. @@ -179,14 +211,26 @@ func (fc *FileCache) Put(ctx context.Context, key string, val interface{}, timeo if err != nil { return err } - return FilePutContents(fc.getCacheFileName(key), data) + + fn, err := fc.getCacheFileName(key) + if err != nil { + return err + } + return FilePutContents(fn, data) } // Delete file cache value. func (fc *FileCache) Delete(ctx context.Context, key string) error { - filename := fc.getCacheFileName(key) + filename, err := fc.getCacheFileName(key) + if err != nil { + return err + } if ok, _ := exists(filename); ok { - return os.Remove(filename) + err = os.Remove(filename) + if err != nil { + return berror.Wrapf(err, DeleteFileCacheItemFailed, + "can not delete this file cache key-value, key is %s and file name is %s", key, filename) + } } return nil } @@ -199,25 +243,12 @@ func (fc *FileCache) Incr(ctx context.Context, key string) error { return err } - var res interface{} - switch val := data.(type) { - case int: - res = val + 1 - case int32: - res = val + 1 - case int64: - res = val + 1 - case uint: - res = val + 1 - case uint32: - res = val + 1 - case uint64: - res = val + 1 - default: - return errors.Errorf("data is not (u)int (u)int32 (u)int64") + val, err := incr(data) + if err != nil { + return err } - return fc.Put(context.Background(), key, res, time.Duration(fc.EmbedExpiry)) + return fc.Put(context.Background(), key, val, time.Duration(fc.EmbedExpiry)) } // Decr decreases cached int value. @@ -227,43 +258,21 @@ func (fc *FileCache) Decr(ctx context.Context, key string) error { return err } - var res interface{} - switch val := data.(type) { - case int: - res = val - 1 - case int32: - res = val - 1 - case int64: - res = val - 1 - case uint: - if val > 0 { - res = val - 1 - } else { - return errors.New("data val is less than 0") - } - case uint32: - if val > 0 { - res = val - 1 - } else { - return errors.New("data val is less than 0") - } - case uint64: - if val > 0 { - res = val - 1 - } else { - return errors.New("data val is less than 0") - } - default: - return errors.Errorf("data is not (u)int (u)int32 (u)int64") + val, err := decr(data) + if err != nil { + return err } - return fc.Put(context.Background(), key, res, time.Duration(fc.EmbedExpiry)) + return fc.Put(context.Background(), key, val, time.Duration(fc.EmbedExpiry)) } // IsExist checks if value exists. func (fc *FileCache) IsExist(ctx context.Context, key string) (bool, error) { - ret, _ := exists(fc.getCacheFileName(key)) - return ret, nil + fn, err := fc.getCacheFileName(key) + if err != nil { + return false, err + } + return exists(fn) } // ClearAll cleans cached files (not implemented) @@ -280,13 +289,19 @@ func exists(path string) (bool, error) { if os.IsNotExist(err) { return false, nil } - return false, err + return false, berror.Wrapf(err, InvalidFileCachePath, "file cache path is invalid: %s", path) } // FileGetContents Reads bytes from a file. // if non-existent, create this file. -func FileGetContents(filename string) (data []byte, e error) { - return ioutil.ReadFile(filename) +func FileGetContents(filename string) ([]byte, error) { + data, err := ioutil.ReadFile(filename) + if err != nil { + return nil, berror.Wrapf(err, ReadFileCacheContentFailed, + "could not read the data from the file: %s, "+ + "please confirm that file exist and Beego has the permission to read the content.", filename) + } + return data, nil } // FilePutContents puts bytes into a file. @@ -301,16 +316,21 @@ func GobEncode(data interface{}) ([]byte, error) { enc := gob.NewEncoder(buf) err := enc.Encode(data) if err != nil { - return nil, err + return nil, berror.Wrap(err, GobEncodeDataFailed, "could not encode this data") } - return buf.Bytes(), err + return buf.Bytes(), nil } // GobDecode Gob decodes a file cache item. func GobDecode(data []byte, to *FileCacheItem) error { buf := bytes.NewBuffer(data) dec := gob.NewDecoder(buf) - return dec.Decode(&to) + err := dec.Decode(&to) + if err != nil { + return berror.Wrap(err, InvalidGobEncodedData, + "could not decode this data to FileCacheItem. Make sure that the data is encoded by GOB.") + } + return nil } func init() { diff --git a/client/cache/file_test.go b/client/cache/file_test.go new file mode 100644 index 00000000..f9a325d6 --- /dev/null +++ b/client/cache/file_test.go @@ -0,0 +1,108 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFileCacheStartAndGC(t *testing.T) { + fc := NewFileCache().(*FileCache) + err := fc.StartAndGC(`{`) + assert.NotNil(t, err) + err = fc.StartAndGC(`{}`) + assert.Nil(t, err) + + assert.Equal(t, fc.CachePath, FileCachePath) + assert.Equal(t, fc.DirectoryLevel, FileCacheDirectoryLevel) + assert.Equal(t, fc.EmbedExpiry, int(FileCacheEmbedExpiry)) + assert.Equal(t, fc.FileSuffix, FileCacheFileSuffix) + + err = fc.StartAndGC(`{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) + // could not create dir + assert.NotNil(t, err) + + str := getTestCacheFilePath() + err = fc.StartAndGC(fmt.Sprintf(`{"CachePath":"%s","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`, str)) + assert.Nil(t, err) + assert.Equal(t, fc.CachePath, str) + assert.Equal(t, fc.DirectoryLevel, 2) + assert.Equal(t, fc.EmbedExpiry, 0) + assert.Equal(t, fc.FileSuffix, ".bin") + + err = fc.StartAndGC(fmt.Sprintf(`{"CachePath":"%s","FileSuffix":".bin","DirectoryLevel":"aaa","EmbedExpiry":"0"}`, str)) + assert.NotNil(t, err) + + err = fc.StartAndGC(fmt.Sprintf(`{"CachePath":"%s","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"aaa"}`, str)) + assert.NotNil(t, err) +} + +func TestFileCacheInit(t *testing.T) { + fc := NewFileCache().(*FileCache) + fc.CachePath = "////aaa" + err := fc.Init() + assert.NotNil(t, err) + fc.CachePath = getTestCacheFilePath() + err = fc.Init() + assert.Nil(t, err) +} + +func TestFileGetContents(t *testing.T) { + _, err := FileGetContents("/bin/aaa") + assert.NotNil(t, err) + fn := filepath.Join(os.TempDir(), "fileCache.txt") + f, err := os.Create(fn) + assert.Nil(t, err) + _, err = f.WriteString("text") + assert.Nil(t, err) + data, err := FileGetContents(fn) + assert.Nil(t, err) + assert.Equal(t, "text", string(data)) +} + +func TestGobEncodeDecode(t *testing.T) { + _, err := GobEncode(func() { + fmt.Print("test func") + }) + assert.NotNil(t, err) + data, err := GobEncode(&FileCacheItem{ + Data: "hello", + }) + assert.Nil(t, err) + err = GobDecode([]byte("wrong data"), &FileCacheItem{}) + assert.NotNil(t, err) + dci := &FileCacheItem{} + err = GobDecode(data, dci) + assert.Nil(t, err) + assert.Equal(t, "hello", dci.Data) +} + +func TestFileCacheDelete(t *testing.T) { + fc := NewFileCache() + err := fc.StartAndGC(`{}`) + assert.Nil(t, err) + err = fc.Delete(context.Background(), "my-key") + assert.Nil(t, err) +} + +func getTestCacheFilePath() string { + return filepath.Join(os.TempDir(), "test", "file.txt") +} diff --git a/client/cache/memcache/memcache.go b/client/cache/memcache/memcache.go index 527d08ca..90242bab 100644 --- a/client/cache/memcache/memcache.go +++ b/client/cache/memcache/memcache.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/cache/memcache" -// "github.com/beego/beego/v2/cache" +// _ "github.com/beego/beego/v2/client/cache/memcache" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) @@ -32,7 +32,6 @@ package memcache import ( "context" "encoding/json" - "errors" "fmt" "strings" "time" @@ -40,6 +39,7 @@ import ( "github.com/bradfitz/gomemcache/memcache" "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" ) // Cache Memcache adapter. @@ -55,36 +55,31 @@ func NewMemCache() cache.Cache { // Get get value from memcache. func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil, err - } - } if item, err := rc.conn.Get(key); err == nil { return item.Value, nil } else { - return nil, err + return nil, berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not read data from memcache, please check your key, network and connection. Root cause: %s", + err.Error()) } } // GetMulti gets a value from a key in memcache. func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { rv := make([]interface{}, len(keys)) - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return rv, err - } - } mv, err := rc.conn.GetMulti(keys) if err != nil { - return rv, err + return rv, berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not read multiple key-values from memcache, "+ + "please check your keys, network and connection. Root cause: %s", + err.Error()) } keysErr := make([]string, 0) for i, ki := range keys { if _, ok := mv[ki]; !ok { - keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, "the key isn't exist")) + keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, "key not exist")) continue } rv[i] = mv[ki].Value @@ -93,78 +88,54 @@ func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, er if len(keysErr) == 0 { return rv, nil } - return rv, fmt.Errorf(strings.Join(keysErr, "; ")) + return rv, berror.Error(cache.MultiGetFailed, strings.Join(keysErr, "; ")) } // Put puts a value into memcache. func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)} if v, ok := val.([]byte); ok { item.Value = v } else if str, ok := val.(string); ok { item.Value = []byte(str) } else { - return errors.New("val only support string and []byte") + return berror.Errorf(cache.InvalidMemCacheValue, + "the value must be string or byte[]. key: %s, value:%v", key, val) } - return rc.conn.Set(&item) + return berror.Wrapf(rc.conn.Set(&item), cache.MemCacheCurdFailed, + "could not put key-value to memcache, key: %s", key) } // Delete deletes a value in memcache. func (rc *Cache) Delete(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return rc.conn.Delete(key) + return berror.Wrapf(rc.conn.Delete(key), cache.MemCacheCurdFailed, + "could not delete key-value from memcache, key: %s", key) } // Incr increases counter. func (rc *Cache) Incr(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Increment(key, 1) - return err + return berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not increase value for key: %s", key) } // Decr decreases counter. func (rc *Cache) Decr(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Decrement(key, 1) - return err + return berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not decrease value for key: %s", key) } // IsExist checks if a value exists in memcache. func (rc *Cache) IsExist(ctx context.Context, key string) (bool, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return false, err - } - } - _, err := rc.conn.Get(key) + _, err := rc.Get(ctx, key) return err == nil, err } // ClearAll clears all cache in memcache. func (rc *Cache) ClearAll(context.Context) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return rc.conn.FlushAll() + return berror.Wrap(rc.conn.FlushAll(), cache.MemCacheCurdFailed, + "try to clear all key-value pairs failed") } // StartAndGC starts the memcache adapter. @@ -172,21 +143,15 @@ func (rc *Cache) ClearAll(context.Context) error { // If an error occurs during connecting, an error is returned func (rc *Cache) StartAndGC(config string) error { var cf map[string]string - json.Unmarshal([]byte(config), &cf) + if err := json.Unmarshal([]byte(config), &cf); err != nil { + return berror.Wrapf(err, cache.InvalidMemCacheCfg, + "could not unmarshal this config, it must be valid json stringP: %s", config) + } + if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") + return berror.Errorf(cache.InvalidMemCacheCfg, `config must contains "conn" field: %s`, config) } rc.conninfo = strings.Split(cf["conn"], ";") - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return nil -} - -// connect to memcache and keep the connection. -func (rc *Cache) connectInit() error { rc.conn = memcache.New(rc.conninfo...) return nil } diff --git a/client/cache/memcache/memcache_test.go b/client/cache/memcache/memcache_test.go index 083e661c..1f93cc3c 100644 --- a/client/cache/memcache/memcache_test.go +++ b/client/cache/memcache/memcache_test.go @@ -19,10 +19,12 @@ import ( "fmt" "os" "strconv" + "strings" "testing" "time" _ "github.com/bradfitz/gomemcache/memcache" + "github.com/stretchr/testify/assert" "github.com/beego/beego/v2/client/cache" ) @@ -34,78 +36,62 @@ func TestMemcacheCache(t *testing.T) { } bm, err := cache.NewCache("memcache", fmt.Sprintf(`{"conn": "%s"}`, addr)) - if err != nil { - t.Error("init err") - } + assert.Nil(t, err) + timeoutDuration := 10 * time.Second - if err = bm.Put(context.Background(), "astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "1", timeoutDuration)) + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) time.Sleep(11 * time.Second) - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("check err") - } - if err = bm.Put(context.Background(), "astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "1", timeoutDuration)) val, _ := bm.Get(context.Background(), "astaxie") - if v, err := strconv.Atoi(string(val.([]byte))); err != nil || v != 1 { - t.Error("get err") - } + v, err := strconv.Atoi(string(val.([]byte))) + assert.Nil(t, err) + assert.Equal(t, 1, v) - if err = bm.Incr(context.Background(), "astaxie"); err != nil { - t.Error("Incr Error", err) - } + assert.Nil(t, bm.Incr(context.Background(), "astaxie")) val, _ = bm.Get(context.Background(), "astaxie") - if v, err := strconv.Atoi(string(val.([]byte))); err != nil || v != 2 { - t.Error("get err") - } + v, err = strconv.Atoi(string(val.([]byte))) + assert.Nil(t, err) + assert.Equal(t, 2, v) - if err = bm.Decr(context.Background(), "astaxie"); err != nil { - t.Error("Decr Error", err) - } + assert.Nil(t, bm.Decr(context.Background(), "astaxie")) val, _ = bm.Get(context.Background(), "astaxie") - if v, err := strconv.Atoi(string(val.([]byte))); err != nil || v != 1 { - t.Error("get err") - } + v, err = strconv.Atoi(string(val.([]byte))) + assert.Nil(t, err) + assert.Equal(t, 1, v) bm.Delete(context.Background(), "astaxie") - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("delete err") - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "author", timeoutDuration)) // test string - if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) val, _ = bm.Get(context.Background(), "astaxie") - if v := val.([]byte); string(v) != "author" { - t.Error("get err") - } + vs := val.([]byte) + assert.Equal(t, "author", string(vs)) // test GetMulti - if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { - t.Error("check err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { t.Error("GetMulti ERROR") } @@ -114,21 +100,14 @@ func TestMemcacheCache(t *testing.T) { } vv, err = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0] != nil { - t.Error("GetMulti ERROR") - } - if string(vv[1].([]byte)) != "author1" { - t.Error("GetMulti ERROR") - } - if err != nil && err.Error() == "key [astaxie0] error: key isn't exist" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + assert.Nil(t, vv[0]) + assert.Equal(t, "author1", string(vv[1].([]byte))) + + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "key not exist")) + + assert.Nil(t, bm.ClearAll(context.Background())) // test clear all - if err = bm.ClearAll(context.Background()); err != nil { - t.Error("clear all err") - } } diff --git a/client/cache/memory.go b/client/cache/memory.go index 28c7d980..c1d1a2e5 100644 --- a/client/cache/memory.go +++ b/client/cache/memory.go @@ -17,17 +17,16 @@ package cache import ( "context" "encoding/json" - "errors" "fmt" "strings" "sync" "time" + + "github.com/beego/beego/v2/core/berror" ) -var ( - // Timer for how often to recycle the expired cache items in memory (in seconds) - DefaultEvery = 60 // 1 minute -) +// DefaultEvery sets a timer for how often to recycle the expired cache items in memory (in seconds) +var DefaultEvery = 60 // 1 minute // MemoryItem stores memory cache item. type MemoryItem struct { @@ -41,7 +40,7 @@ func (mi *MemoryItem) isExpire() bool { if mi.lifespan == 0 { return false } - return time.Now().Sub(mi.createdTime) > mi.lifespan + return time.Since(mi.createdTime) > mi.lifespan } // MemoryCache is a memory cache adapter. @@ -64,13 +63,14 @@ func NewMemoryCache() Cache { func (bc *MemoryCache) Get(ctx context.Context, key string) (interface{}, error) { bc.RLock() defer bc.RUnlock() - if itm, ok := bc.items[key]; ok { + if itm, ok := + bc.items[key]; ok { if itm.isExpire() { - return nil, errors.New("the key is expired") + return nil, ErrKeyExpired } return itm.val, nil } - return nil, errors.New("the key isn't exist") + return nil, ErrKeyNotExist } // GetMulti gets caches from memory. @@ -91,7 +91,7 @@ func (bc *MemoryCache) GetMulti(ctx context.Context, keys []string) ([]interface if len(keysErr) == 0 { return rc, nil } - return rc, errors.New(strings.Join(keysErr, "; ")) + return rc, berror.Error(MultiGetFailed, strings.Join(keysErr, "; ")) } // Put puts cache into memory. @@ -108,16 +108,11 @@ func (bc *MemoryCache) Put(ctx context.Context, key string, val interface{}, tim } // Delete cache in memory. +// If the key is not found, it will not return error func (bc *MemoryCache) Delete(ctx context.Context, key string) error { bc.Lock() defer bc.Unlock() - if _, ok := bc.items[key]; !ok { - return errors.New("key not exist") - } delete(bc.items, key) - if _, ok := bc.items[key]; ok { - return errors.New("delete key error") - } return nil } @@ -128,24 +123,14 @@ func (bc *MemoryCache) Incr(ctx context.Context, key string) error { defer bc.Unlock() itm, ok := bc.items[key] if !ok { - return errors.New("key not exist") + return ErrKeyNotExist } - switch val := itm.val.(type) { - case int: - itm.val = val + 1 - case int32: - itm.val = val + 1 - case int64: - itm.val = val + 1 - case uint: - itm.val = val + 1 - case uint32: - itm.val = val + 1 - case uint64: - itm.val = val + 1 - default: - return errors.New("item val is not (u)int (u)int32 (u)int64") + + val, err := incr(itm.val) + if err != nil { + return err } + itm.val = val return nil } @@ -155,36 +140,14 @@ func (bc *MemoryCache) Decr(ctx context.Context, key string) error { defer bc.Unlock() itm, ok := bc.items[key] if !ok { - return errors.New("key not exist") + return ErrKeyNotExist } - switch val := itm.val.(type) { - case int: - itm.val = val - 1 - case int64: - itm.val = val - 1 - case int32: - itm.val = val - 1 - case uint: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint32: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint64: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - default: - return errors.New("item val is not int int64 int32") + + val, err := decr(itm.val) + if err != nil { + return err } + itm.val = val return nil } @@ -209,7 +172,9 @@ func (bc *MemoryCache) ClearAll(context.Context) error { // StartAndGC starts memory cache. Checks expiration in every clock time. func (bc *MemoryCache) StartAndGC(config string) error { var cf map[string]int - json.Unmarshal([]byte(config), &cf) + if err := json.Unmarshal([]byte(config), &cf); err != nil { + return berror.Wrapf(err, InvalidMemoryCacheCfg, "invalid config, please check your input: %s", config) + } if _, ok := cf["interval"]; !ok { cf = make(map[string]int) cf["interval"] = DefaultEvery diff --git a/client/cache/module.go b/client/cache/module.go new file mode 100644 index 00000000..5a4e499e --- /dev/null +++ b/client/cache/module.go @@ -0,0 +1,17 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +const moduleName = "cache" diff --git a/client/cache/redis/redis.go b/client/cache/redis/redis.go index dcf0cd5a..bd244223 100644 --- a/client/cache/redis/redis.go +++ b/client/cache/redis/redis.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/cache/redis" -// "github.com/beego/beego/v2/cache" +// _ "github.com/beego/beego/v2/client/cache/redis" +// "github.com/beego/beego/v2/client/cache" // ) // // bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) @@ -32,7 +32,6 @@ package redis import ( "context" "encoding/json" - "errors" "fmt" "strconv" "strings" @@ -41,12 +40,11 @@ import ( "github.com/gomodule/redigo/redis" "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" ) -var ( - // The collection name of redis for the cache adapter. - DefaultKey = "beecacheRedis" -) +// DefaultKey defines the collection name of redis for the cache adapter. +var DefaultKey = "beecacheRedis" // Cache is Redis cache adapter. type Cache struct { @@ -67,15 +65,20 @@ func NewRedisCache() cache.Cache { } // Execute the redis commands. args[0] must be the key name -func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { - if len(args) < 1 { - return nil, errors.New("missing required arguments") - } +func (rc *Cache) do(commandName string, args ...interface{}) (interface{}, error) { args[0] = rc.associate(args[0]) c := rc.p.Get() - defer c.Close() + defer func() { + _ = c.Close() + }() - return c.Do(commandName, args...) + reply, err := c.Do(commandName, args...) + if err != nil { + return nil, berror.Wrapf(err, cache.RedisCacheCurdFailed, + "could not execute this command: %s", commandName) + } + + return reply, nil } // associate with config key. @@ -95,7 +98,9 @@ func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { // GetMulti gets cache from redis. func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { c := rc.p.Get() - defer c.Close() + defer func() { + _ = c.Close() + }() var args []interface{} for _, key := range keys { args = append(args, rc.associate(key)) @@ -137,13 +142,16 @@ func (rc *Cache) Decr(ctx context.Context, key string) error { } // ClearAll deletes all cache in the redis collection +// Be careful about this method, because it scans all keys and the delete them one by one func (rc *Cache) ClearAll(context.Context) error { cachedKeys, err := rc.Scan(rc.key + ":*") if err != nil { return err } c := rc.p.Get() - defer c.Close() + defer func() { + _ = c.Close() + }() for _, str := range cachedKeys { if _, err = c.Do("DEL", str); err != nil { return err @@ -155,7 +163,9 @@ func (rc *Cache) ClearAll(context.Context) error { // Scan scans all keys matching a given pattern. func (rc *Cache) Scan(pattern string) (keys []string, err error) { c := rc.p.Get() - defer c.Close() + defer func() { + _ = c.Close() + }() var ( cursor uint64 = 0 // start result []interface{} @@ -186,13 +196,16 @@ func (rc *Cache) Scan(pattern string) (keys []string, err error) { // Cached items in redis are stored forever, no garbage collection happens func (rc *Cache) StartAndGC(config string) error { var cf map[string]string - json.Unmarshal([]byte(config), &cf) + err := json.Unmarshal([]byte(config), &cf) + if err != nil { + return berror.Wrapf(err, cache.InvalidRedisCacheCfg, "could not unmarshal the config: %s", config) + } if _, ok := cf["key"]; !ok { cf["key"] = DefaultKey } if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") + return berror.Wrapf(err, cache.InvalidRedisCacheCfg, "config missing conn field: %s", config) } // Format redis://@: @@ -229,9 +242,16 @@ func (rc *Cache) StartAndGC(config string) error { rc.connectInit() c := rc.p.Get() - defer c.Close() + defer func() { + _ = c.Close() + }() - return c.Err() + // test connection + if err = c.Err(); err != nil { + return berror.Wrapf(err, cache.InvalidConnection, + "can not connect to remote redis server, please check the connection info and network state: %s", config) + } + return nil } // connect to redis. @@ -239,19 +259,20 @@ func (rc *Cache) connectInit() { dialFunc := func() (c redis.Conn, err error) { c, err = redis.Dial("tcp", rc.conninfo) if err != nil { - return nil, err + return nil, berror.Wrapf(err, cache.DialFailed, + "could not dial to remote server: %s ", rc.conninfo) } if rc.password != "" { - if _, err := c.Do("AUTH", rc.password); err != nil { - c.Close() + if _, err = c.Do("AUTH", rc.password); err != nil { + _ = c.Close() return nil, err } } _, selecterr := c.Do("SELECT", rc.dbNum) if selecterr != nil { - c.Close() + _ = c.Close() return nil, selecterr } return diff --git a/client/cache/redis/redis_test.go b/client/cache/redis/redis_test.go index 3344bc34..9509816a 100644 --- a/client/cache/redis/redis_test.go +++ b/client/cache/redis/redis_test.go @@ -28,106 +28,81 @@ import ( ) func TestRedisCache(t *testing.T) { - redisAddr := os.Getenv("REDIS_ADDR") if redisAddr == "" { redisAddr = "127.0.0.1:6379" } bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, redisAddr)) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } + assert.Nil(t, err) + timeoutDuration := 3 * time.Second - time.Sleep(11 * time.Second) + assert.Nil(t, bm.Put(context.Background(), "astaxie", 1, timeoutDuration)) - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("check err") - } - if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + time.Sleep(5 * time.Second) + + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", 1, timeoutDuration)) val, _ := bm.Get(context.Background(), "astaxie") - if v, _ := redis.Int(val, err); v != 1 { - t.Error("get err") - } + v, _ := redis.Int(val, err) + assert.Equal(t, 1, v) - if err = bm.Incr(context.Background(), "astaxie"); err != nil { - t.Error("Incr Error", err) - } + assert.Nil(t, bm.Incr(context.Background(), "astaxie")) val, _ = bm.Get(context.Background(), "astaxie") - if v, _ := redis.Int(val, err); v != 2 { - t.Error("get err") - } + v, _ = redis.Int(val, err) + assert.Equal(t, 2, v) - if err = bm.Decr(context.Background(), "astaxie"); err != nil { - t.Error("Decr Error", err) - } + assert.Nil(t, bm.Decr(context.Background(), "astaxie")) val, _ = bm.Get(context.Background(), "astaxie") - if v, _ := redis.Int(val, err); v != 1 { - t.Error("get err") - } + v, _ = redis.Int(val, err) + assert.Equal(t, 1, v) bm.Delete(context.Background(), "astaxie") - if res, _ := bm.IsExist(context.Background(), "astaxie"); res { - t.Error("delete err") - } + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "author", timeoutDuration)) // test string - if err = bm.Put(context.Background(), "astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { - t.Error("check err") - } + + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) val, _ = bm.Get(context.Background(), "astaxie") - if v, _ := redis.String(val, err); v != "author" { - t.Error("get err") - } + vs, _ := redis.String(val, err) + assert.Equal(t, "author", vs) // test GetMulti - if err = bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := bm.IsExist(context.Background(), "astaxie1"); !res { - t.Error("check err") - } + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[0], nil); v != "author" { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[1], nil); v != "author1" { - t.Error("GetMulti ERROR") - } + assert.Equal(t, 2, len(vv)) + vs, _ = redis.String(vv[0], nil) + assert.Equal(t, "author", vs) + + vs, _ = redis.String(vv[1], nil) + assert.Equal(t, "author1", vs) vv, _ = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) - if vv[0] != nil { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[1], nil); v != "author1" { - t.Error("GetMulti ERROR") - } + assert.Nil(t, vv[0]) + + vs, _ = redis.String(vv[1], nil) + assert.Equal(t, "author1", vs) // test clear all - if err = bm.ClearAll(context.Background()); err != nil { - t.Error("clear all err") - } + assert.Nil(t, bm.ClearAll(context.Background())) } -func TestCache_Scan(t *testing.T) { +func TestCacheScan(t *testing.T) { timeoutDuration := 10 * time.Second addr := os.Getenv("REDIS_ADDR") @@ -137,35 +112,24 @@ func TestCache_Scan(t *testing.T) { // init bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, addr)) - if err != nil { - t.Error("init err") - } + + assert.Nil(t, err) // insert all for i := 0; i < 100; i++ { - if err = bm.Put(context.Background(), fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.Nil(t, bm.Put(context.Background(), fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration)) } time.Sleep(time.Second) // scan all for the first time keys, err := bm.(*Cache).Scan(DefaultKey + ":*") - if err != nil { - t.Error("scan Error", err) - } + assert.Nil(t, err) assert.Equal(t, 100, len(keys), "scan all error") // clear all - if err = bm.ClearAll(context.Background()); err != nil { - t.Error("clear all err") - } + assert.Nil(t, bm.ClearAll(context.Background())) // scan all for the second time keys, err = bm.(*Cache).Scan(DefaultKey + ":*") - if err != nil { - t.Error("scan Error", err) - } - if len(keys) != 0 { - t.Error("scan all err") - } + assert.Nil(t, err) + assert.Equal(t, 0, len(keys)) } diff --git a/client/cache/ssdb/ssdb.go b/client/cache/ssdb/ssdb.go index 93fa9feb..54558ea3 100644 --- a/client/cache/ssdb/ssdb.go +++ b/client/cache/ssdb/ssdb.go @@ -3,7 +3,6 @@ package ssdb import ( "context" "encoding/json" - "errors" "fmt" "strconv" "strings" @@ -12,6 +11,7 @@ import ( "github.com/ssdb/gossdb/ssdb" "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" ) // Cache SSDB adapter @@ -27,31 +27,21 @@ func NewSsdbCache() cache.Cache { // Get gets a key's value from memcache. func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil, err - } - } value, err := rc.conn.Get(key) if err == nil { return value, nil } - return nil, err + return nil, berror.Wrapf(err, cache.SsdbCacheCurdFailed, "could not get value, key: %s", key) } // GetMulti gets one or keys values from ssdb. func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { size := len(keys) values := make([]interface{}, size) - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return values, err - } - } res, err := rc.conn.Do("multi_get", keys) if err != nil { - return values, err + return values, berror.Wrapf(err, cache.SsdbCacheCurdFailed, "multi_get failed, key: %v", keys) } resSize := len(res) @@ -63,14 +53,14 @@ func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, er keysErr := make([]string, 0) for i, ki := range keys { if _, ok := keyIdx[ki]; !ok { - keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, "the key isn't exist")) + keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, "key not exist")) continue } values[i] = res[keyIdx[ki]+1] } if len(keysErr) != 0 { - return values, fmt.Errorf(strings.Join(keysErr, "; ")) + return values, berror.Error(cache.MultiGetFailed, strings.Join(keysErr, "; ")) } return values, nil @@ -78,26 +68,16 @@ func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, er // DelMulti deletes one or more keys from memcache func (rc *Cache) DelMulti(keys []string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Do("multi_del", keys) - return err + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "multi_del failed: %v", keys) } // Put puts value into memcache. // value: must be of type string func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } v, ok := val.(string) if !ok { - return errors.New("value must string") + return berror.Errorf(cache.InvalidSsdbCacheValue, "value must be string: %v", val) } var resp []string var err error @@ -108,72 +88,47 @@ func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout t resp, err = rc.conn.Do("setx", key, v, ttl) } if err != nil { - return err + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "set or setx failed, key: %s", key) } if len(resp) == 2 && resp[0] == "ok" { return nil } - return errors.New("bad response") + return berror.Errorf(cache.SsdbBadResponse, "the response from SSDB server is invalid: %v", resp) } // Delete deletes a value in memcache. func (rc *Cache) Delete(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Del(key) - return err + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "del failed: %s", key) } // Incr increases a key's counter. func (rc *Cache) Incr(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Do("incr", key, 1) - return err + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "increase failed: %s", key) } // Decr decrements a key's counter. func (rc *Cache) Decr(ctx context.Context, key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } _, err := rc.conn.Do("incr", key, -1) - return err + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "decrease failed: %s", key) } // IsExist checks if a key exists in memcache. func (rc *Cache) IsExist(ctx context.Context, key string) (bool, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return false, err - } - } resp, err := rc.conn.Do("exists", key) if err != nil { - return false, err + return false, berror.Wrapf(err, cache.SsdbCacheCurdFailed, "exists failed: %s", key) } if len(resp) == 2 && resp[1] == "1" { return true, nil } return false, nil - } -// ClearAll clears all cached items in memcache. +// ClearAll clears all cached items in ssdb. +// If there are many keys, this method may spent much time. func (rc *Cache) ClearAll(context.Context) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } keyStart, keyEnd, limit := "", "", 50 resp, err := rc.Scan(keyStart, keyEnd, limit) for err == nil { @@ -187,21 +142,16 @@ func (rc *Cache) ClearAll(context.Context) error { } _, e := rc.conn.Do("multi_del", keys) if e != nil { - return e + return berror.Wrapf(e, cache.SsdbCacheCurdFailed, "multi_del failed: %v", keys) } keyStart = resp[size-2] resp, err = rc.Scan(keyStart, keyEnd, limit) } - return err + return berror.Wrap(err, cache.SsdbCacheCurdFailed, "scan failed") } // Scan key all cached in ssdb. func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil, err - } - } resp, err := rc.conn.Do("scan", keyStart, keyEnd, limit) if err != nil { return nil, err @@ -214,30 +164,36 @@ func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, erro // If an error occurs during connection, an error is returned func (rc *Cache) StartAndGC(config string) error { var cf map[string]string - json.Unmarshal([]byte(config), &cf) + err := json.Unmarshal([]byte(config), &cf) + if err != nil { + return berror.Wrapf(err, cache.InvalidSsdbCacheCfg, + "unmarshal this config failed, it must be a valid json string: %s", config) + } if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") + return berror.Wrapf(err, cache.InvalidSsdbCacheCfg, + "Missing conn field: %s", config) } rc.conninfo = strings.Split(cf["conn"], ";") - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return nil + return rc.connectInit() } // connect to memcache and keep the connection. func (rc *Cache) connectInit() error { conninfoArray := strings.Split(rc.conninfo[0], ":") + if len(conninfoArray) < 2 { + return berror.Errorf(cache.InvalidSsdbCacheCfg, "The value of conn should be host:port: %s", rc.conninfo[0]) + } host := conninfoArray[0] port, e := strconv.Atoi(conninfoArray[1]) if e != nil { - return e + return berror.Errorf(cache.InvalidSsdbCacheCfg, "Port is invalid. It must be integer, %s", rc.conninfo[0]) } var err error - rc.conn, err = ssdb.Connect(host, port) - return err + if rc.conn, err = ssdb.Connect(host, port); err != nil { + return berror.Wrapf(err, cache.InvalidConnection, + "could not connect to SSDB, please check your connection info, network and firewall: %s", rc.conninfo[0]) + } + return nil } func init() { diff --git a/client/cache/ssdb/ssdb_test.go b/client/cache/ssdb/ssdb_test.go index 8ac1efd6..0a38b2de 100644 --- a/client/cache/ssdb/ssdb_test.go +++ b/client/cache/ssdb/ssdb_test.go @@ -5,128 +5,96 @@ import ( "fmt" "os" "strconv" + "strings" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/beego/beego/v2/client/cache" ) func TestSsdbcacheCache(t *testing.T) { - ssdbAddr := os.Getenv("SSDB_ADDR") if ssdbAddr == "" { ssdbAddr = "127.0.0.1:8888" } ssdb, err := cache.NewCache("ssdb", fmt.Sprintf(`{"conn": "%s"}`, ssdbAddr)) - if err != nil { - t.Error("init err") - } + assert.Nil(t, err) // test put and exist - if res, _ := ssdb.IsExist(context.Background(), "ssdb"); res { - t.Error("check err") - } - timeoutDuration := 10 * time.Second + res, _ := ssdb.IsExist(context.Background(), "ssdb") + assert.False(t, res) + timeoutDuration := 3 * time.Second // timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent - if err = ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if res, _ := ssdb.IsExist(context.Background(), "ssdb"); !res { - t.Error("check err") - } + + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration)) + + res, _ = ssdb.IsExist(context.Background(), "ssdb") + assert.True(t, res) // Get test done - if err = ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration)) - if v, _ := ssdb.Get(context.Background(), "ssdb"); v != "ssdb" { - t.Error("get Error") - } + v, _ := ssdb.Get(context.Background(), "ssdb") + assert.Equal(t, "ssdb", v) // inc/dec test done - if err = ssdb.Put(context.Background(), "ssdb", "2", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if err = ssdb.Incr(context.Background(), "ssdb"); err != nil { - t.Error("incr Error", err) - } + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "2", timeoutDuration)) + + assert.Nil(t, ssdb.Incr(context.Background(), "ssdb")) val, _ := ssdb.Get(context.Background(), "ssdb") - if v, err := strconv.Atoi(val.(string)); err != nil || v != 3 { - t.Error("get err") - } + v, err = strconv.Atoi(val.(string)) + assert.Nil(t, err) + assert.Equal(t, 3, v) - if err = ssdb.Decr(context.Background(), "ssdb"); err != nil { - t.Error("decr error") - } + assert.Nil(t, ssdb.Decr(context.Background(), "ssdb")) // test del - if err = ssdb.Put(context.Background(), "ssdb", "3", timeoutDuration); err != nil { - t.Error("set Error", err) - } + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "3", timeoutDuration)) val, _ = ssdb.Get(context.Background(), "ssdb") - if v, err := strconv.Atoi(val.(string)); err != nil || v != 3 { - t.Error("get err") - } - if err := ssdb.Delete(context.Background(), "ssdb"); err == nil { - if e, _ := ssdb.IsExist(context.Background(), "ssdb"); e { - t.Error("delete err") - } - } + v, err = strconv.Atoi(val.(string)) + assert.Equal(t, 3, v) + assert.Nil(t, err) + assert.Nil(t, ssdb.Delete(context.Background(), "ssdb")) + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "ssdb", -10*time.Second)) // test string - if err = ssdb.Put(context.Background(), "ssdb", "ssdb", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if res, _ := ssdb.IsExist(context.Background(), "ssdb"); !res { - t.Error("check err") - } - if v, _ := ssdb.Get(context.Background(), "ssdb"); v.(string) != "ssdb" { - t.Error("get err") - } + + res, _ = ssdb.IsExist(context.Background(), "ssdb") + assert.True(t, res) + + v, _ = ssdb.Get(context.Background(), "ssdb") + assert.Equal(t, "ssdb", v.(string)) // test GetMulti done - if err = ssdb.Put(context.Background(), "ssdb1", "ssdb1", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if res, _ := ssdb.IsExist(context.Background(), "ssdb1"); !res { - t.Error("check err") - } + assert.Nil(t, ssdb.Put(context.Background(), "ssdb1", "ssdb1", -10*time.Second)) + + res, _ = ssdb.IsExist(context.Background(), "ssdb1") + assert.True(t, res) vv, _ := ssdb.GetMulti(context.Background(), []string{"ssdb", "ssdb1"}) - if len(vv) != 2 { - t.Error("getmulti error") - } - if vv[0].(string) != "ssdb" { - t.Error("getmulti error") - } - if vv[1].(string) != "ssdb1" { - t.Error("getmulti error") - } + assert.Equal(t, 2, len(vv)) + + assert.Equal(t, "ssdb", vv[0]) + assert.Equal(t, "ssdb1", vv[1]) vv, err = ssdb.GetMulti(context.Background(), []string{"ssdb", "ssdb11"}) - if len(vv) != 2 { - t.Error("getmulti error") - } - if vv[0].(string) != "ssdb" { - t.Error("getmulti error") - } - if vv[1] != nil { - t.Error("getmulti error") - } - if err != nil && err.Error() != "key [ssdb11] error: the key isn't exist" { - t.Error("getmulti error") - } + + assert.Equal(t, 2, len(vv)) + + assert.Equal(t, "ssdb", vv[0]) + assert.Nil(t, vv[1]) + + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "key not exist")) // test clear all done - if err = ssdb.ClearAll(context.Background()); err != nil { - t.Error("clear all err") - } + assert.Nil(t, ssdb.ClearAll(context.Background())) e1, _ := ssdb.IsExist(context.Background(), "ssdb") e2, _ := ssdb.IsExist(context.Background(), "ssdb1") - if e1 || e2 { - t.Error("check err") - } + assert.False(t, e1) + assert.False(t, e2) } diff --git a/client/httplib/README.md b/client/httplib/README.md index 90a6c505..a5723c6d 100644 --- a/client/httplib/README.md +++ b/client/httplib/README.md @@ -8,8 +8,8 @@ httplib is an libs help you to curl remote url. you can use Get to crawl data. - import "github.com/beego/beego/v2/httplib" - + import "github.com/beego/beego/v2/client/httplib" + str, err := httplib.Get("http://beego.me/").String() if err != nil { // error @@ -39,7 +39,7 @@ Example: // GET httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) - + // POST httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) @@ -95,4 +95,4 @@ httplib support mutil file upload, use `req.PostFile()` See godoc for further documentation and examples. -* [godoc.org/github.com/beego/beego/v2/httplib](https://godoc.org/github.com/beego/beego/v2/httplib) +* [godoc.org/github.com/beego/beego/v2/client/httplib](https://godoc.org/github.com/beego/beego/v2/client/httplib) diff --git a/client/httplib/client_option.go b/client/httplib/client_option.go new file mode 100644 index 00000000..f970e67d --- /dev/null +++ b/client/httplib/client_option.go @@ -0,0 +1,155 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "crypto/tls" + "net/http" + "net/url" + "time" +) + +type ( + ClientOption func(client *Client) + BeegoHTTPRequestOption func(request *BeegoHTTPRequest) +) + +// WithEnableCookie will enable cookie in all subsequent request +func WithEnableCookie(enable bool) ClientOption { + return func(client *Client) { + client.Setting.EnableCookie = enable + } +} + +// WithEnableCookie will adds UA in all subsequent request +func WithUserAgent(userAgent string) ClientOption { + return func(client *Client) { + client.Setting.UserAgent = userAgent + } +} + +// WithTLSClientConfig will adds tls config in all subsequent request +func WithTLSClientConfig(config *tls.Config) ClientOption { + return func(client *Client) { + client.Setting.TLSClientConfig = config + } +} + +// WithTransport will set transport field in all subsequent request +func WithTransport(transport http.RoundTripper) ClientOption { + return func(client *Client) { + client.Setting.Transport = transport + } +} + +// WithProxy will set http proxy field in all subsequent request +func WithProxy(proxy func(*http.Request) (*url.URL, error)) ClientOption { + return func(client *Client) { + client.Setting.Proxy = proxy + } +} + +// WithCheckRedirect will specifies the policy for handling redirects in all subsequent request +func WithCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) ClientOption { + return func(client *Client) { + client.Setting.CheckRedirect = redirect + } +} + +// WithHTTPSetting can replace beegoHTTPSeting +func WithHTTPSetting(setting BeegoHTTPSettings) ClientOption { + return func(client *Client) { + client.Setting = setting + } +} + +// WithEnableGzip will enable gzip in all subsequent request +func WithEnableGzip(enable bool) ClientOption { + return func(client *Client) { + client.Setting.Gzip = enable + } +} + +// BeegoHttpRequestOption + +// WithTimeout sets connect time out and read-write time out for BeegoRequest. +func WithTimeout(connectTimeout, readWriteTimeout time.Duration) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.SetTimeout(connectTimeout, readWriteTimeout) + } +} + +// WithHeader adds header item string in request. +func WithHeader(key, value string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Header(key, value) + } +} + +// WithCookie adds a cookie to the request. +func WithCookie(cookie *http.Cookie) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Header("Cookie", cookie.String()) + } +} + +// Withtokenfactory adds a custom function to set Authorization +func WithTokenFactory(tokenFactory func() string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + t := tokenFactory() + + request.Header("Authorization", t) + } +} + +// WithBasicAuth adds a custom function to set basic auth +func WithBasicAuth(basicAuth func() (string, string)) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + username, password := basicAuth() + request.SetBasicAuth(username, password) + } +} + +// WithFilters will use the filter as the invocation filters +func WithFilters(fcs ...FilterChain) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.SetFilters(fcs...) + } +} + +// WithContentType adds ContentType in header +func WithContentType(contentType string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Header(contentTypeKey, contentType) + } +} + +// WithParam adds query param in to request. +func WithParam(key, value string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Param(key, value) + } +} + +// WithRetry set retry times and delay for the request +// default is 0 (never retry) +// -1 retry indefinitely (forever) +// Other numbers specify the exact retry amount +func WithRetry(times int, delay time.Duration) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Retries(times) + request.RetryDelay(delay) + } +} diff --git a/client/httplib/client_option_test.go b/client/httplib/client_option_test.go new file mode 100644 index 00000000..416f092b --- /dev/null +++ b/client/httplib/client_option_test.go @@ -0,0 +1,261 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "errors" + "net" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type respCarrier struct { + bytes []byte +} + +func (r *respCarrier) SetBytes(bytes []byte) { + r.bytes = bytes +} + +func (r *respCarrier) String() string { + return string(r.bytes) +} + +func TestOptionWithEnableCookie(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/", + WithEnableCookie(true)) + if err != nil { + t.Fatal(err) + } + + v := "smallfish" + resp := &respCarrier{} + err = client.Get(resp, "/cookies/set?k1="+v) + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + err = client.Get(resp, "/cookies") + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), v) + if n == -1 { + t.Fatal(v + " not found in cookie") + } +} + +func TestOptionWithUserAgent(t *testing.T) { + v := "beego" + client, err := NewClient("test", "http://httpbin.org/", + WithUserAgent(v)) + if err != nil { + t.Fatal(err) + } + + resp := &respCarrier{} + err = client.Get(resp, "/headers") + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestOptionWithCheckRedirect(t *testing.T) { + client, err := NewClient("test", "https://goolnk.com/33BD2j", + WithCheckRedirect(func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + })) + if err != nil { + t.Fatal(err) + } + err = client.Get(nil, "") + assert.NotNil(t, err) +} + +func TestOptionWithHTTPSetting(t *testing.T) { + v := "beego" + var setting BeegoHTTPSettings + setting.EnableCookie = true + setting.UserAgent = v + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + setting.ReadWriteTimeout = 5 * time.Second + + client, err := NewClient("test", "http://httpbin.org/", + WithHTTPSetting(setting)) + if err != nil { + t.Fatal(err) + } + + resp := &respCarrier{} + err = client.Get(resp, "/get") + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), v) + if n == -1 { + t.Fatal(v + " not found in user-agent") + } +} + +func TestOptionWithHeader(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + client.CommonOpts = append(client.CommonOpts, WithHeader("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36")) + + resp := &respCarrier{} + err = client.Get(resp, "/headers") + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), "Mozilla/5.0") + if n == -1 { + t.Fatal("Mozilla/5.0 not found in user-agent") + } +} + +func TestOptionWithTokenFactory(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + client.CommonOpts = append(client.CommonOpts, + WithTokenFactory(func() string { + return "testauth" + })) + + resp := &respCarrier{} + err = client.Get(resp, "/headers") + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), "testauth") + if n == -1 { + t.Fatal("Auth is not set in request") + } +} + +func TestOptionWithBasicAuth(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + + resp := &respCarrier{} + err = client.Get(resp, "/basic-auth/user/passwd", + WithBasicAuth(func() (string, string) { + return "user", "passwd" + })) + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + n := strings.Index(resp.String(), "authenticated") + if n == -1 { + t.Fatal("authenticated not found in response") + } +} + +func TestOptionWithContentType(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + + v := "application/json" + resp := &respCarrier{} + err = client.Get(resp, "/headers", WithContentType(v)) + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), v) + if n == -1 { + t.Fatal(v + " not found in header") + } +} + +func TestOptionWithParam(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + + v := "smallfish" + resp := &respCarrier{} + err = client.Get(resp, "/get", WithParam("username", v)) + if err != nil { + t.Fatal(err) + } + t.Log(resp.String()) + + n := strings.Index(resp.String(), v) + if n == -1 { + t.Fatal(v + " not found in header") + } +} + +func TestOptionWithRetry(t *testing.T) { + client, err := NewClient("test", "https://goolnk.com/33BD2j", + WithCheckRedirect(func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + })) + if err != nil { + t.Fatal(err) + } + + retryAmount := 1 + retryDelay := 800 * time.Millisecond + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _ = client.Get(nil, "", WithRetry(retryAmount, retryDelay)) + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } +} diff --git a/client/httplib/error_code.go b/client/httplib/error_code.go new file mode 100644 index 00000000..177419ad --- /dev/null +++ b/client/httplib/error_code.go @@ -0,0 +1,134 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "github.com/beego/beego/v2/core/berror" +) + +var InvalidUrl = berror.DefineCode(4001001, moduleName, "InvalidUrl", ` +You pass a invalid url to httplib module. Please check your url, be careful about special character. +`) + +var InvalidUrlProtocolVersion = berror.DefineCode(4001002, moduleName, "InvalidUrlProtocolVersion", ` +You pass a invalid protocol version. In practice, we use HTTP/1.0, HTTP/1.1, HTTP/1.2 +But something like HTTP/3.2 is valid for client, and the major version is 3, minor version is 2. +but you must confirm that server support those abnormal protocol version. +`) + +var UnsupportedBodyType = berror.DefineCode(4001003, moduleName, "UnsupportedBodyType", ` +You use a invalid data as request body. +For now, we only support type string and byte[]. +`) + +var InvalidXMLBody = berror.DefineCode(4001004, moduleName, "InvalidXMLBody", ` +You pass invalid data which could not be converted to XML documents. In general, if you pass structure, it works well. +Sometimes you got XML document and you want to make it as request body. So you call XMLBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + +var InvalidYAMLBody = berror.DefineCode(4001005, moduleName, "InvalidYAMLBody", ` +You pass invalid data which could not be converted to YAML documents. In general, if you pass structure, it works well. +Sometimes you got YAML document and you want to make it as request body. So you call YAMLBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + +var InvalidJSONBody = berror.DefineCode(4001006, moduleName, "InvalidJSONBody", ` +You pass invalid data which could not be converted to JSON documents. In general, if you pass structure, it works well. +Sometimes you got JSON document and you want to make it as request body. So you call JSONBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + +// start with 5 -------------------------------------------------------------------------- + +var CreateFormFileFailed = berror.DefineCode(5001001, moduleName, "CreateFormFileFailed", ` +In normal case than handling files with BeegoRequest, you should not see this error code. +Unexpected EOF, invalid characters, bad file descriptor may cause this error. +`) + +var ReadFileFailed = berror.DefineCode(5001002, moduleName, "ReadFileFailed", ` +There are several cases that cause this error: +1. file not found. Please check the file name; +2. file not found, but file name is correct. If you use relative file path, it's very possible for you to see this code. +make sure that this file is in correct directory which Beego looks for; +3. Beego don't have the privilege to read this file, please change file mode; +`) + +var CopyFileFailed = berror.DefineCode(5001003, moduleName, "CopyFileFailed", ` +When we try to read file content and then copy it to another writer, and failed. +1. Unexpected EOF; +2. Bad file descriptor; +3. Write conflict; + +Please check your file content, and confirm that file is not processed by other process (or by user manually). +`) + +var CloseFileFailed = berror.DefineCode(5001004, moduleName, "CloseFileFailed", ` +After handling files, Beego try to close file but failed. Usually it was caused by bad file descriptor. +`) + +var SendRequestFailed = berror.DefineCode(5001005, moduleName, "SendRequestRetryExhausted", ` +Beego send HTTP request, but it failed. +If you config retry times, it means that Beego had retried and failed. +When you got this error, there are vary kind of reason: +1. Network unstable and timeout. In this case, sometimes server has received the request. +2. Server error. Make sure that server works well. +3. The request is invalid, which means that you pass some invalid parameter. +`) + +var ReadGzipBodyFailed = berror.DefineCode(5001006, moduleName, "BuildGzipReaderFailed", ` +Beego parse gzip-encode body failed. Usually Beego got invalid response. +Please confirm that server returns gzip data. +`) + +var CreateFileIfNotExistFailed = berror.DefineCode(5001007, moduleName, "CreateFileIfNotExist", ` +Beego want to create file if not exist and failed. +In most cases, it means that Beego doesn't have the privilege to create this file. +Please change file mode to ensure that Beego is able to create files on specific directory. +Or you can run Beego with higher authority. +In some cases, you pass invalid filename. Make sure that the file name is valid on your system. +`) + +var UnmarshalJSONResponseToObjectFailed = berror.DefineCode(5001008, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid json document +`) + +var UnmarshalXMLResponseToObjectFailed = berror.DefineCode(5001009, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid XML document +`) + +var UnmarshalYAMLResponseToObjectFailed = berror.DefineCode(5001010, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid YAML document +`) + +var UnmarshalResponseToObjectFailed = berror.DefineCode(5001011, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +There are several cases that cause this error: +1. You pass valid structure pointer to the function; +2. The body is valid json, Yaml or XML document +`) diff --git a/client/httplib/filter/log/filter.go b/client/httplib/filter/log/filter.go new file mode 100644 index 00000000..9d2e09d3 --- /dev/null +++ b/client/httplib/filter/log/filter.go @@ -0,0 +1,130 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package log + +import ( + "context" + "io" + "net/http" + "net/http/httputil" + + "github.com/beego/beego/v2/client/httplib" + "github.com/beego/beego/v2/core/logs" +) + +// FilterChainBuilder can build a log filter +type FilterChainBuilder struct { + printableContentTypes []string // only print the body of included mime types of request and response + log func(f interface{}, v ...interface{}) // custom log function +} + +// BuilderOption option constructor +type BuilderOption func(*FilterChainBuilder) + +type logInfo struct { + req []byte + resp []byte + err error +} + +var defaultprintableContentTypes = []string{ + "text/plain", "text/xml", "text/html", "text/csv", + "text/calendar", "text/javascript", "text/javascript", + "text/css", +} + +// NewFilterChainBuilder initialize a filterChainBuilder, pass options to customize +func NewFilterChainBuilder(opts ...BuilderOption) *FilterChainBuilder { + res := &FilterChainBuilder{ + printableContentTypes: defaultprintableContentTypes, + log: logs.Debug, + } + for _, o := range opts { + o(res) + } + + return res +} + +// WithLog return option constructor modify log function +func WithLog(f func(f interface{}, v ...interface{})) BuilderOption { + return func(h *FilterChainBuilder) { + h.log = f + } +} + +// WithprintableContentTypes return option constructor modify printableContentTypes +func WithprintableContentTypes(types []string) BuilderOption { + return func(h *FilterChainBuilder) { + h.printableContentTypes = types + } +} + +// FilterChain can print the request after FilterChain processing and response before processsing +func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + info := &logInfo{} + defer info.print(builder.log) + resp, err := next(ctx, req) + info.err = err + contentType := req.GetRequest().Header.Get("Content-Type") + shouldPrintBody := builder.shouldPrintBody(contentType, req.GetRequest().Body) + dump, err := httputil.DumpRequest(req.GetRequest(), shouldPrintBody) + info.req = dump + if err != nil { + logs.Error(err) + } + if resp != nil { + contentType = resp.Header.Get("Content-Type") + shouldPrintBody = builder.shouldPrintBody(contentType, resp.Body) + dump, err = httputil.DumpResponse(resp, shouldPrintBody) + info.resp = dump + if err != nil { + logs.Error(err) + } + } + return resp, err + } +} + +func (builder *FilterChainBuilder) shouldPrintBody(contentType string, body io.ReadCloser) bool { + if contains(builder.printableContentTypes, contentType) { + return true + } + if body != nil { + logs.Warn("printableContentTypes do not contain %s, if you want to print request and response body please add it.", contentType) + } + return false +} + +func contains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} + +func (info *logInfo) print(log func(f interface{}, v ...interface{})) { + log("Request: ====================") + log("%q", info.req) + log("Response: ===================") + log("%q", info.resp) + if info.err != nil { + log("Error: ======================") + log("%q", info.err) + } +} diff --git a/client/httplib/filter/log/filter_test.go b/client/httplib/filter/log/filter_test.go new file mode 100644 index 00000000..b17dd395 --- /dev/null +++ b/client/httplib/filter/log/filter_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package log + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestFilterChain(t *testing.T) { + next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + time.Sleep(100 * time.Millisecond) + return &http.Response{ + StatusCode: 404, + }, nil + } + builder := NewFilterChainBuilder() + filter := builder.FilterChain(next) + req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego") + resp, err := filter(context.Background(), req) + assert.NotNil(t, resp) + assert.Nil(t, err) +} + +func TestContains(t *testing.T) { + jsonType := "application/json" + cases := []struct { + Name string + Types []string + ContentType string + Expected bool + }{ + {"case1", []string{jsonType}, jsonType, true}, + {"case2", []string{"text/plain"}, jsonType, false}, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + if ans := contains(c.Types, c.ContentType); ans != c.Expected { + t.Fatalf("Types: %v, ContentType: %v, expected %v, but %v got", + c.Types, c.ContentType, c.Expected, ans) + } + }) + } +} diff --git a/client/httplib/filter/opentracing/filter.go b/client/httplib/filter/opentracing/filter.go index cde50261..aef20e66 100644 --- a/client/httplib/filter/opentracing/filter.go +++ b/client/httplib/filter/opentracing/filter.go @@ -21,20 +21,21 @@ import ( logKit "github.com/go-kit/kit/log" opentracingKit "github.com/go-kit/kit/tracing/opentracing" "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/log" "github.com/beego/beego/v2/client/httplib" ) type FilterChainBuilder struct { + // TagURL true will tag span with url + TagURL bool // CustomSpanFunc users are able to custom their span CustomSpanFunc func(span opentracing.Span, ctx context.Context, req *httplib.BeegoHTTPRequest, resp *http.Response, err error) } func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { - return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { - method := req.GetRequest().Method operationName := method + "#" + req.GetRequest().URL.String() @@ -50,13 +51,19 @@ func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filt } span.SetTag("http.method", method) span.SetTag("peer.hostname", req.GetRequest().URL.Host) - span.SetTag("http.url", req.GetRequest().URL.String()) + span.SetTag("http.scheme", req.GetRequest().URL.Scheme) span.SetTag("span.kind", "client") span.SetTag("component", "beego") + + if builder.TagURL { + span.SetTag("http.url", req.GetRequest().URL.String()) + } + span.LogFields(log.String("http.url", req.GetRequest().URL.String())) + if err != nil { span.SetTag("error", true) - span.SetTag("message", err.Error()) + span.LogFields(log.String("message", err.Error())) } else if resp != nil && !(resp.StatusCode < 300 && resp.StatusCode >= 200) { span.SetTag("error", true) } diff --git a/client/httplib/filter/opentracing/filter_test.go b/client/httplib/filter/opentracing/filter_test.go index b9c1e1e2..a9b9cbb0 100644 --- a/client/httplib/filter/opentracing/filter_test.go +++ b/client/httplib/filter/opentracing/filter_test.go @@ -26,14 +26,16 @@ import ( "github.com/beego/beego/v2/client/httplib" ) -func TestFilterChainBuilder_FilterChain(t *testing.T) { +func TestFilterChainBuilderFilterChain(t *testing.T) { next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { time.Sleep(100 * time.Millisecond) return &http.Response{ StatusCode: 404, }, errors.New("hello") } - builder := &FilterChainBuilder{} + builder := &FilterChainBuilder{ + TagURL: true, + } filter := builder.FilterChain(next) req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego") resp, err := filter(context.Background(), req) diff --git a/client/httplib/filter/prometheus/filter.go b/client/httplib/filter/prometheus/filter.go index 5761eb7e..e93b2298 100644 --- a/client/httplib/filter/prometheus/filter.go +++ b/client/httplib/filter/prometheus/filter.go @@ -32,11 +32,12 @@ type FilterChainBuilder struct { RunMode string } -var summaryVec prometheus.ObserverVec -var initSummaryVec sync.Once +var ( + summaryVec prometheus.ObserverVec + initSummaryVec sync.Once +) func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { - initSummaryVec.Do(func() { summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{ Name: "beego", diff --git a/client/httplib/filter/prometheus/filter_test.go b/client/httplib/filter/prometheus/filter_test.go index 1e7935d0..4a7b29f2 100644 --- a/client/httplib/filter/prometheus/filter_test.go +++ b/client/httplib/filter/prometheus/filter_test.go @@ -25,7 +25,7 @@ import ( "github.com/beego/beego/v2/client/httplib" ) -func TestFilterChainBuilder_FilterChain(t *testing.T) { +func TestFilterChainBuilderFilterChain(t *testing.T) { next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { time.Sleep(100 * time.Millisecond) return &http.Response{ diff --git a/client/httplib/http_response.go b/client/httplib/http_response.go new file mode 100644 index 00000000..89930cb1 --- /dev/null +++ b/client/httplib/http_response.go @@ -0,0 +1,39 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" +) + +// NewHttpResponseWithJsonBody will try to convert the data to json format +// usually you only use this when you want to mock http Response +func NewHttpResponseWithJsonBody(data interface{}) *http.Response { + var body []byte + if str, ok := data.(string); ok { + body = []byte(str) + } else if bts, ok := data.([]byte); ok { + body = bts + } else { + body, _ = json.Marshal(data) + } + return &http.Response{ + ContentLength: int64(len(body)), + Body: ioutil.NopCloser(bytes.NewReader(body)), + } +} diff --git a/client/httplib/http_response_test.go b/client/httplib/http_response_test.go new file mode 100644 index 00000000..a62dd42c --- /dev/null +++ b/client/httplib/http_response_test.go @@ -0,0 +1,35 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewHttpResponseWithJsonBody(t *testing.T) { + // string + resp := NewHttpResponseWithJsonBody("{}") + assert.Equal(t, int64(2), resp.ContentLength) + + resp = NewHttpResponseWithJsonBody([]byte("{}")) + assert.Equal(t, int64(2), resp.ContentLength) + + resp = NewHttpResponseWithJsonBody(&user{ + Name: "Tom", + }) + assert.True(t, resp.ContentLength > 0) +} diff --git a/client/httplib/httpclient.go b/client/httplib/httpclient.go new file mode 100644 index 00000000..c2a61fcf --- /dev/null +++ b/client/httplib/httpclient.go @@ -0,0 +1,174 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" +) + +// Client provides an HTTP client supporting chain call +type Client struct { + Name string + Endpoint string + CommonOpts []BeegoHTTPRequestOption + + Setting BeegoHTTPSettings +} + +// HTTPResponseCarrier If value implement HTTPResponseCarrier. http.Response will pass to SetHTTPResponse +type HTTPResponseCarrier interface { + SetHTTPResponse(resp *http.Response) +} + +// HTTPBodyCarrier If value implement HTTPBodyCarrier. http.Response.Body will pass to SetReader +type HTTPBodyCarrier interface { + SetReader(r io.ReadCloser) +} + +// HTTPBytesCarrier If value implement HTTPBytesCarrier. +// All the byte in http.Response.Body will pass to SetBytes +type HTTPBytesCarrier interface { + SetBytes(bytes []byte) +} + +// HTTPStatusCarrier If value implement HTTPStatusCarrier. http.Response.StatusCode will pass to SetStatusCode +type HTTPStatusCarrier interface { + SetStatusCode(status int) +} + +// HttpHeaderCarrier If value implement HttpHeaderCarrier. http.Response.Header will pass to SetHeader +type HTTPHeadersCarrier interface { + SetHeader(header map[string][]string) +} + +// NewClient return a new http client +func NewClient(name string, endpoint string, opts ...ClientOption) (*Client, error) { + res := &Client{ + Name: name, + Endpoint: endpoint, + } + setting := GetDefaultSetting() + res.Setting = setting + for _, o := range opts { + o(res) + } + return res, nil +} + +func (c *Client) customReq(req *BeegoHTTPRequest, opts []BeegoHTTPRequestOption) { + req.Setting(c.Setting) + opts = append(c.CommonOpts, opts...) + for _, o := range opts { + o(req) + } +} + +// handleResponse try to parse body to meaningful value +func (c *Client) handleResponse(value interface{}, req *BeegoHTTPRequest) error { + // make sure req.resp is not nil + _, err := req.Bytes() + if err != nil { + return err + } + + err = c.handleCarrier(value, req) + if err != nil { + return err + } + + return req.ToValue(value) +} + +// handleCarrier set http data to value +func (c *Client) handleCarrier(value interface{}, req *BeegoHTTPRequest) error { + if value == nil { + return nil + } + + if carrier, ok := value.(HTTPResponseCarrier); ok { + b, err := req.Bytes() + if err != nil { + return err + } + req.resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + carrier.SetHTTPResponse(req.resp) + } + if carrier, ok := value.(HTTPBodyCarrier); ok { + b, err := req.Bytes() + if err != nil { + return err + } + reader := ioutil.NopCloser(bytes.NewReader(b)) + carrier.SetReader(reader) + } + if carrier, ok := value.(HTTPBytesCarrier); ok { + b, err := req.Bytes() + if err != nil { + return err + } + carrier.SetBytes(b) + } + if carrier, ok := value.(HTTPStatusCarrier); ok { + carrier.SetStatusCode(req.resp.StatusCode) + } + if carrier, ok := value.(HTTPHeadersCarrier); ok { + carrier.SetHeader(req.resp.Header) + } + return nil +} + +// Get Send a GET request and try to give its result value +func (c *Client) Get(value interface{}, path string, opts ...BeegoHTTPRequestOption) error { + req := Get(c.Endpoint + path) + c.customReq(req, opts) + return c.handleResponse(value, req) +} + +// Post Send a POST request and try to give its result value +func (c *Client) Post(value interface{}, path string, body interface{}, opts ...BeegoHTTPRequestOption) error { + req := Post(c.Endpoint + path) + c.customReq(req, opts) + if body != nil { + req = req.Body(body) + } + return c.handleResponse(value, req) +} + +// Put Send a Put request and try to give its result value +func (c *Client) Put(value interface{}, path string, body interface{}, opts ...BeegoHTTPRequestOption) error { + req := Put(c.Endpoint + path) + c.customReq(req, opts) + if body != nil { + req = req.Body(body) + } + return c.handleResponse(value, req) +} + +// Delete Send a Delete request and try to give its result value +func (c *Client) Delete(value interface{}, path string, opts ...BeegoHTTPRequestOption) error { + req := Delete(c.Endpoint + path) + c.customReq(req, opts) + return c.handleResponse(value, req) +} + +// Head Send a Head request and try to give its result value +func (c *Client) Head(value interface{}, path string, opts ...BeegoHTTPRequestOption) error { + req := Head(c.Endpoint + path) + c.customReq(req, opts) + return c.handleResponse(value, req) +} diff --git a/client/httplib/httpclient_test.go b/client/httplib/httpclient_test.go new file mode 100644 index 00000000..d1de3828 --- /dev/null +++ b/client/httplib/httpclient_test.go @@ -0,0 +1,220 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "encoding/xml" + "io" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewClient(t *testing.T) { + client, err := NewClient("test1", "http://beego.me", WithEnableCookie(true)) + assert.NoError(t, err) + assert.NotNil(t, client) + assert.Equal(t, true, client.Setting.EnableCookie) +} + +type slideShowResponse struct { + Resp *http.Response + bytes []byte + StatusCode int + Body io.ReadCloser + Header map[string][]string + + Slideshow slideshow `json:"slideshow" yaml:"slideshow"` +} + +func (r *slideShowResponse) SetHTTPResponse(resp *http.Response) { + r.Resp = resp +} + +func (r *slideShowResponse) SetBytes(bytes []byte) { + r.bytes = bytes +} + +func (r *slideShowResponse) SetReader(reader io.ReadCloser) { + r.Body = reader +} + +func (r *slideShowResponse) SetStatusCode(status int) { + r.StatusCode = status +} + +func (r *slideShowResponse) SetHeader(header map[string][]string) { + r.Header = header +} + +func (r *slideShowResponse) String() string { + return string(r.bytes) +} + +type slideshow struct { + XMLName xml.Name `xml:"slideshow"` + + Title string `json:"title" yaml:"title" xml:"title,attr"` + Author string `json:"author" yaml:"author" xml:"author,attr"` + Date string `json:"date" yaml:"date" xml:"date,attr"` + Slides []slide `json:"slides" yaml:"slides" xml:"slide"` +} + +type slide struct { + XMLName xml.Name `xml:"slide"` + + Title string `json:"title" yaml:"title" xml:"title"` +} + +func TestClientHandleCarrier(t *testing.T) { + v := "beego" + client, err := NewClient("test", "http://httpbin.org/", + WithUserAgent(v)) + if err != nil { + t.Fatal(err) + } + + s := &slideShowResponse{} + err = client.Get(s, "/json") + if err != nil { + t.Fatal(err) + } + defer s.Body.Close() + + assert.NotNil(t, s.Resp) + assert.NotNil(t, s.Body) + assert.Equal(t, "429", s.Header["Content-Length"][0]) + assert.Equal(t, 200, s.StatusCode) + + b, err := ioutil.ReadAll(s.Body) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 429, len(b)) + assert.Equal(t, s.String(), string(b)) +} + +func TestClientGet(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org/") + if err != nil { + t.Fatal(err) + } + + // json + var s *slideShowResponse + err = client.Get(&s, "/json") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "Sample Slide Show", s.Slideshow.Title) + assert.Equal(t, 2, len(s.Slideshow.Slides)) + assert.Equal(t, "Overview", s.Slideshow.Slides[1].Title) + + // xml + var ssp *slideshow + err = client.Get(&ssp, "/base64/PD94bWwgPz48c2xpZGVzaG93CnRpdGxlPSJTYW1wbGUgU2xpZGUgU2hvdyIKZGF0ZT0iRGF0ZSBvZiBwdWJsaWNhdGlvbiIKYXV0aG9yPSJZb3VycyBUcnVseSI+PHNsaWRlIHR5cGU9ImFsbCI+PHRpdGxlPldha2UgdXAgdG8gV29uZGVyV2lkZ2V0cyE8L3RpdGxlPjwvc2xpZGU+PHNsaWRlIHR5cGU9ImFsbCI+PHRpdGxlPk92ZXJ2aWV3PC90aXRsZT48aXRlbT5XaHkgPGVtPldvbmRlcldpZGdldHM8L2VtPiBhcmUgZ3JlYXQ8L2l0ZW0+PGl0ZW0vPjxpdGVtPldobyA8ZW0+YnV5czwvZW0+IFdvbmRlcldpZGdldHM8L2l0ZW0+PC9zbGlkZT48L3NsaWRlc2hvdz4=") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "Sample Slide Show", ssp.Title) + assert.Equal(t, 2, len(ssp.Slides)) + assert.Equal(t, "Overview", ssp.Slides[1].Title) + + // yaml + s = nil + err = client.Get(&s, "/base64/c2xpZGVzaG93OgogIGF1dGhvcjogWW91cnMgVHJ1bHkKICBkYXRlOiBkYXRlIG9mIHB1YmxpY2F0aW9uCiAgc2xpZGVzOgogIC0gdGl0bGU6IFdha2UgdXAgdG8gV29uZGVyV2lkZ2V0cyEKICAgIHR5cGU6IGFsbAogIC0gaXRlbXM6CiAgICAtIFdoeSA8ZW0+V29uZGVyV2lkZ2V0czwvZW0+IGFyZSBncmVhdAogICAgLSBXaG8gPGVtPmJ1eXM8L2VtPiBXb25kZXJXaWRnZXRzCiAgICB0aXRsZTogT3ZlcnZpZXcKICAgIHR5cGU6IGFsbAogIHRpdGxlOiBTYW1wbGUgU2xpZGUgU2hvdw==") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "Sample Slide Show", s.Slideshow.Title) + assert.Equal(t, 2, len(s.Slideshow.Slides)) + assert.Equal(t, "Overview", s.Slideshow.Slides[1].Title) +} + +func TestClientPost(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org") + if err != nil { + t.Fatal(err) + } + + resp := &slideShowResponse{} + err = client.Get(resp, "/json") + if err != nil { + t.Fatal(err) + } + + jsonStr := resp.String() + err = client.Post(resp, "/post", jsonStr) + if err != nil { + t.Fatal(err) + } + assert.NotNil(t, resp) + assert.Equal(t, http.MethodPost, resp.Resp.Request.Method) +} + +func TestClientPut(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org") + if err != nil { + t.Fatal(err) + } + + resp := &slideShowResponse{} + err = client.Get(resp, "/json") + if err != nil { + t.Fatal(err) + } + + jsonStr := resp.String() + err = client.Put(resp, "/put", jsonStr) + if err != nil { + t.Fatal(err) + } + assert.NotNil(t, resp) + assert.Equal(t, http.MethodPut, resp.Resp.Request.Method) +} + +func TestClientDelete(t *testing.T) { + client, err := NewClient("test", "http://httpbin.org") + if err != nil { + t.Fatal(err) + } + + resp := &slideShowResponse{} + err = client.Delete(resp, "/delete") + if err != nil { + t.Fatal(err) + } + defer resp.Resp.Body.Close() + assert.NotNil(t, resp) + assert.Equal(t, http.MethodDelete, resp.Resp.Request.Method) +} + +func TestClientHead(t *testing.T) { + client, err := NewClient("test", "http://beego.me") + if err != nil { + t.Fatal(err) + } + + resp := &slideShowResponse{} + err = client.Head(resp, "") + if err != nil { + t.Fatal(err) + } + defer resp.Resp.Body.Close() + assert.NotNil(t, resp) + assert.Equal(t, http.MethodHead, resp.Resp.Request.Method) +} diff --git a/client/httplib/httplib.go b/client/httplib/httplib.go index 56c50cd2..5be4598b 100644 --- a/client/httplib/httplib.go +++ b/client/httplib/httplib.go @@ -15,7 +15,7 @@ // Package httplib is used as http.Client // Usage: // -// import "github.com/beego/beego/v2/httplib" +// import "github.com/beego/beego/v2/client/httplib" // // b := httplib.Post("http://beego.me/") // b.Param("username","astaxie") @@ -40,58 +40,37 @@ import ( "encoding/xml" "io" "io/ioutil" - "log" "mime/multipart" "net" "net/http" - "net/http/cookiejar" - "net/http/httputil" "net/url" "os" "path" "strings" - "sync" "time" "gopkg.in/yaml.v2" + + "github.com/beego/beego/v2/core/berror" + "github.com/beego/beego/v2/core/logs" ) -var defaultSetting = BeegoHTTPSettings{ - UserAgent: "beegoServer", - ConnectTimeout: 60 * time.Second, - ReadWriteTimeout: 60 * time.Second, - Gzip: true, - DumpBody: true, -} - -var defaultCookieJar http.CookieJar -var settingMutex sync.Mutex +const contentTypeKey = "Content-Type" // it will be the last filter and execute request.Do var doRequestFilter = func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { return req.doRequest(ctx) } -// createDefaultCookie creates a global cookiejar to store cookies. -func createDefaultCookie() { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultCookieJar, _ = cookiejar.New(nil) -} - -// SetDefaultSetting overwrites default settings -func SetDefaultSetting(setting BeegoHTTPSettings) { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultSetting = setting -} - // NewBeegoRequest returns *BeegoHttpRequest with specific method +// TODO add error as return value +// I think if we don't return error +// users are hard to check whether we create Beego request successfully func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { var resp http.Response u, err := url.Parse(rawurl) if err != nil { - log.Println("Httplib:", err) + logs.Error("%+v", berror.Wrapf(err, InvalidUrl, "invalid raw url: %s", rawurl)) } req := http.Request{ URL: u, @@ -136,24 +115,6 @@ func Head(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "HEAD") } -// BeegoHTTPSettings is the http.Client setting -type BeegoHTTPSettings struct { - ShowDebug bool - UserAgent string - ConnectTimeout time.Duration - ReadWriteTimeout time.Duration - TLSClientConfig *tls.Config - Proxy func(*http.Request) (*url.URL, error) - Transport http.RoundTripper - CheckRedirect func(req *http.Request, via []*http.Request) error - EnableCookie bool - Gzip bool - DumpBody bool - Retries int // if set to -1 means will retry forever - RetryDelay time.Duration - FilterChains []FilterChain -} - // BeegoHTTPRequest provides more useful methods than http.Request for requesting a url. type BeegoHTTPRequest struct { url string @@ -163,7 +124,6 @@ type BeegoHTTPRequest struct { setting BeegoHTTPSettings resp *http.Response body []byte - dump []byte } // GetRequest returns the request object @@ -195,12 +155,6 @@ func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { return b } -// Debug sets show debug or not when executing request. -func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { - b.setting.ShowDebug = isdebug - return b -} - // Retries sets Retries times. // default is 0 (never retry) // -1 retry indefinitely (forever) @@ -216,17 +170,6 @@ func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { return b } -// DumpBody sets the DumbBody field -func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { - b.setting.DumpBody = isdump - return b -} - -// DumpRequest returns the DumpRequest -func (b *BeegoHTTPRequest) DumpRequest() []byte { - return b.dump -} - // SetTimeout sets connect time out and read-write time out for BeegoRequest. func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { b.setting.ConnectTimeout = connectTimeout @@ -253,9 +196,9 @@ func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { } // SetProtocolVersion sets the protocol version for incoming requests. -// Client requests always use HTTP/1.1. +// Client requests always use HTTP/1.1 func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { - if len(vers) == 0 { + if vers == "" { vers = "HTTP/1.1" } @@ -264,8 +207,9 @@ func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { b.req.Proto = vers b.req.ProtoMajor = major b.req.ProtoMinor = minor + return b } - + logs.Error("%+v", berror.Errorf(InvalidUrlProtocolVersion, "invalid protocol: %s", vers)) return b } @@ -314,6 +258,12 @@ func (b *BeegoHTTPRequest) AddFilters(fcs ...FilterChain) *BeegoHTTPRequest { return b } +// SetEscapeHTML is used to set the flag whether escape HTML special characters during processing +func (b *BeegoHTTPRequest) SetEscapeHTML(isEscape bool) *BeegoHTTPRequest { + b.setting.EscapeHTML = isEscape + return b +} + // Param adds query param in to request. // params build query string as ?key1=value1&key2=value2... func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { @@ -333,16 +283,25 @@ func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest // Body adds request raw body. // Supports string and []byte. +// TODO return error if data is invalid func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { switch t := data.(type) { case string: bf := bytes.NewBufferString(t) b.req.Body = ioutil.NopCloser(bf) + b.req.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bf), nil + } b.req.ContentLength = int64(len(t)) case []byte: bf := bytes.NewBuffer(t) b.req.Body = ioutil.NopCloser(bf) + b.req.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bf), nil + } b.req.ContentLength = int64(len(t)) + default: + logs.Error("%+v", berror.Errorf(UnsupportedBodyType, "unsupported body data type: %s", t)) } return b } @@ -352,11 +311,14 @@ func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { byts, err := xml.Marshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidXMLBody, "obj could not be converted to XML data") } b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) + b.req.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewReader(byts)), nil + } b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/xml") + b.req.Header.Set(contentTypeKey, "application/xml") } return b, nil } @@ -366,11 +328,11 @@ func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) if b.req.Body == nil && obj != nil { byts, err := yaml.Marshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidYAMLBody, "obj could not be converted to YAML data") } b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/x+yaml") + b.req.Header.Set(contentTypeKey, "application/x+yaml") } return b, nil } @@ -378,17 +340,28 @@ func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) // JSONBody adds the request raw body encoded in JSON. func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { - byts, err := json.Marshal(obj) + byts, err := b.JSONMarshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidJSONBody, "obj could not be converted to JSON body") } b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/json") + b.req.Header.Set(contentTypeKey, "application/json") } return b, nil } +func (b *BeegoHTTPRequest) JSONMarshal(obj interface{}) ([]byte, error) { + bf := bytes.NewBuffer([]byte{}) + jsonEncoder := json.NewEncoder(bf) + jsonEncoder.SetEscapeHTML(b.setting.EscapeHTML) + err := jsonEncoder.Encode(obj) + if err != nil { + return nil, err + } + return bf.Bytes(), nil +} + func (b *BeegoHTTPRequest) buildURL(paramBody string) { // build GET url with query string if b.req.Method == "GET" && len(paramBody) > 0 { @@ -404,47 +377,60 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) { if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil { // with files if len(b.files) > 0 { - pr, pw := io.Pipe() - bodyWriter := multipart.NewWriter(pw) - go func() { - for formname, filename := range b.files { - fileWriter, err := bodyWriter.CreateFormFile(formname, filename) - if err != nil { - log.Println("Httplib:", err) - } - fh, err := os.Open(filename) - if err != nil { - log.Println("Httplib:", err) - } - // iocopy - _, err = io.Copy(fileWriter, fh) - fh.Close() - if err != nil { - log.Println("Httplib:", err) - } - } - for k, v := range b.params { - for _, vv := range v { - bodyWriter.WriteField(k, vv) - } - } - bodyWriter.Close() - pw.Close() - }() - b.Header("Content-Type", bodyWriter.FormDataContentType()) - b.req.Body = ioutil.NopCloser(pr) - b.Header("Transfer-Encoding", "chunked") + b.handleFiles() return } // with params if len(paramBody) > 0 { - b.Header("Content-Type", "application/x-www-form-urlencoded") + b.Header(contentTypeKey, "application/x-www-form-urlencoded") b.Body(paramBody) } } } +func (b *BeegoHTTPRequest) handleFiles() { + pr, pw := io.Pipe() + bodyWriter := multipart.NewWriter(pw) + go func() { + for formname, filename := range b.files { + b.handleFileToBody(bodyWriter, formname, filename) + } + for k, v := range b.params { + for _, vv := range v { + _ = bodyWriter.WriteField(k, vv) + } + } + _ = bodyWriter.Close() + _ = pw.Close() + }() + b.Header(contentTypeKey, bodyWriter.FormDataContentType()) + b.req.Body = ioutil.NopCloser(pr) + b.Header("Transfer-Encoding", "chunked") +} + +func (b *BeegoHTTPRequest) handleFileToBody(bodyWriter *multipart.Writer, formname string, filename string) { + fileWriter, err := bodyWriter.CreateFormFile(formname, filename) + const errFmt = "Httplib: %+v" + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CreateFormFileFailed, + "could not create form file, formname: %s, filename: %s", formname, filename)) + } + fh, err := os.Open(filename) + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, ReadFileFailed, "could not open this file %s", filename)) + } + // iocopy + _, err = io.Copy(fileWriter, fh) + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CopyFileFailed, "could not copy this file %s", filename)) + } + err = fh.Close() + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CloseFileFailed, "could not close this file %s", filename)) + } +} + func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { if b.resp.StatusCode != 0 { return b.resp, nil @@ -463,7 +449,6 @@ func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { } func (b *BeegoHTTPRequest) DoRequestWithCtx(ctx context.Context) (resp *http.Response, err error) { - root := doRequestFilter if len(b.setting.FilterChains) > 0 { for i := len(b.setting.FilterChains) - 1; i >= 0; i-- { @@ -473,62 +458,20 @@ func (b *BeegoHTTPRequest) DoRequestWithCtx(ctx context.Context) (resp *http.Res return root(ctx, b) } -func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, err error) { - var paramBody string - if len(b.params) > 0 { - var buf bytes.Buffer - for k, v := range b.params { - for _, vv := range v { - buf.WriteString(url.QueryEscape(k)) - buf.WriteByte('=') - buf.WriteString(url.QueryEscape(vv)) - buf.WriteByte('&') - } - } - paramBody = buf.String() - paramBody = paramBody[0 : len(paramBody)-1] - } +func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (*http.Response, error) { + paramBody := b.buildParamBody() b.buildURL(paramBody) urlParsed, err := url.Parse(b.url) if err != nil { - return nil, err + return nil, berror.Wrapf(err, InvalidUrl, "parse url failed, the url is %s", b.url) } b.req.URL = urlParsed - trans := b.setting.Transport + trans := b.buildTrans() - if trans == nil { - // create default transport - trans = &http.Transport{ - TLSClientConfig: b.setting.TLSClientConfig, - Proxy: b.setting.Proxy, - Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), - MaxIdleConnsPerHost: 100, - } - } else { - // if b.transport is *http.Transport then set the settings. - if t, ok := trans.(*http.Transport); ok { - if t.TLSClientConfig == nil { - t.TLSClientConfig = b.setting.TLSClientConfig - } - if t.Proxy == nil { - t.Proxy = b.setting.Proxy - } - if t.Dial == nil { - t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) - } - } - } - - var jar http.CookieJar - if b.setting.EnableCookie { - if defaultCookieJar == nil { - createDefaultCookie() - } - jar = defaultCookieJar - } + jar := b.buildCookieJar() client := &http.Client{ Transport: trans, @@ -543,13 +486,10 @@ func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, client.CheckRedirect = b.setting.CheckRedirect } - if b.setting.ShowDebug { - dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) - if err != nil { - log.Println(err.Error()) - } - b.dump = dump - } + return b.sendRequest(client) +} + +func (b *BeegoHTTPRequest) sendRequest(client *http.Client) (resp *http.Response, err error) { // retries default value is 0, it will run once. // retries equal to -1, it will run forever until success // retries is setted, it will retries fixed times. @@ -557,11 +497,66 @@ func (b *BeegoHTTPRequest) doRequest(ctx context.Context) (resp *http.Response, for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { resp, err = client.Do(b.req) if err == nil { - break + return } time.Sleep(b.setting.RetryDelay) } - return resp, err + return nil, berror.Wrap(err, SendRequestFailed, "sending request fail") +} + +func (b *BeegoHTTPRequest) buildCookieJar() http.CookieJar { + var jar http.CookieJar + if b.setting.EnableCookie { + if defaultCookieJar == nil { + createDefaultCookie() + } + jar = defaultCookieJar + } + return jar +} + +func (b *BeegoHTTPRequest) buildTrans() http.RoundTripper { + trans := b.setting.Transport + + if trans == nil { + // create default transport + trans = &http.Transport{ + TLSClientConfig: b.setting.TLSClientConfig, + Proxy: b.setting.Proxy, + DialContext: TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), + MaxIdleConnsPerHost: 100, + } + } else if t, ok := trans.(*http.Transport); ok { + // if b.transport is *http.Transport then set the settings. + if t.TLSClientConfig == nil { + t.TLSClientConfig = b.setting.TLSClientConfig + } + if t.Proxy == nil { + t.Proxy = b.setting.Proxy + } + if t.DialContext == nil { + t.DialContext = TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) + } + } + return trans +} + +func (b *BeegoHTTPRequest) buildParamBody() string { + var paramBody string + if len(b.params) > 0 { + var buf bytes.Buffer + for k, v := range b.params { + for _, vv := range v { + buf.WriteString(url.QueryEscape(k)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(vv)) + buf.WriteByte('&') + } + } + paramBody = buf.String() + paramBody = paramBody[0 : len(paramBody)-1] + } + return paramBody } // String returns the body string in response. @@ -592,10 +587,10 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" { reader, err := gzip.NewReader(resp.Body) if err != nil { - return nil, err + return nil, berror.Wrap(err, ReadGzipBodyFailed, "building gzip reader failed") } b.body, err = ioutil.ReadAll(reader) - return b.body, err + return b.body, berror.Wrap(err, ReadGzipBodyFailed, "reading gzip data failed") } b.body, err = ioutil.ReadAll(resp.Body) return b.body, err @@ -638,7 +633,7 @@ func pathExistAndMkdir(filename string) (err error) { return nil } } - return err + return berror.Wrapf(err, CreateFileIfNotExistFailed, "try to create(if not exist) failed: %s", filename) } // ToJSON returns the map that marshals from the body bytes as json in response. @@ -648,7 +643,8 @@ func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { if err != nil { return err } - return json.Unmarshal(data, v) + return berror.Wrap(json.Unmarshal(data, v), + UnmarshalJSONResponseToObjectFailed, "unmarshal json body to object failed.") } // ToXML returns the map that marshals from the body bytes as xml in response . @@ -658,7 +654,8 @@ func (b *BeegoHTTPRequest) ToXML(v interface{}) error { if err != nil { return err } - return xml.Unmarshal(data, v) + return berror.Wrap(xml.Unmarshal(data, v), + UnmarshalXMLResponseToObjectFailed, "unmarshal xml body to object failed.") } // ToYAML returns the map that marshals from the body bytes as yaml in response . @@ -668,7 +665,42 @@ func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { if err != nil { return err } - return yaml.Unmarshal(data, v) + return berror.Wrap(yaml.Unmarshal(data, v), + UnmarshalYAMLResponseToObjectFailed, "unmarshal yaml body to object failed.") +} + +// ToValue attempts to resolve the response body to value using an existing method. +// Calls Response inner. +// If response header contain Content-Type, func will call ToJSON\ToXML\ToYAML. +// Else it will try to parse body as json\yaml\xml, If all attempts fail, an error will be returned +func (b *BeegoHTTPRequest) ToValue(value interface{}) error { + if value == nil { + return nil + } + + contentType := strings.Split(b.resp.Header.Get(contentTypeKey), ";")[0] + // try to parse it as content type + switch contentType { + case "application/json": + return b.ToJSON(value) + case "text/xml", "application/xml": + return b.ToXML(value) + case "text/yaml", "application/x-yaml", "application/x+yaml": + return b.ToYAML(value) + } + + // try to parse it anyway + if err := b.ToJSON(value); err == nil { + return nil + } + if err := b.ToYAML(value); err == nil { + return nil + } + if err := b.ToXML(value); err == nil { + return nil + } + + return berror.Error(UnmarshalResponseToObjectFailed, "unmarshal body to object failed.") } // Response executes request client gets response manually. @@ -677,8 +709,18 @@ func (b *BeegoHTTPRequest) Response() (*http.Response, error) { } // TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +// Deprecated +// we will move this at the end of 2021 +// please use TimeoutDialerCtx func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { return func(netw, addr string) (net.Conn, error) { + return TimeoutDialerCtx(cTimeout, rwTimeout)(context.Background(), netw, addr) + } +} + +func TimeoutDialerCtx(cTimeout time.Duration, + rwTimeout time.Duration) func(ctx context.Context, net, addr string) (c net.Conn, err error) { + return func(ctx context.Context, netw, addr string) (net.Conn, error) { conn, err := net.DialTimeout(netw, addr, cTimeout) if err != nil { return nil, err diff --git a/client/httplib/httplib_test.go b/client/httplib/httplib_test.go index d0f826cb..471be02c 100644 --- a/client/httplib/httplib_test.go +++ b/client/httplib/httplib_test.go @@ -15,8 +15,10 @@ package httplib import ( + "bytes" "context" "errors" + "fmt" "io/ioutil" "net" "net/http" @@ -62,7 +64,6 @@ func TestDoRequest(t *testing.T) { if elapsedTime < delayedTime { t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) } - } func TestGet(t *testing.T) { @@ -247,7 +248,6 @@ func TestToJson(t *testing.T) { t.Fatal("response is not valid ip") } } - } func TestToFile(t *testing.T) { @@ -259,7 +259,7 @@ func TestToFile(t *testing.T) { } defer os.Remove(f) b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { + if n := bytes.Index(b, []byte("origin")); n == -1 { t.Fatal(err) } } @@ -273,7 +273,7 @@ func TestToFileDir(t *testing.T) { } defer os.RemoveAll("./files") b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { + if n := bytes.Index(b, []byte("origin")); n == -1 { t.Fatal(err) } } @@ -300,3 +300,149 @@ func TestAddFilter(t *testing.T) { r := Get("http://beego.me") assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains)) } + +func TestFilterChainOrder(t *testing.T) { + req := Get("http://beego.me") + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("first"), nil + } + }) + + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("second"), nil + } + }) + + resp, err := req.DoRequestWithCtx(context.Background()) + assert.Nil(t, err) + data := make([]byte, 5) + _, _ = resp.Body.Read(data) + assert.Equal(t, "first", string(data)) +} + +func TestHead(t *testing.T) { + req := Head("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "HEAD", req.req.Method) +} + +func TestDelete(t *testing.T) { + req := Delete("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "DELETE", req.req.Method) +} + +func TestPost(t *testing.T) { + req := Post("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "POST", req.req.Method) +} + +func TestNewBeegoRequest(t *testing.T) { + req := NewBeegoRequest("http://beego.me", "GET") + assert.NotNil(t, req) + assert.Equal(t, "GET", req.req.Method) + + // invalid case but still go request + req = NewBeegoRequest("httpa\ta://beego.me", "GET") + assert.NotNil(t, req) +} + +func TestBeegoHTTPRequestSetProtocolVersion(t *testing.T) { + req := NewBeegoRequest("http://beego.me", "GET") + req.SetProtocolVersion("HTTP/3.10") + assert.Equal(t, "HTTP/3.10", req.req.Proto) + assert.Equal(t, 3, req.req.ProtoMajor) + assert.Equal(t, 10, req.req.ProtoMinor) + + req.SetProtocolVersion("") + assert.Equal(t, "HTTP/1.1", req.req.Proto) + assert.Equal(t, 1, req.req.ProtoMajor) + assert.Equal(t, 1, req.req.ProtoMinor) + + // invalid case + req.SetProtocolVersion("HTTP/aaa1.1") + assert.Equal(t, "HTTP/1.1", req.req.Proto) + assert.Equal(t, 1, req.req.ProtoMajor) + assert.Equal(t, 1, req.req.ProtoMinor) +} + +func TestPut(t *testing.T) { + req := Put("http://beego.me") + assert.NotNil(t, req) + assert.Equal(t, "PUT", req.req.Method) +} + +func TestBeegoHTTPRequestHeader(t *testing.T) { + req := Post("http://beego.me") + key, value := "test-header", "test-header-value" + req.Header(key, value) + assert.Equal(t, value, req.req.Header.Get(key)) +} + +func TestBeegoHTTPRequestSetHost(t *testing.T) { + req := Post("http://beego.me") + host := "test-hose" + req.SetHost(host) + assert.Equal(t, host, req.req.Host) +} + +func TestBeegoHTTPRequestParam(t *testing.T) { + req := Post("http://beego.me") + key, value := "test-param", "test-param-value" + req.Param(key, value) + assert.Equal(t, value, req.params[key][0]) + + value1 := "test-param-value-1" + req.Param(key, value1) + assert.Equal(t, value1, req.params[key][1]) +} + +func TestBeegoHTTPRequestBody(t *testing.T) { + req := Post("http://beego.me") + body := `hello, world` + req.Body([]byte(body)) + assert.Equal(t, int64(len(body)), req.req.ContentLength) + assert.NotNil(t, req.req.GetBody) + assert.NotNil(t, req.req.Body) + + body = "hhhh, I am test" + req.Body(body) + assert.Equal(t, int64(len(body)), req.req.ContentLength) + assert.NotNil(t, req.req.GetBody) + assert.NotNil(t, req.req.Body) + + // invalid case + req.Body(13) +} + +type user struct { + Name string `xml:"name"` +} + +func TestBeegoHTTPRequestXMLBody(t *testing.T) { + req := Post("http://beego.me") + body := &user{ + Name: "Tom", + } + _, err := req.XMLBody(body) + assert.True(t, req.req.ContentLength > 0) + assert.Nil(t, err) + assert.NotNil(t, req.req.GetBody) +} + +// TODO +func TestBeegoHTTPRequestResponseForValue(t *testing.T) { +} + +func TestBeegoHTTPRequestJSONMarshal(t *testing.T) { + req := Post("http://beego.me") + req.SetEscapeHTML(false) + body := map[string]interface{}{ + "escape": "left&right", + } + b, _ := req.JSONMarshal(body) + assert.Equal(t, fmt.Sprintf(`{"escape":"left&right"}%s`, "\n"), string(b)) +} diff --git a/client/httplib/mock/mock.go b/client/httplib/mock/mock.go new file mode 100644 index 00000000..421a7a45 --- /dev/null +++ b/client/httplib/mock/mock.go @@ -0,0 +1,78 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "net/http" + + "github.com/beego/beego/v2/client/httplib" + "github.com/beego/beego/v2/core/logs" +) + +const mockCtxKey = "beego-httplib-mock" + +func init() { + InitMockSetting() +} + +type Stub interface { + Mock(cond RequestCondition, resp *http.Response, err error) + Clear() + MockByPath(path string, resp *http.Response, err error) +} + +var mockFilter = &MockResponseFilter{} + +func InitMockSetting() { + httplib.AddDefaultFilter(mockFilter.FilterChain) +} + +func StartMock() Stub { + return mockFilter +} + +func CtxWithMock(ctx context.Context, mock ...*Mock) context.Context { + return context.WithValue(ctx, mockCtxKey, mock) +} + +func mockFromCtx(ctx context.Context) []*Mock { + ms := ctx.Value(mockCtxKey) + if ms != nil { + if res, ok := ms.([]*Mock); ok { + return res + } + logs.Error("mockCtxKey found in context, but value is not type []*Mock") + } + return nil +} + +type Mock struct { + cond RequestCondition + resp *http.Response + err error +} + +func NewMockByPath(path string, resp *http.Response, err error) *Mock { + return NewMock(NewSimpleCondition(path), resp, err) +} + +func NewMock(con RequestCondition, resp *http.Response, err error) *Mock { + return &Mock{ + cond: con, + resp: resp, + err: err, + } +} diff --git a/client/httplib/mock/mock_condition.go b/client/httplib/mock/mock_condition.go new file mode 100644 index 00000000..912699ff --- /dev/null +++ b/client/httplib/mock/mock_condition.go @@ -0,0 +1,176 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "encoding/json" + "net/textproto" + "regexp" + + "github.com/beego/beego/v2/client/httplib" +) + +type RequestCondition interface { + Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool +} + +// reqCondition create condition +// - path: same path +// - pathReg: request path match pathReg +// - method: same method +// - Query parameters (key, value) +// - header (key, value) +// - Body json format, contains specific (key, value). +type SimpleCondition struct { + pathReg string + path string + method string + query map[string]string + header map[string]string + body map[string]interface{} +} + +func NewSimpleCondition(path string, opts ...simpleConditionOption) *SimpleCondition { + sc := &SimpleCondition{ + path: path, + query: make(map[string]string), + header: make(map[string]string), + body: map[string]interface{}{}, + } + for _, o := range opts { + o(sc) + } + return sc +} + +func (sc *SimpleCondition) Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + var res bool + if len(sc.path) > 0 { + res = sc.matchPath(ctx, req) + } else if len(sc.pathReg) > 0 { + res = sc.matchPathReg(ctx, req) + } else { + return false + } + return res && + sc.matchMethod(ctx, req) && + sc.matchQuery(ctx, req) && + sc.matchHeader(ctx, req) && + sc.matchBodyFields(ctx, req) +} + +func (sc *SimpleCondition) matchPath(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + path := req.GetRequest().URL.Path + return path == sc.path +} + +func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + path := req.GetRequest().URL.Path + if b, err := regexp.Match(sc.pathReg, []byte(path)); err == nil { + return b + } + return false +} + +func (sc *SimpleCondition) matchQuery(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + qs := req.GetRequest().URL.Query() + for k, v := range sc.query { + if uv, ok := qs[k]; !ok || uv[0] != v { + return false + } + } + return true +} + +func (sc *SimpleCondition) matchHeader(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + headers := req.GetRequest().Header + for k, v := range sc.header { + if uv, ok := headers[k]; !ok || uv[0] != v { + return false + } + } + return true +} + +func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + if len(sc.body) == 0 { + return true + } + getBody := req.GetRequest().GetBody + body, err := getBody() + if err != nil { + return false + } + bytes := make([]byte, req.GetRequest().ContentLength) + _, err = body.Read(bytes) + if err != nil { + return false + } + + m := make(map[string]interface{}) + + err = json.Unmarshal(bytes, &m) + + if err != nil { + return false + } + + for k, v := range sc.body { + if uv, ok := m[k]; !ok || uv != v { + return false + } + } + return true +} + +func (sc *SimpleCondition) matchMethod(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + if len(sc.method) > 0 { + return sc.method == req.GetRequest().Method + } + return true +} + +type simpleConditionOption func(sc *SimpleCondition) + +func WithPathReg(pathReg string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.pathReg = pathReg + } +} + +func WithQuery(key, value string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.query[key] = value + } +} + +func WithHeader(key, value string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.header[textproto.CanonicalMIMEHeaderKey(key)] = value + } +} + +func WithJsonBodyFields(field string, value interface{}) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.body[field] = value + } +} + +func WithMethod(method string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.method = method + } +} diff --git a/client/httplib/mock/mock_condition_test.go b/client/httplib/mock/mock_condition_test.go new file mode 100644 index 00000000..79bc7ad8 --- /dev/null +++ b/client/httplib/mock/mock_condition_test.go @@ -0,0 +1,119 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestSimpleConditionMatchPath(t *testing.T) { + sc := NewSimpleCondition("/abc/s") + res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s")) + assert.True(t, res) +} + +func TestSimpleConditionMatchQuery(t *testing.T) { + k, v := "my-key", "my-value" + sc := NewSimpleCondition("/abc/s") + res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value")) + assert.True(t, res) + + sc = NewSimpleCondition("/abc/s", WithQuery(k, v)) + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value")) + assert.True(t, res) + + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-valuesss")) + assert.False(t, res) + + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key-a=my-value")) + assert.False(t, res) + + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value&abc=hello")) + assert.True(t, res) +} + +func TestSimpleConditionMatchHeader(t *testing.T) { + k, v := "my-header", "my-header-value" + sc := NewSimpleCondition("/abc/s") + req := httplib.Get("http://localhost:8080/abc/s") + assert.True(t, sc.Match(context.Background(), req)) + + req = httplib.Get("http://localhost:8080/abc/s") + req.Header(k, v) + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithHeader(k, v)) + req.Header(k, v) + assert.True(t, sc.Match(context.Background(), req)) + + req.Header(k, "invalid") + assert.False(t, sc.Match(context.Background(), req)) +} + +func TestSimpleConditionMatchBodyField(t *testing.T) { + sc := NewSimpleCondition("/abc/s") + req := httplib.Post("http://localhost:8080/abc/s") + + assert.True(t, sc.Match(context.Background(), req)) + + req.Body(`{ + "body-field": 123 +}`) + assert.True(t, sc.Match(context.Background(), req)) + + k := "body-field" + v := float64(123) + sc = NewSimpleCondition("/abc/s", WithJsonBodyFields(k, v)) + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithJsonBodyFields(k, v)) + req.Body(`{ + "body-field": abc +}`) + assert.False(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithJsonBodyFields("body-field", "abc")) + req.Body(`{ + "body-field": "abc" +}`) + assert.True(t, sc.Match(context.Background(), req)) +} + +func TestSimpleConditionMatch(t *testing.T) { + sc := NewSimpleCondition("/abc/s") + req := httplib.Post("http://localhost:8080/abc/s") + + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithMethod("POST")) + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithMethod("GET")) + assert.False(t, sc.Match(context.Background(), req)) +} + +func TestSimpleConditionMatchPathReg(t *testing.T) { + sc := NewSimpleCondition("", WithPathReg(`\/abc\/.*`)) + req := httplib.Post("http://localhost:8080/abc/s") + assert.True(t, sc.Match(context.Background(), req)) + + req = httplib.Post("http://localhost:8080/abcd/s") + assert.False(t, sc.Match(context.Background(), req)) +} diff --git a/client/httplib/mock/mock_filter.go b/client/httplib/mock/mock_filter.go new file mode 100644 index 00000000..225d65f3 --- /dev/null +++ b/client/httplib/mock/mock_filter.go @@ -0,0 +1,61 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "net/http" + + "github.com/beego/beego/v2/client/httplib" +) + +// MockResponse will return mock response if find any suitable mock data +// if you want to test your code using httplib, you need this. +type MockResponseFilter struct { + ms []*Mock +} + +func NewMockResponseFilter() *MockResponseFilter { + return &MockResponseFilter{ + ms: make([]*Mock, 0, 1), + } +} + +func (m *MockResponseFilter) FilterChain(next httplib.Filter) httplib.Filter { + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + ms := mockFromCtx(ctx) + ms = append(ms, m.ms...) + for _, mock := range ms { + if mock.cond.Match(ctx, req) { + return mock.resp, mock.err + } + } + return next(ctx, req) + } +} + +func (m *MockResponseFilter) MockByPath(path string, resp *http.Response, err error) { + m.Mock(NewSimpleCondition(path), resp, err) +} + +func (m *MockResponseFilter) Clear() { + m.ms = make([]*Mock, 0, 1) +} + +// Mock add mock data +// If the cond.Match(...) = true, the resp and err will be returned +func (m *MockResponseFilter) Mock(cond RequestCondition, resp *http.Response, err error) { + m.ms = append(m.ms, NewMock(cond, resp, err)) +} diff --git a/client/httplib/mock/mock_filter_test.go b/client/httplib/mock/mock_filter_test.go new file mode 100644 index 00000000..99c2e67f --- /dev/null +++ b/client/httplib/mock/mock_filter_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestMockResponseFilterFilterChain(t *testing.T) { + req := httplib.Get("http://localhost:8080/abc/s") + ft := NewMockResponseFilter() + + expectedResp := httplib.NewHttpResponseWithJsonBody(`{}`) + expectedErr := errors.New("expected error") + ft.Mock(NewSimpleCondition("/abc/s"), expectedResp, expectedErr) + + req.AddFilters(ft.FilterChain) + + resp, err := req.DoRequest() + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) + + req = httplib.Get("http://localhost:8080/abcd/s") + req.AddFilters(ft.FilterChain) + + resp, err = req.DoRequest() + assert.NotEqual(t, expectedErr, err) + assert.NotEqual(t, expectedResp, resp) + + req = httplib.Get("http://localhost:8080/abc/s") + req.AddFilters(ft.FilterChain) + expectedResp1 := httplib.NewHttpResponseWithJsonBody(map[string]string{}) + expectedErr1 := errors.New("expected error") + ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1) + + resp, err = req.DoRequest() + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) + + req = httplib.Get("http://localhost:8080/abc/abs/bbc") + req.AddFilters(ft.FilterChain) + ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1) + resp, err = req.DoRequest() + assert.Equal(t, expectedErr1, err) + assert.Equal(t, expectedResp1, resp) +} diff --git a/client/httplib/mock/mock_test.go b/client/httplib/mock/mock_test.go new file mode 100644 index 00000000..ce7322df --- /dev/null +++ b/client/httplib/mock/mock_test.go @@ -0,0 +1,75 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestStartMock(t *testing.T) { + // httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain} + + stub := StartMock() + // defer stub.Clear() + + expectedResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`)) + expectedErr := errors.New("expected err") + + stub.Mock(NewSimpleCondition("/abc"), expectedResp, expectedErr) + + resp, err := OriginalCodeUsingHttplib() + + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) +} + +// TestStartMock_Isolation Test StartMock that +// mock only work for this request +func TestStartMockIsolation(t *testing.T) { + // httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain} + // setup global stub + stub := StartMock() + globalMockResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`)) + globalMockErr := errors.New("expected err") + stub.Mock(NewSimpleCondition("/abc"), globalMockResp, globalMockErr) + + expectedResp := httplib.NewHttpResponseWithJsonBody(struct { + A string `json:"a"` + }{ + A: "aaa", + }) + expectedErr := errors.New("expected err aa") + m := NewMockByPath("/abc", expectedResp, expectedErr) + ctx := CtxWithMock(context.Background(), m) + + resp, err := OriginnalCodeUsingHttplibPassCtx(ctx) + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) +} + +func OriginnalCodeUsingHttplibPassCtx(ctx context.Context) (*http.Response, error) { + return httplib.Get("http://localhost:7777/abc").DoRequestWithCtx(ctx) +} + +func OriginalCodeUsingHttplib() (*http.Response, error) { + return httplib.Get("http://localhost:7777/abc").DoRequest() +} diff --git a/client/httplib/module.go b/client/httplib/module.go new file mode 100644 index 00000000..8503133c --- /dev/null +++ b/client/httplib/module.go @@ -0,0 +1,17 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +const moduleName = "httplib" diff --git a/client/httplib/setting.go b/client/httplib/setting.go new file mode 100644 index 00000000..3d8d195c --- /dev/null +++ b/client/httplib/setting.go @@ -0,0 +1,87 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "crypto/tls" + "net/http" + "net/http/cookiejar" + "net/url" + "sync" + "time" +) + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings struct { + UserAgent string + ConnectTimeout time.Duration + ReadWriteTimeout time.Duration + TLSClientConfig *tls.Config + Proxy func(*http.Request) (*url.URL, error) + Transport http.RoundTripper + CheckRedirect func(req *http.Request, via []*http.Request) error + EnableCookie bool + Gzip bool + Retries int // if set to -1 means will retry forever + RetryDelay time.Duration + FilterChains []FilterChain + EscapeHTML bool // if set to false means will not escape escape HTML special characters during processing, default true +} + +// createDefaultCookie creates a global cookiejar to store cookies. +func createDefaultCookie() { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultCookieJar, _ = cookiejar.New(nil) +} + +// SetDefaultSetting overwrites default settings +// Keep in mind that when you invoke the SetDefaultSetting +// some methods invoked before SetDefaultSetting +func SetDefaultSetting(setting BeegoHTTPSettings) { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultSetting = setting +} + +// GetDefaultSetting return current default setting +func GetDefaultSetting() BeegoHTTPSettings { + return defaultSetting +} + +var defaultSetting = BeegoHTTPSettings{ + UserAgent: "beegoServer", + ConnectTimeout: 60 * time.Second, + ReadWriteTimeout: 60 * time.Second, + Gzip: true, + FilterChains: make([]FilterChain, 0, 4), + EscapeHTML: true, +} + +var ( + defaultCookieJar http.CookieJar + settingMutex sync.Mutex +) + +// AddDefaultFilter add a new filter into defaultSetting +// Be careful about using this method if you invoke SetDefaultSetting somewhere +func AddDefaultFilter(fc FilterChain) { + settingMutex.Lock() + defer settingMutex.Unlock() + if defaultSetting.FilterChains == nil { + defaultSetting.FilterChains = make([]FilterChain, 0, 4) + } + defaultSetting.FilterChains = append(defaultSetting.FilterChains, fc) +} diff --git a/client/httplib/testing/client.go b/client/httplib/testing/client.go index 517e0722..43e2e968 100644 --- a/client/httplib/testing/client.go +++ b/client/httplib/testing/client.go @@ -18,8 +18,10 @@ import ( "github.com/beego/beego/v2/client/httplib" ) -var port = "" -var baseURL = "http://localhost:" +var ( + port = "" + baseURL = "http://localhost:" +) // TestHTTPRequest beego test request client type TestHTTPRequest struct { diff --git a/client/orm/README.md b/client/orm/README.md index 58669e1f..bb11d1c6 100644 --- a/client/orm/README.md +++ b/client/orm/README.md @@ -27,7 +27,7 @@ more features please read the docs **Install:** - go get github.com/beego/beego/v2/orm + go get github.com/beego/beego/v2/client/orm ## Changelog @@ -45,7 +45,7 @@ package main import ( "fmt" - "github.com/beego/beego/v2/orm" + "github.com/beego/beego/v2/client/orm" _ "github.com/go-sql-driver/mysql" // import your used driver ) diff --git a/client/orm/clauses/const.go b/client/orm/clauses/const.go new file mode 100644 index 00000000..38a6d556 --- /dev/null +++ b/client/orm/clauses/const.go @@ -0,0 +1,6 @@ +package clauses + +const ( + ExprSep = "__" + ExprDot = "." +) diff --git a/client/orm/clauses/order_clause/order.go b/client/orm/clauses/order_clause/order.go new file mode 100644 index 00000000..bdb2d1ca --- /dev/null +++ b/client/orm/clauses/order_clause/order.go @@ -0,0 +1,104 @@ +package order_clause + +import ( + "strings" + + "github.com/beego/beego/v2/client/orm/clauses" +) + +type Sort int8 + +const ( + None Sort = 0 + Ascending Sort = 1 + Descending Sort = 2 +) + +type Option func(order *Order) + +type Order struct { + column string + sort Sort + isRaw bool +} + +func Clause(options ...Option) *Order { + o := &Order{} + for _, option := range options { + option(o) + } + + return o +} + +func (o *Order) GetColumn() string { + return o.column +} + +func (o *Order) GetSort() Sort { + return o.sort +} + +func (o *Order) SortString() string { + switch o.GetSort() { + case Ascending: + return "ASC" + case Descending: + return "DESC" + } + + return `` +} + +func (o *Order) IsRaw() bool { + return o.isRaw +} + +func ParseOrder(expressions ...string) []*Order { + var orders []*Order + for _, expression := range expressions { + sort := Ascending + column := strings.ReplaceAll(expression, clauses.ExprSep, clauses.ExprDot) + if column[0] == '-' { + sort = Descending + column = column[1:] + } + + orders = append(orders, &Order{ + column: column, + sort: sort, + }) + } + + return orders +} + +func Column(column string) Option { + return func(order *Order) { + order.column = strings.ReplaceAll(column, clauses.ExprSep, clauses.ExprDot) + } +} + +func sort(sort Sort) Option { + return func(order *Order) { + order.sort = sort + } +} + +func SortAscending() Option { + return sort(Ascending) +} + +func SortDescending() Option { + return sort(Descending) +} + +func SortNone() Option { + return sort(None) +} + +func Raw() Option { + return func(order *Order) { + order.isRaw = true + } +} diff --git a/client/orm/clauses/order_clause/order_test.go b/client/orm/clauses/order_clause/order_test.go new file mode 100644 index 00000000..9f658702 --- /dev/null +++ b/client/orm/clauses/order_clause/order_test.go @@ -0,0 +1,137 @@ +package order_clause + +import ( + "testing" +) + +func TestClause(t *testing.T) { + column := `a` + + o := Clause( + Column(column), + ) + + if o.GetColumn() != column { + t.Error() + } +} + +func TestSortAscending(t *testing.T) { + o := Clause( + SortAscending(), + ) + + if o.GetSort() != Ascending { + t.Error() + } +} + +func TestSortDescending(t *testing.T) { + o := Clause( + SortDescending(), + ) + + if o.GetSort() != Descending { + t.Error() + } +} + +func TestSortNone(t *testing.T) { + o1 := Clause( + SortNone(), + ) + + if o1.GetSort() != None { + t.Error() + } + + o2 := Clause() + + if o2.GetSort() != None { + t.Error() + } +} + +func TestRaw(t *testing.T) { + o1 := Clause() + + if o1.IsRaw() { + t.Error() + } + + o2 := Clause( + Raw(), + ) + + if !o2.IsRaw() { + t.Error() + } +} + +func TestColumn(t *testing.T) { + o1 := Clause( + Column(`aaa`), + ) + + if o1.GetColumn() != `aaa` { + t.Error() + } +} + +func TestParseOrder(t *testing.T) { + orders := ParseOrder( + `-user__status`, + `status`, + `user__status`, + ) + + t.Log(orders) + + if orders[0].GetSort() != Descending { + t.Error() + } + + if orders[0].GetColumn() != `user.status` { + t.Error() + } + + if orders[1].GetColumn() != `status` { + t.Error() + } + + if orders[1].GetSort() != Ascending { + t.Error() + } + + if orders[2].GetColumn() != `user.status` { + t.Error() + } +} + +func TestOrderGetColumn(t *testing.T) { + o := Clause( + Column(`user__id`), + ) + if o.GetColumn() != `user.id` { + t.Error() + } +} + +func TestSortString(t *testing.T) { + template := "got: %s, want: %s" + + o1 := Clause(sort(Sort(1))) + if o1.SortString() != "ASC" { + t.Errorf(template, o1.SortString(), "ASC") + } + + o2 := Clause(sort(Sort(2))) + if o2.SortString() != "DESC" { + t.Errorf(template, o2.SortString(), "DESC") + } + + o3 := Clause(sort(Sort(3))) + if o3.SortString() != `` { + t.Errorf(template, o3.SortString(), ``) + } +} diff --git a/client/orm/cmd.go b/client/orm/cmd.go index b0661971..432785d6 100644 --- a/client/orm/cmd.go +++ b/client/orm/cmd.go @@ -15,6 +15,7 @@ package orm import ( + "context" "flag" "fmt" "os" @@ -26,9 +27,7 @@ type commander interface { Run() error } -var ( - commands = make(map[string]commander) -) +var commands = make(map[string]commander) // print help. func printHelp(errs ...string) { @@ -141,6 +140,7 @@ func (d *commandSyncDb) Run() error { fmt.Printf(" %s\n", err.Error()) } + ctx := context.Background() for i, mi := range modelCache.allOrdered() { if !isApplicableTableForDB(mi.addrField, d.al.Name) { @@ -154,7 +154,7 @@ func (d *commandSyncDb) Run() error { } var fields []*fieldInfo - columns, err := d.al.DbBaser.GetColumns(db, mi.table) + columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table) if err != nil { if d.rtOnError { return err @@ -188,7 +188,7 @@ func (d *commandSyncDb) Run() error { } for _, idx := range indexes[mi.table] { - if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { + if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) { if !d.noInfo { fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) } diff --git a/client/orm/cmd_utils.go b/client/orm/cmd_utils.go index 8d6c0c33..7b795b22 100644 --- a/client/orm/cmd_utils.go +++ b/client/orm/cmd_utils.go @@ -126,9 +126,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string { // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands func getColumnDefault(fi *fieldInfo) string { - var ( - v, t, d string - ) + var v, t, d string // Skip default attribute if field is in relations if fi.rel || fi.reverse { diff --git a/client/orm/db.go b/client/orm/db.go index 4080f292..5a9f1b58 100644 --- a/client/orm/db.go +++ b/client/orm/db.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "errors" "fmt" @@ -31,41 +32,37 @@ const ( formatDateTime = "2006-01-02 15:04:05" ) -var ( - // ErrMissPK missing pk error - ErrMissPK = errors.New("missed pk value") -) +// ErrMissPK missing pk error +var ErrMissPK = errors.New("missed pk value") -var ( - operators = map[string]bool{ - "exact": true, - "iexact": true, - "strictexact": true, - "contains": true, - "icontains": true, - // "regex": true, - // "iregex": true, - "gt": true, - "gte": true, - "lt": true, - "lte": true, - "eq": true, - "nq": true, - "ne": true, - "startswith": true, - "endswith": true, - "istartswith": true, - "iendswith": true, - "in": true, - "between": true, - // "year": true, - // "month": true, - // "day": true, - // "week_day": true, - "isnull": true, - // "search": true, - } -) +var operators = map[string]bool{ + "exact": true, + "iexact": true, + "strictexact": true, + "contains": true, + "icontains": true, + // "regex": true, + // "iregex": true, + "gt": true, + "gte": true, + "lt": true, + "lte": true, + "eq": true, + "nq": true, + "ne": true, + "startswith": true, + "endswith": true, + "istartswith": true, + "iendswith": true, + "in": true, + "between": true, + // "year": true, + // "month": true, + // "day": true, + // "week_day": true, + "isnull": true, + // "search": true, +} // an instance of dbBaser interface/ type dbBase struct { @@ -268,7 +265,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } // create insert sql preparation statement object. -func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { +func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { Q := d.ins.TableQuote() dbcols := make([]string, 0, len(mi.fields.dbcols)) @@ -289,12 +286,12 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, d.ins.HasReturningID(mi, &query) - stmt, err := q.Prepare(query) + stmt, err := q.PrepareContext(ctx, query) return stmt, query, err } // insert struct with prepared statement and given struct reflect value. -func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { +func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) if err != nil { return 0, err @@ -306,7 +303,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, err := row.Scan(&id) return id, err } - res, err := stmt.Exec(values...) + res, err := stmt.ExecContext(ctx, values...) if err == nil { return res.LastInsertId() } @@ -314,7 +311,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, } // query sql ,read records and persist in dbBaser. -func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { var whereCols []string var args []interface{} @@ -360,7 +357,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo d.ins.ReplaceMarks(&query) - row := q.QueryRow(query, args...) + row := q.QueryRowContext(ctx, query, args...) if err := row.Scan(refs...); err != nil { if err == sql.ErrNoRows { return ErrNoRows @@ -375,26 +372,26 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo } // execute insert sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { +func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { names := make([]string, 0, len(mi.fields.dbcols)) values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) if err != nil { return 0, err } - id, err := d.InsertValue(q, mi, false, names, values) + id, err := d.InsertValue(ctx, q, mi, false, names, values) if err != nil { return 0, err } if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return id, err } // multi-insert sql with given slice struct reflect.Value. -func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { +func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { var ( cnt int64 nums int @@ -440,7 +437,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul } if i > 1 && i%bulk == 0 || length == i { - num, err := d.InsertValue(q, mi, true, names, values[:nums]) + num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums]) if err != nil { return cnt, err } @@ -451,7 +448,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul var err error if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return cnt, err @@ -459,7 +456,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -482,7 +479,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -498,7 +495,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err := row.Scan(&id) return id, err @@ -507,7 +504,7 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []s // InsertOrUpdate a row // If your primary key or unique column conflict will update // If no will insert -func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { args0 := "" iouStr := "" argsMap := map[string]string{} @@ -536,7 +533,6 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a names := make([]string, 0, len(mi.fields.dbcols)-1) Q := d.ins.TableQuote() values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) - if err != nil { return 0, err } @@ -590,7 +586,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -607,7 +603,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err = row.Scan(&id) if err != nil && err.Error() == `pq: syntax error at or near "ON"` { @@ -617,7 +613,7 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a } // execute update sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if !ok { return 0, ErrMissPK @@ -674,7 +670,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, setValues...) + res, err := q.ExecContext(ctx, query, setValues...) if err == nil { return res.RowsAffected() } @@ -683,7 +679,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. // execute delete sql dbQuerier with given struct reflect.Value. // delete index is pk. -func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { var whereCols []string var args []interface{} // if specify cols length > 0, then use it for where condition. @@ -712,21 +708,14 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, args...) + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } if num > 0 { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) - } else { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) - } - } - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -738,7 +727,7 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. // update table-related record by querySet. // need querySet not struct reflect.Value to update related records. -func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { +func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) for col, val := range params { @@ -819,13 +808,7 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } d.ins.ReplaceMarks(&query) - var err error - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, values...) - } else { - res, err = q.Exec(query, values...) - } + res, err := q.ExecContext(ctx, query, values...) if err == nil { return res.RowsAffected() } @@ -834,13 +817,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con // delete related records. // do UpdateBanch or DeleteBanch by condition of tables' relationship. -func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { +func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { for _, fi := range mi.fields.fieldsReverse { fi = fi.reverseFieldInfo switch fi.onDelete { case odCascade: cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) + _, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz) if err != nil { return err } @@ -850,7 +833,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * if fi.onDelete == odSetDefault { params[fi.column] = fi.initial.String() } - _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) + _, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz) if err != nil { return err } @@ -861,7 +844,7 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz * } // delete table-related records. -func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { +func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) tables.skipEnd = true @@ -886,7 +869,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con d.ins.ReplaceMarks(&query) var rs *sql.Rows - r, err := q.Query(query, args...) + r, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -920,19 +903,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) d.ins.ReplaceMarks(&query) - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, args...) - } else { - res, err = q.Exec(query, args...) - } + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } if num > 0 { - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -943,14 +921,14 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } // read related records. -func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { - +func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { val := reflect.ValueOf(container) ind := reflect.Indirect(val) - errTyp := true + unregister := true one := true isPtr := true + name := "" if val.Kind() == reflect.Ptr { fn := "" @@ -963,19 +941,17 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi case reflect.Struct: isPtr = false fn = getFullName(typ) + name = getTableName(reflect.New(typ)) } } else { fn = getFullName(ind.Type()) + name = getTableName(ind) } - errTyp = fn != mi.fullName + unregister = fn != mi.fullName } - if errTyp { - if one { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) - } else { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) - } + if unregister { + RegisterModel(container) } rlimit := qs.limit @@ -1040,6 +1016,9 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi if qs.distinct { sqlSelect += " DISTINCT" } + if qs.aggregate != "" { + sels = qs.aggregate + } query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, specifyIndexes, join, where, groupBy, orderBy, limit) @@ -1050,18 +1029,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi d.ins.ReplaceMarks(&query) - var rs *sql.Rows - var err error - if qs != nil && qs.forContext { - rs, err = q.QueryContext(qs.ctx, query, args...) - if err != nil { - return 0, err - } - } else { - rs, err = q.Query(query, args...) - if err != nil { - return 0, err - } + rs, err := q.QueryContext(ctx, query, args...) + if err != nil { + return 0, err + } + + defer rs.Close() + + slice := ind + if unregister { + mi, _ = modelCache.get(name) + tCols = mi.fields.dbcols + colsNum = len(tCols) } refs := make([]interface{}, colsNum) @@ -1069,11 +1048,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi var ref interface{} refs[i] = &ref } - - defer rs.Close() - - slice := ind - var cnt int64 for rs.Next() { if one && cnt == 0 || !one { @@ -1172,7 +1146,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi } // excute count sql and return count result int64. -func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { +func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) @@ -1194,12 +1168,7 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition d.ins.ReplaceMarks(&query) - var row *sql.Row - if qs != nil && qs.forContext { - row = q.QueryRowContext(qs.ctx, query, args...) - } else { - row = q.QueryRow(query, args...) - } + row := q.QueryRowContext(ctx, query, args...) err = row.Scan(&cnt) return } @@ -1453,12 +1422,10 @@ end: } return value, nil - } // set one value to struct column field. func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { - fieldType := fi.fieldType isNative := !fi.isFielder @@ -1649,8 +1616,7 @@ setValue: } // query sql, read values , save to *[]ParamList. -func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { - +func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { var ( maps []Params lists []ParamsList @@ -1732,7 +1698,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond d.ins.ReplaceMarks(&query) - rs, err := q.Query(query, args...) + rs, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -1847,7 +1813,7 @@ func (d *dbBase) HasReturningID(*modelInfo, *string) bool { } // sync auto key -func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { return nil } @@ -1892,10 +1858,10 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { } // get all cloumns in table. -func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { columns := make(map[string][3]string) query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return columns, err } @@ -1934,7 +1900,7 @@ func (d *dbBase) ShowColumnsQuery(table string) string { } // not implement. -func (d *dbBase) IndexExists(dbQuerier, string, string) bool { +func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool { panic(ErrNotImplement) } diff --git a/client/orm/db_alias.go b/client/orm/db_alias.go index 29e0904c..28c8ab8e 100644 --- a/client/orm/db_alias.go +++ b/client/orm/db_alias.go @@ -112,8 +112,10 @@ type DB struct { stmtDecoratorsLimit int } -var _ dbQuerier = new(DB) -var _ txer = new(DB) +var ( + _ dbQuerier = new(DB) + _ txer = new(DB) +) func (d *DB) Begin() (*sql.Tx, error) { return d.DB.Begin() @@ -221,8 +223,10 @@ type TxDB struct { tx *sql.Tx } -var _ dbQuerier = new(TxDB) -var _ txEnder = new(TxDB) +var ( + _ dbQuerier = new(TxDB) + _ txEnder = new(TxDB) +) func (t *TxDB) Commit() error { return t.tx.Commit() @@ -232,8 +236,18 @@ func (t *TxDB) Rollback() error { return t.tx.Rollback() } -var _ dbQuerier = new(TxDB) -var _ txEnder = new(TxDB) +func (t *TxDB) RollbackUnlessCommit() error { + err := t.tx.Rollback() + if err != sql.ErrTxDone { + return err + } + return nil +} + +var ( + _ dbQuerier = new(TxDB) + _ txEnder = new(TxDB) +) func (t *TxDB) Prepare(query string) (*sql.Stmt, error) { return t.PrepareContext(context.Background(), query) @@ -357,7 +371,6 @@ func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) } func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) { - al := &alias{} al.DB = &DB{ RWMutex: new(sync.RWMutex), @@ -414,7 +427,7 @@ func SetMaxIdleConns(aliasName string, maxIdleConns int) { // Deprecated you should not use this, we will remove it in the future func SetMaxOpenConns(aliasName string, maxOpenConns int) { al := getDbAlias(aliasName) - al.SetMaxIdleConns(maxOpenConns) + al.SetMaxOpenConns(maxOpenConns) } // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name diff --git a/client/orm/db_alias_test.go b/client/orm/db_alias_test.go index 6275cb2a..b5d53464 100644 --- a/client/orm/db_alias_test.go +++ b/client/orm/db_alias_test.go @@ -35,7 +35,7 @@ func TestRegisterDataBase(t *testing.T) { assert.Equal(t, al.ConnMaxLifetime, time.Minute) } -func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { +func TestRegisterDataBaseMaxStmtCacheSizeNegative1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) assert.Nil(t, err) @@ -45,7 +45,7 @@ func TestRegisterDataBase_MaxStmtCacheSizeNegative1(t *testing.T) { assert.Equal(t, al.DB.stmtDecoratorsLimit, 0) } -func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { +func TestRegisterDataBaseMaxStmtCacheSize0(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize0" err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) assert.Nil(t, err) @@ -55,7 +55,7 @@ func TestRegisterDataBase_MaxStmtCacheSize0(t *testing.T) { assert.Equal(t, al.DB.stmtDecoratorsLimit, 0) } -func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { +func TestRegisterDataBaseMaxStmtCacheSize1(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize1" err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) assert.Nil(t, err) @@ -65,7 +65,7 @@ func TestRegisterDataBase_MaxStmtCacheSize1(t *testing.T) { assert.Equal(t, al.DB.stmtDecoratorsLimit, 1) } -func TestRegisterDataBase_MaxStmtCacheSize841(t *testing.T) { +func TestRegisterDataBaseMaxStmtCacheSize841(t *testing.T) { aliasName := "TestRegisterDataBase_MaxStmtCacheSize841" err := RegisterDataBase(aliasName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) assert.Nil(t, err) diff --git a/client/orm/db_mysql.go b/client/orm/db_mysql.go index ee68baf7..5b3333e0 100644 --- a/client/orm/db_mysql.go +++ b/client/orm/db_mysql.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "reflect" "strings" @@ -93,8 +94,8 @@ func (d *dbBaseMysql) ShowColumnsQuery(table string) string { } // execute sql to check index exist. -func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) @@ -105,7 +106,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool // If your primary key or unique column conflict will update // If no will insert // Add "`" for mysql sql building -func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { var iouStr string argsMap := map[string]string{} @@ -123,7 +124,6 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val names := make([]string, 0, len(mi.fields.dbcols)-1) Q := d.ins.TableQuote() values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) - if err != nil { return 0, err } @@ -161,7 +161,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -178,7 +178,7 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err = row.Scan(&id) return id, err diff --git a/client/orm/db_oracle.go b/client/orm/db_oracle.go index 1de440b6..a3b93ff3 100644 --- a/client/orm/db_oracle.go +++ b/client/orm/db_oracle.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "strings" @@ -89,8 +90,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string { } // check index is exist -func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ +func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) @@ -124,7 +125,7 @@ func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, inde // execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -147,7 +148,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() @@ -163,7 +164,7 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err := row.Scan(&id) return id, err diff --git a/client/orm/db_postgres.go b/client/orm/db_postgres.go index 12431d6e..b2f321db 100644 --- a/client/orm/db_postgres.go +++ b/client/orm/db_postgres.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "strconv" ) @@ -140,7 +141,7 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { } // sync auto key -func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { if len(autoFields) == 0 { return nil } @@ -151,7 +152,7 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string mi.table, name, Q, name, Q, Q, mi.table, Q) - if _, err := db.Exec(query); err != nil { + if _, err := db.ExecContext(ctx, query); err != nil { return err } } @@ -174,9 +175,9 @@ func (d *dbBasePostgres) DbTypes() map[string]string { } // check index exist in postgresql. -func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) - row := db.QueryRow(query) + row := db.QueryRowContext(ctx, query) var cnt int row.Scan(&cnt) return cnt > 0 diff --git a/client/orm/db_sqlite.go b/client/orm/db_sqlite.go index aff713a5..6a4b3131 100644 --- a/client/orm/db_sqlite.go +++ b/client/orm/db_sqlite.go @@ -15,6 +15,7 @@ package orm import ( + "context" "database/sql" "fmt" "reflect" @@ -73,11 +74,11 @@ type dbBaseSqlite struct { var _ dbBaser = new(dbBaseSqlite) // override base db read for update behavior as SQlite does not support syntax -func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { if isForUpdate { DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") } - return d.dbBase.Read(q, mi, ind, tz, cols, false) + return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false) } // get sqlite operator. @@ -114,9 +115,9 @@ func (d *dbBaseSqlite) ShowTablesQuery() string { } // get columns in sqlite. -func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return nil, err } @@ -140,9 +141,9 @@ func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { } // check index exist in sqlite. -func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("PRAGMA index_list('%s')", table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { panic(err) } diff --git a/client/orm/db_tables.go b/client/orm/db_tables.go index 5fd472d1..a0b355ca 100644 --- a/client/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -18,6 +18,9 @@ import ( "fmt" "strings" "time" + + "github.com/beego/beego/v2/client/orm/clauses" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" ) // table info struct. @@ -103,7 +106,6 @@ func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related [] // parse related fields. func (t *dbTables) parseRelated(rels []string, depth int) { - relsNum := len(rels) related := make([]string, relsNum) copy(related, rels) @@ -421,7 +423,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { } // generate order sql. -func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { +func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) { if len(orders) == 0 { return } @@ -430,19 +432,25 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { orderSqls := make([]string, 0, len(orders)) for _, order := range orders { - asc := "ASC" - if order[0] == '-' { - asc = "DESC" - order = order[1:] - } - exprs := strings.Split(order, ExprSep) + column := order.GetColumn() + clause := strings.Split(column, clauses.ExprDot) - index, _, fi, suc := t.parseExprs(t.mi, exprs) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) - } + if order.IsRaw() { + if len(clause) == 2 { + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString())) + } else if len(clause) == 1 { + orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString())) + } else { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) + } + } else { + index, _, fi, suc := t.parseExprs(t.mi, clause) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) + } - orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString())) + } } orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) diff --git a/client/orm/db_tidb.go b/client/orm/db_tidb.go index 6020a488..48c5b4e7 100644 --- a/client/orm/db_tidb.go +++ b/client/orm/db_tidb.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" ) @@ -47,8 +48,8 @@ func (d *dbBaseTidb) ShowColumnsQuery(table string) string { } // execute sql to check index exist. -func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) diff --git a/client/orm/db_utils.go b/client/orm/db_utils.go index 8d86b54a..f2e1353c 100644 --- a/client/orm/db_utils.go +++ b/client/orm/db_utils.go @@ -55,7 +55,6 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac // get fields description as flatted string. func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { - outFor: for _, arg := range args { if arg == nil { diff --git a/client/orm/do_nothing_orm.go b/client/orm/do_nothing_orm.go index c6da420d..d9e574a5 100644 --- a/client/orm/do_nothing_orm.go +++ b/client/orm/do_nothing_orm.go @@ -27,8 +27,7 @@ import ( var _ Ormer = new(DoNothingOrm) -type DoNothingOrm struct { -} +type DoNothingOrm struct{} func (d *DoNothingOrm) Read(md interface{}, cols ...string) error { return nil @@ -66,6 +65,7 @@ func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer { return nil } +// NOTE: this method is deprecated, context parameter will not take effect. func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { return nil } @@ -74,6 +74,7 @@ func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { return nil } +// NOTE: this method is deprecated, context parameter will not take effect. func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { return nil } diff --git a/client/orm/do_nothing_orm_test.go b/client/orm/do_nothing_orm_test.go index 4d477353..e10f70af 100644 --- a/client/orm/do_nothing_orm_test.go +++ b/client/orm/do_nothing_orm_test.go @@ -36,7 +36,6 @@ func TestDoNothingOrm(t *testing.T) { assert.Nil(t, o.Driver()) - assert.Nil(t, o.QueryM2MWithCtx(nil, nil, "")) assert.Nil(t, o.QueryM2M(nil, "")) assert.Nil(t, o.ReadWithCtx(nil, nil)) assert.Nil(t, o.Read(nil)) @@ -92,7 +91,6 @@ func TestDoNothingOrm(t *testing.T) { assert.Nil(t, err) assert.Equal(t, int64(0), i) - assert.Nil(t, o.QueryTableWithCtx(nil, nil)) assert.Nil(t, o.QueryTable(nil)) assert.Nil(t, o.Read(nil)) diff --git a/client/orm/filter/bean/default_value_filter.go b/client/orm/filter/bean/default_value_filter.go index b71903b3..5d1fc2a1 100644 --- a/client/orm/filter/bean/default_value_filter.go +++ b/client/orm/filter/bean/default_value_filter.go @@ -19,10 +19,9 @@ import ( "reflect" "strings" - "github.com/beego/beego/v2/core/logs" - "github.com/beego/beego/v2/client/orm" "github.com/beego/beego/v2/core/bean" + "github.com/beego/beego/v2/core/logs" ) // DefaultValueFilterChainBuilder only works for InsertXXX method, @@ -81,13 +80,10 @@ func (d *DefaultValueFilterChainBuilder) FilterChain(next orm.Filter) orm.Filter switch inv.Method { case "Insert", "InsertWithCtx": d.handleInsert(ctx, inv) - break case "InsertOrUpdate", "InsertOrUpdateWithCtx": d.handleInsertOrUpdate(ctx, inv) - break case "InsertMulti", "InsertMultiWithCtx": d.handleInsertMulti(ctx, inv) - break } return next(ctx, inv) } diff --git a/client/orm/filter/bean/default_value_filter_test.go b/client/orm/filter/bean/default_value_filter_test.go index 871d5539..ec47688d 100644 --- a/client/orm/filter/bean/default_value_filter_test.go +++ b/client/orm/filter/bean/default_value_filter_test.go @@ -22,7 +22,7 @@ import ( "github.com/beego/beego/v2/client/orm" ) -func TestDefaultValueFilterChainBuilder_FilterChain(t *testing.T) { +func TestDefaultValueFilterChainBuilderFilterChain(t *testing.T) { builder := NewDefaultValueFilterChainBuilder(nil, true, true) o := orm.NewFilterOrmDecorator(&defaultValueTestOrm{}, builder.FilterChain) @@ -57,7 +57,6 @@ func TestDefaultValueFilterChainBuilder_FilterChain(t *testing.T) { _, _ = o.InsertMulti(3, []*DefaultValueTestEntity{entity}) assert.Equal(t, 12, entity.Age) assert.Equal(t, 13, entity.AgeInOldStyle) - } type defaultValueTestOrm struct { diff --git a/client/orm/filter/opentracing/filter.go b/client/orm/filter/opentracing/filter.go index 75852c63..7afa07f1 100644 --- a/client/orm/filter/opentracing/filter.go +++ b/client/orm/filter/opentracing/filter.go @@ -27,7 +27,7 @@ import ( // this Filter's behavior looks a little bit strange // for example: // if we want to trace QuerySetter -// actually we trace invoking "QueryTable" and "QueryTableWithCtx" +// actually we trace invoking "QueryTable" // the method Begin*, Commit and Rollback are ignored. // When use using those methods, it means that they want to manager their transaction manually, so we won't handle them. type FilterChainBuilder struct { diff --git a/client/orm/filter/opentracing/filter_test.go b/client/orm/filter/opentracing/filter_test.go index d6ee1fd8..67de5c9e 100644 --- a/client/orm/filter/opentracing/filter_test.go +++ b/client/orm/filter/opentracing/filter_test.go @@ -24,7 +24,7 @@ import ( "github.com/beego/beego/v2/client/orm" ) -func TestFilterChainBuilder_FilterChain(t *testing.T) { +func TestFilterChainBuilderFilterChain(t *testing.T) { next := func(ctx context.Context, inv *orm.Invocation) []interface{} { inv.TxName = "Hello" return []interface{}{} diff --git a/client/orm/filter/prometheus/filter.go b/client/orm/filter/prometheus/filter.go index 7563f51e..f525cebc 100644 --- a/client/orm/filter/prometheus/filter.go +++ b/client/orm/filter/prometheus/filter.go @@ -32,18 +32,19 @@ import ( // this Filter's behavior looks a little bit strange // for example: // if we want to records the metrics of QuerySetter -// actually we only records metrics of invoking "QueryTable" and "QueryTableWithCtx" +// actually we only records metrics of invoking "QueryTable" type FilterChainBuilder struct { AppName string ServerName string RunMode string } -var summaryVec prometheus.ObserverVec -var initSummaryVec sync.Once +var ( + summaryVec prometheus.ObserverVec + initSummaryVec sync.Once +) func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { - initSummaryVec.Do(func() { summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{ Name: "beego", @@ -85,7 +86,7 @@ func (builder *FilterChainBuilder) report(ctx context.Context, inv *orm.Invocati } func (builder *FilterChainBuilder) reportTxn(ctx context.Context, inv *orm.Invocation) { - dur := time.Now().Sub(inv.TxStartTime) / time.Millisecond + dur := time.Since(inv.TxStartTime) / time.Millisecond summaryVec.WithLabelValues(inv.Method, inv.TxName, strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur)) } diff --git a/client/orm/filter/prometheus/filter_test.go b/client/orm/filter/prometheus/filter_test.go index a25515a7..912ae073 100644 --- a/client/orm/filter/prometheus/filter_test.go +++ b/client/orm/filter/prometheus/filter_test.go @@ -24,7 +24,7 @@ import ( "github.com/beego/beego/v2/client/orm" ) -func TestFilterChainBuilder_FilterChain1(t *testing.T) { +func TestFilterChainBuilderFilterChain1(t *testing.T) { next := func(ctx context.Context, inv *orm.Invocation) []interface{} { inv.Method = "coming" return []interface{}{} @@ -58,5 +58,4 @@ func TestFilterChainBuilder_FilterChain1(t *testing.T) { inv.Method = "Update" builder.report(ctx, inv, time.Second) - } diff --git a/client/orm/filter_orm_decorator.go b/client/orm/filter_orm_decorator.go index a60390a1..edeaaade 100644 --- a/client/orm/filter_orm_decorator.go +++ b/client/orm/filter_orm_decorator.go @@ -20,6 +20,7 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/core/utils" ) @@ -27,8 +28,10 @@ const ( TxNameKey = "TxName" ) -var _ Ormer = new(filterOrmDecorator) -var _ TxOrmer = new(filterOrmDecorator) +var ( + _ Ormer = new(filterOrmDecorator) + _ TxOrmer = new(filterOrmDecorator) +) type filterOrmDecorator struct { ormer @@ -119,7 +122,6 @@ func (f *filterOrmDecorator) ReadOrCreate(md interface{}, col1 string, cols ...s } func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { - mi, _ := modelCache.getByMd(md) inv := &Invocation{ Method: "ReadOrCreateWithCtx", @@ -142,7 +144,6 @@ func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...ut } func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { - mi, _ := modelCache.getByMd(md) inv := &Invocation{ Method: "LoadRelatedWithCtx", @@ -161,36 +162,33 @@ func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interfac } func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer { - return f.QueryM2MWithCtx(context.Background(), md, name) -} - -func (f *filterOrmDecorator) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { - mi, _ := modelCache.getByMd(md) inv := &Invocation{ - Method: "QueryM2MWithCtx", + Method: "QueryM2M", Args: []interface{}{md, name}, Md: md, mi: mi, InsideTx: f.insideTx, TxStartTime: f.txStartTime, f: func(c context.Context) []interface{} { - res := f.ormer.QueryM2MWithCtx(c, md, name) + res := f.ormer.QueryM2M(md, name) return []interface{}{res} }, } - res := f.root(ctx, inv) + res := f.root(context.Background(), inv) if res[0] == nil { return nil } return res[0].(QueryM2Mer) } -func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { - return f.QueryTableWithCtx(context.Background(), ptrStructOrTableName) +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.") + return f.QueryM2M(md, name) } -func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { +func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { var ( name string md interface{} @@ -209,18 +207,18 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT } inv := &Invocation{ - Method: "QueryTableWithCtx", + Method: "QueryTable", Args: []interface{}{ptrStructOrTableName}, InsideTx: f.insideTx, TxStartTime: f.txStartTime, Md: md, mi: mi, f: func(c context.Context) []interface{} { - res := f.ormer.QueryTableWithCtx(c, ptrStructOrTableName) + res := f.ormer.QueryTable(ptrStructOrTableName) return []interface{}{res} }, } - res := f.root(ctx, inv) + res := f.root(context.Background(), inv) if res[0] == nil { return nil @@ -228,6 +226,12 @@ func (f *filterOrmDecorator) QueryTableWithCtx(ctx context.Context, ptrStructOrT return res[0].(QuerySeter) } +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.") + return f.QueryTable(ptrStructOrTableName) +} + func (f *filterOrmDecorator) DBStats() *sql.DBStats { inv := &Invocation{ Method: "DBStats", @@ -458,7 +462,7 @@ func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.T TxStartTime: f.txStartTime, TxName: getTxNameFromCtx(ctx), f: func(c context.Context) []interface{} { - err := doTxTemplate(f, c, opts, task) + err := doTxTemplate(c, f, opts, task) return []interface{}{err} }, } @@ -498,7 +502,23 @@ func (f *filterOrmDecorator) Rollback() error { return f.convertError(res[0]) } -func (f *filterOrmDecorator) convertError(v interface{}) error { +func (f *filterOrmDecorator) RollbackUnlessCommit() error { + inv := &Invocation{ + Method: "RollbackUnlessCommit", + Args: []interface{}{}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + TxName: f.txName, + f: func(c context.Context) []interface{} { + err := f.TxCommitter.RollbackUnlessCommit() + return []interface{}{err} + }, + } + res := f.root(context.Background(), inv) + return f.convertError(res[0]) +} + +func (*filterOrmDecorator) convertError(v interface{}) error { if v == nil { return nil } diff --git a/client/orm/filter_orm_decorator_test.go b/client/orm/filter_orm_decorator_test.go index 9e223358..e1908041 100644 --- a/client/orm/filter_orm_decorator_test.go +++ b/client/orm/filter_orm_decorator_test.go @@ -21,13 +21,12 @@ import ( "sync" "testing" - "github.com/beego/beego/v2/core/utils" - "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/utils" ) -func TestFilterOrmDecorator_Read(t *testing.T) { - +func TestFilterOrmDecoratorRead(t *testing.T) { register() o := &filterMockOrm{} @@ -46,7 +45,7 @@ func TestFilterOrmDecorator_Read(t *testing.T) { assert.Equal(t, "read error", err.Error()) } -func TestFilterOrmDecorator_BeginTx(t *testing.T) { +func TestFilterOrmDecoratorBeginTx(t *testing.T) { register() o := &filterMockOrm{} @@ -96,7 +95,7 @@ func TestFilterOrmDecorator_BeginTx(t *testing.T) { assert.Equal(t, "rollback", err.Error()) } -func TestFilterOrmDecorator_DBStats(t *testing.T) { +func TestFilterOrmDecoratorDBStats(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { @@ -111,7 +110,7 @@ func TestFilterOrmDecorator_DBStats(t *testing.T) { assert.Equal(t, -1, res.MaxOpenConnections) } -func TestFilterOrmDecorator_Delete(t *testing.T) { +func TestFilterOrmDecoratorDelete(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { @@ -128,7 +127,7 @@ func TestFilterOrmDecorator_Delete(t *testing.T) { assert.Equal(t, int64(-2), res) } -func TestFilterOrmDecorator_DoTx(t *testing.T) { +func TestFilterOrmDecoratorDoTx(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { @@ -175,7 +174,7 @@ func TestFilterOrmDecorator_DoTx(t *testing.T) { assert.NotNil(t, err) } -func TestFilterOrmDecorator_Driver(t *testing.T) { +func TestFilterOrmDecoratorDriver(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { @@ -190,7 +189,7 @@ func TestFilterOrmDecorator_Driver(t *testing.T) { assert.Nil(t, res) } -func TestFilterOrmDecorator_Insert(t *testing.T) { +func TestFilterOrmDecoratorInsert(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { @@ -209,7 +208,7 @@ func TestFilterOrmDecorator_Insert(t *testing.T) { assert.Equal(t, int64(100), i) } -func TestFilterOrmDecorator_InsertMulti(t *testing.T) { +func TestFilterOrmDecoratorInsertMulti(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { @@ -229,7 +228,7 @@ func TestFilterOrmDecorator_InsertMulti(t *testing.T) { assert.Equal(t, int64(2), i) } -func TestFilterOrmDecorator_InsertOrUpdate(t *testing.T) { +func TestFilterOrmDecoratorInsertOrUpdate(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { @@ -247,7 +246,7 @@ func TestFilterOrmDecorator_InsertOrUpdate(t *testing.T) { assert.Equal(t, int64(1), i) } -func TestFilterOrmDecorator_LoadRelated(t *testing.T) { +func TestFilterOrmDecoratorLoadRelated(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { @@ -264,11 +263,11 @@ func TestFilterOrmDecorator_LoadRelated(t *testing.T) { assert.Equal(t, int64(99), i) } -func TestFilterOrmDecorator_QueryM2M(t *testing.T) { +func TestFilterOrmDecoratorQueryM2M(t *testing.T) { o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { - assert.Equal(t, "QueryM2MWithCtx", inv.Method) + assert.Equal(t, "QueryM2M", inv.Method) assert.Equal(t, 2, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) @@ -279,12 +278,12 @@ func TestFilterOrmDecorator_QueryM2M(t *testing.T) { assert.Nil(t, res) } -func TestFilterOrmDecorator_QueryTable(t *testing.T) { +func TestFilterOrmDecoratorQueryTable(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { return func(ctx context.Context, inv *Invocation) []interface{} { - assert.Equal(t, "QueryTableWithCtx", inv.Method) + assert.Equal(t, "QueryTable", inv.Method) assert.Equal(t, 1, len(inv.Args)) assert.Equal(t, "FILTER_TEST", inv.GetTableName()) assert.False(t, inv.InsideTx) @@ -295,7 +294,7 @@ func TestFilterOrmDecorator_QueryTable(t *testing.T) { assert.Nil(t, res) } -func TestFilterOrmDecorator_Raw(t *testing.T) { +func TestFilterOrmDecoratorRaw(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { @@ -311,7 +310,7 @@ func TestFilterOrmDecorator_Raw(t *testing.T) { assert.Nil(t, res) } -func TestFilterOrmDecorator_ReadForUpdate(t *testing.T) { +func TestFilterOrmDecoratorReadForUpdate(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { @@ -328,7 +327,7 @@ func TestFilterOrmDecorator_ReadForUpdate(t *testing.T) { assert.Equal(t, "read for update error", err.Error()) } -func TestFilterOrmDecorator_ReadOrCreate(t *testing.T) { +func TestFilterOrmDecoratorReadOrCreate(t *testing.T) { register() o := &filterMockOrm{} od := NewFilterOrmDecorator(o, func(next Filter) Filter { @@ -402,6 +401,10 @@ func (f *filterMockOrm) Rollback() error { return errors.New("rollback") } +func (f *filterMockOrm) RollbackUnlessCommit() error { + return errors.New("rollback unless commit") +} + func (f *filterMockOrm) DBStats() *sql.DBStats { return &sql.DBStats{ MaxOpenConnections: -1, diff --git a/client/orm/hints/db_hints_test.go b/client/orm/hints/db_hints_test.go index 510f9f16..ab18ee81 100644 --- a/client/orm/hints/db_hints_test.go +++ b/client/orm/hints/db_hints_test.go @@ -21,7 +21,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNewHint_time(t *testing.T) { +func TestNewHintTime(t *testing.T) { key := "qweqwe" value := time.Second hint := NewHint(key, value) @@ -30,7 +30,7 @@ func TestNewHint_time(t *testing.T) { assert.Equal(t, hint.GetValue(), value) } -func TestNewHint_int(t *testing.T) { +func TestNewHintInt(t *testing.T) { key := "qweqwe" value := 281230 hint := NewHint(key, value) @@ -39,7 +39,7 @@ func TestNewHint_int(t *testing.T) { assert.Equal(t, hint.GetValue(), value) } -func TestNewHint_float(t *testing.T) { +func TestNewHintFloat(t *testing.T) { key := "qweqwe" value := 21.2459753 hint := NewHint(key, value) @@ -55,7 +55,7 @@ func TestForceIndex(t *testing.T) { assert.Equal(t, hint.GetKey(), KeyForceIndex) } -func TestForceIndex_0(t *testing.T) { +func TestForceIndex0(t *testing.T) { var s []string hint := ForceIndex(s...) assert.Equal(t, hint.GetValue(), s) @@ -69,7 +69,7 @@ func TestIgnoreIndex(t *testing.T) { assert.Equal(t, hint.GetKey(), KeyIgnoreIndex) } -func TestIgnoreIndex_0(t *testing.T) { +func TestIgnoreIndex0(t *testing.T) { var s []string hint := IgnoreIndex(s...) assert.Equal(t, hint.GetValue(), s) @@ -83,7 +83,7 @@ func TestUseIndex(t *testing.T) { assert.Equal(t, hint.GetKey(), KeyUseIndex) } -func TestUseIndex_0(t *testing.T) { +func TestUseIndex0(t *testing.T) { var s []string hint := UseIndex(s...) assert.Equal(t, hint.GetValue(), s) diff --git a/client/orm/migration/ddl.go b/client/orm/migration/ddl.go index ec6dc2e7..ab452b49 100644 --- a/client/orm/migration/ddl.go +++ b/client/orm/migration/ddl.go @@ -116,7 +116,6 @@ func (m *Migration) UniCol(uni, name string) *Column { // ForeignCol creates a new foreign column and returns the instance of column func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { - foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable} foreign.Name = colname m.AddForeign(foreign) @@ -153,7 +152,6 @@ func (c *Column) SetAuto(inc bool) *Column { func (c *Column) SetNullable(null bool) *Column { if null { c.Null = "" - } else { c.Null = "NOT NULL" } @@ -184,7 +182,6 @@ func (c *Column) SetDataType(dataType string) *Column { func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { if null { c.OldNull = "" - } else { c.OldNull = "NOT NULL" } @@ -219,7 +216,6 @@ func (c *Column) SetPrimary(m *Migration) *Column { // AddColumnsToUnique adds the columns to Unique Struct func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { - unique.Columns = append(unique.Columns, columns...) return unique @@ -227,7 +223,6 @@ func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { // AddColumns adds columns to m struct func (m *Migration) AddColumns(columns ...*Column) *Migration { - m.Columns = append(m.Columns, columns...) return m diff --git a/client/orm/migration/migration.go b/client/orm/migration/migration.go index 86d6f590..dda7737d 100644 --- a/client/orm/migration/migration.go +++ b/client/orm/migration/migration.go @@ -72,9 +72,7 @@ type Migration struct { RemoveForeigns []*Foreign } -var ( - migrationMap map[string]Migrationer -) +var migrationMap map[string]Migrationer func init() { migrationMap = make(map[string]Migrationer) @@ -82,7 +80,6 @@ func init() { // Up implement in the Inheritance struct for upgrade func (m *Migration) Up() { - switch m.ModifyType { case "reverse": m.ModifyType = "alter" @@ -94,7 +91,6 @@ func (m *Migration) Up() { // Down implement in the Inheritance struct for down func (m *Migration) Down() { - switch m.ModifyType { case "alter": m.ModifyType = "reverse" @@ -311,6 +307,7 @@ func isRollBack(name string) bool { } return false } + func getAllMigrations() (map[string]string, error) { o := orm.NewOrm() var maps []orm.Params diff --git a/client/orm/mock/condition.go b/client/orm/mock/condition.go new file mode 100644 index 00000000..eda88824 --- /dev/null +++ b/client/orm/mock/condition.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +type Mock struct { + cond Condition + resp []interface{} + cb func(inv *orm.Invocation) +} + +func NewMock(cond Condition, resp []interface{}, cb func(inv *orm.Invocation)) *Mock { + return &Mock{ + cond: cond, + resp: resp, + cb: cb, + } +} + +type Condition interface { + Match(ctx context.Context, inv *orm.Invocation) bool +} + +type SimpleCondition struct { + tableName string + method string +} + +func NewSimpleCondition(tableName string, methodName string) Condition { + return &SimpleCondition{ + tableName: tableName, + method: methodName, + } +} + +func (s *SimpleCondition) Match(ctx context.Context, inv *orm.Invocation) bool { + res := true + if len(s.tableName) != 0 { + res = res && (s.tableName == inv.GetTableName()) + } + + if len(s.method) != 0 { + res = res && (s.method == inv.Method) + } + return res +} diff --git a/client/orm/mock/condition_test.go b/client/orm/mock/condition_test.go new file mode 100644 index 00000000..7f646e70 --- /dev/null +++ b/client/orm/mock/condition_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestSimpleCondition_Match(t *testing.T) { + cond := NewSimpleCondition("", "") + res := cond.Match(context.Background(), &orm.Invocation{}) + assert.True(t, res) + cond = NewSimpleCondition("hello", "") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{})) + + cond = NewSimpleCondition("", "A") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{ + Method: "B", + })) + + assert.True(t, cond.Match(context.Background(), &orm.Invocation{ + Method: "A", + })) +} diff --git a/client/orm/mock/context.go b/client/orm/mock/context.go new file mode 100644 index 00000000..ca251c5d --- /dev/null +++ b/client/orm/mock/context.go @@ -0,0 +1,40 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/core/logs" +) + +type mockCtxKeyType string + +const mockCtxKey = mockCtxKeyType("beego-orm-mock") + +func CtxWithMock(ctx context.Context, mock ...*Mock) context.Context { + return context.WithValue(ctx, mockCtxKey, mock) +} + +func mockFromCtx(ctx context.Context) []*Mock { + ms := ctx.Value(mockCtxKey) + if ms != nil { + if res, ok := ms.([]*Mock); ok { + return res + } + logs.Error("mockCtxKey found in context, but value is not type []*Mock") + } + return nil +} diff --git a/client/orm/mock/context_test.go b/client/orm/mock/context_test.go new file mode 100644 index 00000000..a3ed1e90 --- /dev/null +++ b/client/orm/mock/context_test.go @@ -0,0 +1,29 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCtx(t *testing.T) { + ms := make([]*Mock, 0, 4) + ctx := CtxWithMock(context.Background(), ms...) + res := mockFromCtx(ctx) + assert.Equal(t, ms, res) +} diff --git a/client/orm/mock/mock.go b/client/orm/mock/mock.go new file mode 100644 index 00000000..7aca914c --- /dev/null +++ b/client/orm/mock/mock.go @@ -0,0 +1,71 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +var stub = newOrmStub() + +func init() { + orm.AddGlobalFilterChain(stub.FilterChain) +} + +type Stub interface { + Mock(m *Mock) + Clear() +} + +type OrmStub struct { + ms []*Mock +} + +func StartMock() Stub { + return stub +} + +func newOrmStub() *OrmStub { + return &OrmStub{ + ms: make([]*Mock, 0, 4), + } +} + +func (o *OrmStub) Mock(m *Mock) { + o.ms = append(o.ms, m) +} + +func (o *OrmStub) Clear() { + o.ms = make([]*Mock, 0, 4) +} + +func (o *OrmStub) FilterChain(next orm.Filter) orm.Filter { + return func(ctx context.Context, inv *orm.Invocation) []interface{} { + ms := mockFromCtx(ctx) + ms = append(ms, o.ms...) + + for _, mock := range ms { + if mock.cond.Match(ctx, inv) { + if mock.cb != nil { + mock.cb(inv) + } + return mock.resp + } + } + return next(ctx, inv) + } +} diff --git a/client/orm/mock/mock_orm.go b/client/orm/mock/mock_orm.go new file mode 100644 index 00000000..70bee4f7 --- /dev/null +++ b/client/orm/mock/mock_orm.go @@ -0,0 +1,167 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "database/sql" + "os" + "path/filepath" + + _ "github.com/mattn/go-sqlite3" + + "github.com/beego/beego/v2/client/orm" +) + +func init() { + RegisterMockDB("default") +} + +// RegisterMockDB create an "virtual DB" by using sqllite +// you should not +func RegisterMockDB(name string) { + source := filepath.Join(os.TempDir(), name+".db") + _ = orm.RegisterDataBase(name, "sqlite3", source) +} + +// MockTable only check table name +func MockTable(tableName string, resp ...interface{}) *Mock { + return NewMock(NewSimpleCondition(tableName, ""), resp, nil) +} + +// MockMethod only check method name +func MockMethod(method string, resp ...interface{}) *Mock { + return NewMock(NewSimpleCondition("", method), resp, nil) +} + +// MockOrmRead support orm.Read and orm.ReadWithCtx +// cb is used to mock read data from DB +func MockRead(tableName string, cb func(data interface{}), err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadWithCtx"), []interface{}{err}, func(inv *orm.Invocation) { + if cb != nil { + cb(inv.Args[0]) + } + }) +} + +// MockReadForUpdateWithCtx support ReadForUpdate and ReadForUpdateWithCtx +// cb is used to mock read data from DB +func MockReadForUpdateWithCtx(tableName string, cb func(data interface{}), err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadForUpdateWithCtx"), + []interface{}{err}, + func(inv *orm.Invocation) { + cb(inv.Args[0]) + }) +} + +// MockReadOrCreateWithCtx support ReadOrCreate and ReadOrCreateWithCtx +// cb is used to mock read data from DB +func MockReadOrCreateWithCtx(tableName string, + cb func(data interface{}), + insert bool, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadOrCreateWithCtx"), + []interface{}{insert, id, err}, + func(inv *orm.Invocation) { + cb(inv.Args[0]) + }) +} + +// MockInsertWithCtx support Insert and InsertWithCtx +func MockInsertWithCtx(tableName string, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertWithCtx"), []interface{}{id, err}, nil) +} + +// MockInsertMultiWithCtx support InsertMulti and InsertMultiWithCtx +func MockInsertMultiWithCtx(tableName string, cnt int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertMultiWithCtx"), []interface{}{cnt, err}, nil) +} + +// MockInsertOrUpdateWithCtx support InsertOrUpdate and InsertOrUpdateWithCtx +func MockInsertOrUpdateWithCtx(tableName string, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertOrUpdateWithCtx"), []interface{}{id, err}, nil) +} + +// MockUpdateWithCtx support UpdateWithCtx and Update +func MockUpdateWithCtx(tableName string, affectedRow int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "UpdateWithCtx"), []interface{}{affectedRow, err}, nil) +} + +// MockDeleteWithCtx support Delete and DeleteWithCtx +func MockDeleteWithCtx(tableName string, affectedRow int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "DeleteWithCtx"), []interface{}{affectedRow, err}, nil) +} + +// MockQueryM2MWithCtx support QueryM2MWithCtx and QueryM2M +// Now you may be need to use golang/mock to generate QueryM2M mock instance +// Or use DoNothingQueryM2Mer +// for example: +// post := Post{Id: 4} +// m2m := Ormer.QueryM2M(&post, "Tags") +// when you write test code: +// MockQueryM2MWithCtx("post", "Tags", mockM2Mer) +// "post" is the table name of model Post structure +// TODO provide orm.QueryM2Mer +func MockQueryM2MWithCtx(tableName string, name string, res orm.QueryM2Mer) *Mock { + return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{res}, nil) +} + +// MockLoadRelatedWithCtx support LoadRelatedWithCtx and LoadRelated +func MockLoadRelatedWithCtx(tableName string, name string, rows int64, err error) *Mock { + return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{rows, err}, nil) +} + +// MockQueryTableWithCtx support QueryTableWithCtx and QueryTable +func MockQueryTableWithCtx(tableName string, qs orm.QuerySeter) *Mock { + return NewMock(NewSimpleCondition(tableName, "QueryTable"), []interface{}{qs}, nil) +} + +// MockRawWithCtx support RawWithCtx and Raw +func MockRawWithCtx(rs orm.RawSeter) *Mock { + return NewMock(NewSimpleCondition("", "RawWithCtx"), []interface{}{rs}, nil) +} + +// MockDriver support Driver +// func MockDriver(driver orm.Driver) *Mock { +// return NewMock(NewSimpleCondition("", "Driver"), []interface{}{driver}) +// } + +// MockDBStats support DBStats +func MockDBStats(stats *sql.DBStats) *Mock { + return NewMock(NewSimpleCondition("", "DBStats"), []interface{}{stats}, nil) +} + +// MockBeginWithCtxAndOpts support Begin, BeginWithCtx, BeginWithOpts, BeginWithCtxAndOpts +// func MockBeginWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock { +// return NewMock(NewSimpleCondition("", "BeginWithCtxAndOpts"), []interface{}{txOrm, err}) +// } + +// MockDoTxWithCtxAndOpts support DoTx, DoTxWithCtx, DoTxWithOpts, DoTxWithCtxAndOpts +// func MockDoTxWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock { +// return MockBeginWithCtxAndOpts(txOrm, err) +// } + +// MockCommit support Commit +func MockCommit(err error) *Mock { + return NewMock(NewSimpleCondition("", "Commit"), []interface{}{err}, nil) +} + +// MockRollback support Rollback +func MockRollback(err error) *Mock { + return NewMock(NewSimpleCondition("", "Rollback"), []interface{}{err}, nil) +} + +// MockRollbackUnlessCommit support RollbackUnlessCommit +func MockRollbackUnlessCommit(err error) *Mock { + return NewMock(NewSimpleCondition("", "RollbackUnlessCommit"), []interface{}{err}, nil) +} diff --git a/client/orm/mock/mock_orm_test.go b/client/orm/mock/mock_orm_test.go new file mode 100644 index 00000000..29d9ea01 --- /dev/null +++ b/client/orm/mock/mock_orm_test.go @@ -0,0 +1,312 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "database/sql" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +const mockErrorMsg = "mock error" + +func init() { + orm.RegisterModel(&User{}) +} + +func TestMockDBStats(t *testing.T) { + s := StartMock() + defer s.Clear() + stats := &sql.DBStats{} + s.Mock(MockDBStats(stats)) + + o := orm.NewOrm() + + res := o.DBStats() + + assert.Equal(t, stats, res) +} + +func TestMockDeleteWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + s.Mock(MockDeleteWithCtx((&User{}).TableName(), 12, nil)) + o := orm.NewOrm() + rows, err := o.Delete(&User{}) + assert.Equal(t, int64(12), rows) + assert.Nil(t, err) +} + +func TestMockInsertOrUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + s.Mock(MockInsertOrUpdateWithCtx((&User{}).TableName(), 12, nil)) + o := orm.NewOrm() + id, err := o.InsertOrUpdate(&User{}) + assert.Equal(t, int64(12), id) + assert.Nil(t, err) +} + +func TestMockRead(t *testing.T) { + s := StartMock() + defer s.Clear() + err := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, err)) + o := orm.NewOrm() + u := &User{} + e := o.Read(u) + assert.Equal(t, err, e) + assert.Equal(t, "Tom", u.Name) +} + +func TestMockQueryM2MWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingQueryM2Mer{} + s.Mock(MockQueryM2MWithCtx((&User{}).TableName(), "Tags", mock)) + o := orm.NewOrm() + res := o.QueryM2M(&User{}, "Tags") + assert.Equal(t, mock, res) +} + +func TestMockQueryTableWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingQuerySetter{} + s.Mock(MockQueryTableWithCtx((&User{}).TableName(), mock)) + o := orm.NewOrm() + res := o.QueryTable(&User{}) + assert.Equal(t, mock, res) +} + +func TestMockTable(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockTable((&User{}).TableName(), mock)) + o := orm.NewOrm() + res := o.Read(&User{}) + assert.Equal(t, mock, res) +} + +func TestMockInsertMultiWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockInsertMultiWithCtx((&User{}).TableName(), 12, mock)) + o := orm.NewOrm() + res, err := o.InsertMulti(11, []interface{}{&User{}}) + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockInsertWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockInsertWithCtx((&User{}).TableName(), 13, mock)) + o := orm.NewOrm() + res, err := o.Insert(&User{}) + assert.Equal(t, int64(13), res) + assert.Equal(t, mock, err) +} + +func TestMockUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockUpdateWithCtx((&User{}).TableName(), 12, mock)) + o := orm.NewOrm() + res, err := o.Update(&User{}) + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockLoadRelatedWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockLoadRelatedWithCtx((&User{}).TableName(), "T", 12, mock)) + o := orm.NewOrm() + res, err := o.LoadRelated(&User{}, "T") + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockMethod(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockMethod("ReadWithCtx", mock)) + o := orm.NewOrm() + err := o.Read(&User{}) + assert.Equal(t, mock, err) +} + +func TestMockReadForUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockReadForUpdateWithCtx((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + o := orm.NewOrm() + u := &User{} + err := o.ReadForUpdate(u) + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestMockRawWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingRawSetter{} + s.Mock(MockRawWithCtx(mock)) + o := orm.NewOrm() + res := o.Raw("") + assert.Equal(t, mock, res) +} + +func TestMockReadOrCreateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockReadOrCreateWithCtx((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, false, 12, mock)) + o := orm.NewOrm() + u := &User{} + inserted, id, err := o.ReadOrCreate(u, "") + assert.Equal(t, mock, err) + assert.Equal(t, int64(12), id) + assert.False(t, inserted) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionClosure(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + u, err := originalTxUsingClosure() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionManually(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + u, err := originalTxManually() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionRollback(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), nil, errors.New("read error"))) + s.Mock(MockRollback(mock)) + _, err := originalTx() + assert.Equal(t, mock, err) +} + +func TestTransactionRollbackUnlessCommit(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRollbackUnlessCommit(mock)) + + // u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.RollbackUnlessCommit() + assert.Equal(t, mock, err) +} + +func TestTransactionCommit(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, nil)) + s.Mock(MockCommit(mock)) + u, err := originalTx() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func originalTx() (*User, error) { + u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.Read(u) + if err == nil { + err = txOrm.Commit() + return u, err + } else { + err = txOrm.Rollback() + return nil, err + } +} + +func originalTxManually() (*User, error) { + u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.Read(u) + _ = txOrm.Commit() + return u, err +} + +func originalTxUsingClosure() (*User, error) { + u := &User{} + var err error + o := orm.NewOrm() + _ = o.DoTx(func(ctx context.Context, txOrm orm.TxOrmer) error { + err = txOrm.Read(u) + return nil + }) + return u, err +} + +type User struct { + Id int + Name string +} + +func (u *User) TableName() string { + return "user" +} diff --git a/client/orm/mock/mock_queryM2Mer.go b/client/orm/mock/mock_queryM2Mer.go new file mode 100644 index 00000000..732e27b6 --- /dev/null +++ b/client/orm/mock/mock_queryM2Mer.go @@ -0,0 +1,88 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +// DoNothingQueryM2Mer do nothing +// use it to build mock orm.QueryM2Mer +type DoNothingQueryM2Mer struct{} + +func (d *DoNothingQueryM2Mer) AddWithCtx(ctx context.Context, i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) RemoveWithCtx(ctx context.Context, i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) ExistWithCtx(ctx context.Context, i interface{}) bool { + return true +} + +func (d *DoNothingQueryM2Mer) ClearWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) CountWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Add(i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Remove(i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Exist(i interface{}) bool { + return true +} + +func (d *DoNothingQueryM2Mer) Clear() (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Count() (int64, error) { + return 0, nil +} + +type QueryM2MerCondition struct { + tableName string + name string +} + +func NewQueryM2MerCondition(tableName string, name string) *QueryM2MerCondition { + return &QueryM2MerCondition{ + tableName: tableName, + name: name, + } +} + +func (q *QueryM2MerCondition) Match(ctx context.Context, inv *orm.Invocation) bool { + res := true + if len(q.tableName) > 0 { + res = res && (q.tableName == inv.GetTableName()) + } + if len(q.name) > 0 { + res = res && (len(inv.Args) > 1) && (q.name == inv.Args[1].(string)) + } + return res +} diff --git a/client/orm/mock/mock_queryM2Mer_test.go b/client/orm/mock/mock_queryM2Mer_test.go new file mode 100644 index 00000000..ef754092 --- /dev/null +++ b/client/orm/mock/mock_queryM2Mer_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestDoNothingQueryM2Mer(t *testing.T) { + m2m := &DoNothingQueryM2Mer{} + + i, err := m2m.Clear() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Count() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Add() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Remove() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + assert.True(t, m2m.Exist(nil)) +} + +func TestNewQueryM2MerCondition(t *testing.T) { + cond := NewQueryM2MerCondition("", "") + res := cond.Match(context.Background(), &orm.Invocation{}) + assert.True(t, res) + cond = NewQueryM2MerCondition("hello", "") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{})) + + cond = NewQueryM2MerCondition("", "A") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{ + Args: []interface{}{0, "B"}, + })) + + assert.True(t, cond.Match(context.Background(), &orm.Invocation{ + Args: []interface{}{0, "A"}, + })) +} diff --git a/client/orm/mock/mock_querySetter.go b/client/orm/mock/mock_querySetter.go new file mode 100644 index 00000000..5bbf9888 --- /dev/null +++ b/client/orm/mock/mock_querySetter.go @@ -0,0 +1,182 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" +) + +// DoNothingQuerySetter do nothing +// usually you use this to build your mock QuerySetter +type DoNothingQuerySetter struct{} + +func (d *DoNothingQuerySetter) OrderClauses(orders ...*order_clause.Order) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) CountWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ExistWithCtx(ctx context.Context) bool { + return true +} + +func (d *DoNothingQuerySetter) UpdateWithCtx(ctx context.Context, values orm.Params) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) DeleteWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) PrepareInsertWithCtx(ctx context.Context) (orm.Inserter, error) { + return nil, nil +} + +func (d *DoNothingQuerySetter) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingQuerySetter) ValuesWithCtx(ctx context.Context, results *[]orm.Params, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesListWithCtx(ctx context.Context, results *[]orm.ParamsList, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesFlatWithCtx(ctx context.Context, result *orm.ParamsList, expr string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Aggregate(s string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Filter(s string, i ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) FilterRaw(s string, s2 string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Exclude(s string, i ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) SetCond(condition *orm.Condition) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) GetCond() *orm.Condition { + return orm.NewCondition() +} + +func (d *DoNothingQuerySetter) Limit(limit interface{}, args ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Offset(offset interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) GroupBy(exprs ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) OrderBy(exprs ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) UseIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) RelatedSel(params ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Distinct() orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) ForUpdate() orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Count() (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Exist() bool { + return true +} + +func (d *DoNothingQuerySetter) Update(values orm.Params) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Delete() (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) PrepareInsert() (orm.Inserter, error) { + return nil, nil +} + +func (d *DoNothingQuerySetter) All(container interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) One(container interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingQuerySetter) Values(results *[]orm.Params, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesList(results *[]orm.ParamsList, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesFlat(result *orm.ParamsList, expr string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return 0, nil +} diff --git a/client/orm/mock/mock_querySetter_test.go b/client/orm/mock/mock_querySetter_test.go new file mode 100644 index 00000000..09e5ad8c --- /dev/null +++ b/client/orm/mock/mock_querySetter_test.go @@ -0,0 +1,74 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDoNothingQuerySetter(t *testing.T) { + setter := &DoNothingQuerySetter{} + setter.GroupBy().Filter("").Limit(10). + Distinct().Exclude("a").FilterRaw("", ""). + ForceIndex().ForUpdate().IgnoreIndex(). + Offset(11).OrderBy().RelatedSel().SetCond(nil).UseIndex() + + assert.True(t, setter.Exist()) + err := setter.One(nil) + assert.Nil(t, err) + i, err := setter.Count() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Delete() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.All(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Update(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.RowsToMap(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.RowsToStruct(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Values(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.ValuesFlat(nil, "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.ValuesList(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + ins, err := setter.PrepareInsert() + assert.Nil(t, err) + assert.Nil(t, ins) + + assert.NotNil(t, setter.GetCond()) +} diff --git a/client/orm/mock/mock_rawSetter.go b/client/orm/mock/mock_rawSetter.go new file mode 100644 index 00000000..f4768cc8 --- /dev/null +++ b/client/orm/mock/mock_rawSetter.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "database/sql" + + "github.com/beego/beego/v2/client/orm" +) + +type DoNothingRawSetter struct{} + +func (d *DoNothingRawSetter) Exec() (sql.Result, error) { + return nil, nil +} + +func (d *DoNothingRawSetter) QueryRow(containers ...interface{}) error { + return nil +} + +func (d *DoNothingRawSetter) QueryRows(containers ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) SetArgs(i ...interface{}) orm.RawSeter { + return d +} + +func (d *DoNothingRawSetter) Values(container *[]orm.Params, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) ValuesList(container *[]orm.ParamsList, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) ValuesFlat(container *orm.ParamsList, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) Prepare() (orm.RawPreparer, error) { + return nil, nil +} diff --git a/client/orm/mock/mock_rawSetter_test.go b/client/orm/mock/mock_rawSetter_test.go new file mode 100644 index 00000000..dd98edbd --- /dev/null +++ b/client/orm/mock/mock_rawSetter_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDoNothingRawSetter(t *testing.T) { + rs := &DoNothingRawSetter{} + i, err := rs.ValuesList(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.Values(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.ValuesFlat(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.RowsToStruct(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.RowsToMap(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.QueryRows() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + err = rs.QueryRow() + // assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + s, err := rs.Exec() + assert.Nil(t, err) + assert.Nil(t, s) + + p, err := rs.Prepare() + assert.Nil(t, err) + assert.Nil(t, p) + + rrs := rs.SetArgs() + assert.Equal(t, rrs, rs) +} diff --git a/client/orm/mock/mock_test.go b/client/orm/mock/mock_test.go new file mode 100644 index 00000000..73bce4e5 --- /dev/null +++ b/client/orm/mock/mock_test.go @@ -0,0 +1,58 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestOrmStub_FilterChain(t *testing.T) { + os := newOrmStub() + inv := &orm.Invocation{ + Args: []interface{}{10}, + } + i := 1 + os.FilterChain(func(ctx context.Context, inv *orm.Invocation) []interface{} { + i++ + return nil + })(context.Background(), inv) + + assert.Equal(t, 2, i) + + m := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) { + arg := inv.Args[0] + j := arg.(int) + inv.Args[0] = j + 1 + }) + os.Mock(m) + + os.FilterChain(nil)(context.Background(), inv) + assert.Equal(t, 11, inv.Args[0]) + + inv.Args[0] = 10 + ctxMock := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) { + arg := inv.Args[0] + j := arg.(int) + inv.Args[0] = j + 3 + }) + + os.FilterChain(nil)(CtxWithMock(context.Background(), ctxMock), inv) + assert.Equal(t, 13, inv.Args[0]) +} diff --git a/client/orm/models.go b/client/orm/models.go index 64dfab09..c78bcf69 100644 --- a/client/orm/models.go +++ b/client/orm/models.go @@ -32,9 +32,7 @@ const ( defaultStructTagDelim = ";" ) -var ( - modelCache = NewModelCacheHandler() -) +var modelCache = NewModelCacheHandler() // model info collection type _modelCache struct { @@ -332,11 +330,6 @@ end: // register register models to model cache func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) { - if mc.done { - err = fmt.Errorf("register must be run before BootStrap") - return - } - for _, model := range models { val := reflect.ValueOf(model) typ := reflect.Indirect(val).Type() @@ -352,7 +345,9 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m err = fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ) return } - + if val.Elem().Kind() == reflect.Slice { + val = reflect.New(val.Elem().Type().Elem()) + } table := getTableName(val) if prefixOrSuffixStr != "" { @@ -371,8 +366,7 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m } if _, ok := mc.get(table); ok { - err = fmt.Errorf(" table name `%s` repeat register, must be unique\n", table) - return + return nil } mi := newModelInfo(val) @@ -389,12 +383,6 @@ func (mc *_modelCache) register(prefixOrSuffixStr string, prefixOrSuffix bool, m } } } - - if mi.fields.pk == nil { - err = fmt.Errorf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) - return - } - } mi.table = table diff --git a/client/orm/models_info_f.go b/client/orm/models_info_f.go index c7ad4801..6a9e7a99 100644 --- a/client/orm/models_info_f.go +++ b/client/orm/models_info_f.go @@ -101,29 +101,30 @@ func newFields() *fields { // single field info type fieldInfo struct { - mi *modelInfo - fieldIndex []int - fieldType int dbcol bool // table column fk and onetoone inModel bool - name string - fullName string - column string - addrValue reflect.Value - sf reflect.StructField auto bool pk bool null bool index bool unique bool - colDefault bool // whether has default tag - initial StrTo // store the default value - size int + colDefault bool // whether has default tag toText bool autoNow bool autoNowAdd bool rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true reverse bool + isFielder bool // implement Fielder interface + mi *modelInfo + fieldIndex []int + fieldType int + name string + fullName string + column string + addrValue reflect.Value + sf reflect.StructField + initial StrTo // store the default value + size int reverseField string reverseFieldInfo *fieldInfo reverseFieldInfoTwo *fieldInfo @@ -134,7 +135,6 @@ type fieldInfo struct { relModelInfo *modelInfo digits int decimals int - isFielder bool // implement Fielder interface onDelete string description string timePrecision *int @@ -387,7 +387,6 @@ checkType: fi.timePrecision = &v } } - } if attrs["auto_now"] { diff --git a/client/orm/models_info_m.go b/client/orm/models_info_m.go index c9a979af..b94480ca 100644 --- a/client/orm/models_info_m.go +++ b/client/orm/models_info_m.go @@ -22,16 +22,16 @@ import ( // single model info type modelInfo struct { + manual bool + isThrough bool pkg string name string fullName string table string model interface{} fields *fields - manual bool addrField reflect.Value // store the original struct value uniques []string - isThrough bool } // new model info diff --git a/client/orm/models_test.go b/client/orm/models_test.go index 9b71d968..ea8a89fc 100644 --- a/client/orm/models_test.go +++ b/client/orm/models_test.go @@ -25,8 +25,6 @@ import ( _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" - // As tidb can't use go get, so disable the tidb testing now - // _ "github.com/pingcap/tidb" ) // A slice string field. @@ -118,6 +116,10 @@ var _ Fielder = new(JSONFieldTest) type Data struct { ID int `orm:"column(id)"` Boolean bool + Byte byte + Int8 int8 + Uint8 uint8 + Rune rune Char string `orm:"size(50)"` Text string `orm:"type(text)"` JSON string `orm:"type(json);default({\"name\":\"json\"})"` @@ -125,26 +127,21 @@ type Data struct { Time time.Time `orm:"type(time)"` Date time.Time `orm:"type(date)"` DateTime time.Time `orm:"column(datetime)"` - Byte byte - Rune rune Int int - Int8 int8 + Uint uint Int16 int16 + Uint16 uint16 Int32 int32 Int64 int64 - Uint uint - Uint8 uint8 - Uint16 uint16 Uint32 uint32 - Uint64 uint64 Float32 float32 + Uint64 uint64 Float64 float64 Decimal float64 `orm:"digits(8);decimals(4)"` } type DataNull struct { ID int `orm:"column(id)"` - Boolean bool `orm:"null"` Char string `orm:"null;size(50)"` Text string `orm:"null;type(text)"` JSON string `orm:"type(json);null"` @@ -153,19 +150,20 @@ type DataNull struct { Date time.Time `orm:"null;type(date)"` DateTime time.Time `orm:"null;column(datetime)"` DateTimePrecision time.Time `orm:"null;type(datetime);precision(4)"` + Boolean bool `orm:"null"` Byte byte `orm:"null"` + Int8 int8 `orm:"null"` + Uint8 uint8 `orm:"null"` Rune rune `orm:"null"` Int int `orm:"null"` - Int8 int8 `orm:"null"` + Uint uint `orm:"null"` Int16 int16 `orm:"null"` + Uint16 uint16 `orm:"null"` Int32 int32 `orm:"null"` Int64 int64 `orm:"null"` - Uint uint `orm:"null"` - Uint8 uint8 `orm:"null"` - Uint16 uint16 `orm:"null"` Uint32 uint32 `orm:"null"` - Uint64 uint64 `orm:"null"` Float32 float32 `orm:"null"` + Uint64 uint64 `orm:"null"` Float64 float64 `orm:"null"` Decimal float64 `orm:"digits(8);decimals(4);null"` NullString sql.NullString `orm:"null"` @@ -195,41 +193,43 @@ type DataNull struct { DateTimePtr *time.Time `orm:"null"` } -type String string -type Boolean bool -type Byte byte -type Rune rune -type Int int -type Int8 int8 -type Int16 int16 -type Int32 int32 -type Int64 int64 -type Uint uint -type Uint8 uint8 -type Uint16 uint16 -type Uint32 uint32 -type Uint64 uint64 -type Float32 float64 -type Float64 float64 +type ( + String string + Boolean bool + Byte byte + Rune rune + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float64 + Float64 float64 +) type DataCustom struct { ID int `orm:"column(id)"` Boolean Boolean + Byte Byte + Int8 Int8 + Uint8 Uint8 + Rune Rune Char string `orm:"size(50)"` Text string `orm:"type(text)"` - Byte Byte - Rune Rune Int Int - Int8 Int8 + Uint Uint Int16 Int16 + Uint16 Uint16 Int32 Int32 Int64 Int64 - Uint Uint - Uint8 Uint8 - Uint16 Uint16 Uint32 Uint32 - Uint64 Uint64 Float32 Float32 + Uint64 Uint64 Float64 Float64 Decimal Float64 `orm:"digits(8);decimals(4)"` } @@ -255,24 +255,40 @@ func NewTM() *TM { return obj } +type DeptInfo struct { + ID int `orm:"column(id)"` + Created time.Time `orm:"auto_now_add"` + DeptName string + EmployeeName string + Salary int +} + +type UnregisterModel struct { + ID int `orm:"column(id)"` + Created time.Time `orm:"auto_now_add"` + DeptName string + EmployeeName string + Salary int +} + type User struct { - ID int `orm:"column(id)"` - UserName string `orm:"size(30);unique"` - Email string `orm:"size(100)"` - Password string `orm:"size(100)"` - Status int16 `orm:"column(Status)"` - IsStaff bool - IsActive bool `orm:"default(true)"` - Created time.Time `orm:"auto_now_add;type(date)"` - Updated time.Time `orm:"auto_now"` - Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` - Posts []*Post `orm:"reverse(many)" json:"-"` - ShouldSkip string `orm:"-"` - Nums int - Langs SliceStringField `orm:"size(100)"` - Extra JSONFieldTest `orm:"type(text)"` - unexport bool `orm:"-"` - unexportBool bool + ID int `orm:"column(id)"` + UserName string `orm:"size(30);unique"` + Email string `orm:"size(100)"` + Password string `orm:"size(100)"` + Status int16 `orm:"column(Status)"` + IsStaff bool + IsActive bool `orm:"default(true)"` + Unexported bool `orm:"-"` + UnexportedBool bool + Created time.Time `orm:"auto_now_add;type(date)"` + Updated time.Time `orm:"auto_now"` + Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` + Posts []*Post `orm:"reverse(many)" json:"-"` + ShouldSkip string `orm:"-"` + Nums int + Langs SliceStringField `orm:"size(100)"` + Extra JSONFieldTest `orm:"type(text)"` } func (u *User) TableIndex() [][]string { @@ -477,51 +493,49 @@ var ( dDbBaser dbBaser ) -var ( - helpinfo = `need driver and source! +var helpinfo = `need driver and source! Default DB Drivers. - + driver: url mysql: https://github.com/go-sql-driver/mysql sqlite3: https://github.com/mattn/go-sqlite3 postgres: https://github.com/lib/pq tidb: https://github.com/pingcap/tidb - + usage: - + go get -u github.com/beego/beego/v2/client/orm go get -u github.com/go-sql-driver/mysql go get -u github.com/mattn/go-sqlite3 go get -u github.com/lib/pq go get -u github.com/pingcap/tidb - + #### MySQL mysql -u root -e 'create database orm_test;' export ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" go test -v github.com/beego/beego/v2/client/orm - - + + #### Sqlite3 export ORM_DRIVER=sqlite3 export ORM_SOURCE='file:memory_test?mode=memory' go test -v github.com/beego/beego/v2/client/orm - - + + #### PostgreSQL psql -c 'create database orm_test;' -U postgres export ORM_DRIVER=postgres export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" go test -v github.com/beego/beego/v2/client/orm - + #### TiDB export ORM_DRIVER=tidb export ORM_SOURCE='memory://test/test' go test -v github.com/beego/beego/v2/pgk/orm - + ` -) func init() { // Debug, _ = StrTo(DBARGS.Debug).Bool() @@ -533,7 +547,6 @@ func init() { } err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) - if err != nil { panic(fmt.Sprintf("can not register database: %v", err)) } @@ -542,5 +555,4 @@ func init() { if alias.Driver == DRMySQL { alias.Engine = "INNODB" } - } diff --git a/client/orm/models_utils_test.go b/client/orm/models_utils_test.go index 0a6995b3..4dceda1c 100644 --- a/client/orm/models_utils_test.go +++ b/client/orm/models_utils_test.go @@ -29,7 +29,7 @@ func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool { return db == "default" } -func Test_IsApplicableTableForDB(t *testing.T) { +func TestIsApplicableTableForDB(t *testing.T) { assert.False(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa")) assert.True(t, isApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default")) } diff --git a/client/orm/orm.go b/client/orm/orm.go index 1adf84e2..30753acf 100644 --- a/client/orm/orm.go +++ b/client/orm/orm.go @@ -62,10 +62,10 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/hints" - "github.com/beego/beego/v2/core/utils" - "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/core/utils" ) // DebugQueries define the debug @@ -101,27 +101,43 @@ type ormBase struct { db dbQuerier } -var _ DQL = new(ormBase) -var _ DML = new(ormBase) -var _ DriverGetter = new(ormBase) +var ( + _ DQL = new(ormBase) + _ DML = new(ormBase) + _ DriverGetter = new(ormBase) +) // get model info and model reflect value -func (o *ormBase) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { +func (*ormBase) getMi(md interface{}) (mi *modelInfo) { + val := reflect.ValueOf(md) + ind := reflect.Indirect(val) + typ := ind.Type() + mi = getTypeMi(typ) + return +} + +// get need ptr model info and model reflect value +func (*ormBase) getPtrMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { val := reflect.ValueOf(md) ind = reflect.Indirect(val) typ := ind.Type() - if needPtr && val.Kind() != reflect.Ptr { + if val.Kind() != reflect.Ptr { panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) } - name := getFullName(typ) + mi = getTypeMi(typ) + return +} + +func getTypeMi(mdTyp reflect.Type) *modelInfo { + name := getFullName(mdTyp) if mi, ok := modelCache.getByFullName(name); ok { - return mi, ind + return mi } panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) } // get field info from model info by given field name -func (o *ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo { +func (*ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo { fi, ok := mi.fields.GetByAny(name) if !ok { panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) @@ -133,28 +149,31 @@ func (o *ormBase) getFieldInfo(mi *modelInfo, name string) *fieldInfo { func (o *ormBase) Read(md interface{}, cols ...string) error { return o.ReadWithCtx(context.Background(), md, cols...) } + func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + mi, ind := o.getPtrMiInd(md) + return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) } // read data to model, like Read(), but use "SELECT FOR UPDATE" form func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { return o.ReadForUpdateWithCtx(context.Background(), md, cols...) } + func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) + mi, ind := o.getPtrMiInd(md) + return o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, true) } // Try to read a row from the database, or insert one if it doesn't exist func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { return o.ReadOrCreateWithCtx(context.Background(), md, col1, cols...) } + func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { cols = append([]string{col1}, cols...) - mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) + mi, ind := o.getPtrMiInd(md) + err := o.alias.DbBaser.Read(ctx, o.db, mi, ind, o.alias.TZ, cols, false) if err == ErrNoRows { // Create id, err := o.InsertWithCtx(ctx, md) @@ -177,9 +196,10 @@ func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 func (o *ormBase) Insert(md interface{}) (int64, error) { return o.InsertWithCtx(context.Background(), md) } + func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { - mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + mi, ind := o.getPtrMiInd(md) + id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return id, err } @@ -190,7 +210,7 @@ func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, err } // set auto pk field -func (o *ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) { +func (*ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) { if mi.fields.pk.auto { if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) @@ -204,6 +224,7 @@ func (o *ormBase) setPk(mi *modelInfo, ind reflect.Value, id int64) { func (o *ormBase) InsertMulti(bulk int, mds interface{}) (int64, error) { return o.InsertMultiWithCtx(context.Background(), bulk, mds) } + func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { var cnt int64 @@ -221,8 +242,8 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac if bulk <= 1 { for i := 0; i < sind.Len(); i++ { ind := reflect.Indirect(sind.Index(i)) - mi, _ := o.getMiInd(ind.Interface(), false) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) + mi := o.getMi(ind.Interface()) + id, err := o.alias.DbBaser.Insert(ctx, o.db, mi, ind, o.alias.TZ) if err != nil { return cnt, err } @@ -232,8 +253,8 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac cnt++ } } else { - mi, _ := o.getMiInd(sind.Index(0).Interface(), false) - return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) + mi := o.getMi(sind.Index(0).Interface()) + return o.alias.DbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.TZ) } return cnt, nil } @@ -242,9 +263,10 @@ func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interfac func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (int64, error) { return o.InsertOrUpdateWithCtx(context.Background(), md, colConflictAndArgs...) } + func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) + mi, ind := o.getPtrMiInd(md) + id, err := o.alias.DbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...) if err != nil { return id, err } @@ -259,9 +281,10 @@ func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, col func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { return o.UpdateWithCtx(context.Background(), md, cols...) } + func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) + mi, ind := o.getPtrMiInd(md) + return o.alias.DbBaser.Update(ctx, o.db, mi, ind, o.alias.TZ, cols) } // delete model in database @@ -269,24 +292,16 @@ func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...str func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { return o.DeleteWithCtx(context.Background(), md, cols...) } + func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) - if err != nil { - return num, err - } - if num > 0 { - o.setPk(mi, ind, 0) - } - return num, nil + mi, ind := o.getPtrMiInd(md) + num, err := o.alias.DbBaser.Delete(ctx, o.db, mi, ind, o.alias.TZ, cols) + return num, err } // create a models to models queryer func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { - return o.QueryM2MWithCtx(context.Background(), md, name) -} -func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) fi := o.getFieldInfo(mi, name) switch { @@ -299,6 +314,12 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri return newQueryM2M(md, o, mi, fi, ind) } +// NOTE: this method is deprecated, context parameter will not take effect. +func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.") + return o.QueryM2M(md, name) +} + // load related models to md model. // args are limit, offset int and order string. // @@ -310,7 +331,8 @@ func (o *ormBase) QueryM2MWithCtx(ctx context.Context, md interface{}, name stri func (o *ormBase) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) { return o.LoadRelatedWithCtx(context.Background(), md, name, args...) } -func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { + +func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { _, fi, ind, qs := o.queryRelated(md, name) var relDepth int @@ -351,7 +373,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s qs.relDepth = relDepth if len(order) > 0 { - qs.orders = []string{order} + qs.orders = order_clause.ParseOrder(order) } find := ind.FieldByIndex(fi.fieldIndex) @@ -376,7 +398,7 @@ func (o *ormBase) LoadRelatedWithCtx(ctx context.Context, md interface{}, name s // get QuerySeter for related models to md model func (o *ormBase) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, *querySet) { - mi, ind := o.getMiInd(md, true) + mi, ind := o.getPtrMiInd(md) fi := o.getFieldInfo(mi, name) _, _, exist := getExistPk(mi, ind) @@ -451,9 +473,6 @@ func (o *ormBase) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { - return o.QueryTableWithCtx(context.Background(), ptrStructOrTableName) -} -func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { var name string if table, ok := ptrStructOrTableName.(string); ok { name = nameStrategyMap[defaultNameStrategy](table) @@ -469,14 +488,21 @@ func (o *ormBase) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName in if qs == nil { panic(fmt.Errorf(" table name: `%s` not exists", name)) } - return + return qs +} + +// NOTE: this method is deprecated, context parameter will not take effect. +func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.") + return o.QueryTable(ptrStructOrTableName) } // return a raw query seter for raw sql string. func (o *ormBase) Raw(query string, args ...interface{}) RawSeter { return o.RawWithCtx(context.Background(), query, args...) } -func (o *ormBase) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter { + +func (o *ormBase) RawWithCtx(_ context.Context, query string, args ...interface{}) RawSeter { return newRawSet(o, query, args) } @@ -525,6 +551,10 @@ func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxO }, } + if Debug { + _txOrm.db = newDbQueryLog(o.alias, _txOrm.db) + } + var taskTxOrm TxOrmer = _txOrm return taskTxOrm, nil } @@ -542,10 +572,10 @@ func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, t } func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { - return doTxTemplate(o, ctx, opts, task) + return doTxTemplate(ctx, o, opts, task) } -func doTxTemplate(o TxBeginner, ctx context.Context, opts *sql.TxOptions, +func doTxTemplate(ctx context.Context, o TxBeginner, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { _txOrm, err := o.BeginWithCtxAndOpts(ctx, opts) if err != nil { @@ -565,7 +595,7 @@ func doTxTemplate(o TxBeginner, ctx context.Context, opts *sql.TxOptions, } } }() - var taskTxOrm = _txOrm + taskTxOrm := _txOrm err = task(ctx, taskTxOrm) panicked = false return err @@ -585,6 +615,10 @@ func (t *txOrm) Rollback() error { return t.db.(txEnder).Rollback() } +func (t *txOrm) RollbackUnlessCommit() error { + return t.db.(txEnder).RollbackUnlessCommit() +} + // NewOrm create new orm func NewOrm() Ormer { BootStrap() // execute only once @@ -595,9 +629,8 @@ func NewOrm() Ormer { func NewOrmUsingDB(aliasName string) Ormer { if al, ok := dataBaseCache.get(aliasName); ok { return newDBWithAlias(al) - } else { - panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } + panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) } // NewOrmWithDB create a new ormer object with specify *sql.DB for query diff --git a/client/orm/orm_conds.go b/client/orm/orm_conds.go index 5409406e..7d6eabe2 100644 --- a/client/orm/orm_conds.go +++ b/client/orm/orm_conds.go @@ -17,11 +17,13 @@ package orm import ( "fmt" "strings" + + "github.com/beego/beego/v2/client/orm/clauses" ) // ExprSep define the expression separation const ( - ExprSep = "__" + ExprSep = clauses.ExprSep ) type condValue struct { @@ -76,7 +78,6 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition { // AndCond combine a condition to current condition func (c *Condition) AndCond(cond *Condition) *Condition { - if c == cond { panic(fmt.Errorf(" cannot use self as sub cond")) } diff --git a/client/orm/orm_log.go b/client/orm/orm_log.go index 61addeb5..e6f8bc83 100644 --- a/client/orm/orm_log.go +++ b/client/orm/orm_log.go @@ -40,8 +40,8 @@ func NewLog(out io.Writer) *Log { } func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { - var logMap = make(map[string]interface{}) - sub := time.Now().Sub(t) / 1e5 + logMap := make(map[string]interface{}) + sub := time.Since(t) / 1e5 elsp := float64(int(sub)) / 10.0 logMap["cost_time"] = elsp flag := " OK" @@ -85,20 +85,32 @@ func (d *stmtQueryLog) Close() error { } func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { + return d.ExecContext(context.Background(), args...) +} + +func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { a := time.Now() - res, err := d.stmt.Exec(args...) + res, err := d.stmt.ExecContext(ctx, args...) debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) return res, err } func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { + return d.QueryContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { a := time.Now() - res, err := d.stmt.Query(args...) + res, err := d.stmt.QueryContext(ctx, args...) debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) return res, err } func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { + return d.QueryRowContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { a := time.Now() res := d.stmt.QueryRow(args...) debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) @@ -122,9 +134,11 @@ type dbQueryLog struct { txe txEnder } -var _ dbQuerier = new(dbQueryLog) -var _ txer = new(dbQueryLog) -var _ txEnder = new(dbQueryLog) +var ( + _ dbQuerier = new(dbQueryLog) + _ txer = new(dbQueryLog) + _ txEnder = new(dbQueryLog) +) func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { return d.PrepareContext(context.Background(), query) @@ -195,6 +209,13 @@ func (d *dbQueryLog) Rollback() error { return err } +func (d *dbQueryLog) RollbackUnlessCommit() error { + a := time.Now() + err := d.db.(txEnder).RollbackUnlessCommit() + debugLogQueies(d.alias, "tx.RollbackUnlessCommit", "ROLLBACK UNLESS COMMIT", a, err) + return err +} + func (d *dbQueryLog) SetDB(db dbQuerier) { d.db = db } diff --git a/client/orm/orm_object.go b/client/orm/orm_object.go index 6f9798d3..50c1ca41 100644 --- a/client/orm/orm_object.go +++ b/client/orm/orm_object.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" "reflect" ) @@ -31,6 +32,10 @@ var _ Inserter = new(insertSet) // insert model ignore it's registered or not. func (o *insertSet) Insert(md interface{}) (int64, error) { + return o.InsertWithCtx(context.Background(), md) +} + +func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { if o.closed { return 0, ErrStmtClosed } @@ -44,7 +49,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) { if name != o.mi.fullName { panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) } - id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) + id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ) if err != nil { return id, err } @@ -70,11 +75,11 @@ func (o *insertSet) Close() error { } // create new insert queryer. -func newInsertSet(orm *ormBase, mi *modelInfo) (Inserter, error) { +func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) { bi := new(insertSet) bi.orm = orm bi.mi = mi - st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) + st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi) if err != nil { return nil, err } diff --git a/client/orm/orm_querym2m.go b/client/orm/orm_querym2m.go index 17e1b5d1..9da49bba 100644 --- a/client/orm/orm_querym2m.go +++ b/client/orm/orm_querym2m.go @@ -14,7 +14,10 @@ package orm -import "reflect" +import ( + "context" + "reflect" +) // model to model struct type queryM2M struct { @@ -33,6 +36,10 @@ type queryM2M struct { // // make sure the relation is defined in post model struct tag. func (o *queryM2M) Add(mds ...interface{}) (int64, error) { + return o.AddWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi mi := fi.relThroughModelInfo mfi := fi.reverseFieldInfo @@ -96,11 +103,15 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { } names = append(names, otherNames...) values = append(values, otherValues...) - return dbase.InsertValue(orm.db, mi, true, names, values) + return dbase.InsertValue(ctx, orm.db, mi, true, names, values) } // remove models following the origin model relationship func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { + return o.RemoveWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) @@ -109,21 +120,33 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { // check model is existed in relationship of origin model func (o *queryM2M) Exist(md interface{}) bool { + return o.ExistWithCtx(context.Background(), md) +} + +func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool { fi := o.fi return o.qs.Filter(fi.reverseFieldInfo.name, o.md). - Filter(fi.reverseFieldInfoTwo.name, md).Exist() + Filter(fi.reverseFieldInfoTwo.name, md).ExistWithCtx(ctx) } // clean all models in related of origin model func (o *queryM2M) Clear() (int64, error) { + return o.ClearWithCtx(context.Background()) +} + +func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).DeleteWithCtx(ctx) } // count all related models of origin model func (o *queryM2M) Count() (int64, error) { + return o.CountWithCtx(context.Background()) +} + +func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() + return o.qs.Filter(fi.reverseFieldInfo.name, o.md).CountWithCtx(ctx) } var _ QueryM2Mer = new(queryM2M) diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go index 177cfc3a..9f7b8441 100644 --- a/client/orm/orm_queryset.go +++ b/client/orm/orm_queryset.go @@ -18,6 +18,7 @@ import ( "context" "fmt" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/client/orm/hints" ) @@ -64,21 +65,20 @@ func ColValue(opt operator, value interface{}) interface{} { // real query struct type querySet struct { - mi *modelInfo - cond *Condition - related []string - relDepth int - limit int64 - offset int64 - groups []string - orders []string - distinct bool - forUpdate bool - useIndex int - indexes []string - orm *ormBase - ctx context.Context - forContext bool + mi *modelInfo + cond *Condition + related []string + relDepth int + limit int64 + offset int64 + groups []string + orders []*order_clause.Order + distinct bool + forUpdate bool + useIndex int + indexes []string + orm *ormBase + aggregate string } var _ QuerySeter = new(querySet) @@ -139,8 +139,20 @@ func (o querySet) GroupBy(exprs ...string) QuerySeter { // add ORDER expression. // "column" means ASC, "-column" means DESC. -func (o querySet) OrderBy(exprs ...string) QuerySeter { - o.orders = exprs +func (o querySet) OrderBy(expressions ...string) QuerySeter { + if len(expressions) <= 0 { + return &o + } + o.orders = order_clause.ParseOrder(expressions...) + return &o +} + +// add ORDER expression. +func (o querySet) OrderClauses(orders ...*order_clause.Order) QuerySeter { + if len(orders) <= 0 { + return &o + } + o.orders = orders return &o } @@ -210,23 +222,39 @@ func (o querySet) GetCond() *Condition { // return QuerySeter execution result number func (o *querySet) Count() (int64, error) { - return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.CountWithCtx(context.Background()) +} + +func (o *querySet) CountWithCtx(ctx context.Context) (int64, error) { + return o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } // check result empty or not after QuerySeter executed func (o *querySet) Exist() bool { - cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.ExistWithCtx(context.Background()) +} + +func (o *querySet) ExistWithCtx(ctx context.Context) bool { + cnt, _ := o.orm.alias.DbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return cnt > 0 } // execute update with parameters func (o *querySet) Update(values Params) (int64, error) { - return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) + return o.UpdateWithCtx(context.Background(), values) +} + +func (o *querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) { + return o.orm.alias.DbBaser.UpdateBatch(ctx, o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) } // execute delete func (o *querySet) Delete() (int64, error) { - return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) + return o.DeleteWithCtx(context.Background()) +} + +func (o *querySet) DeleteWithCtx(ctx context.Context) (int64, error) { + return o.orm.alias.DbBaser.DeleteBatch(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) } // return a insert queryer. @@ -235,20 +263,32 @@ func (o *querySet) Delete() (int64, error) { // i,err := sq.PrepareInsert() // i.Add(&user1{},&user2{}) func (o *querySet) PrepareInsert() (Inserter, error) { - return newInsertSet(o.orm, o.mi) + return o.PrepareInsertWithCtx(context.Background()) +} + +func (o *querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) { + return newInsertSet(ctx, o.orm, o.mi) } // query all data and map to containers. // cols means the columns when querying. func (o *querySet) All(container interface{}, cols ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + return o.AllWithCtx(context.Background(), container, cols...) +} + +func (o *querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) } // query one row data and map to containers. // cols means the columns when querying. func (o *querySet) One(container interface{}, cols ...string) error { + return o.OneWithCtx(context.Background(), container, cols...) +} + +func (o *querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error { o.limit = 1 - num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) + num, err := o.orm.alias.DbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) if err != nil { return err } @@ -266,19 +306,31 @@ func (o *querySet) One(container interface{}, cols ...string) error { // expres means condition expression. // it converts data to []map[column]value. func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) + return o.ValuesWithCtx(context.Background(), results, exprs...) +} + +func (o *querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to [][]interface // it converts data to [][column_index]value func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) + return o.ValuesListWithCtx(context.Background(), results, exprs...) +} + +func (o *querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) } // query all data and map to []interface. // it's designed for one row record set, auto change to []value, not [][column]value. func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) + return o.ValuesFlatWithCtx(context.Background(), result, expr) +} + +func (o *querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) { + return o.orm.alias.DbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) } // query all rows into map[string]interface with specify key and value column name. @@ -309,13 +361,6 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) panic(ErrNotImplement) } -// set context to QuerySeter. -func (o querySet) WithContext(ctx context.Context) QuerySeter { - o.ctx = ctx - o.forContext = true - return &o -} - // create new QuerySeter. func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o := new(querySet) @@ -323,3 +368,9 @@ func newQuerySet(orm *ormBase, mi *modelInfo) QuerySeter { o.orm = orm return o } + +// aggregate func +func (o querySet) Aggregate(s string) QuerySeter { + o.aggregate = s + return &o +} diff --git a/client/orm/orm_raw.go b/client/orm/orm_raw.go index e11e97fa..25452660 100644 --- a/client/orm/orm_raw.go +++ b/client/orm/orm_raw.go @@ -181,6 +181,12 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { if err == nil { ind.Set(reflect.ValueOf(t)) } + } else if len(str) >= 8 { + str = str[:8] + t, err := time.ParseInLocation(formatTime, str, DefaultTimeLoc) + if err == nil { + ind.Set(reflect.ValueOf(t)) + } } } case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool: @@ -247,7 +253,6 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr } cur++ } - } else { value := reflect.ValueOf(refs[cur]).Elem().Interface() if isPtr && value == nil { @@ -431,7 +436,6 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { sInd.Set(nInd) } } - } else { return ErrNoRows } @@ -600,7 +604,6 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { } if cnt > 0 { - if structMode { sInds[0].Set(sInd) } else { diff --git a/client/orm/orm_test.go b/client/orm/orm_test.go index eb0d644f..764e5b6d 100644 --- a/client/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -31,9 +31,10 @@ import ( "testing" "time" - "github.com/beego/beego/v2/client/orm/hints" - "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/hints" ) var _ = os.PathSeparator @@ -84,11 +85,7 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err er } ok = is && ok || !is && !ok if !ok { - if is { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } else { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } + err = fmt.Errorf("expected: `%v`, get `%v`", b, a) } wrongArg: @@ -132,7 +129,8 @@ func getCaller(skip int) string { if cur == line { flag = ">>" } - code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) + ls := formatLines(string(lines[o+i])) + code := fmt.Sprintf(" %s %5d: %s", flag, cur, ls) if code != "" { codes = append(codes, code) } @@ -145,6 +143,10 @@ func getCaller(skip int) string { return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) } +func formatLines(s string) string { + return strings.ReplaceAll(s, "\t", " ") +} + // Deprecated: Using stretchr/testify/assert func throwFail(t *testing.T, err error, args ...interface{}) { if err != nil { @@ -161,6 +163,7 @@ func throwFail(t *testing.T, err error, args ...interface{}) { } } +// deprecated using assert.XXX func throwFailNow(t *testing.T, err error, args ...interface{}) { if err != nil { con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) @@ -190,8 +193,8 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(User)) RegisterModel(new(Profile)) RegisterModel(new(Post)) - RegisterModel(new(Tag)) RegisterModel(new(NullValue)) + RegisterModel(new(Tag)) RegisterModel(new(Comment)) RegisterModel(new(UserBig)) RegisterModel(new(PostTags)) @@ -206,6 +209,7 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(Index)) RegisterModel(new(StrPk)) RegisterModel(new(TM)) + RegisterModel(new(DeptInfo)) err := RunSyncdb("default", true, Debug) throwFail(t, err) @@ -213,7 +217,7 @@ func TestSyncDb(t *testing.T) { modelCache.clean() } -func TestRegisterModels(t *testing.T) { +func TestRegisterModels(_ *testing.T) { RegisterModel(new(Data), new(DataNull), new(DataCustom)) RegisterModel(new(User)) RegisterModel(new(Profile)) @@ -234,6 +238,7 @@ func TestRegisterModels(t *testing.T) { RegisterModel(new(Index)) RegisterModel(new(StrPk)) RegisterModel(new(TM)) + RegisterModel(new(DeptInfo)) BootStrap() @@ -245,10 +250,10 @@ func TestModelSyntax(t *testing.T) { user := &User{} ind := reflect.ValueOf(user).Elem() fn := getFullName(ind.Type()) - mi, ok := modelCache.getByFullName(fn) + _, ok := modelCache.getByFullName(fn) throwFail(t, AssertIs(ok, true)) - mi, ok = modelCache.get("user") + mi, ok := modelCache.get("user") throwFail(t, AssertIs(ok, true)) if ok { throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) @@ -310,7 +315,6 @@ func TestDataTypes(t *testing.T) { case "DateTime": case "Time": assert.True(t, vu.(time.Time).In(DefaultTimeLoc).Sub(value.(time.Time).In(DefaultTimeLoc)) <= time.Second) - break default: assert.Equal(t, value, vu) } @@ -335,6 +339,73 @@ func TestTM(t *testing.T) { throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC")) } +func TestUnregisterModel(t *testing.T) { + data := []*DeptInfo{ + { + DeptName: "A", + EmployeeName: "A1", + Salary: 1000, + }, + { + DeptName: "A", + EmployeeName: "A2", + Salary: 2000, + }, + { + DeptName: "B", + EmployeeName: "B1", + Salary: 2000, + }, + { + DeptName: "B", + EmployeeName: "B2", + Salary: 4000, + }, + { + DeptName: "B", + EmployeeName: "B3", + Salary: 3000, + }, + } + qs := dORM.QueryTable("dept_info") + i, _ := qs.PrepareInsert() + for _, d := range data { + _, err := i.Insert(d) + if err != nil { + throwFail(t, err) + } + } + + f := func() { + var res []UnregisterModel + n, err := dORM.QueryTable("dept_info").All(&res) + throwFail(t, err) + throwFail(t, AssertIs(n, 5)) + throwFail(t, AssertIs(res[0].EmployeeName, "A1")) + + type Sum struct { + DeptName string + Total int + } + var sun []Sum + qs.Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").OrderBy("dept_name").All(&sun) + throwFail(t, AssertIs(sun[0].DeptName, "A")) + throwFail(t, AssertIs(sun[0].Total, 3000)) + + type Max struct { + DeptName string + Max float64 + } + var max []Max + qs.Aggregate("dept_name,max(salary) as max").GroupBy("dept_name").OrderBy("dept_name").All(&max) + throwFail(t, AssertIs(max[1].DeptName, "B")) + throwFail(t, AssertIs(max[1].Max, 4000)) + } + for i := 0; i < 5; i++ { + f() + } +} + func TestNullDataTypes(t *testing.T) { d := DataNull{} @@ -777,7 +848,6 @@ The program—and web server—godoc processes Go source files to extract docume throwFailNow(t, AssertIs(nums, num)) } } - } func TestCustomField(t *testing.T) { @@ -818,10 +888,11 @@ func TestCustomField(t *testing.T) { func TestExpr(t *testing.T) { user := &User{} - qs := dORM.QueryTable(user) - qs = dORM.QueryTable((*User)(nil)) - qs = dORM.QueryTable("User") - qs = dORM.QueryTable("user") + var qs QuerySeter + assert.NotPanics(t, func() { qs = dORM.QueryTable(user) }) + assert.NotPanics(t, func() { qs = dORM.QueryTable((*User)(nil)) }) + assert.NotPanics(t, func() { qs = dORM.QueryTable("User") }) + assert.NotPanics(t, func() { qs = dORM.QueryTable("user") }) num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) @@ -1079,6 +1150,26 @@ func TestOrderBy(t *testing.T) { num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`profile__age`), + order_clause.SortDescending(), + ), + ).Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsMysql { + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`rand()`), + order_clause.Raw(), + ), + ).Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + } } func TestAll(t *testing.T) { @@ -1150,7 +1241,6 @@ func TestOne(t *testing.T) { err = qs.Filter("user_name", "nothing").One(&user) throwFail(t, AssertIs(err, ErrNoRows)) - } func TestValues(t *testing.T) { @@ -1165,6 +1255,19 @@ func TestValues(t *testing.T) { throwFail(t, AssertIs(maps[2]["Profile"], nil)) } + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column("Id"), + order_clause.SortAscending(), + ), + ).Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[2]["Profile"], nil)) + } + num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") throwFail(t, err) throwFail(t, AssertIs(num, 3)) @@ -1187,8 +1290,8 @@ func TestValuesList(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 3)) if num == 3 { - throwFail(t, AssertIs(list[0][1], "slene")) - throwFail(t, AssertIs(list[2][9], nil)) + throwFail(t, AssertIs(list[0][1], "slene")) // username + throwFail(t, AssertIs(list[2][10], nil)) // profile } num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") @@ -1622,7 +1725,7 @@ func TestQueryM2M(t *testing.T) { throwFailNow(t, AssertIs(num, 1)) } -func TestQueryRelate(t *testing.T) { +func TestQueryRelate(_ *testing.T) { // post := &Post{Id: 2} // qs := dORM.QueryRelate(post, "Tags") @@ -1746,14 +1849,11 @@ func TestRawQueryRow(t *testing.T) { switch col { case "id": throwFail(t, AssertIs(id, 1)) - break - case "time": - case "date": - case "datetime": + case "time", "datetime": v = v.(time.Time).In(DefaultTimeLoc) value := dataValues[col].(time.Time).In(DefaultTimeLoc) assert.True(t, v.(time.Time).Sub(value) <= time.Second) - break + case "date": default: throwFail(t, AssertIs(v, dataValues[col])) } @@ -1847,7 +1947,6 @@ func TestQueryRows(t *testing.T) { case "Date": case "DateTime": assert.True(t, vu.(time.Time).In(DefaultTimeLoc).Sub(value.(time.Time).In(DefaultTimeLoc)) <= time.Second) - break default: assert.Equal(t, value, vu) } @@ -1871,7 +1970,6 @@ func TestQueryRows(t *testing.T) { case "Date": case "DateTime": assert.True(t, vu.(time.Time).In(DefaultTimeLoc).Sub(value.(time.Time).In(DefaultTimeLoc)) <= time.Second) - break default: assert.Equal(t, value, vu) } @@ -1977,9 +2075,7 @@ func TestRawPrepare(t *testing.T) { err error pre RawPreparer ) - switch { - case IsMysql || IsSqlite: - + if IsMysql || IsSqlite { pre, err = dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() assert.Nil(t, err) if pre != nil { @@ -2014,9 +2110,7 @@ func TestRawPrepare(t *testing.T) { assert.Nil(t, err) assert.Equal(t, num, int64(3)) } - - case IsPostgres: - + } else if IsPostgres { pre, err = dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() assert.Nil(t, err) if pre != nil { @@ -2134,7 +2228,7 @@ func TestTransaction(t *testing.T) { to, err := o.Begin() throwFail(t, err) - var names = []string{"1", "2", "3"} + names := []string{"1", "2", "3"} var tag Tag tag.Name = names[0] @@ -2146,8 +2240,7 @@ func TestTransaction(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 1)) - switch { - case IsMysql || IsSqlite: + if IsMysql || IsSqlite { res, err := to.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() throwFail(t, err) if err == nil { @@ -2158,27 +2251,62 @@ func TestTransaction(t *testing.T) { } err = to.Rollback() - throwFail(t, err) - + assert.Nil(t, err) num, err = o.QueryTable("tag").Filter("name__in", names).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) + assert.Nil(t, err) + assert.Equal(t, int64(0), num) to, err = o.Begin() - throwFail(t, err) + assert.Nil(t, err) tag.Name = "commit" id, err = to.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) + assert.Nil(t, err) + assert.True(t, id > 0) - to.Commit() - throwFail(t, err) + err = to.Commit() + assert.Nil(t, err) num, err = o.QueryTable("tag").Filter("name", "commit").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) + assert.Nil(t, err) + assert.Equal(t, int64(1), num) +} +func TestTxOrmRollbackUnlessCommit(t *testing.T) { + o := NewOrm() + var tag Tag + + // test not commited and call RollbackUnlessCommit + to, err := o.Begin() + assert.Nil(t, err) + tag.Name = "rollback unless commit" + rows, err := to.Insert(&tag) + assert.Nil(t, err) + assert.True(t, rows > 0) + err = to.RollbackUnlessCommit() + assert.Nil(t, err) + num, err := o.QueryTable("tag").Filter("name", tag.Name).Delete() + assert.Nil(t, err) + assert.Equal(t, int64(0), num) + + // test commit and call RollbackUnlessCommit + + to, err = o.Begin() + assert.Nil(t, err) + tag.Name = "rollback unless commit" + rows, err = to.Insert(&tag) + assert.Nil(t, err) + assert.True(t, rows > 0) + + err = to.Commit() + assert.Nil(t, err) + + err = to.RollbackUnlessCommit() + assert.Nil(t, err) + + num, err = o.QueryTable("tag").Filter("name", tag.Name).Delete() + assert.Nil(t, err) + assert.Equal(t, int64(1), num) } func TestTransactionIsolationLevel(t *testing.T) { @@ -2276,8 +2404,10 @@ func TestReadOrCreate(t *testing.T) { throwFail(t, AssertIs(nu.Status, u.Status)) throwFail(t, AssertIs(nu.IsStaff, u.IsStaff)) throwFail(t, AssertIs(nu.IsActive, u.IsActive)) + assert.True(t, u.ID > 0) dORM.Delete(u) + assert.True(t, u.ID > 0) } func TestInLine(t *testing.T) { @@ -2514,93 +2644,101 @@ func TestIgnoreCaseTag(t *testing.T) { func TestInsertOrUpdate(t *testing.T) { RegisterModel(new(User)) - user := User{UserName: "unique_username133", Status: 1, Password: "o"} - user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} - user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} + userName := "unique_username133" + column := "user_name" + user := User{UserName: userName, Status: 1, Password: "o"} + user1 := User{UserName: userName, Status: 2, Password: "o"} + user2 := User{UserName: userName, Status: 3, Password: "oo"} dORM.Insert(&user) - test := User{UserName: "unique_username133"} fmt.Println(dORM.Driver().Name()) if dORM.Driver().Name() == "sqlite3" { fmt.Println("sqlite3 is nonsupport") return } - // test1 - _, err := dORM.InsertOrUpdate(&user1, "user_name") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user1.Status, test.Status)) - } - // test2 - _, err = dORM.InsertOrUpdate(&user2, "user_name") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status, test.Status)) - throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) + + specs := []struct { + description string + user User + colConflitAndArgs []string + assertion func(expected User, actual User) + isPostgresCompatible bool + }{ + { + description: "test1", + user: user1, + colConflitAndArgs: []string{column}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs(expected.Status, actual.Status)) + }, + isPostgresCompatible: true, + }, + { + description: "test2", + user: user2, + colConflitAndArgs: []string{column}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs(expected.Status, actual.Status)) + throwFailNow(t, AssertIs(expected.Password, strings.TrimSpace(actual.Password))) + }, + isPostgresCompatible: true, + }, + { + description: "test3 +", + user: user2, + colConflitAndArgs: []string{column, "status=status+1"}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs(expected.Status+1, actual.Status)) + }, + isPostgresCompatible: false, + }, + { + description: "test4 -", + user: user2, + colConflitAndArgs: []string{column, "status=status-1"}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs((expected.Status+1)-1, actual.Status)) + }, + isPostgresCompatible: false, + }, + { + description: "test5 *", + user: user2, + colConflitAndArgs: []string{column, "status=status*3"}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs(((expected.Status+1)-1)*3, actual.Status)) + }, + isPostgresCompatible: false, + }, + { + description: "test6 /", + user: user2, + colConflitAndArgs: []string{column, "Status=Status/3"}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs((((expected.Status+1)-1)*3)/3, actual.Status)) + }, + isPostgresCompatible: false, + }, } - // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values - if IsPostgres { - return - } - // test3 + - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) + for _, spec := range specs { + // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + if IsPostgres && !spec.isPostgresCompatible { + continue } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status+1, test.Status)) - } - // test4 - - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) + + _, err := dORM.InsertOrUpdate(&spec.user, spec.colConflitAndArgs...) + if err != nil { + fmt.Println(err) + if !(err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") { + throwFailNow(t, err) + } + continue } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) - } - // test5 * - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) - } - // test6 / - _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) + + test := User{UserName: userName} + err = dORM.Read(&test, column) + throwFailNow(t, AssertIs(err, nil)) + spec.assertion(spec.user, test) } } @@ -2634,7 +2772,9 @@ func TestStrPkInsert(t *testing.T) { if err != nil { fmt.Println(err) if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + return } else if err == ErrLastInsertIdUnavailable { + return } else { throwFailNow(t, err) } @@ -2720,10 +2860,83 @@ func TestCondition(t *testing.T) { hasCycle(p.cond) } } - return } hasCycle(cond) // cycleFlag was true,meaning use self as sub cond throwFail(t, AssertIs(!cycleFlag, true)) - return +} + +func TestContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + user := User{UserName: "slene"} + + err := dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, err) + + cancel() + err = dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, AssertIs(err, context.Canceled)) + + ctx, cancel = context.WithCancel(context.Background()) + cancel() + + qs := dORM.QueryTable(user) + _, err = qs.Filter("UserName", "slene").CountWithCtx(ctx) + throwFail(t, AssertIs(err, context.Canceled)) +} + +func TestDebugLog(t *testing.T) { + txCommitFn := func() { + o := NewOrm() + o.DoTx(func(ctx context.Context, txOrm TxOrmer) (txerr error) { + _, txerr = txOrm.QueryTable(&User{}).Count() + return + }) + } + + txRollbackFn := func() { + o := NewOrm() + o.DoTx(func(ctx context.Context, txOrm TxOrmer) (txerr error) { + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 3 + user.IsStaff = true + user.IsActive = true + + txOrm.Insert(user) + txerr = fmt.Errorf("mock error") + return + }) + } + + Debug = true + output1 := captureDebugLogOutput(txCommitFn) + + assert.Contains(t, output1, "START TRANSACTION") + assert.Contains(t, output1, "COMMIT") + + output2 := captureDebugLogOutput(txRollbackFn) + + assert.Contains(t, output2, "START TRANSACTION") + assert.Contains(t, output2, "ROLLBACK") + + Debug = false + output1 = captureDebugLogOutput(txCommitFn) + assert.EqualValues(t, output1, "") + + output2 = captureDebugLogOutput(txRollbackFn) + assert.EqualValues(t, output2, "") +} + +func captureDebugLogOutput(f func()) string { + var buf bytes.Buffer + DebugLog.SetOutput(&buf) + defer func() { + DebugLog.SetOutput(os.Stderr) + }() + f() + return buf.String() } diff --git a/client/orm/qb_postgres.go b/client/orm/qb_postgres.go index eec784df..d7f21692 100644 --- a/client/orm/qb_postgres.go +++ b/client/orm/qb_postgres.go @@ -21,7 +21,6 @@ func processingStr(str []string) string { // Select will join the fields func (qb *PostgresQueryBuilder) Select(fields ...string) QueryBuilder { - var str string n := len(fields) @@ -80,7 +79,6 @@ func (qb *PostgresQueryBuilder) RightJoin(table string) QueryBuilder { // On join with on cond func (qb *PostgresQueryBuilder) On(cond string) QueryBuilder { - var str string cond = strings.Replace(cond, " ", "", -1) slice := strings.Split(cond, "=") diff --git a/client/orm/types.go b/client/orm/types.go index cb735ac8..145f2897 100644 --- a/client/orm/types.go +++ b/client/orm/types.go @@ -20,6 +20,7 @@ import ( "reflect" "time" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" "github.com/beego/beego/v2/core/utils" ) @@ -109,8 +110,32 @@ type TxBeginner interface { } type TxCommitter interface { + txEnder +} + +// transaction beginner +type txer interface { + Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +// transaction ending +type txEnder interface { Commit() error Rollback() error + + // RollbackUnlessCommit if the transaction has been committed, do nothing, or transaction will be rollback + // For example: + // ```go + // txOrm := orm.Begin() + // defer txOrm.RollbackUnlessCommit() + // err := txOrm.Insert() // do something + // if err != nil { + // return err + // } + // txOrm.Commit() + // ``` + RollbackUnlessCommit() error } // Data Manipulation Language @@ -196,12 +221,16 @@ type DQL interface { // post := Post{Id: 4} // m2m := Ormer.QueryM2M(&post, "Tags") QueryM2M(md interface{}, name string) QueryM2Mer + // NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer // return a QuerySeter for table operations. // table name can be string or struct. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), QueryTable(ptrStructOrTableName interface{}) QuerySeter + // NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter DBStats() *sql.DBStats @@ -217,19 +246,25 @@ type ormer interface { DriverGetter } -type Ormer interface { +// QueryExecutor wrapping for ormer +type QueryExecutor interface { ormer +} + +type Ormer interface { + QueryExecutor TxBeginner } type TxOrmer interface { - ormer + QueryExecutor TxCommitter } // Inserter insert prepared statement type Inserter interface { Insert(interface{}) (int64, error) + InsertWithCtx(context.Context, interface{}) (int64, error) Close() error } @@ -289,6 +324,28 @@ type QuerySeter interface { // for example: // qs.OrderBy("-status") OrderBy(exprs ...string) QuerySeter + // add ORDER expression by order clauses + // for example: + // OrderClauses( + // order_clause.Clause( + // order.Column("Id"), + // order.SortAscending(), + // ), + // order_clause.Clause( + // order.Column("status"), + // order.SortDescending(), + // ), + // ) + // OrderClauses(order_clause.Clause( + // order_clause.Column(`user__status`), + // order_clause.SortDescending(),//default None + // )) + // OrderClauses(order_clause.Clause( + // order_clause.Column(`random()`), + // order_clause.SortNone(),//default None + // order_clause.Raw(),//default false.if true, do not check field is valid or not + // )) + OrderClauses(orders ...*order_clause.Order) QuerySeter // add FORCE INDEX expression. // for example: // qs.ForceIndex(`idx_name1`,`idx_name2`) @@ -327,9 +384,11 @@ type QuerySeter interface { // for example: // num, err = qs.Filter("profile__age__gt", 28).Count() Count() (int64, error) + CountWithCtx(context.Context) (int64, error) // check result empty or not after QuerySeter executed // the same as QuerySeter.Count > 0 Exist() bool + ExistWithCtx(context.Context) bool // execute update with parameters // for example: // num, err = qs.Filter("user_name", "slene").Update(Params{ @@ -339,11 +398,13 @@ type QuerySeter interface { // "user_name": "slene2" // }) // user slene's name will change to slene2 Update(values Params) (int64, error) + UpdateWithCtx(ctx context.Context, values Params) (int64, error) // delete from table // for example: // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() // //delete two user who's name is testing1 or testing2 Delete() (int64, error) + DeleteWithCtx(context.Context) (int64, error) // return a insert queryer. // it can be used in times. // example: @@ -352,18 +413,21 @@ type QuerySeter interface { // num, err = i.Insert(&user2) // user table will add one record user2 at once // err = i.Close() //don't forget call Close PrepareInsert() (Inserter, error) + PrepareInsertWithCtx(context.Context) (Inserter, error) // query all data and map to containers. // cols means the columns when querying. // for example: // var users []*User // qs.All(&users) // users[0],users[1],users[2] ... All(container interface{}, cols ...string) (int64, error) + AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) // query one row data and map to containers. // cols means the columns when querying. // for example: // var user User // qs.One(&user) //user.UserName == "slene" One(container interface{}, cols ...string) error + OneWithCtx(ctx context.Context, container interface{}, cols ...string) error // query all data and map to []map[string]interface. // expres means condition expression. // it converts data to []map[column]value. @@ -371,18 +435,21 @@ type QuerySeter interface { // var maps []Params // qs.Values(&maps) //maps[0]["UserName"]=="slene" Values(results *[]Params, exprs ...string) (int64, error) + ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) // query all data and map to [][]interface // it converts data to [][column_index]value // for example: // var list []ParamsList // qs.ValuesList(&list) // list[0][1] == "slene" ValuesList(results *[]ParamsList, exprs ...string) (int64, error) + ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) // query all data and map to []interface. // it's designed for one column record set, auto change to []value, not [][column]value. // for example: // var list ParamsList // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" ValuesFlat(result *ParamsList, expr string) (int64, error) + ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) // query all rows into map[string]interface with specify key and value column name. // keyCol = "name", valueCol = "value" // table data @@ -405,6 +472,15 @@ type QuerySeter interface { // Found int // } RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) + // aggregate func. + // for example: + // type result struct { + // DeptName string + // Total int + // } + // var res []result + // o.QueryTable("dept_info").Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").All(&res) + Aggregate(s string) QuerySeter } // QueryM2Mer model to model query struct @@ -422,18 +498,23 @@ type QueryM2Mer interface { // insert one or more rows to m2m table // make sure the relation is defined in post model struct tag. Add(...interface{}) (int64, error) + AddWithCtx(context.Context, ...interface{}) (int64, error) // remove models following the origin model relationship // only delete rows from m2m table // for example: // tag3 := &Tag{Id:5,Name: "TestTag3"} // num, err = m2m.Remove(tag3) Remove(...interface{}) (int64, error) + RemoveWithCtx(context.Context, ...interface{}) (int64, error) // check model is existed in relationship of origin model Exist(interface{}) bool + ExistWithCtx(context.Context, interface{}) bool // clean all models in related of origin model Clear() (int64, error) + ClearWithCtx(context.Context) (int64, error) // count all related models of origin model Count() (int64, error) + CountWithCtx(context.Context) (int64, error) } // RawPreparer raw query statement @@ -507,11 +588,11 @@ type RawSeter interface { type stmtQuerier interface { Close() error Exec(args ...interface{}) (sql.Result, error) - // ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) Query(args ...interface{}) (*sql.Rows, error) - // QueryContext(args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) QueryRow(args ...interface{}) *sql.Row - // QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row } // db querier @@ -534,42 +615,30 @@ type dbQuerier interface { // QueryRow(query string, args ...interface{}) *sql.Row // } -// transaction beginner -type txer interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -// transaction ending -type txEnder interface { - Commit() error - Rollback() error -} - // base database struct type dbBaser interface { - Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) - Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + Read(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error + ReadBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + Count(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + ReadValues(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) - Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) - InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) - InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) - InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + Insert(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) + InsertOrUpdate(context.Context, dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) + InsertMulti(context.Context, dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(context.Context, dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) + InsertStmt(context.Context, stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) + Update(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + UpdateBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) + Delete(context.Context, dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) + DeleteBatch(context.Context, dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) SupportUpdateJoin() bool OperatorSQL(string) string GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) GenerateOperatorLeftCol(*fieldInfo, string, *string) - PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) + PrepareInsert(context.Context, dbQuerier, *modelInfo) (stmtQuerier, string, error) MaxLimit() uint64 TableQuote() string ReplaceMarks(*string) @@ -578,12 +647,12 @@ type dbBaser interface { TimeToDB(*time.Time, *time.Location) DbTypes() map[string]string GetTables(dbQuerier) (map[string]bool, error) - GetColumns(dbQuerier, string) (map[string][3]string, error) + GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error) ShowTablesQuery() string ShowColumnsQuery(string) string - IndexExists(dbQuerier, string, string) bool + IndexExists(context.Context, dbQuerier, string, string) bool collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) - setval(dbQuerier, *modelInfo, []string) error + setval(context.Context, dbQuerier, *modelInfo, []string) error GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string } diff --git a/client/orm/utils.go b/client/orm/utils.go index d6c0a8e8..8d05c080 100644 --- a/client/orm/utils.go +++ b/client/orm/utils.go @@ -228,7 +228,7 @@ func snakeStringWithAcronym(s string) string { } data = append(data, d) } - return strings.ToLower(string(data[:])) + return strings.ToLower(string(data)) } // snake string, XxYy to xx_yy , XxYY to xx_y_y @@ -246,7 +246,7 @@ func snakeString(s string) string { } data = append(data, d) } - return strings.ToLower(string(data[:])) + return strings.ToLower(string(data)) } // SetNameStrategy set different name strategy @@ -274,7 +274,7 @@ func camelString(s string) string { } data = append(data, d) } - return string(data[:]) + return string(data) } type argString []string diff --git a/core/admin/profile.go b/core/admin/profile.go index 5b3fdb21..bd9eff4d 100644 --- a/core/admin/profile.go +++ b/core/admin/profile.go @@ -29,8 +29,10 @@ import ( "github.com/beego/beego/v2/core/utils" ) -var startTime = time.Now() -var pid int +var ( + startTime = time.Now() + pid int +) func init() { pid = os.Getpid() @@ -105,10 +107,9 @@ func PrintGCSummary(w io.Writer) { } func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) { - if gcstats.NumGC > 0 { lastPause := gcstats.Pause[0] - elapsed := time.Now().Sub(startTime) + elapsed := time.Since(startTime) overhead := float64(gcstats.PauseTotal) / float64(elapsed) * 100 allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() @@ -125,7 +126,7 @@ func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) { utils.ToShortTimeFormat(gcstats.PauseQuantiles[99])) } else { // while GC has disabled - elapsed := time.Now().Sub(startTime) + elapsed := time.Since(startTime) allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() fmt.Fprintf(w, "Alloc:%s Sys:%s Alloc(Rate):%s/s\n", diff --git a/core/bean/context.go b/core/bean/context.go index 7cee2c7e..78666437 100644 --- a/core/bean/context.go +++ b/core/bean/context.go @@ -16,5 +16,4 @@ package bean // ApplicationContext define for future // when we decide to support DI, IoC, this will be core API -type ApplicationContext interface { -} +type ApplicationContext interface{} diff --git a/core/bean/mock.go b/core/bean/mock.go new file mode 100644 index 00000000..0a8d3e29 --- /dev/null +++ b/core/bean/mock.go @@ -0,0 +1,142 @@ +package bean + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +// the mock object must be pointer of struct +// the element in mock object can be slices, structures, basic data types, pointers and interface +func Mock(v interface{}) (err error) { + pv := reflect.ValueOf(v) + // the input must be pointer of struct + if pv.Kind() != reflect.Ptr || pv.IsNil() { + err = fmt.Errorf("not a pointer of struct") + return + } + err = mock(pv) + return +} + +func mock(pv reflect.Value) (err error) { + pt := pv.Type() + for i := 0; i < pt.Elem().NumField(); i++ { + ptt := pt.Elem().Field(i) + pvv := pv.Elem().FieldByName(ptt.Name) + if !pvv.CanSet() { + continue + } + kt := ptt.Type.Kind() + tagValue := ptt.Tag.Get("mock") + switch kt { + case reflect.Map: + continue + case reflect.Interface: + if pvv.IsNil() { // when interface is nil,can not sure the type + continue + } + pvv.Set(reflect.New(pvv.Elem().Type().Elem())) + err = mock(pvv.Elem()) + case reflect.Ptr: + err = mockPtr(pvv, ptt.Type.Elem()) + case reflect.Struct: + err = mock(pvv.Addr()) + case reflect.Array, reflect.Slice: + err = mockSlice(tagValue, pvv) + case reflect.String: + pvv.SetString(tagValue) + case reflect.Bool: + err = mockBool(tagValue, pvv) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + value, e := strconv.ParseInt(tagValue, 10, 64) + if e != nil || pvv.OverflowInt(value) { + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + pvv.SetInt(value) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + value, e := strconv.ParseUint(tagValue, 10, 64) + if e != nil || pvv.OverflowUint(value) { + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + pvv.SetUint(value) + case reflect.Float32, reflect.Float64: + value, e := strconv.ParseFloat(tagValue, pvv.Type().Bits()) + if e != nil || pvv.OverflowFloat(value) { + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + pvv.SetFloat(value) + default: + } + if err != nil { + return + } + } + return +} + +// mock slice value +func mockSlice(tagValue string, pvv reflect.Value) (err error) { + if tagValue == "" { + return + } + sliceMetas := strings.Split(tagValue, ":") + if len(sliceMetas) != 2 || sliceMetas[0] != "length" { + err = fmt.Errorf("the value:%s is invalid", tagValue) + return + } + length, e := strconv.Atoi(sliceMetas[1]) + if e != nil { + return e + } + + sliceType := reflect.SliceOf(pvv.Type().Elem()) // get slice type + itemType := sliceType.Elem() // get the type of item in slice + value := reflect.MakeSlice(sliceType, 0, length) + newSliceValue := make([]reflect.Value, 0, length) + for k := 0; k < length; k++ { + itemValue := reflect.New(itemType).Elem() + // if item in slice is struct or pointer,must set zero value + switch itemType.Kind() { + case reflect.Struct: + err = mock(itemValue.Addr()) + case reflect.Ptr: + if itemValue.IsNil() { + itemValue.Set(reflect.New(itemType.Elem())) + if e := mock(itemValue); e != nil { + return e + } + } + } + newSliceValue = append(newSliceValue, itemValue) + if err != nil { + return + } + } + value = reflect.Append(value, newSliceValue...) + pvv.Set(value) + return +} + +// mock bool value +func mockBool(tagValue string, pvv reflect.Value) (err error) { + switch tagValue { + case "true": + pvv.SetBool(true) + case "false": + pvv.SetBool(false) + default: + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + return +} + +// mock pointer +func mockPtr(pvv reflect.Value, ptt reflect.Type) (err error) { + if pvv.IsNil() { + pvv.Set(reflect.New(ptt)) // must set nil value to zero value + } + err = mock(pvv) + return +} diff --git a/core/bean/mock_test.go b/core/bean/mock_test.go new file mode 100644 index 00000000..0fe81f43 --- /dev/null +++ b/core/bean/mock_test.go @@ -0,0 +1,74 @@ +package bean + +import ( + "fmt" + "testing" +) + +func TestMock(t *testing.T) { + type MockSubSubObject struct { + A int `mock:"20"` + } + type MockSubObjectAnoy struct { + Anoy int `mock:"20"` + } + type MockSubObject struct { + A bool `mock:"true"` + B MockSubSubObject + } + type MockObject struct { + A string `mock:"aaaaa"` + B int8 `mock:"10"` + C []*MockSubObject `mock:"length:2"` + D bool `mock:"true"` + E *MockSubObject + F []int `mock:"length:3"` + G InterfaceA + H InterfaceA + MockSubObjectAnoy + } + m := &MockObject{G: &ImplA{}} + err := Mock(m) + if err != nil { + t.Fatalf("mock failed: %v", err) + } + if m.A != "aaaaa" || m.B != 10 || m.C[1].B.A != 20 || + !m.E.A || m.E.B.A != 20 || !m.D || len(m.F) != 3 { + t.Fail() + } + _, ok := m.G.(*ImplA) + if !ok { + t.Fail() + } + _, ok = m.G.(*ImplB) + if ok { + t.Fail() + } + _, ok = m.H.(*ImplA) + if ok { + t.Fail() + } + if m.Anoy != 20 { + t.Fail() + } +} + +type InterfaceA interface { + Item() +} + +type ImplA struct { + A string `mock:"aaa"` +} + +func (i *ImplA) Item() { + fmt.Println("implA") +} + +type ImplB struct { + B string `mock:"bbb"` +} + +func (i *ImplB) Item() { + fmt.Println("implB") +} diff --git a/core/bean/tag_auto_wire_bean_factory_test.go b/core/bean/tag_auto_wire_bean_factory_test.go index bcdada67..b5744af7 100644 --- a/core/bean/tag_auto_wire_bean_factory_test.go +++ b/core/bean/tag_auto_wire_bean_factory_test.go @@ -51,24 +51,25 @@ func TestTagAutoWireBeanFactory_AutoWire(t *testing.T) { } type ComplicateStruct struct { - IntValue int `default:"12"` - StrValue string `default:"hello, strValue"` - Int8Value int8 `default:"8"` - Int16Value int16 `default:"16"` - Int32Value int32 `default:"32"` - Int64Value int64 `default:"64"` + BoolValue bool `default:"true"` + Int8Value int8 `default:"8"` + Uint8Value uint8 `default:"88"` - UintValue uint `default:"13"` - Uint8Value uint8 `default:"88"` + Int16Value int16 `default:"16"` Uint16Value uint16 `default:"1616"` + Int32Value int32 `default:"32"` Uint32Value uint32 `default:"3232"` + + IntValue int `default:"12"` + UintValue uint `default:"13"` + Int64Value int64 `default:"64"` Uint64Value uint64 `default:"6464"` + StrValue string `default:"hello, strValue"` + Float32Value float32 `default:"32.32"` Float64Value float64 `default:"64.64"` - BoolValue bool `default:"true"` - ignoreInt int `default:"11"` TimeValue time.Time `default:"2018-02-03 12:13:14.000"` diff --git a/core/berror/codes.go b/core/berror/codes.go new file mode 100644 index 00000000..b6712a84 --- /dev/null +++ b/core/berror/codes.go @@ -0,0 +1,86 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" + "sync" +) + +// A Code is an unsigned 32-bit error code as defined in the beego spec. +type Code interface { + Code() uint32 + Module() string + Desc() string + Name() string +} + +var defaultCodeRegistry = &codeRegistry{ + codes: make(map[uint32]*codeDefinition, 127), +} + +// DefineCode defining a new Code +// Before defining a new code, please read Beego specification. +// desc could be markdown doc +func DefineCode(code uint32, module string, name string, desc string) Code { + res := &codeDefinition{ + code: code, + module: module, + desc: desc, + } + defaultCodeRegistry.lock.Lock() + defer defaultCodeRegistry.lock.Unlock() + + if _, ok := defaultCodeRegistry.codes[code]; ok { + panic(fmt.Sprintf("duplicate code, code %d has been registered", code)) + } + defaultCodeRegistry.codes[code] = res + return res +} + +type codeRegistry struct { + lock sync.RWMutex + codes map[uint32]*codeDefinition +} + +func (cr *codeRegistry) Get(code uint32) (Code, bool) { + cr.lock.RLock() + defer cr.lock.RUnlock() + c, ok := cr.codes[code] + return c, ok +} + +type codeDefinition struct { + code uint32 + module string + desc string + name string +} + +func (c *codeDefinition) Name() string { + return c.name +} + +func (c *codeDefinition) Code() uint32 { + return c.code +} + +func (c *codeDefinition) Module() string { + return c.module +} + +func (c *codeDefinition) Desc() string { + return c.desc +} diff --git a/core/berror/error.go b/core/berror/error.go new file mode 100644 index 00000000..ca09798a --- /dev/null +++ b/core/berror/error.go @@ -0,0 +1,69 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" + "strconv" + "strings" + + "github.com/pkg/errors" +) + +// code, msg +const errFmt = "ERROR-%d, %s" + +// Err returns an error representing c and msg. If c is OK, returns nil. +func Error(c Code, msg string) error { + return fmt.Errorf(errFmt, c.Code(), msg) +} + +// Errorf returns error +func Errorf(c Code, format string, a ...interface{}) error { + return Error(c, fmt.Sprintf(format, a...)) +} + +func Wrap(err error, c Code, msg string) error { + if err == nil { + return nil + } + return errors.Wrap(err, fmt.Sprintf(errFmt, c.Code(), msg)) +} + +func Wrapf(err error, c Code, format string, a ...interface{}) error { + return Wrap(err, c, fmt.Sprintf(format, a...)) +} + +// FromError is very simple. It just parse error msg and check whether code has been register +// if code not being register, return unknown +// if err.Error() is not valid beego error code, return unknown +func FromError(err error) (Code, bool) { + msg := err.Error() + codeSeg := strings.SplitN(msg, ",", 2) + if strings.HasPrefix(codeSeg[0], "ERROR-") { + codeStr := strings.SplitN(codeSeg[0], "-", 2) + if len(codeStr) < 2 { + return Unknown, false + } + codeInt, e := strconv.ParseUint(codeStr[1], 10, 32) + if e != nil { + return Unknown, false + } + if code, ok := defaultCodeRegistry.Get(uint32(codeInt)); ok { + return code, true + } + } + return Unknown, false +} diff --git a/core/berror/error_test.go b/core/berror/error_test.go new file mode 100644 index 00000000..7a14d933 --- /dev/null +++ b/core/berror/error_test.go @@ -0,0 +1,77 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +var testCode1 = DefineCode(1, "unit_test", "TestError", "Hello, test code1") + +var testErr = errors.New("hello, this is error") + +func TestErrorf(t *testing.T) { + msg := Errorf(testCode1, "errorf %s", "aaaa") + assert.NotNil(t, msg) + assert.Equal(t, "ERROR-1, errorf aaaa", msg.Error()) +} + +func TestWrapf(t *testing.T) { + err := Wrapf(testErr, testCode1, "Wrapf %s", "aaaa") + assert.NotNil(t, err) + assert.True(t, errors.Is(err, testErr)) +} + +func TestFromError(t *testing.T) { + err := errors.New("ERROR-1, errorf aaaa") + code, ok := FromError(err) + assert.True(t, ok) + assert.Equal(t, testCode1, code) + assert.Equal(t, "unit_test", code.Module()) + assert.Equal(t, "Hello, test code1", code.Desc()) + + err = errors.New("not beego error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-2, not register") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-aaa, invalid code") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("aaaaaaaaaaaaaa") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-2-3, invalid error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR, invalid error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) +} diff --git a/core/berror/pre_define_code.go b/core/berror/pre_define_code.go new file mode 100644 index 00000000..ff8eb46b --- /dev/null +++ b/core/berror/pre_define_code.go @@ -0,0 +1,52 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" +) + +// pre define code + +// Unknown indicates got some error which is not defined +var Unknown = DefineCode(5000001, "error", "Unknown", fmt.Sprintf(` +Unknown error code. Usually you will see this code in three cases: +1. You forget to define Code or function DefineCode not being executed; +2. This is not Beego's error but you call FromError(); +3. Beego got unexpected error and don't know how to handle it, and then return Unknown error + +A common practice to DefineCode looks like: +%s + +In this way, you may forget to import this package, and got Unknown error. + +Sometimes, you believe you got Beego error, but actually you don't, and then you call FromError(err) + +`, goCodeBlock(` +import your_package + +func init() { + DefineCode(5100100, "your_module", "detail") + // ... +} +`))) + +func goCodeBlock(code string) string { + return codeBlock("go", code) +} + +func codeBlock(lan string, code string) string { + return fmt.Sprintf("```%s\n%s\n```", lan, code) +} diff --git a/core/config/base_config_test.go b/core/config/base_config_test.go index 74a669a7..340d05db 100644 --- a/core/config/base_config_test.go +++ b/core/config/base_config_test.go @@ -66,7 +66,6 @@ func newBaseConfier(str1 string) *BaseConfiger { } else { return "", errors.New("mock error") } - }, } } diff --git a/core/config/config.go b/core/config/config.go index d0add317..bccbc738 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -14,7 +14,7 @@ // Package config is used to parse config. // Usage: -// import "github.com/beego/beego/v2/config" +// import "github.com/beego/beego/v2/core/config" // Examples. // // cnf, err := config.NewConfig("ini", "config.conf") @@ -165,6 +165,7 @@ func (c *BaseConfiger) DefaultBool(key string, defaultVal bool) bool { } return defaultVal } + func (c *BaseConfiger) DefaultFloat(key string, defaultVal float64) float64 { if res, err := c.Float(key); err == nil { return res @@ -186,11 +187,11 @@ func (c *BaseConfiger) Strings(key string) ([]string, error) { return strings.Split(res, ";"), nil } -func (c *BaseConfiger) Sub(key string) (Configer, error) { +func (*BaseConfiger) Sub(string) (Configer, error) { return nil, errors.New("unsupported operation") } -func (c *BaseConfiger) OnChange(key string, fn func(value string)) { +func (*BaseConfiger) OnChange(_ string, _ func(value string)) { // do nothing } @@ -248,6 +249,12 @@ func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { value[k2] = ExpandValueEnv(v2) } m[k] = value + case map[interface{}]interface{}: + tmp := make(map[string]interface{}, len(value)) + for k2, v2 := range value { + tmp[k2.(string)] = v2 + } + m[k] = ExpandValueEnvForMap(tmp) } } return m @@ -370,5 +377,4 @@ func ToString(x interface{}) string { type DecodeOption func(options decodeOptions) -type decodeOptions struct { -} +type decodeOptions struct{} diff --git a/core/config/config_test.go b/core/config/config_test.go index 15d6ffa6..86d3a2c5 100644 --- a/core/config/config_test.go +++ b/core/config/config_test.go @@ -20,7 +20,6 @@ import ( ) func TestExpandValueEnv(t *testing.T) { - testCases := []struct { item string want string @@ -51,5 +50,4 @@ func TestExpandValueEnv(t *testing.T) { t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) } } - } diff --git a/core/config/env/env.go b/core/config/env/env.go index fbf06c5d..d745362a 100644 --- a/core/config/env/env.go +++ b/core/config/env/env.go @@ -18,26 +18,29 @@ package env import ( "fmt" + "go/build" + "io/ioutil" "os" + "path/filepath" "strings" - - "github.com/beego/beego/v2/core/utils" + "sync" ) -var env *utils.BeeMap +var env sync.Map func init() { - env = utils.NewBeeMap() for _, e := range os.Environ() { splits := strings.Split(e, "=") - env.Set(splits[0], os.Getenv(splits[0])) + env.Store(splits[0], os.Getenv(splits[0])) } + env.Store("GOBIN", GetGOBIN()) // set non-empty GOBIN when initialization + env.Store("GOPATH", GetGOPATH()) // set non-empty GOPATH when initialization } // Get returns a value for a given key. // If the key does not exist, the default value will be returned. func Get(key string, defVal string) string { - if val := env.Get(key); val != nil { + if val, ok := env.Load(key); ok { return val.(string) } return defVal @@ -46,7 +49,7 @@ func Get(key string, defVal string) string { // MustGet returns a value by key. // If the key does not exist, it will return an error. func MustGet(key string) (string, error) { - if val := env.Get(key); val != nil { + if val, ok := env.Load(key); ok { return val.(string), nil } return "", fmt.Errorf("no env variable with %s", key) @@ -55,7 +58,7 @@ func MustGet(key string) (string, error) { // Set sets a value in the ENV copy. // This does not affect the child process environment. func Set(key string, value string) { - env.Set(key, value) + env.Store(key, value) } // MustSet sets a value in the ENV copy and the child process environment. @@ -65,23 +68,110 @@ func MustSet(key string, value string) error { if err != nil { return err } - env.Set(key, value) + env.Store(key, value) return nil } // GetAll returns all keys/values in the current child process environment. func GetAll() map[string]string { - items := env.Items() - envs := make(map[string]string, env.Count()) - - for key, val := range items { + envs := make(map[string]string, 32) + env.Range(func(key, value interface{}) bool { switch key := key.(type) { case string: - switch val := val.(type) { + switch val := value.(type) { case string: envs[key] = val } } - } + return true + }) return envs } + +// envFile returns the name of the Go environment configuration file. +// Copy from https://github.com/golang/go/blob/c4f2a9788a7be04daf931ac54382fbe2cb754938/src/cmd/go/internal/cfg/cfg.go#L150-L166 +func envFile() (string, error) { + if file := os.Getenv("GOENV"); file != "" { + if file == "off" { + return "", fmt.Errorf("GOENV=off") + } + return file, nil + } + dir, err := os.UserConfigDir() + if err != nil { + return "", err + } + if dir == "" { + return "", fmt.Errorf("missing user-config dir") + } + return filepath.Join(dir, "go", "env"), nil +} + +// GetRuntimeEnv returns the value of runtime environment variable, +// that is set by running following command: `go env -w key=value`. +func GetRuntimeEnv(key string) (string, error) { + file, err := envFile() + if err != nil { + return "", err + } + if file == "" { + return "", fmt.Errorf("missing runtime env file") + } + + var runtimeEnv string + data, err := ioutil.ReadFile(file) + if err != nil { + return "", err + } + envStrings := strings.Split(string(data), "\n") + for _, envItem := range envStrings { + envItem = strings.TrimSuffix(envItem, "\r") + envKeyValue := strings.Split(envItem, "=") + if len(envKeyValue) == 2 && strings.TrimSpace(envKeyValue[0]) == key { + runtimeEnv = strings.TrimSpace(envKeyValue[1]) + } + } + return runtimeEnv, nil +} + +// GetGOBIN returns GOBIN environment variable as a string. +// It will NOT be an empty string. +func GetGOBIN() string { + // The one set by user explicitly by `export GOBIN=/path` or `env GOBIN=/path command` + gobin := strings.TrimSpace(Get("GOBIN", "")) + if gobin == "" { + var err error + // The one set by user by running `go env -w GOBIN=/path` + gobin, err = GetRuntimeEnv("GOBIN") + if err != nil { + // The default one that Golang uses + return filepath.Join(build.Default.GOPATH, "bin") + } + if gobin == "" { + return filepath.Join(build.Default.GOPATH, "bin") + } + return gobin + } + return gobin +} + +// GetGOPATH returns GOPATH environment variable as a string. +// It will NOT be an empty string. +func GetGOPATH() string { + // The one set by user explicitly by `export GOPATH=/path` or `env GOPATH=/path command` + gopath := strings.TrimSpace(Get("GOPATH", "")) + if gopath == "" { + var err error + // The one set by user by running `go env -w GOPATH=/path` + gopath, err = GetRuntimeEnv("GOPATH") + if err != nil { + // The default one that Golang uses + return build.Default.GOPATH + } + if gopath == "" { + return build.Default.GOPATH + } + return gopath + } + return gopath +} diff --git a/core/config/env/env_test.go b/core/config/env/env_test.go index 3f1d4dba..1cd88147 100644 --- a/core/config/env/env_test.go +++ b/core/config/env/env_test.go @@ -15,7 +15,9 @@ package env import ( + "go/build" "os" + "path/filepath" "testing" ) @@ -73,3 +75,41 @@ func TestEnvGetAll(t *testing.T) { t.Error("expected environment not empty.") } } + +func TestEnvFile(t *testing.T) { + envFile, err := envFile() + if err != nil { + t.Errorf("expected to get env file without error, but got %s.", err) + } + if envFile == "" { + t.Error("expected to get valid env file, but got empty string.") + } +} + +func TestGetGOBIN(t *testing.T) { + customGOBIN := filepath.Join("path", "to", "gobin") + Set("GOBIN", customGOBIN) + if gobin := GetGOBIN(); gobin != Get("GOBIN", "") { + t.Errorf("expected GOBIN environment variable equals to %s, but got %s.", customGOBIN, gobin) + } + + Set("GOBIN", "") + defaultGOBIN := filepath.Join(build.Default.GOPATH, "bin") + if gobin := GetGOBIN(); gobin != defaultGOBIN { + t.Errorf("expected GOBIN environment variable equals to %s, but got %s.", defaultGOBIN, gobin) + } +} + +func TestGetGOPATH(t *testing.T) { + customGOPATH := filepath.Join("path", "to", "gopath") + Set("GOPATH", customGOPATH) + if goPath := GetGOPATH(); goPath != Get("GOPATH", "") { + t.Errorf("expected GOPATH environment variable equals to %s, but got %s.", customGOPATH, goPath) + } + + Set("GOPATH", "") + defaultGOPATH := build.Default.GOPATH + if goPath := GetGOPATH(); goPath != defaultGOPATH { + t.Errorf("expected GOPATH environment variable equals to %s, but got %s.", defaultGOPATH, goPath) + } +} diff --git a/core/config/etcd/config.go b/core/config/etcd/config.go index acc43f35..fb12b065 100644 --- a/core/config/etcd/config.go +++ b/core/config/etcd/config.go @@ -20,10 +20,10 @@ import ( "fmt" "time" - "github.com/coreos/etcd/clientv3" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" + clientv3 "go.etcd.io/etcd/client/v3" "google.golang.org/grpc" "github.com/beego/beego/v2/core/config" @@ -119,7 +119,6 @@ func (e *EtcdConfiger) Sub(key string) (config.Configer, error) { // TODO remove this before release v2.0.0 func (e *EtcdConfiger) OnChange(key string, fn func(value string)) { - buildOptsFunc := func() []clientv3.OpOption { return []clientv3.OpOption{} } @@ -144,11 +143,9 @@ func (e *EtcdConfiger) OnChange(key string, fn func(value string)) { rch = e.client.Watch(context.Background(), e.prefix+key, buildOptsFunc()...) } }() - } -type EtcdConfigerProvider struct { -} +type EtcdConfigerProvider struct{} // Parse = ParseData([]byte(key)) // key must be json diff --git a/core/config/etcd/config_test.go b/core/config/etcd/config_test.go index 6d0bb793..ffaf2727 100644 --- a/core/config/etcd/config_test.go +++ b/core/config/etcd/config_test.go @@ -20,8 +20,8 @@ import ( "testing" "time" - "github.com/coreos/etcd/clientv3" "github.com/stretchr/testify/assert" + clientv3 "go.etcd.io/etcd/client/v3" ) func TestEtcdConfigerProvider_Parse(t *testing.T) { @@ -32,7 +32,6 @@ func TestEtcdConfigerProvider_Parse(t *testing.T) { } func TestEtcdConfiger(t *testing.T) { - provider := &EtcdConfigerProvider{} cfger, _ := provider.Parse(readEtcdConfig()) diff --git a/core/config/global.go b/core/config/global.go index d0b74253..3a334cb6 100644 --- a/core/config/global.go +++ b/core/config/global.go @@ -42,15 +42,19 @@ func String(key string) (string, error) { func Strings(key string) ([]string, error) { return globalInstance.Strings(key) } + func Int(key string) (int, error) { return globalInstance.Int(key) } + func Int64(key string) (int64, error) { return globalInstance.Int64(key) } + func Bool(key string) (bool, error) { return globalInstance.Bool(key) } + func Float(key string) (float64, error) { return globalInstance.Float(key) } @@ -64,15 +68,19 @@ func DefaultString(key string, defaultVal string) string { func DefaultStrings(key string, defaultVal []string) []string { return globalInstance.DefaultStrings(key, defaultVal) } + func DefaultInt(key string, defaultVal int) int { return globalInstance.DefaultInt(key, defaultVal) } + func DefaultInt64(key string, defaultVal int64) int64 { return globalInstance.DefaultInt64(key, defaultVal) } + func DefaultBool(key string, defaultVal bool) bool { return globalInstance.DefaultBool(key, defaultVal) } + func DefaultFloat(key string, defaultVal float64) float64 { return globalInstance.DefaultFloat(key, defaultVal) } @@ -89,6 +97,7 @@ func GetSection(section string) (map[string]string, error) { func Unmarshaler(prefix string, obj interface{}, opt ...DecodeOption) error { return globalInstance.Unmarshaler(prefix, obj, opt...) } + func Sub(key string) (Configer, error) { return globalInstance.Sub(key) } diff --git a/core/config/ini.go b/core/config/ini.go index f2877ce8..4d80fbbf 100644 --- a/core/config/ini.go +++ b/core/config/ini.go @@ -46,8 +46,7 @@ var ( ) // IniConfig implements Config to parse ini file. -type IniConfig struct { -} +type IniConfig struct{} // Parse creates a new Config and parses the file configuration from the named file. func (ini *IniConfig) Parse(name string) (Configer, error) { @@ -521,7 +520,7 @@ func init() { err := InitGlobalInstance("ini", "conf/app.conf") if err != nil { - logs.Warn("init global config instance failed. If you donot use this, just ignore it. ", err) + logs.Debug("init global config instance failed. If you donot use this, just ignore it. ", err) } } diff --git a/core/config/ini_test.go b/core/config/ini_test.go index b7a03aa2..37ec7203 100644 --- a/core/config/ini_test.go +++ b/core/config/ini_test.go @@ -23,7 +23,6 @@ import ( ) func TestIni(t *testing.T) { - var ( inicontext = ` ;comment one @@ -129,11 +128,9 @@ password = ${GOPATH} if res != "astaxie" { t.Fatal("get name error") } - } func TestIniSave(t *testing.T) { - const ( inicontext = ` app = app diff --git a/core/config/json/json.go b/core/config/json/json.go index c1a29cad..098f0308 100644 --- a/core/config/json/json.go +++ b/core/config/json/json.go @@ -31,8 +31,7 @@ import ( ) // JSONConfig is a json config parser and implements Config interface. -type JSONConfig struct { -} +type JSONConfig struct{} // Parse returns a ConfigContainer with parsed json config map. func (js *JSONConfig) Parse(filename string) (config.Configer, error) { @@ -211,7 +210,6 @@ func (c *JSONConfigContainer) String(key string) (string, error) { // DefaultString returns the string value for a given key. // if err != nil return defaultval func (c *JSONConfigContainer) DefaultString(key string, defaultVal string) string { - // TODO FIXME should not use "" to replace non existence if v, err := c.String(key); v != "" && err == nil { return v } diff --git a/core/config/json/json_test.go b/core/config/json/json_test.go index 8f5b2c83..2d107183 100644 --- a/core/config/json/json_test.go +++ b/core/config/json/json_test.go @@ -25,7 +25,6 @@ import ( ) func TestJsonStartsWithArray(t *testing.T) { - const jsoncontextwitharray = `[ { "url": "user", @@ -72,7 +71,6 @@ func TestJsonStartsWithArray(t *testing.T) { } func TestJson(t *testing.T) { - var ( jsoncontext = `{ "appname": "beeapi", diff --git a/core/config/toml/toml.go b/core/config/toml/toml.go index 9261cd27..0c678164 100644 --- a/core/config/toml/toml.go +++ b/core/config/toml/toml.go @@ -47,7 +47,6 @@ func (c *Config) ParseData(data []byte) (config.Configer, error) { return &configContainer{ t: t, }, nil - } // configContainer support key looks like "a.b.c" @@ -70,7 +69,6 @@ func (c *configContainer) Set(key, val string) error { // return error if key not found or value is invalid type func (c *configContainer) String(key string) (string, error) { res, err := c.get(key) - if err != nil { return "", err } @@ -90,7 +88,6 @@ func (c *configContainer) String(key string) (string, error) { // return error if key not found or value is invalid type func (c *configContainer) Strings(key string) ([]string, error) { val, err := c.get(key) - if err != nil { return []string{}, err } @@ -141,9 +138,7 @@ func (c *configContainer) Int64(key string) (int64, error) { // bool return bool value // return error if key not found or value is invalid type func (c *configContainer) Bool(key string) (bool, error) { - res, err := c.get(key) - if err != nil { return false, err } @@ -330,7 +325,6 @@ func (c *configContainer) get(key string) (interface{}, error) { segs := strings.Split(key, keySeparator) t, err := subTree(c.t, segs[0:len(segs)-1]) - if err != nil { return nil, err } diff --git a/core/config/xml/xml.go b/core/config/xml/xml.go index 059ada5c..c260d3b5 100644 --- a/core/config/xml/xml.go +++ b/core/config/xml/xml.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/config/xml" -// "github.com/beego/beego/v2/config" +// _ "github.com/beego/beego/v2/core/config/xml" +// "github.com/beego/beego/v2/core/config" // ) // // cnf, err := config.NewConfig("xml", "config.xml") @@ -39,12 +39,11 @@ import ( "strings" "sync" + "github.com/beego/x2j" "github.com/mitchellh/mapstructure" "github.com/beego/beego/v2/core/config" "github.com/beego/beego/v2/core/logs" - - "github.com/beego/x2j" ) // Config is a xml config parser and implements Config interface. @@ -71,7 +70,17 @@ func (xc *Config) ParseData(data []byte) (config.Configer, error) { return nil, err } - x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) + v := d["config"] + if v == nil { + return nil, fmt.Errorf("xml parse should include in tags") + } + + confVal, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xml parse tags should include sub tags") + } + + x.data = config.ExpandValueEnvForMap(confVal) return x, nil } @@ -103,7 +112,6 @@ func (c *ConfigContainer) Sub(key string) (config.Configer, error) { return &ConfigContainer{ data: sub, }, nil - } func (c *ConfigContainer) sub(key string) (map[string]interface{}, error) { @@ -171,7 +179,6 @@ func (c *ConfigContainer) DefaultInt64(key string, defaultVal int64) int64 { return defaultVal } return v - } // Float returns the float value for a given key. diff --git a/core/config/xml/xml_test.go b/core/config/xml/xml_test.go index 37c5fe7f..2807be69 100644 --- a/core/config/xml/xml_test.go +++ b/core/config/xml/xml_test.go @@ -25,9 +25,8 @@ import ( ) func TestXML(t *testing.T) { - var ( - // xml parse should incluce in tags + // xml parse should include in tags xmlcontext = ` beeapi @@ -149,7 +148,25 @@ func TestXML(t *testing.T) { err = xmlconf.Unmarshaler("mysection", sec) assert.Nil(t, err) assert.Equal(t, "MySection", sec.Name) +} +func TestXMLMissConfig(t *testing.T) { + xmlcontext1 := ` + + beeapi + ` + + c := &Config{} + _, err := c.ParseData([]byte(xmlcontext1)) + assert.Equal(t, "xml parse should include in tags", err.Error()) + + xmlcontext2 := ` + + + ` + + _, err = c.ParseData([]byte(xmlcontext2)) + assert.Equal(t, "xml parse tags should include sub tags", err.Error()) } type Section struct { diff --git a/core/config/yaml/yaml.go b/core/config/yaml/yaml.go index 778a4eb1..ae02ad16 100644 --- a/core/config/yaml/yaml.go +++ b/core/config/yaml/yaml.go @@ -13,15 +13,10 @@ // limitations under the License. // Package yaml for config provider -// -// depend on github.com/beego/goyaml2 -// -// go install github.com/beego/goyaml2 -// // Usage: // import( -// _ "github.com/beego/beego/v2/config/yaml" -// "github.com/beego/beego/v2/config" +// _ "github.com/beego/beego/v2/core/config/yaml" +// "github.com/beego/beego/v2/core/config" // ) // // cnf, err := config.NewConfig("yaml", "config.yaml") @@ -30,17 +25,13 @@ package yaml import ( - "bytes" - "encoding/json" "errors" "fmt" "io/ioutil" - "log" "os" "strings" "sync" - "github.com/beego/goyaml2" "gopkg.in/yaml.v2" "github.com/beego/beego/v2/core/config" @@ -51,7 +42,7 @@ import ( type Config struct{} // Parse returns a ConfigContainer with parsed yaml config map. -func (yaml *Config) Parse(filename string) (y config.Configer, err error) { +func (*Config) Parse(filename string) (y config.Configer, err error) { cnf, err := ReadYmlReader(filename) if err != nil { return @@ -63,7 +54,7 @@ func (yaml *Config) Parse(filename string) (y config.Configer, err error) { } // ParseData parse yaml data -func (yaml *Config) ParseData(data []byte) (config.Configer, error) { +func (*Config) ParseData(data []byte) (config.Configer, error) { cnf, err := parseYML(data) if err != nil { return nil, err @@ -75,7 +66,6 @@ func (yaml *Config) ParseData(data []byte) (config.Configer, error) { } // ReadYmlReader Read yaml file to map. -// if json like, use json package, unless goyaml2 package. func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { buf, err := ioutil.ReadFile(path) if err != nil { @@ -86,37 +76,14 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { } // parseYML parse yaml formatted []byte to map. -func parseYML(buf []byte) (cnf map[string]interface{}, err error) { - if len(buf) < 3 { - return - } - - if string(buf[0:1]) == "{" { - log.Println("Look like a Json, try json umarshal") - err = json.Unmarshal(buf, &cnf) - if err == nil { - log.Println("It is Json Map") - return - } - } - - data, err := goyaml2.Read(bytes.NewReader(buf)) +func parseYML(buf []byte) (map[string]interface{}, error) { + cnf := make(map[string]interface{}) + err := yaml.Unmarshal(buf, cnf) if err != nil { - log.Println("Goyaml2 ERR>", string(buf), err) - return - } - - if data == nil { - log.Println("Goyaml2 output nil? Pls report bug\n" + string(buf)) - return - } - cnf, ok := data.(map[string]interface{}) - if !ok { - log.Println("Not a Map? >> ", string(buf), data) - cnf = nil + return nil, err } cnf = config.ExpandValueEnvForMap(cnf) - return + return cnf, err } // ConfigContainer is a config which represents the yaml configuration. @@ -126,8 +93,8 @@ type ConfigContainer struct { } // Unmarshaler is similar to Sub -func (c *ConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { - sub, err := c.sub(prefix) +func (c *ConfigContainer) Unmarshaler(prefix string, obj interface{}, _ ...config.DecodeOption) error { + sub, err := c.subMap(prefix) if err != nil { return err } @@ -140,7 +107,7 @@ func (c *ConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ...con } func (c *ConfigContainer) Sub(key string) (config.Configer, error) { - sub, err := c.sub(key) + sub, err := c.subMap(key) if err != nil { return nil, err } @@ -149,21 +116,19 @@ func (c *ConfigContainer) Sub(key string) (config.Configer, error) { }, nil } -func (c *ConfigContainer) sub(key string) (map[string]interface{}, error) { +func (c *ConfigContainer) subMap(key string) (map[string]interface{}, error) { tmpData := c.data keys := strings.Split(key, ".") for idx, k := range keys { if v, ok := tmpData[k]; ok { - switch v.(type) { + switch val := v.(type) { case map[string]interface{}: - { - tmpData = v.(map[string]interface{}) - if idx == len(keys)-1 { - return tmpData, nil - } + tmpData = val + if idx == len(keys)-1 { + return tmpData, nil } default: - return nil, errors.New(fmt.Sprintf("the key is invalid: %s", key)) + return nil, fmt.Errorf("the key is invalid: %s", key) } } } @@ -171,7 +136,7 @@ func (c *ConfigContainer) sub(key string) (map[string]interface{}, error) { return tmpData, nil } -func (c *ConfigContainer) OnChange(key string, fn func(value string)) { +func (*ConfigContainer) OnChange(_ string, _ func(value string)) { // do nothing logs.Warn("Unsupported operation: OnChange") } @@ -219,12 +184,18 @@ func (c *ConfigContainer) DefaultInt(key string, defaultVal int) int { // Int64 returns the int64 value for a given key. func (c *ConfigContainer) Int64(key string) (int64, error) { - if v, err := c.getData(key); err != nil { + v, err := c.getData(key) + if err != nil { return 0, err - } else if vv, ok := v.(int64); ok { - return vv, nil } - return 0, errors.New("not bool value") + switch val := v.(type) { + case int: + return int64(val), nil + case int64: + return val, nil + default: + return 0, errors.New("not int or int64 value") + } } // DefaultInt64 returns the int64 value for a given key. @@ -302,9 +273,19 @@ func (c *ConfigContainer) DefaultStrings(key string, defaultVal []string) []stri // GetSection returns map for the given section func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { - if v, ok := c.data[section]; ok { - return v.(map[string]string), nil + switch val := v.(type) { + case map[string]interface{}: + res := make(map[string]string, len(val)) + for k2, v2 := range val { + res[k2] = fmt.Sprintf("%v", v2) + } + return res, nil + case map[string]string: + return val, nil + default: + return nil, fmt.Errorf("unexpected type: %v", v) + } } return nil, errors.New("not exist section") } @@ -317,7 +298,11 @@ func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { return err } defer f.Close() - err = goyaml2.Write(f, c.data) + buf, err := yaml.Marshal(c.data) + if err != nil { + return err + } + _, err = f.Write(buf) return err } @@ -335,40 +320,30 @@ func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { } func (c *ConfigContainer) getData(key string) (interface{}, error) { - - if len(key) == 0 { + if key == "" { return nil, errors.New("key is empty") } c.RLock() defer c.RUnlock() - keys := strings.Split(c.key(key), ".") + keys := strings.Split(key, ".") tmpData := c.data for idx, k := range keys { if v, ok := tmpData[k]; ok { - switch v.(type) { + switch val := v.(type) { case map[string]interface{}: - { - tmpData = v.(map[string]interface{}) - if idx == len(keys)-1 { - return tmpData, nil - } + tmpData = val + if idx == len(keys)-1 { + return tmpData, nil } default: - { - return v, nil - } - + return v, nil } } } return nil, fmt.Errorf("not exist key %q", key) } -func (c *ConfigContainer) key(key string) string { - return key -} - func init() { config.Register("yaml", &Config{}) } diff --git a/core/config/yaml/yaml_test.go b/core/config/yaml/yaml_test.go index cf889613..0430962e 100644 --- a/core/config/yaml/yaml_test.go +++ b/core/config/yaml/yaml_test.go @@ -25,7 +25,6 @@ import ( ) func TestYaml(t *testing.T) { - var ( yamlcontext = ` "appname": beeapi @@ -59,6 +58,7 @@ func TestYaml(t *testing.T) { "emptystrings": []string{}, } ) + f, err := os.Create("testyaml.conf") if err != nil { t.Fatal(err) @@ -70,11 +70,28 @@ func TestYaml(t *testing.T) { } f.Close() defer os.Remove("testyaml.conf") + yamlconf, err := config.NewConfig("yaml", "testyaml.conf") if err != nil { t.Fatal(err) } + m, err := ReadYmlReader("testyaml.conf") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, m, yamlconf.(*ConfigContainer).data) + + shadow, err := (&Config{}).ParseData([]byte(yamlcontext)) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, shadow, yamlconf) + + yamlconf.OnChange("abc", func(value string) { + fmt.Printf("on change, value is %s \n", value) + }) + res, _ := yamlconf.String("appname") if res != "beeapi" { t.Fatal("appname not equal to beeapi") @@ -104,7 +121,7 @@ func TestYaml(t *testing.T) { value, err = yamlconf.DIY(k) } if err != nil { - t.Errorf("get key %q value fatal,%v err %s", k, v, err) + t.Errorf("get key %q value fatal, %v err %s", k, v, err) } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { t.Errorf("get key %q value, want %v got %v .", k, v, value) } @@ -120,7 +137,9 @@ func TestYaml(t *testing.T) { } sub, err := yamlconf.Sub("user") - assert.Nil(t, err) + if err != nil { + t.Fatal(err) + } assert.NotNil(t, sub) name, err := sub.String("name") assert.Nil(t, err) @@ -143,6 +162,38 @@ func TestYaml(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "tom", user.Name) assert.Equal(t, 13, user.Age) + + // default value + assert.Equal(t, "beeapi", yamlconf.DefaultString("appname", "invalid")) + assert.Equal(t, "invalid", yamlconf.DefaultString("i-appname", "invalid")) + assert.Equal(t, 8080, yamlconf.DefaultInt("httpport", 8090)) + assert.Equal(t, 8090, yamlconf.DefaultInt("i-httpport", 8090)) + assert.Equal(t, 3.1415976, yamlconf.DefaultFloat("PI", 3.14)) + assert.Equal(t, 3.14, yamlconf.DefaultFloat("1-PI", 3.14)) + assert.True(t, yamlconf.DefaultBool("copyrequestbody", false)) + assert.True(t, yamlconf.DefaultBool("i-copyrequestbody", true)) + assert.Equal(t, int64(8080), yamlconf.DefaultInt64("httpport", 8090)) + assert.Equal(t, int64(8090), yamlconf.DefaultInt64("i-httpport", 8090)) + assert.Equal(t, "tom", yamlconf.DefaultString("user.name", "invalid")) + assert.Equal(t, "invalid", yamlconf.DefaultString("user.1-name", "invalid")) + assert.Equal(t, []string{"tom"}, yamlconf.DefaultStrings("strings", []string{"tom"})) + + appName, err := yamlconf.DIY("appname") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "beeapi", appName) + + err = yamlconf.SaveConfigFile(f.Name()) + if err != nil { + t.Fatal(err) + } + + section, err := yamlconf.GetSection("user") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "tom", section["name"]) } type User struct { diff --git a/core/logs/README.md b/core/logs/README.md index c7c82110..b2c405ff 100644 --- a/core/logs/README.md +++ b/core/logs/README.md @@ -4,7 +4,7 @@ logs is a Go logs manager. It can use many logs adapters. The repo is inspired b ## How to install? - go get github.com/beego/beego/v2/logs + go get github.com/beego/beego/v2/core/logs ## What adapters are supported? @@ -16,7 +16,7 @@ First you must import it ```golang import ( - "github.com/beego/beego/v2/logs" + "github.com/beego/beego/v2/core/logs" ) ``` diff --git a/core/logs/alils/alils.go b/core/logs/alils/alils.go index 7a3e4ddf..6fd8702a 100644 --- a/core/logs/alils/alils.go +++ b/core/logs/alils/alils.go @@ -180,7 +180,6 @@ func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { // Flush implementing method. empty. func (c *aliLSWriter) Flush() { - // flush all group for _, lg := range c.group { c.flush(lg) @@ -192,7 +191,6 @@ func (c *aliLSWriter) Destroy() { } func (c *aliLSWriter) flush(lg *LogGroup) { - c.lock.Lock() defer c.lock.Unlock() err := c.store.PutLogs(lg) diff --git a/core/logs/alils/log_project.go b/core/logs/alils/log_project.go index 7ede3fef..f993003e 100755 --- a/core/logs/alils/log_project.go +++ b/core/logs/alils/log_project.go @@ -128,7 +128,6 @@ func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) { // and ttl is time-to-live(in day) of logs, // and shardCnt is the number of shards. func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) { - type Body struct { Name string `json:"logstoreName"` TTL int `json:"ttl"` @@ -212,7 +211,6 @@ func (p *LogProject) DeleteLogStore(name string) (err error) { // UpdateLogStore updates a logstore according by logstore name, // obviously we can't modify the logstore name itself. func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) { - type Body struct { Name string `json:"logstoreName"` TTL int `json:"ttl"` @@ -355,7 +353,6 @@ func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) { // CreateMachineGroup creates a new machine group in SLS. func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { - body, err := json.Marshal(m) if err != nil { return @@ -395,7 +392,6 @@ func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { // UpdateMachineGroup updates a machine group. func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) { - body, err := json.Marshal(m) if err != nil { return @@ -555,7 +551,6 @@ func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) { // UpdateConfig updates a config. func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { - body, err := json.Marshal(c) if err != nil { return @@ -595,7 +590,6 @@ func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { // CreateConfig creates a new config in SLS. func (p *LogProject) CreateConfig(c *LogConfig) (err error) { - body, err := json.Marshal(c) if err != nil { return diff --git a/core/logs/alils/log_store.go b/core/logs/alils/log_store.go index d5ff25e2..fd9d9eaf 100755 --- a/core/logs/alils/log_store.go +++ b/core/logs/alils/log_store.go @@ -241,7 +241,6 @@ func (s *LogStore) GetLogsBytes(shardID int, cursor string, // LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) { - gl = &LogGroupList{} err = proto.Unmarshal(data, gl) if err != nil { diff --git a/core/logs/conn.go b/core/logs/conn.go index 1fd71be7..cfeb3f91 100644 --- a/core/logs/conn.go +++ b/core/logs/conn.go @@ -95,7 +95,6 @@ func (c *connWriter) WriteMsg(lm *LogMsg) error { // Flush implementing method. empty. func (c *connWriter) Flush() { - } // Destroy destroy connection writer and close tcp listener. diff --git a/core/logs/conn_test.go b/core/logs/conn_test.go index ca9ea1c7..85b95c10 100644 --- a/core/logs/conn_test.go +++ b/core/logs/conn_test.go @@ -26,7 +26,6 @@ import ( // ConnTCPListener takes a TCP listener and accepts n TCP connections // Returns connections using connChan func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) { - // Listen and accept n incoming connections for i := 0; i < n; i++ { conn, err := ln.Accept() diff --git a/core/logs/console.go b/core/logs/console.go index 05b13cd7..ff4fcf46 100644 --- a/core/logs/console.go +++ b/core/logs/console.go @@ -62,8 +62,7 @@ func (c *consoleWriter) Format(lm *LogMsg) string { msg = strings.Replace(msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) } h, _, _ := formatTimeHeader(lm.When) - bytes := append(h, msg...) - return string(bytes) + return string(append(h, msg...)) } func (c *consoleWriter) SetFormatter(f LogFormatter) { @@ -88,7 +87,6 @@ func newConsole() *consoleWriter { // Init initianlizes the console logger. // jsonConfig must be in the format '{"level":LevelTrace}' func (c *consoleWriter) Init(config string) error { - if len(config) == 0 { return nil } @@ -116,12 +114,10 @@ func (c *consoleWriter) WriteMsg(lm *LogMsg) error { // Destroy implementing method. empty. func (c *consoleWriter) Destroy() { - } // Flush implementing method. empty. func (c *consoleWriter) Flush() { - } func init() { diff --git a/core/logs/es/es.go b/core/logs/es/es.go index 2e592ffd..2073ab3d 100644 --- a/core/logs/es/es.go +++ b/core/logs/es/es.go @@ -29,7 +29,7 @@ func NewES() logs.Logger { // please import this package // usually means that you can import this package in your main package // for example, anonymous: -// import _ "github.com/beego/beego/v2/logs/es" +// import _ "github.com/beego/beego/v2/core/logs/es" type esLogger struct { *elasticsearch.Client DSN string `json:"dsn"` @@ -41,7 +41,6 @@ type esLogger struct { } func (el *esLogger) Format(lm *logs.LogMsg) string { - msg := lm.OldStyleFormat() idx := LogDocument{ Timestamp: lm.When.Format(time.RFC3339), @@ -60,7 +59,6 @@ func (el *esLogger) SetFormatter(f logs.LogFormatter) { // {"dsn":"http://localhost:9200/","level":1} func (el *esLogger) Init(config string) error { - err := json.Unmarshal([]byte(config), el) if err != nil { return err @@ -113,7 +111,6 @@ func (el *esLogger) Destroy() { // Flush is a empty method func (el *esLogger) Flush() { - } type LogDocument struct { diff --git a/core/logs/file.go b/core/logs/file.go index b01be357..3234a674 100644 --- a/core/logs/file.go +++ b/core/logs/file.go @@ -33,6 +33,11 @@ import ( // Writes messages by lines limit, file size limit, or time frequency. type fileLogWriter struct { sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize + + Rotate bool `json:"rotate"` + Daily bool `json:"daily"` + Hourly bool `json:"hourly"` + // The opened file Filename string `json:"filename"` fileWriter *os.File @@ -49,29 +54,27 @@ type fileLogWriter struct { maxSizeCurSize int // Rotate daily - Daily bool `json:"daily"` MaxDays int64 `json:"maxdays"` dailyOpenDate int dailyOpenTime time.Time // Rotate hourly - Hourly bool `json:"hourly"` MaxHours int64 `json:"maxhours"` hourlyOpenDate int hourlyOpenTime time.Time - Rotate bool `json:"rotate"` - Level int `json:"level"` - + // Permissions for log file Perm string `json:"perm"` + // Permissions for directory if it is specified in FileName + DirPerm string `json:"dirperm"` RotatePerm string `json:"rotateperm"` fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix - formatter LogFormatter - Formatter string `json:"formatter"` + logFormatter LogFormatter + Formatter string `json:"formatter"` } // newFileWriter creates a FileLogWriter returning as LoggerInterface. @@ -85,15 +88,16 @@ func newFileWriter() Logger { RotatePerm: "0440", Level: LevelTrace, Perm: "0660", + DirPerm: "0770", MaxLines: 10000000, MaxFiles: 999, MaxSize: 1 << 28, } - w.formatter = w + w.logFormatter = w return w } -func (w *fileLogWriter) Format(lm *LogMsg) string { +func (*fileLogWriter) Format(lm *LogMsg) string { msg := lm.OldStyleFormat() hd, _, _ := formatTimeHeader(lm.When) msg = fmt.Sprintf("%s %s\n", string(hd), msg) @@ -101,7 +105,7 @@ func (w *fileLogWriter) Format(lm *LogMsg) string { } func (w *fileLogWriter) SetFormatter(f LogFormatter) { - w.formatter = f + w.logFormatter = f } // Init file logger with json config. @@ -116,12 +120,11 @@ func (w *fileLogWriter) SetFormatter(f LogFormatter) { // "perm":"0600" // } func (w *fileLogWriter) Init(config string) error { - err := json.Unmarshal([]byte(config), w) if err != nil { return err } - if len(w.Filename) == 0 { + if w.Filename == "" { return errors.New("jsonconfig must have filename") } w.suffix = filepath.Ext(w.Filename) @@ -133,9 +136,9 @@ func (w *fileLogWriter) Init(config string) error { if len(w.Formatter) > 0 { fmtr, ok := GetFormatter(w.Formatter) if !ok { - return errors.New(fmt.Sprintf("the formatter with name: %s not found", w.Formatter)) + return fmt.Errorf("the formatter with name: %s not found", w.Formatter) } - w.formatter = fmtr + w.logFormatter = fmtr } err = w.startLogger() return err @@ -164,7 +167,6 @@ func (w *fileLogWriter) needRotateHourly(hour int) bool { return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || (w.Hourly && hour != w.hourlyOpenDate) - } // WriteMsg writes logger message into file. @@ -175,7 +177,7 @@ func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { _, d, h := formatTimeHeader(lm.When) - msg := w.formatter.Format(lm) + msg := w.logFormatter.Format(lm) if w.Rotate { w.RLock() if w.needRotateHourly(h) { @@ -218,8 +220,13 @@ func (w *fileLogWriter) createLogFile() (*os.File, error) { return nil, err } + dirperm, err := strconv.ParseInt(w.DirPerm, 8, 64) + if err != nil { + return nil, err + } + filepath := path.Dir(w.Filename) - os.MkdirAll(filepath, os.FileMode(perm)) + os.MkdirAll(filepath, os.FileMode(dirperm)) fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm)) if err == nil { diff --git a/core/logs/file_test.go b/core/logs/file_test.go index e9a8922b..6aac919a 100644 --- a/core/logs/file_test.go +++ b/core/logs/file_test.go @@ -42,12 +42,63 @@ func TestFilePerm(t *testing.T) { if err != nil { t.Fatal(err) } - if file.Mode() != 0666 { + if file.Mode() != 0o666 { t.Fatal("unexpected log file permission") } os.Remove("test.log") } +func TestFileWithPrefixPath(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"log/test.log"}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + _, err := os.Stat("log/test.log") + if err != nil { + t.Fatal(err) + } + os.Remove("log/test.log") + os.Remove("log") +} + +func TestFilePermWithPrefixPath(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"mylogpath/test.log", "perm": "0220", "dirperm": "0770"}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + + dir, err := os.Stat("mylogpath") + if err != nil { + t.Fatal(err) + } + if !dir.IsDir() { + t.Fatal("mylogpath expected to be a directory") + } + + file, err := os.Stat("mylogpath/test.log") + if err != nil { + t.Fatal(err) + } + if file.Mode() != 0o0220 { + t.Fatal("unexpected file permission") + } + + os.Remove("mylogpath/test.log") + os.Remove("mylogpath") +} + func TestFile1(t *testing.T) { log := NewLogger(10000) log.SetLogger("file", `{"filename":"test.log"}`) @@ -74,7 +125,7 @@ func TestFile1(t *testing.T) { lineNum++ } } - var expected = LevelDebug + 1 + expected := LevelDebug + 1 if lineNum != expected { t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") } @@ -107,7 +158,7 @@ func TestFile2(t *testing.T) { lineNum++ } } - var expected = LevelError + 1 + expected := LevelError + 1 if lineNum != expected { t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") } @@ -164,6 +215,7 @@ func TestFileDailyRotate_05(t *testing.T) { testFileDailyRotate(t, fn1, fn2) os.Remove(fn) } + func TestFileDailyRotate_06(t *testing.T) { // test file mode log := NewLogger(10000) log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) @@ -177,7 +229,7 @@ func TestFileDailyRotate_06(t *testing.T) { // test file mode log.Emergency("emergency") rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" s, _ := os.Lstat(rotateName) - if s.Mode() != 0440 { + if s.Mode() != 0o440 { os.Remove(rotateName) os.Remove("test3.log") t.Fatal("rotate file mode error") @@ -250,7 +302,7 @@ func TestFileHourlyRotate_06(t *testing.T) { // test file mode log.Emergency("emergency") rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" s, _ := os.Lstat(rotateName) - if s.Mode() != 0440 { + if s.Mode() != 0o440 { os.Remove(rotateName) os.Remove("test3.log") t.Fatal("rotate file mode error") @@ -268,17 +320,18 @@ func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { Rotate: true, Level: LevelTrace, Perm: "0660", + DirPerm: "0770", RotatePerm: "0440", } - fw.formatter = fw + fw.logFormatter = fw - if daily { + if fw.Daily { fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) fw.dailyOpenDate = fw.dailyOpenTime.Day() } - if hourly { + if fw.Hourly { fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) fw.hourlyOpenDate = fw.hourlyOpenTime.Day() @@ -309,9 +362,10 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) { Rotate: true, Level: LevelTrace, Perm: "0660", + DirPerm: "0770", RotatePerm: "0440", } - fw.formatter = fw + fw.logFormatter = fw fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) @@ -343,10 +397,11 @@ func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { Rotate: true, Level: LevelTrace, Perm: "0660", + DirPerm: "0770", RotatePerm: "0440", } - fw.formatter = fw + fw.logFormatter = fw fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) fw.hourlyOpenDate = fw.hourlyOpenTime.Hour() @@ -369,6 +424,7 @@ func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { } fw.Destroy() } + func exists(path string) (bool, error) { _, err := os.Stat(path) if err == nil { diff --git a/core/logs/jianliao.go b/core/logs/jianliao.go index c82a0957..95835c06 100644 --- a/core/logs/jianliao.go +++ b/core/logs/jianliao.go @@ -31,7 +31,6 @@ func newJLWriter() Logger { // Init JLWriter with json config string func (s *JLWriter) Init(config string) error { - res := json.Unmarshal([]byte(config), s) if res == nil && len(s.Formatter) > 0 { fmtr, ok := GetFormatter(s.Formatter) diff --git a/core/logs/log.go b/core/logs/log.go index 52d3007f..ad2ef953 100644 --- a/core/logs/log.go +++ b/core/logs/log.go @@ -15,7 +15,7 @@ // Package logs provide a general log interface // Usage: // -// import "github.com/beego/beego/v2/logs" +// import "github.com/beego/beego/v2/core/logs" // // log := NewLogger(10000) // log.SetLogger("console", "") @@ -92,8 +92,10 @@ type Logger interface { SetFormatter(f LogFormatter) } -var adapters = make(map[string]newLoggerFunc) -var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} +var ( + adapters = make(map[string]newLoggerFunc) + levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} +) // Register makes a log provide available by the provided name. // If Register is called twice with the same name or if driver is nil, @@ -112,17 +114,17 @@ func Register(name string, log newLoggerFunc) { // Can contain several providers and log message into all providers. type BeeLogger struct { lock sync.Mutex - level int init bool enableFuncCallDepth bool - loggerFuncCallDepth int enableFullFilePath bool asynchronous bool + wg sync.WaitGroup + level int + loggerFuncCallDepth int prefix string msgChanLen int64 msgChan chan *LogMsg signalChan chan string - wg sync.WaitGroup outputs []*nameLogger globalFormatter string } @@ -201,7 +203,6 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { } err := lg.Init(config) - if err != nil { fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) return err @@ -261,13 +262,13 @@ func (bl *BeeLogger) Write(p []byte) (n int, err error) { lm := &LogMsg{ Msg: string(p), Level: levelLoggerImpl, - When: time.Now(), + When: time.Now(), } // set levelLoggerImpl to ensure all log message will be write out err = bl.writeMsg(lm) if err == nil { - return len(p), err + return len(p), nil } return 0, err } diff --git a/core/logs/log_test.go b/core/logs/log_test.go index b65d8dbb..c8a3a478 100644 --- a/core/logs/log_test.go +++ b/core/logs/log_test.go @@ -18,7 +18,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - ) func TestBeeLoggerDelLogger(t *testing.T) { diff --git a/core/logs/logger.go b/core/logs/logger.go index 77e68ea2..29e33a82 100644 --- a/core/logs/logger.go +++ b/core/logs/logger.go @@ -112,8 +112,10 @@ var ( reset = string([]byte{27, 91, 48, 109}) ) -var once sync.Once -var colorMap map[string]string +var ( + once sync.Once + colorMap map[string]string +) func initColor() { if runtime.GOOS == "windows" { diff --git a/core/logs/multifile.go b/core/logs/multifile.go index 79178211..0ad8244f 100644 --- a/core/logs/multifile.go +++ b/core/logs/multifile.go @@ -45,7 +45,6 @@ var levelNames = [...]string{"emergency", "alert", "critical", "error", "warning // } func (f *multiFileLogWriter) Init(config string) error { - writer := newFileWriter().(*fileLogWriter) err := writer.Init(config) if err != nil { diff --git a/core/logs/multifile_test.go b/core/logs/multifile_test.go index 57b96094..8050bc6b 100644 --- a/core/logs/multifile_test.go +++ b/core/logs/multifile_test.go @@ -60,7 +60,7 @@ func TestFiles_1(t *testing.T) { lineNum++ } } - var expected = 1 + expected := 1 if fn == "" { expected = LevelDebug + 1 } @@ -74,5 +74,4 @@ func TestFiles_1(t *testing.T) { } os.Remove(file) } - } diff --git a/core/utils/debug.go b/core/utils/debug.go index 93c27b70..93279181 100644 --- a/core/utils/debug.go +++ b/core/utils/debug.go @@ -47,20 +47,20 @@ func GetDisplayString(data ...interface{}) string { } func display(displayed bool, data ...interface{}) string { - var pc, file, line, ok = runtime.Caller(2) + pc, file, line, ok := runtime.Caller(2) if !ok { return "" } - var buf = new(bytes.Buffer) + buf := new(bytes.Buffer) fmt.Fprintf(buf, "[Debug] at %s() [%s:%d]\n", function(pc), file, line) fmt.Fprintf(buf, "\n[Variables]\n") for i := 0; i < len(data); i += 2 { - var output = fomateinfo(len(data[i].(string))+3, data[i+1]) + output := fomateinfo(len(data[i].(string))+3, data[i+1]) fmt.Fprintf(buf, "%s = %s", data[i], output) } @@ -72,7 +72,7 @@ func display(displayed bool, data ...interface{}) string { // return data dump and format bytes func fomateinfo(headlen int, data ...interface{}) []byte { - var buf = new(bytes.Buffer) + buf := new(bytes.Buffer) if len(data) > 1 { fmt.Fprint(buf, " ") @@ -83,9 +83,9 @@ func fomateinfo(headlen int, data ...interface{}) []byte { } for k, v := range data { - var buf2 = new(bytes.Buffer) + buf2 := new(bytes.Buffer) var pointers *pointerInfo - var interfaces = make([]reflect.Value, 0, 10) + interfaces := make([]reflect.Value, 0, 10) printKeyValue(buf2, reflect.ValueOf(v), &pointers, &interfaces, nil, true, " ", 1) @@ -140,13 +140,13 @@ func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, return true } - var elem = val.Elem() + elem := val.Elem() if isSimpleType(elem, elem.Kind(), pointers, interfaces) { return true } - var addr = val.Elem().UnsafeAddr() + addr := val.Elem().UnsafeAddr() for p := *pointers; p != nil; p = p.prev { if addr == p.addr { @@ -162,7 +162,7 @@ func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, // dump value func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) { - var t = val.Kind() + t := val.Kind() switch t { case reflect.Bool: @@ -183,7 +183,7 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, return } - var addr = val.Elem().UnsafeAddr() + addr := val.Elem().UnsafeAddr() for p := *pointers; p != nil; p = p.prev { if addr == p.addr { @@ -206,7 +206,7 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, case reflect.String: fmt.Fprint(buf, "\"", val.String(), "\"") case reflect.Interface: - var value = val.Elem() + value := val.Elem() if !value.IsValid() { fmt.Fprint(buf, "nil") @@ -223,7 +223,7 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, printKeyValue(buf, value, pointers, interfaces, structFilter, formatOutput, indent, level+1) } case reflect.Struct: - var t = val.Type() + t := val.Type() fmt.Fprint(buf, t) fmt.Fprint(buf, "{") @@ -235,7 +235,7 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, fmt.Fprint(buf, " ") } - var name = t.Field(i).Name + name := t.Field(i).Name if formatOutput { for ind := 0; ind < level; ind++ { @@ -270,12 +270,12 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, fmt.Fprint(buf, val.Type()) fmt.Fprint(buf, "{") - var allSimple = true + allSimple := true for i := 0; i < val.Len(); i++ { - var elem = val.Index(i) + elem := val.Index(i) - var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) + isSimple := isSimpleType(elem, elem.Kind(), pointers, interfaces) if !isSimple { allSimple = false @@ -312,18 +312,18 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, fmt.Fprint(buf, "}") case reflect.Map: - var t = val.Type() - var keys = val.MapKeys() + t := val.Type() + keys := val.MapKeys() fmt.Fprint(buf, t) fmt.Fprint(buf, "{") - var allSimple = true + allSimple := true for i := 0; i < len(keys); i++ { - var elem = val.MapIndex(keys[i]) + elem := val.MapIndex(keys[i]) - var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) + isSimple := isSimpleType(elem, elem.Kind(), pointers, interfaces) if !isSimple { allSimple = false @@ -372,8 +372,8 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, // PrintPointerInfo dump pointer value func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { - var anyused = false - var pointerNum = 0 + anyused := false + pointerNum := 0 for p := pointers; p != nil; p = p.prev { if len(p.used) > 0 { @@ -384,10 +384,10 @@ func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { } if anyused { - var pointerBufs = make([][]rune, pointerNum+1) + pointerBufs := make([][]rune, pointerNum+1) for i := 0; i < len(pointerBufs); i++ { - var pointerBuf = make([]rune, buf.Len()+headlen) + pointerBuf := make([]rune, buf.Len()+headlen) for j := 0; j < len(pointerBuf); j++ { pointerBuf[j] = ' ' @@ -402,7 +402,7 @@ func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { if pn == p.n { pointerBufs[pn][p.pos+headlen] = '└' - var maxpos = 0 + maxpos := 0 for i, pos := range p.used { if i < len(p.used)-1 { @@ -440,10 +440,10 @@ func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { // Stack get stack bytes func Stack(skip int, indent string) []byte { - var buf = new(bytes.Buffer) + buf := new(bytes.Buffer) for i := skip; ; i++ { - var pc, file, line, ok = runtime.Caller(i) + pc, file, line, ok := runtime.Caller(i) if !ok { break diff --git a/core/utils/debug_test.go b/core/utils/debug_test.go index efb8924e..a748d20a 100644 --- a/core/utils/debug_test.go +++ b/core/utils/debug_test.go @@ -28,8 +28,8 @@ func TestPrint(t *testing.T) { } func TestPrintPoint(t *testing.T) { - var v1 = new(mytype) - var v2 = new(mytype) + v1 := new(mytype) + v2 := new(mytype) v1.prev = nil v1.next = v2 diff --git a/core/utils/kv_test.go b/core/utils/kv_test.go index 4c9643dc..cd0dc5c3 100644 --- a/core/utils/kv_test.go +++ b/core/utils/kv_test.go @@ -34,5 +34,4 @@ func TestKVs(t *testing.T) { v = kvs.GetValueOr(`key-not-exists`, 8546) assert.Equal(t, 8546, v) - } diff --git a/core/utils/rand.go b/core/utils/rand.go index 344d1cd5..f80dad7d 100644 --- a/core/utils/rand.go +++ b/core/utils/rand.go @@ -27,7 +27,7 @@ func RandomCreateBytes(n int, alphabets ...byte) []byte { if len(alphabets) == 0 { alphabets = alphaNum } - var bytes = make([]byte, n) + bytes := make([]byte, n) var randBy bool if num, err := rand.Read(bytes); num != n || err != nil { r.Seed(time.Now().UnixNano()) diff --git a/core/utils/safemap.go b/core/utils/safemap.go index 1793030a..8c48cb5b 100644 --- a/core/utils/safemap.go +++ b/core/utils/safemap.go @@ -18,7 +18,7 @@ import ( "sync" ) -// BeeMap is a map with lock +// Deprecated: using sync.Map type BeeMap struct { lock *sync.RWMutex bm map[interface{}]interface{} diff --git a/core/utils/slice.go b/core/utils/slice.go index 8f2cef98..0eda4670 100644 --- a/core/utils/slice.go +++ b/core/utils/slice.go @@ -20,6 +20,7 @@ import ( ) type reducetype func(interface{}) interface{} + type filtertype func(interface{}) bool // InSlice checks given string in string slice or not. diff --git a/core/utils/time.go b/core/utils/time.go index 579b292a..00d13861 100644 --- a/core/utils/time.go +++ b/core/utils/time.go @@ -21,7 +21,6 @@ import ( // short string format func ToShortTimeFormat(d time.Duration) string { - u := uint64(d) if u < uint64(time.Second) { switch { @@ -44,5 +43,4 @@ func ToShortTimeFormat(d time.Duration) string { return fmt.Sprintf("%.2fh", float64(u)/1000/1000/1000/60/60) } } - } diff --git a/core/validation/README.md b/core/validation/README.md index 46d7c935..dee5a7b1 100644 --- a/core/validation/README.md +++ b/core/validation/README.md @@ -7,18 +7,18 @@ validation is a form validation for a data validation and error collecting using Install: - go get github.com/beego/beego/v2/validation + go get github.com/beego/beego/v2/core/validation Test: - go test github.com/beego/beego/v2/validation + go test github.com/beego/beego/v2/core/validation ## Example Direct Use: import ( - "github.com/beego/beego/v2/validation" + "github.com/beego/beego/v2/core/validation" "log" ) @@ -49,7 +49,7 @@ Direct Use: Struct Tag Use: import ( - "github.com/beego/beego/v2/validation" + "github.com/beego/beego/v2/core/validation" ) // validation function follow with "valid" tag @@ -81,7 +81,7 @@ Struct Tag Use: Use custom function: import ( - "github.com/beego/beego/v2/validation" + "github.com/beego/beego/v2/core/validation" ) type user struct { diff --git a/core/validation/validation.go b/core/validation/validation.go index eb3a1042..78a31f13 100644 --- a/core/validation/validation.go +++ b/core/validation/validation.go @@ -15,7 +15,7 @@ // Package validation for validations // // import ( -// "github.com/beego/beego/v2/validation" +// "github.com/beego/beego/v2/core/validation" // "log" // ) // @@ -121,7 +121,7 @@ func (v *Validation) Clear() { v.ErrorsMap = nil } -// HasErrors Has ValidationError nor not. +// HasErrors Has ValidationError or not. func (v *Validation) HasErrors() bool { return len(v.Errors) > 0 } @@ -158,7 +158,7 @@ func (v *Validation) Max(obj interface{}, max int, key string) *Result { return v.apply(Max{max, key}, obj) } -// Range Test that the obj is between mni and max if obj's type is int +// Range Test that the obj is between min and max if obj's type is int func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj) } @@ -235,8 +235,11 @@ func (v *Validation) Tel(obj interface{}, key string) *Result { // Phone Test that the obj is chinese mobile or telephone number if type is string func (v *Validation) Phone(obj interface{}, key string) *Result { - return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}}, - Tel{Match: Match{Regexp: telPattern}}, key}, obj) + return v.apply(Phone{ + Mobile{Match: Match{Regexp: mobilePattern}}, + Tel{Match: Match{Regexp: telPattern}}, + key, + }, obj) } // ZipCode Test that the obj is chinese zip code if type is string diff --git a/core/validation/validation_test.go b/core/validation/validation_test.go index bca4f560..00a45a98 100644 --- a/core/validation/validation_test.go +++ b/core/validation/validation_test.go @@ -605,7 +605,6 @@ func TestCanSkipAlso(t *testing.T) { if !b { t.Fatal("validation should be passed") } - } func TestFieldNoEmpty(t *testing.T) { diff --git a/go.mod b/go.mod index a6d30e37..46e6f139 100644 --- a/go.mod +++ b/go.mod @@ -1,60 +1,40 @@ -//module github.com/beego/beego/v2 module github.com/beego/beego/v2 +go 1.14 + require ( - github.com/Knetic/govaluate v3.0.0+incompatible // indirect - github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 - github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 - github.com/casbin/casbin v1.7.0 + github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b + github.com/casbin/casbin v1.9.1 github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 - github.com/coreos/etcd v3.3.25+incompatible - github.com/coreos/go-semver v0.3.0 // indirect - github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf // indirect - github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f // indirect - github.com/couchbase/go-couchbase v0.0.0-20210126152612-8e416c37c8ef - github.com/couchbase/gomemcached v0.1.2-0.20210126151728-840240974836 // indirect - github.com/couchbase/goutils v0.0.0-20201030094643-5e82bb967e67 // indirect - github.com/elastic/go-elasticsearch/v6 v6.8.5 - github.com/elazarl/go-bindata-assetfs v1.0.0 - github.com/go-kit/kit v0.9.0 - github.com/go-redis/redis v6.14.2+incompatible + github.com/couchbase/go-couchbase v0.1.0 + github.com/couchbase/gomemcached v0.1.3 // indirect + github.com/couchbase/goutils v0.1.0 // indirect + github.com/elastic/go-elasticsearch/v6 v6.8.10 + github.com/elazarl/go-bindata-assetfs v1.0.1 + github.com/go-kit/kit v0.10.0 github.com/go-redis/redis/v7 v7.4.0 - github.com/go-sql-driver/mysql v1.5.0 - github.com/gogo/protobuf v1.3.1 - github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect + github.com/go-sql-driver/mysql v1.6.0 + github.com/gogo/protobuf v1.3.2 github.com/gomodule/redigo v2.0.0+incompatible - github.com/google/go-cmp v0.5.0 // indirect - github.com/google/uuid v1.1.1 // indirect + github.com/google/uuid v1.2.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/hashicorp/golang-lru v0.5.4 github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 - github.com/lib/pq v1.0.0 - github.com/mattn/go-sqlite3 v2.0.3+incompatible - github.com/mitchellh/mapstructure v1.3.3 + github.com/lib/pq v1.10.2 + github.com/mattn/go-sqlite3 v1.14.7 + github.com/mitchellh/mapstructure v1.4.1 github.com/opentracing/opentracing-go v1.2.0 - github.com/pelletier/go-toml v1.8.1 + github.com/pelletier/go-toml v1.9.2 github.com/pkg/errors v0.9.1 - github.com/prometheus/client_golang v1.7.0 - github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 + github.com/prometheus/client_golang v1.11.0 + github.com/shiena/ansicolor v0.0.0-20200904210342-c7312218db18 github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec - github.com/stretchr/testify v1.4.0 - github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c // indirect - github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect - go.etcd.io/bbolt v1.3.5 // indirect - go.etcd.io/etcd v3.3.25+incompatible // indirect - go.uber.org/zap v1.15.0 // indirect - golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 - golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 // indirect - golang.org/x/text v0.3.3 // indirect - golang.org/x/tools v0.0.0-20201211185031-d93e913c1a58 - google.golang.org/grpc v1.26.0 - gopkg.in/yaml.v2 v2.2.8 - honnef.co/go/tools v0.0.1-2020.1.5 // indirect + github.com/stretchr/testify v1.7.0 + go.etcd.io/etcd/client/v3 v3.5.0 + golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a + google.golang.org/grpc v1.38.0 + google.golang.org/protobuf v1.26.0 + gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b ) - -replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85 - -replace gopkg.in/yaml.v2 v2.2.1 => github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d -replace github.com/coreos/bbolt => go.etcd.io/bbolt v1.3.5 -go 1.14 diff --git a/go.sum b/go.sum index 888160b2..d9f9a12f 100644 --- a/go.sum +++ b/go.sum @@ -1,98 +1,126 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/Knetic/govaluate v3.0.0+incompatible h1:7o6+MAPhYTCF0+fdvoz1xDedhRb4f6s9Tn1Tt7/WTEg= -github.com/Knetic/govaluate v3.0.0+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw= +github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= +github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= +github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= github.com/alicebob/miniredis v2.5.0+incompatible/go.mod h1:8HZjEj4yU0dwhYHky+DxYx+6BMjkBbe5ONFIF1MXffk= -github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd h1:jZtX5jh5IOMu0fpOTC3ayh6QGSPJ/KWOv1lgPvbRw1M= -github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd/go.mod h1:1b+Y/CofkYwXMUU0OhQqGvsY2Bvgr4j6jfT699wyZKQ= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= +github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= +github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= +github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= +github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 h1:nYXb+3jF6Oq/j8R/y90XrKpreCxIalBWfeyeKymgOPk= github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542/go.mod h1:kSeGC/p1AbBiEp5kat81+DSQrZenVBZXklMLaELspWU= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 h1:rRISKWyXfVxvoa702s91Zl5oREZTrR3yv+tXrrX7G/g= -github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= -github.com/casbin/casbin v1.7.0 h1:PuzlE8w0JBg/DhIqnkF1Dewf3z+qmUZMVN07PonvVUQ= -github.com/casbin/casbin v1.7.0/go.mod h1:c67qKN6Oum3UF5Q1+BByfFxkwKvhwW57ITjqwtzR1KE= +github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= +github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b h1:L/QXpzIa3pOvUGt1D1lA5KjYhPBAN/3iWdP7xeFS9F0= +github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA= +github.com/casbin/casbin v1.9.1 h1:ucjbS5zTrmSLtH4XogqOG920Poe6QatdXtz1FEbApeM= +github.com/casbin/casbin v1.9.1/go.mod h1:z8uPsfBJGUsnkagrt3G8QvjgTKFMBJ32UP8HpZllfog= +github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= +github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= +github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/coreos/etcd v0.5.0-alpha.5 h1:0Qi6Jzjk2CDuuGlIeecpu+em2nrjhOgz2wsIwCmQHmc= -github.com/coreos/etcd v3.3.25+incompatible h1:0GQEw6h3YnuOVdtwygkIfJ+Omx0tZ8/QkVyXI4LkbeY= -github.com/coreos/etcd v3.3.25+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= +github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= +github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= -github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f h1:lBNOc5arjvs8E5mO2tbpBpLoyyu8B6e44T7hJy6potg= -github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= -github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d h1:OMrhQqj1QCyDT2sxHCDjE+k8aMdn2ngTCGG7g4wrdLo= -github.com/couchbase/go-couchbase v0.0.0-20200519150804-63f3cdb75e0d/go.mod h1:TWI8EKQMs5u5jLKW/tsb9VwauIrMIxQG1r5fMsswK5U= -github.com/couchbase/go-couchbase v0.0.0-20201216133707-c04035124b17 h1:1ZELwRDUvpBpmgKSIUP6VMW1jIehzD0sCdWxRyejegw= -github.com/couchbase/go-couchbase v0.0.0-20201216133707-c04035124b17/go.mod h1:+/bddYDxXsf9qt0xpDUtRR47A2GjaXmGGAqQ/k3GJ8A= -github.com/couchbase/go-couchbase v0.0.0-20210126152612-8e416c37c8ef h1:pXh08kdOlR7eZbHWd7Zvz2KZbI2sxbHRxM+UTWj6oQ0= -github.com/couchbase/go-couchbase v0.0.0-20210126152612-8e416c37c8ef/go.mod h1:+/bddYDxXsf9qt0xpDUtRR47A2GjaXmGGAqQ/k3GJ8A= -github.com/couchbase/gomemcached v0.0.0-20200526233749-ec430f949808 h1:8s2l8TVUwMXl6tZMe3+hPCRJ25nQXiA3d1x622JtOqc= -github.com/couchbase/gomemcached v0.0.0-20200526233749-ec430f949808/go.mod h1:srVSlQLB8iXBVXHgnqemxUXqN6FCvClgCMPCsjBDR7c= -github.com/couchbase/gomemcached v0.1.2-0.20201215185628-3bc3f73e68cb h1:ZCFku0K/3Xvl7rXkGGM+ioT76Rxko8V9wDEWa0GFp14= -github.com/couchbase/gomemcached v0.1.2-0.20201215185628-3bc3f73e68cb/go.mod h1:mxliKQxOv84gQ0bJWbI+w9Wxdpt9HjDvgW9MjCym5Vo= -github.com/couchbase/gomemcached v0.1.2-0.20210126151728-840240974836 h1:ZxgtUfduO/Fk2NY1e1YhlgN6tRl0TMdXK9ElddO7uZY= -github.com/couchbase/gomemcached v0.1.2-0.20210126151728-840240974836/go.mod h1:mxliKQxOv84gQ0bJWbI+w9Wxdpt9HjDvgW9MjCym5Vo= -github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a h1:Y5XsLCEhtEI8qbD9RP3Qlv5FXdTDHxZM9UPUnMRgBp8= -github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a/go.mod h1:BQwMFlJzDjFDG3DJUdU0KORxn88UlsOULuxLExMh3Hs= -github.com/couchbase/goutils v0.0.0-20201030094643-5e82bb967e67 h1:NCqJ6fwen6YP0WlV/IyibaT0kPt3JEI1rA62V/UPKT4= -github.com/couchbase/goutils v0.0.0-20201030094643-5e82bb967e67/go.mod h1:BQwMFlJzDjFDG3DJUdU0KORxn88UlsOULuxLExMh3Hs= +github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7 h1:u9SHYsPQNyt5tgDm3YN7+9dYrpK96E5wFilTFWIDZOM= +github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd/v22 v22.3.2 h1:D9/bQk5vlXQFZ6Kwuu6zaiXJ9oTPe68++AzAJc1DzSI= +github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= +github.com/couchbase/go-couchbase v0.1.0 h1:g4bCvDwRL+ZL6HLhYeRlXxEYP31Wpy0VFxnFw6efEp8= +github.com/couchbase/go-couchbase v0.1.0/go.mod h1:+/bddYDxXsf9qt0xpDUtRR47A2GjaXmGGAqQ/k3GJ8A= +github.com/couchbase/gomemcached v0.1.3 h1:HIc5qMYNbuhB7zNaiEtj61DCYkquAwrQlf64q7JzdEY= +github.com/couchbase/gomemcached v0.1.3/go.mod h1:mxliKQxOv84gQ0bJWbI+w9Wxdpt9HjDvgW9MjCym5Vo= +github.com/couchbase/goutils v0.1.0 h1:0WLlKJilu7IBm98T8nS9+J36lBFVLRUSIUtyD/uWpAE= +github.com/couchbase/goutils v0.1.0/go.mod h1:BQwMFlJzDjFDG3DJUdU0KORxn88UlsOULuxLExMh3Hs= +github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76 h1:Lgdd/Qp96Qj8jqLpq2cI1I1X7BJnu06efS+XkhRoLUQ= github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76/go.mod h1:vYwsqCOLxGiisLwp9rITslkFNpZD5rz43tf41QFkTWY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 h1:aaQcKT9WumO6JEJcRyTqFVq4XUZiUcKR2/GI31TOcz8= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= +github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= -github.com/elastic/go-elasticsearch/v6 v6.8.5 h1:U2HtkBseC1FNBmDr0TR2tKltL6FxoY+niDAlj5M8TK8= -github.com/elastic/go-elasticsearch/v6 v6.8.5/go.mod h1:UwaDJsD3rWLM5rKNFzv9hgox93HoX8utj1kxD9aFUcI= -github.com/elazarl/go-bindata-assetfs v1.0.0 h1:G/bYguwHIzWq9ZoyUQqrjTmJbbYn3j3CKKpKinvZLFk= -github.com/elazarl/go-bindata-assetfs v1.0.0/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= +github.com/edsrzf/mmap-go v1.0.0 h1:CEBF7HpRnUCSJgGUb5h1Gm7e3VkmVDrR8lvWVLtrOFw= +github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= +github.com/elastic/go-elasticsearch/v6 v6.8.10 h1:2lN0gJ93gMBXvkhwih5xquldszpm8FlUwqG5sPzr6a8= +github.com/elastic/go-elasticsearch/v6 v6.8.10/go.mod h1:UwaDJsD3rWLM5rKNFzv9hgox93HoX8utj1kxD9aFUcI= +github.com/elazarl/go-bindata-assetfs v1.0.1 h1:m0kkaHRKEu7tUIUFVwhGGGYClXvyl4RE03qmvRTNfbw= +github.com/elazarl/go-bindata-assetfs v1.0.1/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= +github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4safvEdbitLhGGK48rN6g= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= +github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/glendc/gopher-json v0.0.0-20170414221815-dc4743023d0c/go.mod h1:Gja1A+xZ9BoviGJNA2E9vFkPjjsl+CoJxSXiQM1UXtw= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-kit/kit v0.9.0 h1:wDJmvq38kDhkVxi50ni9ykkdUr1PKgqKOoi01fa0Mdk= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.10.0 h1:dXFJfIHVvUcpSgDOV+Ne6t7jXri8Tfv2uOLHUZ2XNuo= +github.com/go-kit/kit v0.10.0/go.mod h1:xUsJbQ/Fp4kEt7AFgCuvyX4a71u8h9jB8tj/ORgOZ7o= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= -github.com/go-logfmt/logfmt v0.4.0 h1:MP4Eh7ZCb31lleYCFuwm0oe4/YGak+5l1vA2NOE80nA= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= -github.com/go-redis/redis v6.14.2+incompatible h1:UE9pLhzmWf+xHNmZsoccjXosPicuiNaInPgym8nzfg0= -github.com/go-redis/redis v6.14.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-redis/redis/v7 v7.4.0 h1:7obg6wUoj05T0EpY0o8B59S9w5yeMWql7sw2kwNW1x4= github.com/go-redis/redis/v7 v7.4.0/go.mod h1:JDNMw23GTyLNC4GZu9njt15ctBQVn7xjRfnwdHj/Dcg= -github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d h1:xy93KVe+KrIIwWDEAfQBdIfsiHJkepbYsDr+VY3g9/o= -github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= -github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -104,36 +132,85 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= -github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= +github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= +github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= +github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/consul/api v1.3.0/go.mod h1:MmDNSzIMUjNpY/mQ398R4bk2FnqQLoPndWW5VkKPlCE= +github.com/hashicorp/consul/sdk v0.3.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= +github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= +github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= +github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= +github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= +github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= +github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= +github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY= +github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -141,63 +218,123 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 h1:wxyqOzKxsRJ6vVRL9sXQ64Z45wmBuQ+OTH9sLsC5rKc= github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6/go.mod h1:n931TsDuKuq+uX4v1fulaMbA/7ZLLhjc85h7chZGBCQ= -github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= -github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= -github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= +github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= +github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= +github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= +github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= +github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mitchellh/mapstructure v1.3.3 h1:SzB1nHZ2Xi+17FP0zVQBHIZqvwRN9408fJO8h+eeNA8= -github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= +github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= +github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= +github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= +github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mitchellh/mapstructure v1.4.1 h1:CpVNEelQCZBooIPDn+AR3NpivK/TIKU8bDxdASFVQag= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg= +github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU= +github.com/nats-io/nats-server/v2 v2.1.2/go.mod h1:Afk+wRZqkMQs/p45uXdrVLuab3gwv3Z8C4HTBu8GD/k= +github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w= +github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= +github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= +github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= +github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.0 h1:Iw5WCbBcaAAd0fpRb1c9r5YCylv4XDoCSigm1zLevwU= github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg= +github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1 h1:K0jcRCwNQM3vFGh1ppMtDh/+7ApJrjldlX8fA0jDTLQ= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk= +github.com/opentracing-contrib/go-observer v0.0.0-20170622124052-a52f23424492/go.mod h1:Ngi6UdF0k5OKD5t5wlmGhe/EDKPoUM3BXZSSfIuJbis= +github.com/opentracing/basictracer-go v1.0.0/go.mod h1:QfBfYuafItcjQuMwinw9GhYKwFXS9KnPs5lxoYwgW74= +github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/openzipkin-contrib/zipkin-go-opentracing v0.4.5/go.mod h1:/wsWhb9smxSfWAKL3wpBW7V8scJMt8N8gnaMCS9E/cA= +github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= +github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= +github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= +github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= +github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= +github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-toml v1.0.1/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pelletier/go-toml v1.8.1 h1:1Nf83orprkJyknT6h7zbuEGUEjcyVlCxSUGTENmNCRM= -github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc= +github.com/pelletier/go-toml v1.9.2 h1:7NiByeVF4jKSG1lDF3X8LTIkq2/bu+1uYbIm1eS5tzk= +github.com/pelletier/go-toml v1.9.2/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= +github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= github.com/peterh/liner v1.0.1-0.20171122030339-3681c2a91233/go.mod h1:xIteQHvHuaLYG9IFj6mSxM0fCKrs34IrEQUhOYuGPHc= -github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= +github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= +github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= -github.com/prometheus/client_golang v1.7.0 h1:wCi7urQOGBsYcQROHqpUUX4ct84xp40t9R9JX0FuA/U= -github.com/prometheus/client_golang v1.7.0/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= +github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_golang v1.11.0 h1:HNkLOAEQMIDv/K+04rukrLx6ch7msSRwf3/SASFAGtQ= +github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.1.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.10.0 h1:RyRA7RzGXQZiW+tGMr7sxa85G1z0yOpM1qq5c8lNawc= +github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= +github.com/prometheus/common v0.26.0 h1:iMAkS2TDoNWnKM+Kopnx/8tnEStIfpYA0ur0xQzzhMQ= +github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8= +github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= +github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4= +github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= +github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644 h1:X+yvsM2yrEktyI+b2qND5gpH8YhURn0k8OCaeRnkINo= -github.com/shiena/ansicolor v0.0.0-20151119151921-a422bbe96644/go.mod h1:nkxAfR/5quYxwPZhyDxgasBMnRtBZd0FCEpawpjMUFg= +github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= +github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= +github.com/shiena/ansicolor v0.0.0-20200904210342-c7312218db18 h1:DAYUYH5869yV94zvCES9F51oYtN5oGlwjxJJz7ZCnik= +github.com/shiena/ansicolor v0.0.0-20200904210342-c7312218db18/go.mod h1:nkxAfR/5quYxwPZhyDxgasBMnRtBZd0FCEpawpjMUFg= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/siddontang/go v0.0.0-20170517070808-cb568a3e5cc0 h1:QIF48X1cihydXibm+4wfAc0r/qyPyuFiPFRNphdMpEE= github.com/siddontang/go v0.0.0-20170517070808-cb568a3e5cc0/go.mod h1:3yhqj7WBBfRhbBlzyOC3gUxftwsU0u8gqevxwIHQpMw= github.com/siddontang/goredis v0.0.0-20150324035039-760763f78400/go.mod h1:DDcKzU3qCuvj/tPnimWSsZZzvk9qvkvrIL5naVBPh5s= @@ -205,179 +342,251 @@ github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d h1:NVwnfyR3rENtlz62 github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d/go.mod h1:AMEsy7v5z92TR1JKMkLLoaOQk++LVnOKL3ScbJ8GNGA= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= +github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec h1:q6XVwXmKvCRHRqesF3cSv6lNqqHi0QWOvgDlSohg8UA= github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec/go.mod h1:QBvMkMya+gXctz3kmljlUCu/yB3GZ6oee+dUozsezQE= +github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= +github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= +github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/syndtr/goleveldb v0.0.0-20160425020131-cfa635847112 h1:NBrpnvz0pDPf3+HXZ1C9GcJd1DTpWDLcLWZhNq6uP7o= github.com/syndtr/goleveldb v0.0.0-20160425020131-cfa635847112/go.mod h1:Z4AUp2Km+PwemOoO/VB5AOx9XSsIItzFjoJlOSiYmn0= -github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c h1:3eGShk3EQf5gJCYW+WzA0TEJQd37HLOmlYF7N0YJwv0= -github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c/go.mod h1:Z4AUp2Km+PwemOoO/VB5AOx9XSsIItzFjoJlOSiYmn0= +github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/ugorji/go v0.0.0-20171122102828-84cb69a8af83/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ= -github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b h1:0Ve0/CCjiAiyKddUMUn3RwIGlq2iTW4GuVzyoKBYO/8= -github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b/go.mod h1:Q12BUT7DqIlHRmgv3RskH+UCM/4eqVMgI0EMmlSpAXc= -github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= +github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/gopher-lua v0.0.0-20171031051903-609c9cd26973/go.mod h1:aEV29XrmTYFr3CiRxZeGHpkvbwq+prZduBqMaascyCU= -go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= -go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= -go.etcd.io/etcd v0.5.0-alpha.5 h1:VOolFSo3XgsmnYDLozjvZ6JL6AAwIDu1Yx1y+4EYLDo= -go.etcd.io/etcd v3.3.25+incompatible h1:V1RzkZJj9LqsJRy+TUBgpWSbZXITLB819lstuTFoZOY= -go.etcd.io/etcd v3.3.25+incompatible/go.mod h1:yaeTdrJi5lOmYerz05bd8+V7KubZs8YSFZfzsF9A6aI= -go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= -go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= -go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= +go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738 h1:VcrIfasaLFkyjk6KNlXQSzO+B0fZcnECiDrKJsfxka0= +go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= +go.etcd.io/etcd/api/v3 v3.5.0 h1:GsV3S+OfZEOCNXdtNkBSR7kgLobAa/SO6tCxRa0GAYw= +go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs= +go.etcd.io/etcd/client/pkg/v3 v3.5.0 h1:2aQv6F436YnN7I4VbI8PPYrBhu+SmrTaADcf8Mi/6PU= +go.etcd.io/etcd/client/pkg/v3 v3.5.0/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g= +go.etcd.io/etcd/client/v3 v3.5.0 h1:62Eh0XOro+rDwkrypAGDfgmNh5Joq+z+W9HZdlXMzek= +go.etcd.io/etcd/client/v3 v3.5.0/go.mod h1:AIKXXVX/DQXtfTEqBryiLTUXwON+GuvO6Z7lLS/oTh0= +go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= +go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= +go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= -go.uber.org/zap v1.15.0 h1:ZZCA22JRF2gQE5FoNmhmrf7jeJJ2uhqDUNRYKm8dvmM= -go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +go.uber.org/zap v1.17.0 h1:MTjgFu6ZLKvY6Pvaqk97GlxNBuMpV4Hy/3P6tRGlI2U= +go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= +golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 h1:2M3HP5CCK1Si9FQhwnzYhXdG6DXeebvUHFpre8QvbyI= -golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a h1:gOpx8G595UYyvj8UK4+OFyY4rx037g3fmfhe5SasG3U= +golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20201021035429-f5854403a974 h1:IX6qOQeG5uLjB/hjjwjedwfjND0hgjPMMyO1RoIXQNI= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 h1:AvbQYmiaaaza3cW3QXRyPo5kYgpFIzOAfeAAN7m3qQ4= -golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= +golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40 h1:JWgyZ1qgdTaF3N3oxC+MdTV7qvEEgHo3otj+HB5CM7Q= +golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200117065230-39095c1d176c h1:FodBYPZKH5tAN2O60HlglMwXGAeV/4k+NKbli79M/2c= -golang.org/x/tools v0.0.0-20200117065230-39095c1d176c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200815165600-90abf76919f3 h1:0aScV/0rLmANzEYIhjCOi2pTvDyhZNduBUMD2q3iqs4= -golang.org/x/tools v0.0.0-20200815165600-90abf76919f3/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20201211185031-d93e913c1a58 h1:1Bs6RVeBFtLZ8Yi1Hk07DiOqzvwLD/4hln4iahvFlag= -golang.org/x/tools v0.0.0-20201211185031-d93e913c1a58/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55 h1:gSJIx1SDwno+2ElGhA4+qG2zF97qiUzTM+rQ0klBOcE= +google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190530194941-fb225487d101/go.mod h1:z3L6/3dTEVtUr6QSP8miRzeRqwQOioJ9I66odjN4I7s= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 h1:PDIOdWxZ8eRizhKa1AAvY53xsvLB1cWorMjslvY3VA8= -google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c h1:wtujag7C+4D6KMoulW9YauvK2lgdvCMS260jsqqBXr0= +google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= +google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= +google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= +google.golang.org/grpc v1.22.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.26.0 h1:2dTRdpdFEEhJYQD8EMLB61nnrzSCTbG38PhqdhvOltg= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.0 h1:rRYRFMVgRv6E0D70Skyfsr28tDXIuuPZyWGMPdMcnXg= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.31.0 h1:T7P4R73V3SSDPhH7WW7ATbfViLtmamH0DKrP3f9AuDI= -google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.31.1 h1:SfXqXS5hkufcdZ/mHtYCh53P2b+92WQq/DZcKLgsFRs= -google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= +google.golang.org/grpc v1.38.0 h1:/9BgsAsa5nWe26HqOlvlgJnqBuktYOLCgjCPqsa56W0= +google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= -google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o= gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= +gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= +gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc h1:/hemPrYIhOhy8zYrNj+069zDB68us2sMGsfkFJO0iZs= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k= -honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= +sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= +sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= diff --git a/scripts/prepare_etcd.sh b/scripts/prepare_etcd.sh new file mode 100644 index 00000000..d34c05a3 --- /dev/null +++ b/scripts/prepare_etcd.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +etcdctl put current.float 1.23 +etcdctl put current.bool true +etcdctl put current.int 11 +etcdctl put current.string hello +etcdctl put current.serialize.name test +etcdctl put sub.sub.key1 sub.sub.key \ No newline at end of file diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100644 index 00000000..977055a3 --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +docker-compose -f "$(pwd)/scripts/test_docker_compose.yaml" up -d + +export ORM_DRIVER=mysql +export TZ=UTC +export ORM_SOURCE="beego:test@tcp(localhost:13306)/orm_test?charset=utf8" + +# wait for services in images ready +sleep 5 + +go test "$(pwd)/..." + +# clear all container +docker-compose -f "$(pwd)/scripts/test_docker_compose.yaml" down + + diff --git a/scripts/test_docker_compose.yaml b/scripts/test_docker_compose.yaml new file mode 100644 index 00000000..f22b6deb --- /dev/null +++ b/scripts/test_docker_compose.yaml @@ -0,0 +1,55 @@ +version: "3.8" +services: + redis: + container_name: "beego-redis" + image: redis + environment: + - ALLOW_EMPTY_PASSWORD=yes + ports: + - "6379:6379" + + mysql: + container_name: "beego-mysql" + image: mysql:5.7.30 + ports: + - "13306:3306" + environment: + - MYSQL_ROOT_PASSWORD=1q2w3e + - MYSQL_DATABASE=orm_test + - MYSQL_USER=beego + - MYSQL_PASSWORD=test + + postgresql: + container_name: "beego-postgresql" + image: bitnami/postgresql:latest + ports: + - "5432:5432" + environment: + - ALLOW_EMPTY_PASSWORD=yes + ssdb: + container_name: "beego-ssdb" + image: wendal/ssdb + ports: + - "8888:8888" + memcache: + container_name: "beego-memcache" + image: memcached + ports: + - "11211:11211" + etcd: + command: > + sh -c " + etcdctl put current.float 1.23 + && etcdctl put current.bool true + && etcdctl put current.int 11 + && etcdctl put current.string hello + && etcdctl put current.serialize.name test + " + container_name: "beego-etcd" + environment: + - ALLOW_NONE_AUTHENTICATION=yes +# - ETCD_ADVERTISE_CLIENT_URLS=http://etcd:2379 + image: bitnami/etcd + ports: + - "2379:2379" + - "2380:2380" \ No newline at end of file diff --git a/server/web/admin.go b/server/web/admin.go index 1d4e5d49..66ba3396 100644 --- a/server/web/admin.go +++ b/server/web/admin.go @@ -45,9 +45,7 @@ var beeAdminApp *adminApp var FilterMonitorFunc func(string, string, time.Duration, string, int) bool func init() { - FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } - } func list(root string, p interface{}, m M) { @@ -84,18 +82,13 @@ type adminApp struct { // Run start Beego admin func (admin *adminApp) Run() { - logs.Debug("now we don't start tasks here, if you use task module," + " please invoke task.StartTask, or task will not be executed") - addr := BConfig.Listen.AdminAddr - if BConfig.Listen.AdminPort != 0 { addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort) } - logs.Info("Admin server Running on %s", addr) - admin.HttpServer.Run(addr) } diff --git a/server/web/admin_controller.go b/server/web/admin_controller.go index 8ac0ccd4..ce77ecdb 100644 --- a/server/web/admin_controller.go +++ b/server/web/admin_controller.go @@ -79,7 +79,6 @@ func (a *adminController) PrometheusMetrics() { // TaskStatus is a http.Handler with running task status (task name, status and the last execution). // it's in "/task" pattern in admin module. func (a *adminController) TaskStatus() { - rw, req := a.Ctx.ResponseWriter, a.Ctx.Request data := make(map[interface{}]interface{}) @@ -91,11 +90,11 @@ func (a *adminController) TaskStatus() { cmd := admin.GetCommand("task", "run") res := cmd.Execute(taskname) if res.IsSuccess() { - - data["Message"] = []string{"success", + data["Message"] = []string{ + "success", template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", - taskname, res.Content.(string)))} - + taskname, res.Content.(string))), + } } else { data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", res.Error))} } @@ -104,7 +103,7 @@ func (a *adminController) TaskStatus() { // List Tasks content := make(M) resultList := admin.GetCommand("task", "list").Execute().Content.([][]string) - var fields = []string{ + fields := []string{ "Task Name", "Task Spec", "Task Status", @@ -238,14 +237,12 @@ func (a *adminController) ListConf() { data["Title"] = "Routers" writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) case "filter": - var ( - content = M{ - "Fields": []string{ - "Router Pattern", - "Filter Function", - }, - } - ) + content := M{ + "Fields": []string{ + "Router Pattern", + "Filter Function", + }, + } filterTypeData := BeeApp.reportFilter() @@ -287,7 +284,6 @@ func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]i } return response - } // PrintTree print all routers diff --git a/server/web/admin_test.go b/server/web/admin_test.go index 04da4fab..57aa0fe6 100644 --- a/server/web/admin_test.go +++ b/server/web/admin_test.go @@ -14,11 +14,9 @@ import ( "github.com/beego/beego/v2/core/admin" ) -type SampleDatabaseCheck struct { -} +type SampleDatabaseCheck struct{} -type SampleCacheCheck struct { -} +type SampleCacheCheck struct{} func (dc *SampleDatabaseCheck) Check() error { return nil @@ -31,7 +29,6 @@ func (cc *SampleCacheCheck) Check() error { func TestList_01(t *testing.T) { m := make(M) list("BConfig", BConfig, m) - t.Log(m) om := oldMap() for k, v := range om { if fmt.Sprint(m[k]) != fmt.Sprint(v) { @@ -100,8 +97,6 @@ func oldMap() M { } func TestWriteJSON(t *testing.T) { - t.Log("Testing the adding of JSON to the response") - w := httptest.NewRecorder() originalBody := []int{1, 2, 3} @@ -111,7 +106,6 @@ func TestWriteJSON(t *testing.T) { decodedBody := []int{} err := json.NewDecoder(w.Body).Decode(&decodedBody) - if err != nil { t.Fatal("Could not decode response body into slice.") } @@ -147,7 +141,6 @@ func TestHealthCheckHandlerDefault(t *testing.T) { if !strings.Contains(w.Body.String(), "database") { t.Errorf("Expected 'database' in generated template.") } - } func TestBuildHealthCheckResponseList(t *testing.T) { @@ -180,13 +173,10 @@ func TestBuildHealthCheckResponseList(t *testing.T) { t.Errorf("expected %s to be in the response %v", field, response) } } - } - } func TestHealthCheckHandlerReturnsJSON(t *testing.T) { - admin.AddHealthCheck("database", &SampleDatabaseCheck{}) admin.AddHealthCheck("cache", &SampleCacheCheck{}) @@ -245,5 +235,4 @@ func TestHealthCheckHandlerReturnsJSON(t *testing.T) { assert.Equal(t, expectedResponseBody[0], database) assert.Equal(t, expectedResponseBody[1], cache) - } diff --git a/server/web/beego.go b/server/web/beego.go index 14e51a94..cfe9a5dd 100644 --- a/server/web/beego.go +++ b/server/web/beego.go @@ -33,9 +33,7 @@ type M map[string]interface{} // Hook function to run type hookfunc func() error -var ( - hooks = make([]hookfunc, 0) // hook function slice to store the hookfunc -) +var hooks = make([]hookfunc, 0) // hook function slice to store the hookfunc // AddAPPStartHook is used to register the hookfunc // The hookfuncs will run in beego.Run() @@ -50,7 +48,6 @@ func AddAPPStartHook(hf ...hookfunc) { // beego.Run(":8089") // beego.Run("127.0.0.1:8089") func Run(params ...string) { - if len(params) > 0 && params[0] != "" { BeeApp.Run(params[0]) } @@ -75,7 +72,7 @@ func initBeforeHTTPRun() { registerTemplate, registerAdmin, registerGzip, - registerCommentRouter, + // registerCommentRouter, ) for _, hk := range hooks { diff --git a/server/web/captcha/README.md b/server/web/captcha/README.md index 74e1cf82..07a4dc4d 100644 --- a/server/web/captcha/README.md +++ b/server/web/captcha/README.md @@ -7,8 +7,8 @@ package controllers import ( "github.com/beego/beego/v2" - "github.com/beego/beego/v2/cache" - "github.com/beego/beego/v2/utils/captcha" + "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/server/web/captcha" ) var cpt *captcha.Captcha diff --git a/server/web/captcha/captcha.go b/server/web/captcha/captcha.go index d052af13..6d8c33b4 100644 --- a/server/web/captcha/captcha.go +++ b/server/web/captcha/captcha.go @@ -20,8 +20,8 @@ // // import ( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/cache" -// "github.com/beego/beego/v2/utils/captcha" +// "github.com/beego/beego/v2/client/cache" +// "github.com/beego/beego/v2/server/web/captcha" // ) // // var cpt *captcha.Captcha @@ -68,15 +68,12 @@ import ( "time" "github.com/beego/beego/v2/core/logs" - "github.com/beego/beego/v2/core/utils" "github.com/beego/beego/v2/server/web" "github.com/beego/beego/v2/server/web/context" ) -var ( - defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} -) +var defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} const ( // default captcha attributes @@ -200,7 +197,7 @@ func (c *Captcha) VerifyReq(req *http.Request) bool { // Verify direct verify id and challenge string func (c *Captcha) Verify(id string, challenge string) (success bool) { - if len(challenge) == 0 || len(id) == 0 { + if challenge == "" || id == "" { return } @@ -244,7 +241,7 @@ func NewCaptcha(urlPrefix string, store Storage) *Captcha { cpt.StdWidth = stdWidth cpt.StdHeight = stdHeight - if len(urlPrefix) == 0 { + if urlPrefix == "" { urlPrefix = defaultURLPrefix } diff --git a/server/web/config.go b/server/web/config.go index c8b5bc51..6bddfe4a 100644 --- a/server/web/config.go +++ b/server/web/config.go @@ -17,6 +17,7 @@ package web import ( "crypto/tls" "fmt" + "net/http" "os" "path/filepath" "reflect" @@ -26,106 +27,414 @@ import ( "github.com/beego/beego/v2" "github.com/beego/beego/v2/core/config" "github.com/beego/beego/v2/core/logs" - "github.com/beego/beego/v2/server/web/session" - "github.com/beego/beego/v2/core/utils" "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" ) // Config is the main struct for BConfig // TODO after supporting multiple servers, remove common config to somewhere else type Config struct { - AppName string // Application name - RunMode string // Running Mode: dev | prod + // AppName + // @Description Application's name. You'd better set it because we use it to do some logging and tracing + // @Default beego + AppName string // Application name + // RunMode + // @Description it's the same as environment. In general, we have different run modes. + // For example, the most common case is using dev, test, prod three environments + // when you are developing the application, you should set it as dev + // when you completed coding and want QA to test your code, you should deploy your application to test environment + // and the RunMode should be set as test + // when you completed all tests, you want to deploy it to prod, you should set it to prod + // You should never set RunMode="dev" when you deploy the application to prod + // because Beego will do more things which need Go SDK and other tools when it found out the RunMode="dev" + // @Default dev + RunMode string // Running Mode: dev | prod + + // RouterCaseSensitive + // @Description If it was true, it means that the router is case sensitive. + // For example, when you register a router with pattern "/hello", + // 1. If this is true, and the request URL is "/Hello", it won't match this pattern + // 2. If this is false and the request URL is "/Hello", it will match this pattern + // @Default true RouterCaseSensitive bool - ServerName string - RecoverPanic bool - RecoverFunc func(*context.Context, *Config) - CopyRequestBody bool - EnableGzip bool - // MaxMemory and MaxUploadSize are used to limit the request body + // RecoverPanic + // @Description if it was true, Beego will try to recover from panic when it serves your http request + // So you should notice that it doesn't mean that Beego will recover all panic cases. + // @Default true + RecoverPanic bool + // CopyRequestBody + // @Description if it's true, Beego will copy the request body. But if the request body's size > MaxMemory, + // Beego will return 413 as http status + // If you are building RESTful API, please set it to true. + // And if you want to read data from request Body multiple times, please set it to true + // In general, if you don't meet any performance issue, you could set it to true + // @Default false + CopyRequestBody bool + // EnableGzip + // @Description If it was true, Beego will try to compress data by using zip algorithm. + // But there are two points: + // 1. Only static resources will be compressed + // 2. Only those static resource which has the extension specified by StaticExtensionsToGzip will be compressed + // @Default false + EnableGzip bool + // EnableErrorsShow + // @Description If it's true, Beego will show error message to page + // it will work with ErrorMaps which allows you register some error handler + // You may want to set it to false when application was deploy to prod environment + // because you may not want to expose your internal error msg to your users + // it's a little bit unsafe + // @Default true + EnableErrorsShow bool + // EnableErrorsRender + // @Description If it's true, it will output the error msg as a page. It's similar to EnableErrorsShow + // And this configure item only work in dev run mode (see RunMode) + // @Default true + EnableErrorsRender bool + // ServerName + // @Description server name. For example, in large scale system, + // you may want to deploy your application to several machines, so that each of them has a server name + // we suggest you'd better set value because Beego use this to output some DEBUG msg, + // or integrated with other tools such as tracing, metrics + // @Default + ServerName string + + // RecoverFunc + // @Description when Beego want to recover from panic, it will use this func as callback + // see RecoverPanic + // @Default defaultRecoverPanic + RecoverFunc func(*context.Context, *Config) + // @Description MaxMemory and MaxUploadSize are used to limit the request body // if the request is not uploading file, MaxMemory is the max size of request body // if the request is uploading file, MaxUploadSize is the max size of request body - MaxMemory int64 - MaxUploadSize int64 - EnableErrorsShow bool - EnableErrorsRender bool - Listen Listen - WebConfig WebConfig - Log LogConfig + // if CopyRequestBody is true, this value will be used as the threshold of request body + // see CopyRequestBody + // the default value is 1 << 26 (64MB) + // @Default 67108864 + MaxMemory int64 + // MaxUploadSize + // @Description MaxMemory and MaxUploadSize are used to limit the request body + // if the request is not uploading file, MaxMemory is the max size of request body + // if the request is uploading file, MaxUploadSize is the max size of request body + // the default value is 1 << 30 (1GB) + // @Default 1073741824 + MaxUploadSize int64 + // Listen + // @Description the configuration about socket or http protocol + Listen Listen + // WebConfig + // @Description the configuration about Web + WebConfig WebConfig + // LogConfig + // @Description log configuration + Log LogConfig } // Listen holds for http and https related config type Listen struct { - Graceful bool // Graceful means use graceful module to start the server - ServerTimeOut int64 - ListenTCP4 bool - EnableHTTP bool - HTTPAddr string - HTTPPort int - AutoTLS bool - Domains []string - TLSCacheDir string - EnableHTTPS bool + // Graceful + // @Description means use graceful module to start the server + // @Default false + Graceful bool + // ListenTCP4 + // @Description if it's true, means that Beego only work for TCP4 + // please check net.Listen function + // In general, you should not set it to true + // @Default false + ListenTCP4 bool + // EnableHTTP + // @Description if it's true, Beego will accept HTTP request. + // But if you want to use HTTPS only, please set it to false + // see EnableHTTPS + // @Default true + EnableHTTP bool + // AutoTLS + // @Description If it's true, Beego will use default value to initialize the TLS configure + // But those values could be override if you have custom value. + // see Domains, TLSCacheDir + // @Default false + AutoTLS bool + // EnableHTTPS + // @Description If it's true, Beego will accept HTTPS request. + // Now, you'd better use HTTPS protocol on prod environment to get better security + // In prod, the best option is EnableHTTPS=true and EnableHTTP=false + // see EnableHTTP + // @Default false + EnableHTTPS bool + // EnableMutualHTTPS + // @Description if it's true, Beego will handle requests on incoming mutual TLS connections + // see Server.ListenAndServeMutualTLS + // @Default false EnableMutualHTTPS bool - HTTPSAddr string - HTTPSPort int - HTTPSCertFile string - HTTPSKeyFile string - TrustCaFile string - EnableAdmin bool - AdminAddr string - AdminPort int - EnableFcgi bool - EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O - ClientAuth int + // EnableAdmin + // @Description if it's true, Beego will provide admin service. + // You can visit the admin service via browser. + // The default port is 8088 + // see AdminPort + // @Default false + EnableAdmin bool + // EnableFcgi + // @Description + // @Default false + EnableFcgi bool + // EnableStdIo + // @Description EnableStdIo works with EnableFcgi Use FCGI via standard I/O + // @Default false + EnableStdIo bool + // ServerTimeOut + // @Description Beego use this as ReadTimeout and WriteTimeout + // The unit is second. + // see http.Server.ReadTimeout, WriteTimeout + // @Default 0 + ServerTimeOut int64 + // HTTPAddr + // @Description Beego listen to this address when the application start up. + // @Default "" + HTTPAddr string + // HTTPPort + // @Description Beego listen to this port + // you'd better change this value when you deploy to prod environment + // @Default 8080 + HTTPPort int + // Domains + // @Description Beego use this to configure TLS. Those domains are "white list" domain + // @Default [] + Domains []string + // TLSCacheDir + // @Description Beego use this as cache dir to store TLS cert data + // @Default "" + TLSCacheDir string + // HTTPSAddr + // @Description Beego will listen to this address to accept HTTPS request + // see EnableHTTPS + // @Default "" + HTTPSAddr string + // HTTPSPort + // @Description Beego will listen to this port to accept HTTPS request + // @Default 10443 + HTTPSPort int + // HTTPSCertFile + // @Description Beego read this file as cert file + // When you are using HTTPS protocol, please configure it + // see HTTPSKeyFile + // @Default "" + HTTPSCertFile string + // HTTPSKeyFile + // @Description Beego read this file as key file + // When you are using HTTPS protocol, please configure it + // see HTTPSCertFile + // @Default "" + HTTPSKeyFile string + // TrustCaFile + // @Description Beego read this file as CA file + // @Default "" + TrustCaFile string + // AdminAddr + // @Description Beego will listen to this address to provide admin service + // In general, it should be the same with your application address, HTTPAddr or HTTPSAddr + // @Default "" + AdminAddr string + // AdminPort + // @Description Beego will listen to this port to provide admin service + // @Default 8088 + AdminPort int + // @Description Beego use this tls.ClientAuthType to initialize TLS connection + // The default value is tls.RequireAndVerifyClientCert + // @Default 4 + ClientAuth int } // WebConfig holds web related config type WebConfig struct { - AutoRender bool - EnableDocs bool - FlashName string - FlashSeparator string - DirectoryIndex bool - StaticDir map[string]string + // AutoRender + // @Description If it's true, Beego will render the page based on your template and data + // In general, keep it as true. + // But if you are building RESTFul API and you don't have any page, + // you can set it to false + // @Default true + AutoRender bool + // Deprecated: Beego didn't use it anymore + EnableDocs bool + // EnableXSRF + // @Description If it's true, Beego will help to provide XSRF support + // But you should notice that, now Beego only work for HTTPS protocol with XSRF + // because it's not safe if using HTTP protocol + // And, the cookie storing XSRF token has two more flags HttpOnly and Secure + // It means that you must use HTTPS protocol and you can not read the token from JS script + // This is completed different from Beego 1.x because we got many security reports + // And if you are in dev environment, you could set it to false + // @Default false + EnableXSRF bool + // DirectoryIndex + // @Description When Beego serves static resources request, it will look up the file. + // If the file is directory, Beego will try to find the index.html as the response + // But if the index.html is not exist or it's a directory, + // Beego will return 403 response if DirectoryIndex is **false** + // @Default false + DirectoryIndex bool + // FlashName + // @Description the cookie's name when Beego try to store the flash data into cookie + // @Default BEEGO_FLASH + FlashName string + // FlashSeparator + // @Description When Beego read flash data from request, it uses this as the separator + // @Default BEEGOFLASH + FlashSeparator string + // StaticDir + // @Description Beego uses this as static resources' root directory. + // It means that Beego will try to search static resource from this start point + // It's a map, the key is the path and the value is the directory + // For example, the default value is /static => static, + // which means that when Beego got a request with path /static/xxx + // Beego will try to find the resource from static directory + // @Default /static => static + StaticDir map[string]string + // StaticExtensionsToGzip + // @Description The static resources with those extension will be compressed if EnableGzip is true + // @Default [".css", ".js" ] StaticExtensionsToGzip []string - StaticCacheFileSize int - StaticCacheFileNum int - TemplateLeft string - TemplateRight string - ViewsPath string - CommentRouterPath string - EnableXSRF bool - XSRFKey string - XSRFExpire int - Session SessionConfig + // StaticCacheFileSize + // @Description If the size of static resource < StaticCacheFileSize, Beego will try to handle it by itself, + // it means that Beego will compressed the file data (if enable) and cache this file. + // But if the file size > StaticCacheFileSize, Beego just simply delegate the request to http.ServeFile + // the default value is 100KB. + // the max memory size of caching static files is StaticCacheFileSize * StaticCacheFileNum + // see StaticCacheFileNum + // @Default 102400 + StaticCacheFileSize int + // StaticCacheFileNum + // @Description Beego use it to control the memory usage of caching static resource file + // If the caching files > StaticCacheFileNum, Beego use LRU algorithm to remove caching file + // the max memory size of caching static files is StaticCacheFileSize * StaticCacheFileNum + // see StaticCacheFileSize + // @Default 1000 + StaticCacheFileNum int + // TemplateLeft + // @Description Beego use this to render page + // see TemplateRight + // @Default {{ + TemplateLeft string + // TemplateRight + // @Description Beego use this to render page + // see TemplateLeft + // @Default }} + TemplateRight string + // ViewsPath + // @Description The directory of Beego application storing template + // @Default views + ViewsPath string + // CommentRouterPath + // @Description Beego scans this directory and its sub directory to generate router + // Beego only scans this directory when it's in dev environment + // @Default controllers + CommentRouterPath string + // XSRFKey + // @Description the name of cookie storing XSRF token + // see EnableXSRF + // @Default beegoxsrf + XSRFKey string + // XSRFExpire + // @Description the expiration time of XSRF token cookie + // second + // @Default 0 + XSRFExpire int + // @Description session related config + Session SessionConfig } // SessionConfig holds session related config type SessionConfig struct { - SessionOn bool - SessionProvider string - SessionName string - SessionGCMaxLifetime int64 - SessionProviderConfig string - SessionCookieLifeTime int - SessionAutoSetCookie bool - SessionDomain string - SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. - SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers - SessionNameInHTTPHeader string - SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params - SessionIDPrefix string + // SessionOn + // @Description if it's true, Beego will auto manage session + // @Default false + SessionOn bool + // SessionAutoSetCookie + // @Description if it's true, Beego will put the session token into cookie too + // @Default true + SessionAutoSetCookie bool + // SessionDisableHTTPOnly + // @Description used to allow for cross domain cookies/javascript cookies + // In general, you should not set it to true unless you understand the risk + // @Default false + SessionDisableHTTPOnly bool + // SessionEnableSidInHTTPHeader + // @Description enable store/get the sessionId into/from http headers + // @Default false + SessionEnableSidInHTTPHeader bool + // SessionEnableSidInURLQuery + // @Description enable get the sessionId from Url Query params + // @Default false + SessionEnableSidInURLQuery bool + // SessionProvider + // @Description session provider's name. + // You should confirm that this provider has been register via session.Register method + // the default value is memory. This is not suitable for distributed system + // @Default memory + SessionProvider string + // SessionName + // @Description If SessionAutoSetCookie is true, we use this value as the cookie's name + // @Default beegosessionID + SessionName string + // SessionGCMaxLifetime + // @Description Beego will GC session to clean useless session. + // unit: second + // @Default 3600 + SessionGCMaxLifetime int64 + // SessionProviderConfig + // @Description the config of session provider + // see SessionProvider + // you should read the document of session provider to learn how to set this value + // @Default "" + SessionProviderConfig string + // SessionCookieLifeTime + // @Description If SessionAutoSetCookie is true, + // we use this value as the expiration time and max age of the cookie + // unit second + // @Default 0 + SessionCookieLifeTime int + // SessionDomain + // @Description If SessionAutoSetCookie is true, we use this value as the cookie's domain + // @Default "" + SessionDomain string + // SessionNameInHTTPHeader + // @Description if SessionEnableSidInHTTPHeader is true, this value will be used as the http header + // @Default Beegosessionid + SessionNameInHTTPHeader string + // SessionCookieSameSite + // @Description If SessionAutoSetCookie is true, we use this value as the cookie's same site policy + // the default value is http.SameSiteDefaultMode + // @Default 1 + SessionCookieSameSite http.SameSite + + // SessionIDPrefix + // @Description session id's prefix + // @Default "" + SessionIDPrefix string } // LogConfig holds Log related config type LogConfig struct { - AccessLogs bool - EnableStaticLogs bool // log static files requests default: false - AccessLogsFormat string // access log format: JSON_FORMAT, APACHE_FORMAT or empty string - FileLineNum bool - Outputs map[string]string // Store Adaptor : config + // AccessLogs + // @Description If it's true, Beego will log the HTTP request info + // @Default false + AccessLogs bool + // EnableStaticLogs + // @Description log static files requests + // @Default false + EnableStaticLogs bool + // FileLineNum + // @Description if it's true, it will log the line number + // @Default true + FileLineNum bool + // AccessLogsFormat + // @Description access log format: JSON_FORMAT, APACHE_FORMAT or empty string + // @Default APACHE_FORMAT + AccessLogsFormat string + // Outputs + // @Description the destination of access log + // the key is log adapter and the value is adapter's configure + // @Default "console" => "" + Outputs map[string]string // Store Adaptor : config } var ( @@ -156,7 +465,7 @@ func init() { if err != nil { panic(err) } - var filename = "app.conf" + filename := "app.conf" if os.Getenv("BEEGO_RUNMODE") != "" { filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" } @@ -275,6 +584,7 @@ func newBConfig() *Config { SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers SessionNameInHTTPHeader: "Beegosessionid", SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params + SessionCookieSameSite: http.SameSiteDefaultMode, }, }, Log: LogConfig{ @@ -303,7 +613,6 @@ func parseConfig(appConfigPath string) (err error) { // For 1.x, it use assignSingleConfig to parse the file // but for 2.x, we use Unmarshaler method func assignConfig(ac config.Configer) error { - parseConfigForV1(ac) err := ac.Unmarshaler("", BConfig) @@ -421,7 +730,6 @@ func assignSingleConfig(p interface{}, ac config.Configer) { // do nothing here } } - } // LoadAppConfig allow developer to apply a config file diff --git a/server/web/config_test.go b/server/web/config_test.go index 9f8f7a80..5fb78b56 100644 --- a/server/web/config_test.go +++ b/server/web/config_test.go @@ -32,6 +32,10 @@ func TestDefaults(t *testing.T) { } } +func TestLoadAppConfig(t *testing.T) { + println(1 << 30) +} + func TestAssignConfig_01(t *testing.T) { _BConfig := &Config{} _BConfig.AppName = "beego_test" @@ -105,7 +109,6 @@ func TestAssignConfig_02(t *testing.T) { t.Log(_BConfig.Log.FileLineNum) t.FailNow() } - } func TestAssignConfig_03(t *testing.T) { diff --git a/server/web/context/acceptencoder.go b/server/web/context/acceptencoder.go index 312e50de..4181bfee 100644 --- a/server/web/context/acceptencoder.go +++ b/server/web/context/acceptencoder.go @@ -131,14 +131,12 @@ var ( } ) -var ( - encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore - "gzip": gzipCompressEncoder, - "deflate": deflateCompressEncoder, - "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip - "identity": noneCompressEncoder, // identity means none-compress - } -) +var encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore + "gzip": gzipCompressEncoder, + "deflate": deflateCompressEncoder, + "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip + "identity": noneCompressEncoder, // identity means none-compress +} // WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { @@ -159,7 +157,7 @@ func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { var outputWriter resetWriter var err error - var ce = noneCompressEncoder + ce := noneCompressEncoder if cf, ok := encoderMap[encoding]; ok { ce = cf diff --git a/server/web/context/context.go b/server/web/context/context.go index 6070c996..42cc4035 100644 --- a/server/web/context/context.go +++ b/server/web/context/context.go @@ -15,7 +15,7 @@ // Package context provide the context utils // Usage: // -// import "github.com/beego/beego/v2/context" +// import "github.com/beego/beego/v2/server/web/context" // // ctx := context.Context{Request:req,ResponseWriter:rw} // @@ -27,23 +27,38 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/base64" + "encoding/json" + "encoding/xml" "errors" "fmt" "net" "net/http" + "net/url" + "reflect" "strconv" "strings" "time" + "google.golang.org/protobuf/proto" + "gopkg.in/yaml.v3" + "github.com/beego/beego/v2/core/utils" + "github.com/beego/beego/v2/server/web/session" ) // Commonly used mime-types const ( - ApplicationJSON = "application/json" - ApplicationXML = "application/xml" - ApplicationYAML = "application/x-yaml" - TextXML = "text/xml" + ApplicationJSON = "application/json" + ApplicationXML = "application/xml" + ApplicationForm = "application/x-www-form-urlencoded" + ApplicationProto = "application/x-protobuf" + ApplicationYAML = "application/x-yaml" + TextXML = "text/xml" + + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" + formatDateTimeT = "2006-01-02T15:04:05" ) // NewContext return the Context with Input and Output @@ -64,6 +79,108 @@ type Context struct { _xsrfToken string } +func (ctx *Context) Bind(obj interface{}) error { + ct, exist := ctx.Request.Header["Content-Type"] + if !exist || len(ct) == 0 { + return ctx.BindJSON(obj) + } + i, l := 0, len(ct[0]) + for i < l && ct[0][i] != ';' { + i++ + } + switch ct[0][0:i] { + case ApplicationJSON: + return ctx.BindJSON(obj) + case ApplicationXML, TextXML: + return ctx.BindXML(obj) + case ApplicationForm: + return ctx.BindForm(obj) + case ApplicationProto: + return ctx.BindProtobuf(obj.(proto.Message)) + case ApplicationYAML: + return ctx.BindYAML(obj) + default: + return errors.New("Unsupported Content-Type:" + ct[0]) + } +} + +// Resp sends response based on the Accept Header +// By default response will be in JSON +func (ctx *Context) Resp(data interface{}) error { + accept := ctx.Input.Header("Accept") + switch accept { + case ApplicationYAML: + return ctx.YamlResp(data) + case ApplicationXML, TextXML: + return ctx.XMLResp(data) + case ApplicationProto: + return ctx.ProtoResp(data.(proto.Message)) + default: + return ctx.JSONResp(data) + } +} + +func (ctx *Context) JSONResp(data interface{}) error { + return ctx.Output.JSON(data, false, false) +} + +func (ctx *Context) XMLResp(data interface{}) error { + return ctx.Output.XML(data, false) +} + +func (ctx *Context) YamlResp(data interface{}) error { + return ctx.Output.YAML(data) +} + +func (ctx *Context) ProtoResp(data proto.Message) error { + return ctx.Output.Proto(data) +} + +// BindYAML only read data from http request body +func (ctx *Context) BindYAML(obj interface{}) error { + return yaml.Unmarshal(ctx.Input.RequestBody, obj) +} + +// BindForm will parse form values to struct via tag. +func (ctx *Context) BindForm(obj interface{}) error { + err := ctx.Request.ParseForm() + if err != nil { + return err + } + return ParseForm(ctx.Request.Form, obj) +} + +// BindJSON only read data from http request body +func (ctx *Context) BindJSON(obj interface{}) error { + return json.Unmarshal(ctx.Input.RequestBody, obj) +} + +// BindProtobuf only read data from http request body +func (ctx *Context) BindProtobuf(obj proto.Message) error { + return proto.Unmarshal(ctx.Input.RequestBody, obj) +} + +// BindXML only read data from http request body +func (ctx *Context) BindXML(obj interface{}) error { + return xml.Unmarshal(ctx.Input.RequestBody, obj) +} + +// ParseForm will parse form values to struct via tag. +func ParseForm(form url.Values, obj interface{}) error { + objT := reflect.TypeOf(obj) + objV := reflect.ValueOf(obj) + if !isStructPtr(objT) { + return fmt.Errorf("%v must be a struct pointer", obj) + } + objT = objT.Elem() + objV = objV.Elem() + return parseFormToStruct(form, objT, objV) +} + +func isStructPtr(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + // Reset initializes Context, BeegoInput and BeegoOutput func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { ctx.Request = r @@ -90,7 +207,7 @@ func (ctx *Context) Abort(status int, body string) { // WriteString writes a string to response body. func (ctx *Context) WriteString(content string) { - ctx.ResponseWriter.Write([]byte(content)) + _, _ = ctx.ResponseWriter.Write([]byte(content)) } // GetCookie gets a cookie from a request for a given key. @@ -123,7 +240,10 @@ func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { sig := parts[2] h := hmac.New(sha256.New, []byte(Secret)) - fmt.Fprintf(h, "%s%s", vs, timestamp) + _, err := fmt.Fprintf(h, "%s%s", vs, timestamp) + if err != nil { + return "", false + } if fmt.Sprintf("%02x", h.Sum(nil)) != sig { return "", false @@ -137,7 +257,7 @@ func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interf vs := base64.URLEncoding.EncodeToString([]byte(value)) timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) h := hmac.New(sha256.New, []byte(Secret)) - fmt.Fprintf(h, "%s%s", vs, timestamp) + _, _ = fmt.Fprintf(h, "%s%s", vs, timestamp) sig := fmt.Sprintf("%02x", h.Sum(nil)) cookie := strings.Join([]string{vs, timestamp, sig}, "|") ctx.Output.Cookie(name, cookie, others...) @@ -150,7 +270,7 @@ func (ctx *Context) XSRFToken(key string, expire int64) string { if !ok { token = string(utils.RandomCreateBytes(32)) // TODO make it configurable - ctx.SetSecureCookie(key, "_xsrf", token, expire, "", "") + ctx.SetSecureCookie(key, "_xsrf", token, expire, "/", "") } ctx._xsrfToken = token } @@ -195,6 +315,22 @@ func (ctx *Context) RenderMethodResult(result interface{}) { } } +// Session return session store of this context of request +func (ctx *Context) Session() (store session.Store, err error) { + if ctx.Input != nil { + if ctx.Input.CruSession != nil { + store = ctx.Input.CruSession + return + } else { + err = errors.New(`no valid session store(please initialize session)`) + return + } + } else { + err = errors.New(`no valid input`) + return + } +} + // Response is a wrapper for the http.ResponseWriter // Started: if true, response was already written to so the other handler will not be executed type Response struct { diff --git a/server/web/context/context_test.go b/server/web/context/context_test.go index 7c0535e0..53717d31 100644 --- a/server/web/context/context_test.go +++ b/server/web/context/context_test.go @@ -18,6 +18,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/beego/beego/v2/server/web/session" ) func TestXsrfReset_01(t *testing.T) { @@ -45,3 +47,67 @@ func TestXsrfReset_01(t *testing.T) { t.FailNow() } } + +func TestContext_Session(t *testing.T) { + c := NewContext() + if store, err := c.Session(); store != nil || err == nil { + t.FailNow() + } +} + +func TestContext_Session1(t *testing.T) { + c := Context{} + if store, err := c.Session(); store != nil || err == nil { + t.FailNow() + } +} + +func TestContext_Session2(t *testing.T) { + c := NewContext() + c.Input.CruSession = &session.MemSessionStore{} + + if store, err := c.Session(); store == nil || err != nil { + t.FailNow() + } +} + +func TestSetCookie(t *testing.T) { + type cookie struct { + Name string + Value string + MaxAge int64 + Path string + Domain string + Secure bool + HttpOnly bool + SameSite string + } + type testItem struct { + item cookie + want string + } + cases := []struct { + request string + valueGp []testItem + }{ + {"/", []testItem{{cookie{"name", "value", -1, "/", "", false, false, "Strict"}, "name=value; Max-Age=0; Path=/; SameSite=Strict"}}}, + {"/", []testItem{{cookie{"name", "value", -1, "/", "", false, false, "Lax"}, "name=value; Max-Age=0; Path=/; SameSite=Lax"}}}, + {"/", []testItem{{cookie{"name", "value", -1, "/", "", false, false, "None"}, "name=value; Max-Age=0; Path=/; SameSite=None"}}}, + {"/", []testItem{{cookie{"name", "value", -1, "/", "", false, false, ""}, "name=value; Max-Age=0; Path=/"}}}, + } + for _, c := range cases { + r, _ := http.NewRequest("GET", c.request, nil) + output := NewOutput() + output.Context = NewContext() + output.Context.Reset(httptest.NewRecorder(), r) + for _, item := range c.valueGp { + params := item.item + var others = []interface{}{params.MaxAge, params.Path, params.Domain, params.Secure, params.HttpOnly, params.SameSite} + output.Context.SetCookie(params.Name, params.Value, others...) + got := output.Context.ResponseWriter.Header().Get("Set-Cookie") + if got != item.want { + t.Fatalf("SetCookie error,should be:\n%v \ngot:\n%v", item.want, got) + } + } + } +} diff --git a/server/web/context/form.go b/server/web/context/form.go new file mode 100644 index 00000000..07a1b0b2 --- /dev/null +++ b/server/web/context/form.go @@ -0,0 +1,189 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/url" + "reflect" + "strconv" + "strings" + "time" +) + +var ( + sliceOfInts = reflect.TypeOf([]int(nil)) + sliceOfStrings = reflect.TypeOf([]string(nil)) +) + +// ParseForm will parse form values to struct via tag. +// Support for anonymous struct. +func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) error { + for i := 0; i < objT.NumField(); i++ { + fieldV := objV.Field(i) + if !fieldV.CanSet() { + continue + } + + fieldT := objT.Field(i) + if fieldT.Anonymous && fieldT.Type.Kind() == reflect.Struct { + err := parseFormToStruct(form, fieldT.Type, fieldV) + if err != nil { + return err + } + continue + } + + tag, ok := formTagName(fieldT) + if !ok { + continue + } + + value, ok := formValue(tag, form, fieldT) + if !ok { + continue + } + + switch fieldT.Type.Kind() { + case reflect.Bool: + b, err := parseFormBoolValue(value) + if err != nil { + return err + } + fieldV.SetBool(b) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + x, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + fieldV.SetInt(x) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + x, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return err + } + fieldV.SetUint(x) + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + fieldV.SetFloat(x) + case reflect.Interface: + fieldV.Set(reflect.ValueOf(value)) + case reflect.String: + fieldV.SetString(value) + case reflect.Struct: + if fieldT.Type.String() == "time.Time" { + t, err := parseFormTime(value) + if err != nil { + return err + } + fieldV.Set(reflect.ValueOf(t)) + } + case reflect.Slice: + if fieldT.Type == sliceOfInts { + formVals := form[tag] + fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(int(1))), len(formVals), len(formVals))) + for i := 0; i < len(formVals); i++ { + val, err := strconv.Atoi(formVals[i]) + if err != nil { + return err + } + fieldV.Index(i).SetInt(int64(val)) + } + } else if fieldT.Type == sliceOfStrings { + formVals := form[tag] + fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(formVals), len(formVals))) + for i := 0; i < len(formVals); i++ { + fieldV.Index(i).SetString(formVals[i]) + } + } + } + } + return nil +} + +// nolint +func parseFormTime(value string) (time.Time, error) { + var pattern string + if len(value) >= 25 { + value = value[:25] + pattern = time.RFC3339 + } else if strings.HasSuffix(strings.ToUpper(value), "Z") { + pattern = time.RFC3339 + } else if len(value) >= 19 { + if strings.Contains(value, "T") { + pattern = formatDateTimeT + } else { + pattern = formatDateTime + } + value = value[:19] + } else if len(value) >= 10 { + if len(value) > 10 { + value = value[:10] + } + pattern = formatDate + } else if len(value) >= 8 { + if len(value) > 8 { + value = value[:8] + } + pattern = formatTime + } + return time.ParseInLocation(pattern, value, time.Local) +} + +func parseFormBoolValue(value string) (bool, error) { + if strings.ToLower(value) == "on" || strings.ToLower(value) == "1" || strings.ToLower(value) == "yes" { + return true, nil + } + if strings.ToLower(value) == "off" || strings.ToLower(value) == "0" || strings.ToLower(value) == "no" { + return false, nil + } + return strconv.ParseBool(value) +} + +// nolint +func formTagName(fieldT reflect.StructField) (string, bool) { + tags := strings.Split(fieldT.Tag.Get("form"), ",") + var tag string + if len(tags) == 0 || tags[0] == "" { + tag = fieldT.Name + } else if tags[0] == "-" { + return "", false + } else { + tag = tags[0] + } + return tag, true +} + +func formValue(tag string, form url.Values, fieldT reflect.StructField) (string, bool) { + formValues := form[tag] + var value string + if len(formValues) == 0 { + defaultValue := fieldT.Tag.Get("default") + if defaultValue != "" { + value = defaultValue + } else { + return "", false + } + } + if len(formValues) == 1 { + value = formValues[0] + if value == "" { + return "", false + } + } + return value, true +} diff --git a/server/web/context/input_test.go b/server/web/context/input_test.go index 3a6c2e7b..005ccd9a 100644 --- a/server/web/context/input_test.go +++ b/server/web/context/input_test.go @@ -203,8 +203,8 @@ func TestParams(t *testing.T) { if val := inp.Param("p2"); val != "val2_ver2" { t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") } - } + func BenchmarkQuery(b *testing.B) { beegoInput := NewInput() beegoInput.Context = NewContext() diff --git a/server/web/context/output.go b/server/web/context/output.go index aba5e560..f52eac9d 100644 --- a/server/web/context/output.go +++ b/server/web/context/output.go @@ -31,7 +31,8 @@ import ( "strings" "time" - yaml "gopkg.in/yaml.v2" + "google.golang.org/protobuf/proto" + "gopkg.in/yaml.v2" ) // BeegoOutput does work for sending response header. @@ -63,7 +64,7 @@ func (output *BeegoOutput) Header(key, val string) { // Sends out response body directly. func (output *BeegoOutput) Body(content []byte) error { var encoding string - var buf = &bytes.Buffer{} + buf := &bytes.Buffer{} if output.EnableGzip { encoding = ParseEncoding(output.Context.Request) } @@ -117,13 +118,13 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface // can use nil skip set // default "/" + tmpPath := "/" if len(others) > 1 { if v, ok := others[1].(string); ok && len(v) > 0 { - fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v)) + tmpPath = sanitizeValue(v) } - } else { - fmt.Fprintf(&b, "; Path=%s", "/") } + fmt.Fprintf(&b, "; Path=%s", tmpPath) // default empty if len(others) > 2 { @@ -155,6 +156,13 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface } } + // default empty + if len(others) > 5 { + if v, ok := others[5].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; SameSite=%s", sanitizeValue(v)) + } + } + output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) } @@ -217,6 +225,19 @@ func (output *BeegoOutput) YAML(data interface{}) error { return output.Body(content) } +// Proto writes protobuf to the response body. +func (output *BeegoOutput) Proto(data proto.Message) error { + output.Header("Content-Type", "application/x-protobuf; charset=utf-8") + var content []byte + var err error + content, err = proto.Marshal(data) + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + return output.Body(content) +} + // JSONP writes jsonp to the response body. func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { output.Header("Content-Type", "application/javascript; charset=utf-8") diff --git a/server/web/context/param/parsers.go b/server/web/context/param/parsers.go index fb4099f5..99ad9275 100644 --- a/server/web/context/param/parsers.go +++ b/server/web/context/param/parsers.go @@ -55,29 +55,25 @@ func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error return f(value, toType) } -type boolParser struct { -} +type boolParser struct{} func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) { return strconv.ParseBool(value) } -type stringParser struct { -} +type stringParser struct{} func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) { return value, nil } -type intParser struct { -} +type intParser struct{} func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) { return strconv.Atoi(value) } -type floatParser struct { -} +type floatParser struct{} func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) { if toType.Kind() == reflect.Float32 { @@ -90,8 +86,7 @@ func (p floatParser) parse(value string, toType reflect.Type) (interface{}, erro return strconv.ParseFloat(value, 64) } -type timeParser struct { -} +type timeParser struct{} func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) { result, err = time.Parse(time.RFC3339, value) @@ -101,8 +96,7 @@ func (p timeParser) parse(value string, toType reflect.Type) (result interface{} return } -type jsonParser struct { -} +type jsonParser struct{} func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) { pResult := reflect.New(toType) diff --git a/server/web/context/param/parsers_test.go b/server/web/context/param/parsers_test.go index 7d4c9661..931d88f6 100644 --- a/server/web/context/param/parsers_test.go +++ b/server/web/context/param/parsers_test.go @@ -13,7 +13,6 @@ type testDefinition struct { } func Test_Parsers(t *testing.T) { - // ints checkParser(testDefinition{"1", 1, intParser{}}, t) checkParser(testDefinition{"-1", int64(-1), intParser{}}, t) @@ -48,12 +47,11 @@ func Test_Parsers(t *testing.T) { checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body}) // pointers - var someInt = 1 + someInt := 1 checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t) - var someStruct = struct{ X int }{5} + someStruct := struct{ X int }{5} checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t) - } func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) { diff --git a/server/web/controller.go b/server/web/controller.go index 5983cfbd..6bf061dd 100644 --- a/server/web/controller.go +++ b/server/web/controller.go @@ -28,11 +28,13 @@ import ( "reflect" "strconv" "strings" + "sync" - "github.com/beego/beego/v2/server/web/session" + "google.golang.org/protobuf/proto" "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/context/param" + "github.com/beego/beego/v2/server/web/session" ) var ( @@ -40,8 +42,21 @@ var ( ErrAbort = errors.New("user stop run") // GlobalControllerRouter store comments with controller. pkgpath+controller:comments GlobalControllerRouter = make(map[string][]ControllerComments) + copyBufferPool sync.Pool ) +const ( + bytePerKb = 1024 + copyBufferKb = 32 + filePerm = 0o666 +) + +func init() { + copyBufferPool.New = func() interface{} { + return make([]byte, bytePerKb*copyBufferKb) + } +} + // ControllerFilter store the filter for controller type ControllerFilter struct { Pattern string @@ -108,9 +123,9 @@ type Controller struct { EnableRender bool // xsrf data + EnableXSRF bool _xsrfToken string XSRFExpire int - EnableXSRF bool // session CruSession session.Store @@ -226,6 +241,37 @@ func (c *Controller) HandlerFunc(fnname string) bool { // URLMapping register the internal Controller router. func (c *Controller) URLMapping() {} +// Bind if the content type is form, we read data from form +// otherwise, read data from request body +func (c *Controller) Bind(obj interface{}) error { + return c.Ctx.Bind(obj) +} + +// BindYAML only read data from http request body +func (c *Controller) BindYAML(obj interface{}) error { + return c.Ctx.BindYAML(obj) +} + +// BindForm read data from form +func (c *Controller) BindForm(obj interface{}) error { + return c.Ctx.BindForm(obj) +} + +// BindJSON only read data from http request body +func (c *Controller) BindJSON(obj interface{}) error { + return c.Ctx.BindJSON(obj) +} + +// BindProtobuf only read data from http request body +func (c *Controller) BindProtobuf(obj proto.Message) error { + return c.Ctx.BindProtobuf(obj) +} + +// BindXML only read data from http request body +func (c *Controller) BindXML(obj interface{}) error { + return c.Ctx.BindXML(obj) +} + // Mapping the method to function func (c *Controller) Mapping(method string, fn func()) { c.methodMapping[method] = fn @@ -348,9 +394,9 @@ func (c *Controller) Abort(code string) { // CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. func (c *Controller) CustomAbort(status int, body string) { + c.Ctx.Output.Status = status // first panic from ErrorMaps, it is user defined error functions. if _, ok := ErrorMaps[body]; ok { - c.Ctx.Output.Status = status panic(body) } // last panic user string @@ -376,6 +422,26 @@ func (c *Controller) URLFor(endpoint string, values ...interface{}) string { return URLFor(endpoint, values...) } +func (c *Controller) JSONResp(data interface{}) error { + return c.Ctx.JSONResp(data) +} + +func (c *Controller) XMLResp(data interface{}) error { + return c.Ctx.XMLResp(data) +} + +func (c *Controller) YamlResp(data interface{}) error { + return c.Ctx.YamlResp(data) +} + +// Resp sends response based on the Accept Header +// By default response will be in JSON +// it's different from ServeXXX methods +// because we don't store the data to Data field +func (c *Controller) Resp(data interface{}) error { + return c.Ctx.Resp(data) +} + // ServeJSON sends a json response with encoding charset. func (c *Controller) ServeJSON(encoding ...bool) error { var ( @@ -423,11 +489,7 @@ func (c *Controller) Input() (url.Values, error) { // ParseForm maps input data map to obj struct. func (c *Controller) ParseForm(obj interface{}) error { - form, err := c.Input() - if err != nil { - return err - } - return ParseForm(form, obj) + return c.Ctx.BindForm(obj) } // GetString returns the input value by key string or the default value while it's present and input is blank @@ -605,19 +667,31 @@ func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { // SaveToFile saves uploaded file to new path. // it only operates the first one of mutil-upload form file field. -func (c *Controller) SaveToFile(fromfile, tofile string) error { - file, _, err := c.Ctx.Request.FormFile(fromfile) +func (c *Controller) SaveToFile(fromFile, toFile string) error { + buf := copyBufferPool.Get().([]byte) + defer copyBufferPool.Put(buf) + return c.SaveToFileWithBuffer(fromFile, toFile, buf) +} + +type onlyWriter struct { + io.Writer +} + +func (c *Controller) SaveToFileWithBuffer(fromFile string, toFile string, buf []byte) error { + src, _, err := c.Ctx.Request.FormFile(fromFile) if err != nil { return err } - defer file.Close() - f, err := os.OpenFile(tofile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + defer src.Close() + + dst, err := os.OpenFile(toFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, filePerm) if err != nil { return err } - defer f.Close() - io.Copy(f, file) - return nil + defer dst.Close() + + _, err = io.CopyBuffer(onlyWriter{dst}, src, buf) + return err } // StartSession starts session and load old session data info this controller. diff --git a/server/web/controller_test.go b/server/web/controller_test.go index 4dd203f6..6b0a5203 100644 --- a/server/web/controller_test.go +++ b/server/web/controller_test.go @@ -15,13 +15,17 @@ package web import ( + "io/ioutil" "math" + "net/http" + "net/http/httptest" "os" "path/filepath" "strconv" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/beego/beego/v2/server/web/context" ) @@ -127,10 +131,9 @@ func TestGetUint64(t *testing.T) { } func TestAdditionalViewPaths(t *testing.T) { - wkdir, err := os.Getwd() - assert.Nil(t, err) - dir1 := filepath.Join(wkdir, "_beeTmp", "TestAdditionalViewPaths") - dir2 := filepath.Join(wkdir, "_beeTmp2", "TestAdditionalViewPaths") + tmpDir := os.TempDir() + dir1 := filepath.Join(tmpDir, "_beeTmp", "TestAdditionalViewPaths") + dir2 := filepath.Join(tmpDir, "_beeTmp2", "TestAdditionalViewPaths") defer os.RemoveAll(dir1) defer os.RemoveAll(dir2) @@ -138,7 +141,7 @@ func TestAdditionalViewPaths(t *testing.T) { dir2file := "file2.tpl" genFile := func(dir string, name string, content string) { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0o777) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) } else { @@ -146,7 +149,6 @@ func TestAdditionalViewPaths(t *testing.T) { f.WriteString(content) f.Close() } - } genFile(dir1, dir1file, `
{{.Content}}
`) genFile(dir2, dir2file, `{{.Content}}`) @@ -183,3 +185,127 @@ func TestAdditionalViewPaths(t *testing.T) { ctrl.ViewPath = dir2 ctrl.RenderString() } + +func TestBindJson(t *testing.T) { + var s struct { + Foo string `json:"foo"` + } + header := map[string][]string{"Content-Type": {"application/json"}} + request := &http.Request{Header: header} + input := &context.BeegoInput{RequestBody: []byte(`{"foo": "FOO"}`)} + ctx := &context.Context{Request: request, Input: input} + ctrlr := Controller{Ctx: ctx} + err := ctrlr.Bind(&s) + require.NoError(t, err) + assert.Equal(t, "FOO", s.Foo) +} + +func TestBindNoContentType(t *testing.T) { + var s struct { + Foo string `json:"foo"` + } + header := map[string][]string{} + request := &http.Request{Header: header} + input := &context.BeegoInput{RequestBody: []byte(`{"foo": "FOO"}`)} + ctx := &context.Context{Request: request, Input: input} + ctrlr := Controller{Ctx: ctx} + err := ctrlr.Bind(&s) + require.NoError(t, err) + assert.Equal(t, "FOO", s.Foo) +} + +func TestBindXML(t *testing.T) { + var s struct { + Foo string `xml:"foo"` + } + xmlBody := ` + + FOO +` + header := map[string][]string{"Content-Type": {"text/xml"}} + request := &http.Request{Header: header} + input := &context.BeegoInput{RequestBody: []byte(xmlBody)} + ctx := &context.Context{Request: request, Input: input} + ctrlr := Controller{Ctx: ctx} + err := ctrlr.Bind(&s) + require.NoError(t, err) + assert.Equal(t, "FOO", s.Foo) +} + +func TestBindYAML(t *testing.T) { + var s struct { + Foo string `yaml:"foo"` + } + header := map[string][]string{"Content-Type": {"application/x-yaml"}} + request := &http.Request{Header: header} + input := &context.BeegoInput{RequestBody: []byte("foo: FOO")} + ctx := &context.Context{Request: request, Input: input} + ctrlr := Controller{Ctx: ctx} + err := ctrlr.Bind(&s) + require.NoError(t, err) + assert.Equal(t, "FOO", s.Foo) +} + +type TestRespController struct { + Controller +} + +func (t *TestRespController) TestResponse() { + type S struct { + Foo string `json:"foo" xml:"foo" yaml:"foo"` + } + + bar := S{Foo: "bar"} + + _ = t.Resp(bar) +} + +type respTestCase struct { + Accept string + ExpectedContentLength int64 + ExpectedResponse string +} + +func TestControllerResp(t *testing.T) { + // test cases + tcs := []respTestCase{ + {Accept: context.ApplicationJSON, ExpectedContentLength: 13, ExpectedResponse: `{"foo":"bar"}`}, + {Accept: context.ApplicationXML, ExpectedContentLength: 21, ExpectedResponse: `bar`}, + {Accept: context.ApplicationYAML, ExpectedContentLength: 9, ExpectedResponse: "foo: bar\n"}, + {Accept: "OTHER", ExpectedContentLength: 13, ExpectedResponse: `{"foo":"bar"}`}, + } + + for _, tc := range tcs { + testControllerRespTestCases(t, tc) + } +} + +func testControllerRespTestCases(t *testing.T, tc respTestCase) { + // create fake GET request + r, _ := http.NewRequest("GET", "/", nil) + r.Header.Set("Accept", tc.Accept) + w := httptest.NewRecorder() + + // setup the handler + handler := NewControllerRegister() + handler.Add("/", &TestRespController{}, WithRouterMethods(&TestRespController{}, "get:TestResponse")) + handler.ServeHTTP(w, r) + + response := w.Result() + if response.ContentLength != tc.ExpectedContentLength { + t.Errorf("TestResponse() unable to validate content length %d for %s", response.ContentLength, tc.Accept) + } + + if response.StatusCode != http.StatusOK { + t.Errorf("TestResponse() failed to validate response code for %s", tc.Accept) + } + + bodyBytes, err := ioutil.ReadAll(response.Body) + if err != nil { + t.Errorf("TestResponse() failed to parse response body for %s", tc.Accept) + } + bodyString := string(bodyBytes) + if bodyString != tc.ExpectedResponse { + t.Errorf("TestResponse() failed to validate response body '%s' for %s", bodyString, tc.Accept) + } +} diff --git a/server/web/error.go b/server/web/error.go index 85ae5c8c..118585dc 100644 --- a/server/web/error.go +++ b/server/web/error.go @@ -25,7 +25,6 @@ import ( "github.com/beego/beego/v2" "github.com/beego/beego/v2/core/utils" - "github.com/beego/beego/v2/server/web/context" ) diff --git a/server/web/filter.go b/server/web/filter.go index 0baa269f..2237703d 100644 --- a/server/web/filter.go +++ b/server/web/filter.go @@ -26,7 +26,9 @@ import ( type FilterChain func(next FilterFunc) FilterFunc // FilterFunc defines a filter function which is invoked before the controller handler is executed. -type FilterFunc func(ctx *context.Context) +// It's a alias of HandleFunc +// In fact, the HandleFunc is the last Filter. This is the truth +type FilterFunc = HandleFunc // FilterRouter defines a filter operation which is invoked before the controller handler is executed. // It can match the URL against a pattern, and execute a filter function diff --git a/server/web/filter/apiauth/apiauth.go b/server/web/filter/apiauth/apiauth.go index 9e6c30dc..55a914a2 100644 --- a/server/web/filter/apiauth/apiauth.go +++ b/server/web/filter/apiauth/apiauth.go @@ -17,12 +17,12 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/apiauth" +// "github.com/beego/beego/v2/server/web/filter/apiauth" // ) // // func main(){ // // apiauth every request -// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBasicAuth("appid","appkey")) // beego.Run() // } // diff --git a/server/web/filter/auth/basic.go b/server/web/filter/auth/basic.go index 5a01f260..403d4b2c 100644 --- a/server/web/filter/auth/basic.go +++ b/server/web/filter/auth/basic.go @@ -16,7 +16,7 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/auth" +// "github.com/beego/beego/v2/server/web/filter/auth" // ) // // func main(){ diff --git a/server/web/filter/authz/authz.go b/server/web/filter/authz/authz.go index 8009c976..4ff2d6bf 100644 --- a/server/web/filter/authz/authz.go +++ b/server/web/filter/authz/authz.go @@ -16,7 +16,7 @@ // Simple Usage: // import( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/authz" +// "github.com/beego/beego/v2/server/web/filter/authz" // "github.com/casbin/casbin" // ) // diff --git a/server/web/filter/cors/cors.go b/server/web/filter/cors/cors.go index 0eb9aa30..f6c68ca0 100644 --- a/server/web/filter/cors/cors.go +++ b/server/web/filter/cors/cors.go @@ -16,7 +16,7 @@ // Usage // import ( // "github.com/beego/beego/v2" -// "github.com/beego/beego/v2/plugins/cors" +// "github.com/beego/beego/v2/server/web/filter/cors" // ) // // func main() { @@ -69,10 +69,10 @@ var ( type Options struct { // If set, all origins are allowed. AllowAllOrigins bool - // A list of allowed origins. Wild cards and FQDNs are supported. - AllowOrigins []string // If set, allows to share auth credentials such as cookies. AllowCredentials bool + // A list of allowed origins. Wild cards and FQDNs are supported. + AllowOrigins []string // A list of allowed HTTP methods. AllowMethods []string // A list of allowed HTTP headers. @@ -143,7 +143,7 @@ func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map rHeader = strings.TrimSpace(rHeader) lookupLoop: for _, allowedHeader := range o.AllowHeaders { - if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) { + if strings.EqualFold(rHeader, allowedHeader) { allowed = append(allowed, rHeader) break lookupLoop } diff --git a/server/web/filter/prometheus/filter.go b/server/web/filter/prometheus/filter.go index 520170e0..2a26e2f1 100644 --- a/server/web/filter/prometheus/filter.go +++ b/server/web/filter/prometheus/filter.go @@ -33,15 +33,15 @@ const unknownRouterPattern = "UnknownRouterPattern" // FilterChainBuilder is an extension point, // when we want to support some configuration, // please use this structure -type FilterChainBuilder struct { -} +type FilterChainBuilder struct{} -var summaryVec prometheus.ObserverVec -var initSummaryVec sync.Once +var ( + summaryVec prometheus.ObserverVec + initSummaryVec sync.Once +) // FilterChain returns a FilterFunc. The filter will records some metrics func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFunc { - initSummaryVec.Do(func() { summaryVec = builder.buildVec() err := prometheus.Register(summaryVec) diff --git a/server/web/filter/prometheus/filter_test.go b/server/web/filter/prometheus/filter_test.go index 8da63df9..46ba3caf 100644 --- a/server/web/filter/prometheus/filter_test.go +++ b/server/web/filter/prometheus/filter_test.go @@ -38,10 +38,10 @@ func TestFilterChain(t *testing.T) { ctx.Input.SetData("RouterPattern", "my-route") filter(ctx) assert.True(t, ctx.Input.GetData("invocation").(bool)) + time.Sleep(1 * time.Second) } func TestFilterChainBuilder_report(t *testing.T) { - ctx := context.NewContext() r, _ := http.NewRequest("GET", "/prometheus/user", nil) w := httptest.NewRecorder() @@ -52,4 +52,4 @@ func TestFilterChainBuilder_report(t *testing.T) { ctx.Input.SetData("RouterPattern", "my-route") report(time.Second, ctx, fb.buildVec()) -} \ No newline at end of file +} diff --git a/server/web/filter/ratelimit/bucket.go b/server/web/filter/ratelimit/bucket.go new file mode 100644 index 00000000..67a3907e --- /dev/null +++ b/server/web/filter/ratelimit/bucket.go @@ -0,0 +1,14 @@ +package ratelimit + +import "time" + +// bucket is an interface store ratelimit info +type bucket interface { + take(amount uint) bool + getCapacity() uint + getRemaining() uint + getRate() time.Duration +} + +// bucketOption is constructor option +type bucketOption func(bucket) diff --git a/server/web/filter/ratelimit/limiter.go b/server/web/filter/ratelimit/limiter.go new file mode 100644 index 00000000..5b64b5dd --- /dev/null +++ b/server/web/filter/ratelimit/limiter.go @@ -0,0 +1,168 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit + +import ( + "net/http" + "sync" + "time" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" +) + +// Limiter is an interface used to ratelimit +type Limiter interface { + take(amount uint, r *http.Request) bool +} + +// limiterOption is constructor option +type limiterOption func(l *limiter) + +type limiter struct { + sync.RWMutex + capacity uint + rate time.Duration + buckets map[string]bucket + bucketFactory func(opts ...bucketOption) bucket + sessionKey func(r *http.Request) string + resp RejectionResponse +} + +// RejectionResponse stores response information +// for the request rejected by limiter +type RejectionResponse struct { + code int + body string +} + +const perRequestConsumedAmount = 1 + +var defaultRejectionResponse = RejectionResponse{ + code: 429, + body: "too many requests", +} + +// NewLimiter return FilterFunc, the limiter enables rate limit +// according to the configuration. +func NewLimiter(opts ...limiterOption) web.FilterFunc { + l := &limiter{ + buckets: make(map[string]bucket), + sessionKey: func(r *http.Request) string { + return defaultSessionKey(r) + }, + bucketFactory: NewTokenBucket, + resp: defaultRejectionResponse, + } + for _, o := range opts { + o(l) + } + + return func(ctx *context.Context) { + if !l.take(perRequestConsumedAmount, ctx.Request) { + ctx.ResponseWriter.WriteHeader(l.resp.code) + ctx.WriteString(l.resp.body) + } + } +} + +// WithSessionKey return limiterOption. WithSessionKey config func +// which defines the request characteristic againstthe limit is applied +func WithSessionKey(f func(r *http.Request) string) limiterOption { + return func(l *limiter) { + l.sessionKey = f + } +} + +// WithRate return limiterOption. WithRate config how long it takes to +// generate a token. +func WithRate(r time.Duration) limiterOption { + return func(l *limiter) { + l.rate = r + } +} + +// WithCapacity return limiterOption. WithCapacity config the capacity size. +// The bucket with a capacity of n has n tokens after initialization. The capacity +// defines how many requests a client can make in excess of the rate. +func WithCapacity(c uint) limiterOption { + return func(l *limiter) { + l.capacity = c + } +} + +// WithBucketFactory return limiterOption. WithBucketFactory customize the +// implementation of Bucket. +func WithBucketFactory(f func(opts ...bucketOption) bucket) limiterOption { + return func(l *limiter) { + l.bucketFactory = f + } +} + +// WithRejectionResponse return limiterOption. WithRejectionResponse +// customize the response for the request rejected by the limiter. +func WithRejectionResponse(resp RejectionResponse) limiterOption { + return func(l *limiter) { + l.resp = resp + } +} + +func (l *limiter) take(amount uint, r *http.Request) bool { + bucket := l.getBucket(r) + if bucket == nil { + return true + } + return bucket.take(amount) +} + +func (l *limiter) getBucket(r *http.Request) bucket { + key := l.sessionKey(r) + l.RLock() + b, ok := l.buckets[key] + l.RUnlock() + if !ok { + b = l.createBucket(key) + } + + return b +} + +func (l *limiter) createBucket(key string) bucket { + l.Lock() + defer l.Unlock() + // double check avoid overwriting + b, ok := l.buckets[key] + if ok { + return b + } + b = l.bucketFactory(withCapacity(l.capacity), withRate(l.rate)) + l.buckets[key] = b + return b +} + +func defaultSessionKey(r *http.Request) string { + return "" +} + +func RemoteIPSessionKey(r *http.Request) string { + IPAddress := r.Header.Get("X-Real-Ip") + if IPAddress == "" { + IPAddress = r.Header.Get("X-Forwarded-For") + } + if IPAddress == "" { + IPAddress = r.RemoteAddr + } + return IPAddress +} diff --git a/server/web/filter/ratelimit/limiter_test.go b/server/web/filter/ratelimit/limiter_test.go new file mode 100644 index 00000000..cafede5e --- /dev/null +++ b/server/web/filter/ratelimit/limiter_test.go @@ -0,0 +1,76 @@ +package ratelimit + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" +) + +func testRequest(t *testing.T, handler *web.ControllerRegister, requestIP, method, path string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.Header.Set("X-Real-Ip", requestIP) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", requestIP, method, path, w.Code, code) + } +} + +func TestLimiter(t *testing.T) { + handler := web.NewControllerRegister() + err := handler.InsertFilter("/foo/*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(1), WithSessionKey(RemoteIPSessionKey))) + if err != nil { + t.Error(err) + } + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + route := "/foo/1" + ip := "127.0.0.1" + testRequest(t, handler, ip, "GET", route, 200) + testRequest(t, handler, ip, "GET", route, 429) + testRequest(t, handler, "127.0.0.2", "GET", route, 200) + time.Sleep(1 * time.Millisecond) + testRequest(t, handler, ip, "GET", route, 200) +} + +func BenchmarkWithoutLimiter(b *testing.B) { + recorder := httptest.NewRecorder() + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + handler.ServeHTTP(recorder, r) + } + }) +} + +func BenchmarkWithLimiter(b *testing.B) { + recorder := httptest.NewRecorder() + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD + err := handler.InsertFilter("*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(100))) + if err != nil { + b.Error(err) + } + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + handler.ServeHTTP(recorder, r) + } + }) +} diff --git a/server/web/filter/ratelimit/token_bucket.go b/server/web/filter/ratelimit/token_bucket.go new file mode 100644 index 00000000..5906ee9e --- /dev/null +++ b/server/web/filter/ratelimit/token_bucket.go @@ -0,0 +1,76 @@ +package ratelimit + +import ( + "sync" + "time" +) + +type tokenBucket struct { + sync.RWMutex + remaining uint + capacity uint + lastCheckAt time.Time + rate time.Duration +} + +// NewTokenBucket return an bucket that implements token bucket +func NewTokenBucket(opts ...bucketOption) bucket { + b := &tokenBucket{lastCheckAt: time.Now()} + for _, o := range opts { + o(b) + } + return b +} + +func withCapacity(capacity uint) bucketOption { + return func(b bucket) { + bucket := b.(*tokenBucket) + bucket.capacity = capacity + bucket.remaining = capacity + } +} + +func withRate(rate time.Duration) bucketOption { + return func(b bucket) { + bucket := b.(*tokenBucket) + bucket.rate = rate + } +} + +func (b *tokenBucket) getRemaining() uint { + b.RLock() + defer b.RUnlock() + return b.remaining +} + +func (b *tokenBucket) getRate() time.Duration { + b.RLock() + defer b.RUnlock() + return b.rate +} + +func (b *tokenBucket) getCapacity() uint { + b.RLock() + defer b.RUnlock() + return b.capacity +} + +func (b *tokenBucket) take(amount uint) bool { + if b.rate <= 0 { + return true + } + b.Lock() + defer b.Unlock() + now := time.Now() + times := uint(now.Sub(b.lastCheckAt) / b.rate) + b.lastCheckAt = b.lastCheckAt.Add(time.Duration(times) * b.rate) + b.remaining += times + if b.remaining < amount { + return false + } + b.remaining -= amount + if b.remaining > b.capacity { + b.remaining = b.capacity + } + return true +} diff --git a/server/web/filter/ratelimit/token_bucket_test.go b/server/web/filter/ratelimit/token_bucket_test.go new file mode 100644 index 00000000..93a1b3bd --- /dev/null +++ b/server/web/filter/ratelimit/token_bucket_test.go @@ -0,0 +1,32 @@ +package ratelimit + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGetRate(t *testing.T) { + b := NewTokenBucket(withRate(1 * time.Second)).(*tokenBucket) + assert.Equal(t, b.getRate(), 1*time.Second) +} + +func TestGetRemainingAndCapacity(t *testing.T) { + b := NewTokenBucket(withCapacity(10)) + assert.Equal(t, b.getRemaining(), uint(10)) + assert.Equal(t, b.getCapacity(), uint(10)) +} + +func TestTake(t *testing.T) { + b := NewTokenBucket(withCapacity(10), withRate(10*time.Millisecond)).(*tokenBucket) + for i := 0; i < 10; i++ { + assert.True(t, b.take(1)) + } + assert.False(t, b.take(1)) + assert.Equal(t, b.getRemaining(), uint(0)) + b = NewTokenBucket(withCapacity(1), withRate(1*time.Millisecond)).(*tokenBucket) + assert.True(t, b.take(1)) + time.Sleep(2 * time.Millisecond) + assert.True(t, b.take(1)) +} diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go new file mode 100644 index 00000000..40f3e198 --- /dev/null +++ b/server/web/filter/session/filter.go @@ -0,0 +1,36 @@ +package session + +import ( + "context" + + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" +) + +// Session maintain session for web service +// Session new a session storage and store it into webContext.Context +func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { + sessionConfig := session.NewManagerConfig(options...) + sessionManager, _ := session.NewManager(string(providerType), sessionConfig) + go sessionManager.GC() + + return func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if ctx.Input.CruSession != nil { + return + } + + if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil { + logs.Error(`init session error:%s`, err.Error()) + } else { + // release session at the end of request + defer sess.SessionRelease(context.Background(), ctx.ResponseWriter) + ctx.Input.CruSession = sess + } + + next(ctx) + } + } +} diff --git a/server/web/filter/session/filter_test.go b/server/web/filter/session/filter_test.go new file mode 100644 index 00000000..6b413962 --- /dev/null +++ b/server/web/filter/session/filter_test.go @@ -0,0 +1,88 @@ +package session + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" +) + +func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code) + } +} + +func TestSession(t *testing.T) { + storeKey := uuid.New().String() + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store := ctx.Input.GetData(storeKey); store == nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} + +func TestSession1(t *testing.T) { + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store, err := ctx.Session(); store == nil || err != nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} diff --git a/server/web/filter_chain_test.go b/server/web/filter_chain_test.go index ee19b1e9..d3faf516 100644 --- a/server/web/filter_chain_test.go +++ b/server/web/filter_chain_test.go @@ -15,16 +15,19 @@ package web import ( - "github.com/stretchr/testify/assert" + "fmt" "net/http" "net/http/httptest" + "strconv" "testing" + "time" + + "github.com/stretchr/testify/assert" "github.com/beego/beego/v2/server/web/context" ) func TestControllerRegisterInsertFilterChain(t *testing.T) { - InsertFilterChain("/*", func(next FilterFunc) FilterFunc { return func(ctx *context.Context) { ctx.Output.Header("filter", "filter-chain") @@ -35,22 +38,51 @@ func TestControllerRegisterInsertFilterChain(t *testing.T) { ns := NewNamespace("/chain") ns.Get("/*", func(ctx *context.Context) { - ctx.Output.Body([]byte("hello")) + _ = ctx.Output.Body([]byte("hello")) }) r, _ := http.NewRequest("GET", "/chain/user", nil) w := httptest.NewRecorder() + BeeApp.Handlers.Init() BeeApp.Handlers.ServeHTTP(w, r) assert.Equal(t, "filter-chain", w.Header().Get("filter")) } +func TestControllerRegister_InsertFilterChain_Order(t *testing.T) { + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + r, _ := http.NewRequest("GET", "/abc", nil) + w := httptest.NewRecorder() + + BeeApp.Handlers.Init() + BeeApp.Handlers.ServeHTTP(w, r) + first := w.Header().Get("first") + second := w.Header().Get("second") + + ft, _ := strconv.ParseInt(first, 10, 64) + st, _ := strconv.ParseInt(second, 10, 64) + + assert.True(t, st > ft) +} + func TestFilterChainRouter(t *testing.T) { - - app := NewHttpSever() - const filterNonMatch = "filter-chain-non-match" app.InsertFilterChain("/app/nonMatch/before/*", func(next FilterFunc) FilterFunc { return func(ctx *context.Context) { @@ -81,6 +113,8 @@ func TestFilterChainRouter(t *testing.T) { } }) + app.Handlers.Init() + r, _ := http.NewRequest("GET", "/app/match", nil) w := httptest.NewRecorder() diff --git a/server/web/flash_test.go b/server/web/flash_test.go index 2deef54e..3e20c8fb 100644 --- a/server/web/flash_test.go +++ b/server/web/flash_test.go @@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) { // setup the handler handler := NewControllerRegister() - handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") + handler.Add("/", &TestFlashController{}, WithRouterMethods(&TestFlashController{}, "get:TestWriteFlash")) handler.ServeHTTP(w, r) // get the Set-Cookie value diff --git a/server/web/fs.go b/server/web/fs.go index 5457457a..4d65630a 100644 --- a/server/web/fs.go +++ b/server/web/fs.go @@ -6,8 +6,7 @@ import ( "path/filepath" ) -type FileSystem struct { -} +type FileSystem struct{} func (d FileSystem) Open(name string) (http.File, error) { return os.Open(name) @@ -17,7 +16,6 @@ func (d FileSystem) Open(name string) (http.File, error) { // directory in the tree, including root. All errors that arise visiting files // and directories are filtered by walkFn. func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { - f, err := fs.Open(root) if err != nil { return err diff --git a/server/web/grace/grace.go b/server/web/grace/grace.go index 0adc8654..96ae10ef 100644 --- a/server/web/grace/grace.go +++ b/server/web/grace/grace.go @@ -22,7 +22,7 @@ // "net/http" // "os" // -// "github.com/beego/beego/v2/grace" +// "github.com/beego/beego/v2/server/web/grace" // ) // // func handler(w http.ResponseWriter, r *http.Request) { diff --git a/server/web/grace/server.go b/server/web/grace/server.go index 13fa6e34..5546546d 100644 --- a/server/web/grace/server.go +++ b/server/web/grace/server.go @@ -303,8 +303,8 @@ func (srv *Server) fork() (err error) { } runningServersForked = true - var files = make([]*os.File, len(runningServers)) - var orderArgs = make([]string, len(runningServers)) + files := make([]*os.File, len(runningServers)) + orderArgs := make([]string, len(runningServers)) for _, srvPtr := range runningServers { f, _ := srvPtr.ln.(*net.TCPListener).File() files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f diff --git a/server/web/hooks.go b/server/web/hooks.go index d44f477f..68cc61a1 100644 --- a/server/web/hooks.go +++ b/server/web/hooks.go @@ -6,8 +6,6 @@ import ( "net/http" "path/filepath" - "github.com/coreos/etcd/pkg/fileutil" - "github.com/beego/beego/v2/core/logs" "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/session" @@ -63,6 +61,7 @@ func registerSession() error { conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery + conf.CookieSameSite = BConfig.WebConfig.Session.SessionCookieSameSite conf.SessionIDPrefix = BConfig.WebConfig.Session.SessionIDPrefix } else { if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { @@ -98,18 +97,3 @@ func registerGzip() error { } return nil } - -func registerCommentRouter() error { - if BConfig.RunMode == DEV { - ctrlDir := filepath.Join(WorkPath, BConfig.WebConfig.CommentRouterPath) - if !fileutil.Exist(ctrlDir) { - logs.Warn("controller package not found, won't generate router: ", ctrlDir) - return nil - } - if err := parserPkg(ctrlDir); err != nil { - return err - } - } - - return nil -} diff --git a/server/web/mock/context.go b/server/web/mock/context.go new file mode 100644 index 00000000..46336c70 --- /dev/null +++ b/server/web/mock/context.go @@ -0,0 +1,28 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "net/http" + + beegoCtx "github.com/beego/beego/v2/server/web/context" +) + +func NewMockContext(req *http.Request) (*beegoCtx.Context, *HttpResponse) { + ctx := beegoCtx.NewContext() + resp := NewMockHttpResponse() + ctx.Reset(resp, req) + return ctx, resp +} diff --git a/server/web/mock/context_test.go b/server/web/mock/context_test.go new file mode 100644 index 00000000..98af12d3 --- /dev/null +++ b/server/web/mock/context_test.go @@ -0,0 +1,66 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "bytes" + "context" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/server/web" +) + +type TestController struct { + web.Controller +} + +func TestMockContext(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:8080/hello?name=tom", bytes.NewReader([]byte{})) + assert.Nil(t, err) + ctx, resp := NewMockContext(req) + ctrl := &TestController{ + Controller: web.Controller{ + Ctx: ctx, + }, + } + ctrl.HelloWorld() + result := resp.BodyToString() + assert.Equal(t, "name=tom", result) +} + +// GET hello?name=XXX +func (c *TestController) HelloWorld() { + name := c.GetString("name") + c.Ctx.WriteString(fmt.Sprintf("name=%s", name)) +} + +func (c *TestController) HelloSession() { + err := c.SessionRegenerateID() + if err != nil { + c.Ctx.WriteString("error") + return + } + _ = c.SetSession("name", "Tom") + c.Ctx.WriteString("set") +} + +func (c *TestController) HelloSessionName() { + name := c.CruSession.Get(context.Background(), "name") + c.Ctx.WriteString(name.(string)) +} diff --git a/server/web/mock/response.go b/server/web/mock/response.go new file mode 100644 index 00000000..6dcd77a1 --- /dev/null +++ b/server/web/mock/response.go @@ -0,0 +1,69 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "encoding/json" + "net/http" +) + +// HttpResponse mock response, which should be used in tests +type HttpResponse struct { + body []byte + header http.Header + StatusCode int +} + +// NewMockHttpResponse you should only use this in your test code +func NewMockHttpResponse() *HttpResponse { + return &HttpResponse{ + body: make([]byte, 0), + header: make(http.Header), + } +} + +// Header return headers +func (m *HttpResponse) Header() http.Header { + return m.header +} + +// Write append the body +func (m *HttpResponse) Write(bytes []byte) (int, error) { + m.body = append(m.body, bytes...) + return len(bytes), nil +} + +// WriteHeader set the status code +func (m *HttpResponse) WriteHeader(statusCode int) { + m.StatusCode = statusCode +} + +// JsonUnmarshal convert the body to object +func (m *HttpResponse) JsonUnmarshal(value interface{}) error { + return json.Unmarshal(m.body, value) +} + +// BodyToString return the body as the string +func (m *HttpResponse) BodyToString() string { + return string(m.body) +} + +// Reset will reset the status to init status +// Usually, you want to reuse this instance you may need to call Reset +func (m *HttpResponse) Reset() { + m.body = make([]byte, 0) + m.header = make(http.Header) + m.StatusCode = 0 +} diff --git a/server/web/mock/session.go b/server/web/mock/session.go new file mode 100644 index 00000000..66a65747 --- /dev/null +++ b/server/web/mock/session.go @@ -0,0 +1,123 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "net/http" + + "github.com/google/uuid" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/session" +) + +// NewSessionProvider create new SessionProvider +// and you could use it to mock data +// Parameter "name" is the real SessionProvider you used +func NewSessionProvider(name string) *SessionProvider { + sp := newSessionProvider() + session.Register(name, sp) + web.GlobalSessions, _ = session.NewManager(name, session.NewManagerConfig()) + return sp +} + +// SessionProvider will replace session provider with "mock" provider +type SessionProvider struct { + Store *SessionStore +} + +func newSessionProvider() *SessionProvider { + return &SessionProvider{ + Store: newSessionStore(), + } +} + +// SessionInit do nothing +func (s *SessionProvider) SessionInit(ctx context.Context, gclifetime int64, config string) error { + return nil +} + +// SessionRead return Store +func (s *SessionProvider) SessionRead(ctx context.Context, sid string) (session.Store, error) { + return s.Store, nil +} + +// SessionExist always return true +func (s *SessionProvider) SessionExist(ctx context.Context, sid string) (bool, error) { + return true, nil +} + +// SessionRegenerate create new Store +func (s *SessionProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { + s.Store = newSessionStore() + return s.Store, nil +} + +// SessionDestroy reset Store to nil +func (s *SessionProvider) SessionDestroy(ctx context.Context, sid string) error { + s.Store = nil + return nil +} + +// SessionAll return 0 +func (s *SessionProvider) SessionAll(ctx context.Context) int { + return 0 +} + +// SessionGC do nothing +func (s *SessionProvider) SessionGC(ctx context.Context) { + // we do anything since we don't need to mock GC +} + +type SessionStore struct { + sid string + values map[interface{}]interface{} +} + +func (s *SessionStore) Set(ctx context.Context, key, value interface{}) error { + s.values[key] = value + return nil +} + +func (s *SessionStore) Get(ctx context.Context, key interface{}) interface{} { + return s.values[key] +} + +func (s *SessionStore) Delete(ctx context.Context, key interface{}) error { + delete(s.values, key) + return nil +} + +func (s *SessionStore) SessionID(ctx context.Context) string { + return s.sid +} + +// SessionRelease do nothing +func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { + // Support in the future if necessary, now I think we don't need to implement this +} + +func (s *SessionStore) Flush(ctx context.Context) error { + s.values = make(map[interface{}]interface{}, 4) + return nil +} + +func newSessionStore() *SessionStore { + return &SessionStore{ + sid: uuid.New().String(), + values: make(map[interface{}]interface{}, 4), + } +} diff --git a/server/web/mock/session_test.go b/server/web/mock/session_test.go new file mode 100644 index 00000000..f0abd167 --- /dev/null +++ b/server/web/mock/session_test.go @@ -0,0 +1,48 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "bytes" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/server/web" +) + +func TestSessionProvider(t *testing.T) { + sp := NewSessionProvider("file") + assert.NotNil(t, sp) + + req, err := http.NewRequest("GET", "http://localhost:8080/hello?name=tom", bytes.NewReader([]byte{})) + assert.Nil(t, err) + ctx, resp := NewMockContext(req) + ctrl := &TestController{ + Controller: web.Controller{ + Ctx: ctx, + }, + } + ctrl.HelloSession() + result := resp.BodyToString() + assert.Equal(t, "set", result) + + resp.Reset() + ctrl.HelloSessionName() + result = resp.BodyToString() + + assert.Equal(t, "Tom", result) +} diff --git a/server/web/namespace.go b/server/web/namespace.go index 4e0c3b85..825aa4b8 100644 --- a/server/web/namespace.go +++ b/server/web/namespace.go @@ -99,7 +99,7 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { // Router same as beego.Rourer // refer: https://godoc.org/github.com/beego/beego/v2#Router func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { - n.handlers.Add(rootpath, c, mappingMethods...) + n.handlers.Add(rootpath, c, WithRouterMethods(c, mappingMethods...)) return n } @@ -119,56 +119,56 @@ func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace // Get same as beego.Get // refer: https://godoc.org/github.com/beego/beego/v2#Get -func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { +func (n *Namespace) Get(rootpath string, f HandleFunc) *Namespace { n.handlers.Get(rootpath, f) return n } // Post same as beego.Post // refer: https://godoc.org/github.com/beego/beego/v2#Post -func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { +func (n *Namespace) Post(rootpath string, f HandleFunc) *Namespace { n.handlers.Post(rootpath, f) return n } // Delete same as beego.Delete // refer: https://godoc.org/github.com/beego/beego/v2#Delete -func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { +func (n *Namespace) Delete(rootpath string, f HandleFunc) *Namespace { n.handlers.Delete(rootpath, f) return n } // Put same as beego.Put // refer: https://godoc.org/github.com/beego/beego/v2#Put -func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { +func (n *Namespace) Put(rootpath string, f HandleFunc) *Namespace { n.handlers.Put(rootpath, f) return n } // Head same as beego.Head // refer: https://godoc.org/github.com/beego/beego/v2#Head -func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { +func (n *Namespace) Head(rootpath string, f HandleFunc) *Namespace { n.handlers.Head(rootpath, f) return n } // Options same as beego.Options // refer: https://godoc.org/github.com/beego/beego/v2#Options -func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { +func (n *Namespace) Options(rootpath string, f HandleFunc) *Namespace { n.handlers.Options(rootpath, f) return n } // Patch same as beego.Patch // refer: https://godoc.org/github.com/beego/beego/v2#Patch -func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { +func (n *Namespace) Patch(rootpath string, f HandleFunc) *Namespace { n.handlers.Patch(rootpath, f) return n } // Any same as beego.Any // refer: https://godoc.org/github.com/beego/beego/v2#Any -func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { +func (n *Namespace) Any(rootpath string, f HandleFunc) *Namespace { n.handlers.Any(rootpath, f) return n } @@ -187,6 +187,54 @@ func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { return n } +// CtrlGet same as beego.CtrlGet +func (n *Namespace) CtrlGet(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlGet(rootpath, f) + return n +} + +// CtrlPost same as beego.CtrlPost +func (n *Namespace) CtrlPost(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlPost(rootpath, f) + return n +} + +// CtrlDelete same as beego.CtrlDelete +func (n *Namespace) CtrlDelete(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlDelete(rootpath, f) + return n +} + +// CtrlPut same as beego.CtrlPut +func (n *Namespace) CtrlPut(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlPut(rootpath, f) + return n +} + +// CtrlHead same as beego.CtrlHead +func (n *Namespace) CtrlHead(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlHead(rootpath, f) + return n +} + +// CtrlOptions same as beego.CtrlOptions +func (n *Namespace) CtrlOptions(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlOptions(rootpath, f) + return n +} + +// CtrlPatch same as beego.CtrlPatch +func (n *Namespace) CtrlPatch(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlPatch(rootpath, f) + return n +} + +// Any same as beego.CtrlAny +func (n *Namespace) CtrlAny(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlAny(rootpath, f) + return n +} + // Namespace add nest Namespace // usage: // ns := beego.NewNamespace(“/v1”). @@ -311,61 +359,117 @@ func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) } // NSGet call Namespace Get -func NSGet(rootpath string, f FilterFunc) LinkNamespace { +func NSGet(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Get(rootpath, f) } } // NSPost call Namespace Post -func NSPost(rootpath string, f FilterFunc) LinkNamespace { +func NSPost(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Post(rootpath, f) } } // NSHead call Namespace Head -func NSHead(rootpath string, f FilterFunc) LinkNamespace { +func NSHead(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Head(rootpath, f) } } // NSPut call Namespace Put -func NSPut(rootpath string, f FilterFunc) LinkNamespace { +func NSPut(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Put(rootpath, f) } } // NSDelete call Namespace Delete -func NSDelete(rootpath string, f FilterFunc) LinkNamespace { +func NSDelete(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Delete(rootpath, f) } } // NSAny call Namespace Any -func NSAny(rootpath string, f FilterFunc) LinkNamespace { +func NSAny(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Any(rootpath, f) } } // NSOptions call Namespace Options -func NSOptions(rootpath string, f FilterFunc) LinkNamespace { +func NSOptions(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Options(rootpath, f) } } // NSPatch call Namespace Patch -func NSPatch(rootpath string, f FilterFunc) LinkNamespace { +func NSPatch(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Patch(rootpath, f) } } +// NSCtrlGet call Namespace CtrlGet +func NSCtrlGet(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlGet(rootpath, f) + } +} + +// NSCtrlPost call Namespace CtrlPost +func NSCtrlPost(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlPost(rootpath, f) + } +} + +// NSCtrlHead call Namespace CtrlHead +func NSCtrlHead(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlHead(rootpath, f) + } +} + +// NSCtrlPut call Namespace CtrlPut +func NSCtrlPut(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlPut(rootpath, f) + } +} + +// NSCtrlDelete call Namespace CtrlDelete +func NSCtrlDelete(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlDelete(rootpath, f) + } +} + +// NSCtrlAny call Namespace CtrlAny +func NSCtrlAny(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlAny(rootpath, f) + } +} + +// NSCtrlOptions call Namespace CtrlOptions +func NSCtrlOptions(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlOptions(rootpath, f) + } +} + +// NSCtrlPatch call Namespace CtrlPatch +func NSCtrlPatch(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlPatch(rootpath, f) + } +} + // NSAutoRouter call Namespace AutoRouter func NSAutoRouter(c ControllerInterface) LinkNamespace { return func(ns *Namespace) { diff --git a/server/web/namespace_test.go b/server/web/namespace_test.go index 05042c96..e0e15d6f 100644 --- a/server/web/namespace_test.go +++ b/server/web/namespace_test.go @@ -15,6 +15,7 @@ package web import ( + "fmt" "net/http" "net/http/httptest" "strconv" @@ -23,6 +24,40 @@ import ( "github.com/beego/beego/v2/server/web/context" ) +const ( + exampleBody = "hello world" + examplePointerBody = "hello world pointer" + + nsNamespace = "/router" + nsPath = "/user" + nsNamespacePath = "/router/user" +) + +type ExampleController struct { + Controller +} + +func (m ExampleController) Ping() { + err := m.Ctx.Output.Body([]byte(exampleBody)) + if err != nil { + fmt.Println(err) + } +} + +func (m *ExampleController) PingPointer() { + err := m.Ctx.Output.Body([]byte(examplePointerBody)) + if err != nil { + fmt.Println(err) + } +} + +func (m ExampleController) ping() { + err := m.Ctx.Output.Body([]byte("ping method")) + if err != nil { + fmt.Println(err) + } +} + func TestNamespaceGet(t *testing.T) { r, _ := http.NewRequest("GET", "/v1/user", nil) w := httptest.NewRecorder() @@ -166,3 +201,215 @@ func TestNamespaceInside(t *testing.T) { t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String()) } } + +func TestNamespaceCtrlGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlGet(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlPost(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlPost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlDelete(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlDelete can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlPut(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlPut can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlHead(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlHead can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlOptions(t *testing.T) { + r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlOptions(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlOptions can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlPatch(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlPatch can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlAny(t *testing.T) { + ns := NewNamespace(nsNamespace) + ns.CtrlAny(nsPath, ExampleController.Ping) + AddNamespace(ns) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, nsNamespacePath, nil) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlAny can't run, get the response is " + w.Body.String()) + } + } +} + +func TestNamespaceNSCtrlGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlGet(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/router") + NSCtrlPost(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlPost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlDelete(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlDelete can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlPut(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlPut can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlHead(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlHead can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlOptions(t *testing.T) { + r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlOptions(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlOptions can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlPatch("/user", ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlPatch can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlAny(t *testing.T) { + ns := NewNamespace(nsNamespace) + NSCtrlAny(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, nsNamespacePath, nil) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlAny can't run, get the response is " + w.Body.String()) + } + } +} diff --git a/server/web/parser.go b/server/web/parser.go deleted file mode 100644 index 803880fe..00000000 --- a/server/web/parser.go +++ /dev/null @@ -1,600 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package web - -import ( - "encoding/json" - "errors" - "fmt" - "go/ast" - "io/ioutil" - "os" - "path/filepath" - "regexp" - "sort" - "strconv" - "strings" - "unicode" - - "golang.org/x/tools/go/packages" - - "github.com/beego/beego/v2/core/logs" - - "github.com/beego/beego/v2/core/utils" - "github.com/beego/beego/v2/server/web/context/param" -) - -var globalRouterTemplate = `package {{.routersDir}} - -import ( - beego "github.com/beego/beego/v2/server/web" - "github.com/beego/beego/v2/server/web/context/param"{{.globalimport}} -) - -func init() { -{{.globalinfo}} -} -` - -var ( - lastupdateFilename = "lastupdate.tmp" - commentFilename string - pkgLastupdate map[string]int64 - genInfoList map[string][]ControllerComments - - routerHooks = map[string]int{ - "beego.BeforeStatic": BeforeStatic, - "beego.BeforeRouter": BeforeRouter, - "beego.BeforeExec": BeforeExec, - "beego.AfterExec": AfterExec, - "beego.FinishRouter": FinishRouter, - } - - routerHooksMapping = map[int]string{ - BeforeStatic: "beego.BeforeStatic", - BeforeRouter: "beego.BeforeRouter", - BeforeExec: "beego.BeforeExec", - AfterExec: "beego.AfterExec", - FinishRouter: "beego.FinishRouter", - } -) - -const commentPrefix = "commentsRouter_" - -func init() { - pkgLastupdate = make(map[string]int64) -} - -func parserPkg(pkgRealpath string) error { - rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_") - commentFilename, _ = filepath.Rel(AppPath, pkgRealpath) - commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go" - if !compareFile(pkgRealpath) { - logs.Info(pkgRealpath + " no changed") - return nil - } - genInfoList = make(map[string][]ControllerComments) - pkgs, err := packages.Load(&packages.Config{ - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedSyntax, - Dir: pkgRealpath, - }, "./...") - - if err != nil { - return err - } - for _, pkg := range pkgs { - for _, fl := range pkg.Syntax { - for _, d := range fl.Decls { - switch specDecl := d.(type) { - case *ast.FuncDecl: - if specDecl.Recv != nil { - exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser - if ok { - parserComments(specDecl, fmt.Sprint(exp.X), pkg.PkgPath) - } - } - } - } - } - } - genRouterCode(pkgRealpath) - savetoFile(pkgRealpath) - return nil -} - -type parsedComment struct { - routerPath string - methods []string - params map[string]parsedParam - filters []parsedFilter - imports []parsedImport -} - -type parsedImport struct { - importPath string - importAlias string -} - -type parsedFilter struct { - pattern string - pos int - filter string - params []bool -} - -type parsedParam struct { - name string - datatype string - location string - defValue string - required bool -} - -func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { - if f.Doc != nil { - parsedComments, err := parseComment(f.Doc.List) - if err != nil { - return err - } - for _, parsedComment := range parsedComments { - if parsedComment.routerPath != "" { - key := pkgpath + ":" + controllerName - cc := ControllerComments{} - cc.Method = f.Name.String() - cc.Router = parsedComment.routerPath - cc.AllowHTTPMethods = parsedComment.methods - cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) - cc.FilterComments = buildFilters(parsedComment.filters) - cc.ImportComments = buildImports(parsedComment.imports) - genInfoList[key] = append(genInfoList[key], cc) - } - } - } - return nil -} - -func buildImports(pis []parsedImport) []*ControllerImportComments { - var importComments []*ControllerImportComments - - for _, pi := range pis { - importComments = append(importComments, &ControllerImportComments{ - ImportPath: pi.importPath, - ImportAlias: pi.importAlias, - }) - } - - return importComments -} - -func buildFilters(pfs []parsedFilter) []*ControllerFilterComments { - var filterComments []*ControllerFilterComments - - for _, pf := range pfs { - var ( - returnOnOutput bool - resetParams bool - ) - - if len(pf.params) >= 1 { - returnOnOutput = pf.params[0] - } - - if len(pf.params) >= 2 { - resetParams = pf.params[1] - } - - filterComments = append(filterComments, &ControllerFilterComments{ - Filter: pf.filter, - Pattern: pf.pattern, - Pos: pf.pos, - ReturnOnOutput: returnOnOutput, - ResetParams: resetParams, - }) - } - - return filterComments -} - -func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { - result := make([]*param.MethodParam, 0, len(funcParams)) - for _, fparam := range funcParams { - for _, pName := range fparam.Names { - methodParam := buildMethodParam(fparam, pName.Name, pc) - result = append(result, methodParam) - } - } - return result -} - -func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam { - options := []param.MethodParamOption{} - if cparam, ok := pc.params[name]; ok { - // Build param from comment info - name = cparam.name - if cparam.required { - options = append(options, param.IsRequired) - } - switch cparam.location { - case "body": - options = append(options, param.InBody) - case "header": - options = append(options, param.InHeader) - case "path": - options = append(options, param.InPath) - } - if cparam.defValue != "" { - options = append(options, param.Default(cparam.defValue)) - } - } else { - if paramInPath(name, pc.routerPath) { - options = append(options, param.InPath) - } - } - return param.New(name, options...) -} - -func paramInPath(name, route string) bool { - return strings.HasSuffix(route, ":"+name) || - strings.Contains(route, ":"+name+"/") -} - -var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) - -func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { - pcs = []*parsedComment{} - params := map[string]parsedParam{} - filters := []parsedFilter{} - imports := []parsedImport{} - - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Param") { - pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) - if len(pv) < 4 { - logs.Error("Invalid @Param format. Needs at least 4 parameters") - } - p := parsedParam{} - names := strings.SplitN(pv[0], "=>", 2) - p.name = names[0] - funcParamName := p.name - if len(names) > 1 { - funcParamName = names[1] - } - p.location = pv[1] - p.datatype = pv[2] - switch len(pv) { - case 5: - p.required, _ = strconv.ParseBool(pv[3]) - case 6: - p.defValue = pv[3] - p.required, _ = strconv.ParseBool(pv[4]) - } - params[funcParamName] = p - } - } - - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Import") { - iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import"))) - if len(iv) == 0 || len(iv) > 2 { - logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters") - continue - } - - p := parsedImport{} - p.importPath = iv[0] - - if len(iv) == 2 { - p.importAlias = iv[1] - } - - imports = append(imports, p) - } - } - -filterLoop: - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Filter") { - fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter"))) - if len(fv) < 3 { - logs.Error("Invalid @Filter format. Needs at least 3 parameters") - continue filterLoop - } - - p := parsedFilter{} - p.pattern = fv[0] - posName := fv[1] - if pos, exists := routerHooks[posName]; exists { - p.pos = pos - } else { - logs.Error("Invalid @Filter pos: ", posName) - continue filterLoop - } - - p.filter = fv[2] - fvParams := fv[3:] - for _, fvParam := range fvParams { - switch fvParam { - case "true": - p.params = append(p.params, true) - case "false": - p.params = append(p.params, false) - default: - logs.Error("Invalid @Filter param: ", fvParam) - continue filterLoop - } - } - - filters = append(filters, p) - } - } - - for _, c := range lines { - var pc = &parsedComment{} - pc.params = params - pc.filters = filters - pc.imports = imports - - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@router") { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - matches := routeRegex.FindStringSubmatch(t) - if len(matches) == 3 { - pc.routerPath = matches[1] - methods := matches[2] - if methods == "" { - pc.methods = []string{"get"} - // pc.hasGet = true - } else { - pc.methods = strings.Split(methods, ",") - // pc.hasGet = strings.Contains(methods, "get") - } - pcs = append(pcs, pc) - } else { - return nil, errors.New("Router information is missing") - } - } - } - return -} - -// direct copy from bee\g_docs.go -// analysis params return []string -// @Param query form string true "The email for login" -// [query form string true "The email for login"] -func getparams(str string) []string { - var s []rune - var j int - var start bool - var r []string - var quoted int8 - for _, c := range str { - if unicode.IsSpace(c) && quoted == 0 { - if !start { - continue - } else { - start = false - j++ - r = append(r, string(s)) - s = make([]rune, 0) - continue - } - } - - start = true - if c == '"' { - quoted ^= 1 - continue - } - s = append(s, c) - } - if len(s) > 0 { - r = append(r, string(s)) - } - return r -} - -func genRouterCode(pkgRealpath string) { - os.Mkdir(getRouterDir(pkgRealpath), 0755) - logs.Info("generate router from comments") - var ( - globalinfo string - globalimport string - sortKey []string - ) - for k := range genInfoList { - sortKey = append(sortKey, k) - } - sort.Strings(sortKey) - for _, k := range sortKey { - cList := genInfoList[k] - sort.Sort(ControllerCommentsSlice(cList)) - for _, c := range cList { - allmethod := "nil" - if len(c.AllowHTTPMethods) > 0 { - allmethod = "[]string{" - for _, m := range c.AllowHTTPMethods { - allmethod += `"` + m + `",` - } - allmethod = strings.TrimRight(allmethod, ",") + "}" - } - - params := "nil" - if len(c.Params) > 0 { - params = "[]map[string]string{" - for _, p := range c.Params { - for k, v := range p { - params = params + `map[string]string{` + k + `:"` + v + `"},` - } - } - params = strings.TrimRight(params, ",") + "}" - } - - methodParams := "param.Make(" - if len(c.MethodParams) > 0 { - lines := make([]string, 0, len(c.MethodParams)) - for _, m := range c.MethodParams { - lines = append(lines, fmt.Sprint(m)) - } - methodParams += "\n " + - strings.Join(lines, ",\n ") + - ",\n " - } - methodParams += ")" - - imports := "" - if len(c.ImportComments) > 0 { - for _, i := range c.ImportComments { - var s string - if i.ImportAlias != "" { - s = fmt.Sprintf(` - %s "%s"`, i.ImportAlias, i.ImportPath) - } else { - s = fmt.Sprintf(` - "%s"`, i.ImportPath) - } - if !strings.Contains(globalimport, s) { - imports += s - } - } - } - - filters := "" - if len(c.FilterComments) > 0 { - for _, f := range c.FilterComments { - filters += fmt.Sprintf(` &beego.ControllerFilter{ - Pattern: "%s", - Pos: %s, - Filter: %s, - ReturnOnOutput: %v, - ResetParams: %v, - },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams) - } - } - - if filters == "" { - filters = "nil" - } else { - filters = fmt.Sprintf(`[]*beego.ControllerFilter{ -%s - }`, filters) - } - - globalimport += imports - - globalinfo = globalinfo + ` - beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], - beego.ControllerComments{ - Method: "` + strings.TrimSpace(c.Method) + `", - ` + `Router: "` + c.Router + `"` + `, - AllowHTTPMethods: ` + allmethod + `, - MethodParams: ` + methodParams + `, - Filters: ` + filters + `, - Params: ` + params + `}) -` - } - } - - if globalinfo != "" { - f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) - if err != nil { - panic(err) - } - defer f.Close() - - routersDir := AppConfig.DefaultString("routersdir", "routers") - content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) - content = strings.Replace(content, "{{.routersDir}}", routersDir, -1) - content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) - f.WriteString(content) - } -} - -func compareFile(pkgRealpath string) bool { - if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) { - return true - } - if utils.FileExists(lastupdateFilename) { - content, err := ioutil.ReadFile(lastupdateFilename) - if err != nil { - return true - } - json.Unmarshal(content, &pkgLastupdate) - lastupdate, err := getpathTime(pkgRealpath) - if err != nil { - return true - } - if v, ok := pkgLastupdate[pkgRealpath]; ok { - if lastupdate <= v { - return false - } - } - } - return true -} - -func savetoFile(pkgRealpath string) { - lastupdate, err := getpathTime(pkgRealpath) - if err != nil { - return - } - pkgLastupdate[pkgRealpath] = lastupdate - d, err := json.Marshal(pkgLastupdate) - if err != nil { - return - } - ioutil.WriteFile(lastupdateFilename, d, os.ModePerm) -} - -func getpathTime(pkgRealpath string) (lastupdate int64, err error) { - fl, err := ioutil.ReadDir(pkgRealpath) - if err != nil { - return lastupdate, err - } - for _, f := range fl { - var t int64 - if f.IsDir() { - t, err = getpathTime(filepath.Join(pkgRealpath, f.Name())) - if err != nil { - return lastupdate, err - } - } else { - t = f.ModTime().UnixNano() - } - if lastupdate < t { - lastupdate = t - } - } - return lastupdate, nil -} - -func getRouterDir(pkgRealpath string) string { - dir := filepath.Dir(pkgRealpath) - for { - routersDir := AppConfig.DefaultString("routersdir", "routers") - d := filepath.Join(dir, routersDir) - if utils.FileExists(d) { - return d - } - - if r, _ := filepath.Rel(dir, AppPath); r == "." { - return d - } - // Parent dir. - dir = filepath.Dir(dir) - } -} diff --git a/server/web/policy.go b/server/web/policy.go index 1b810520..41fb2669 100644 --- a/server/web/policy.go +++ b/server/web/policy.go @@ -25,7 +25,7 @@ type PolicyFunc func(*context.Context) // FindPolicy Find Router info for URL func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { - var urlPath = cont.Input.URL() + urlPath := cont.Input.URL() if !BConfig.RouterCaseSensitive { urlPath = strings.ToLower(urlPath) } diff --git a/server/web/router.go b/server/web/router.go index 93c77e7d..c4400b20 100644 --- a/server/web/router.go +++ b/server/web/router.go @@ -20,13 +20,13 @@ import ( "net/http" "path" "reflect" + "runtime" "strconv" "strings" "sync" "time" "github.com/beego/beego/v2/core/logs" - "github.com/beego/beego/v2/core/utils" beecontext "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/context/param" @@ -68,20 +68,8 @@ var ( "UNLOCK": true, } // these web.Controller's methods shouldn't reflect to AutoRouter - exceptMethod = []string{"Abort", "CheckXSRFCookie", "CustomAbort", "DelSession", - "DestroySession", "Finish", "GetBool", "GetControllerAndAction", - "GetFile", "GetFiles", "GetFloat", "GetInt", "GetInt16", - "GetInt32", "GetInt64", "GetInt8", "GetSecureCookie", "GetSession", - "GetString", "GetStrings", "GetUint16", "GetUint32", "GetUint64", - "GetUint8", "HandlerFunc", "Init", "Input", - "IsAjax", "Mapping", "ParseForm", - "Prepare", "Redirect", "Render", "RenderBytes", - "RenderString", "SaveToFile", "SaveToFileWithBuffer", "SaveToFileWithBuffer", - "ServeFormatted", "ServeJSON", "ServeJSONP", "ServeXML", "ServeYAML", - "SessionRegenerateID", "SetData", "SetSecureCookie", "SetSession", "StartSession", - "StopRun", "URLFor", "URLMapping", "XSRFFormHTML", - "XSRFToken", - } + // see registerControllerExceptMethods + exceptMethod = initExceptMethod() urlPlaceholder = "{{placeholder}}" // DefaultAccessLogFilter will skip the accesslog if return true @@ -94,8 +82,7 @@ type FilterHandler interface { } // default log filter static file will not show -type logFilter struct { -} +type logFilter struct{} func (l *logFilter) Filter(ctx *beecontext.Context) bool { requestPath := path.Clean(ctx.Request.URL.Path) @@ -115,34 +102,69 @@ func ExceptMethodAppend(action string) { exceptMethod = append(exceptMethod, action) } +func initExceptMethod() []string { + res := make([]string, 0, 32) + c := &Controller{} + t := reflect.TypeOf(c) + for i := 0; i < t.NumMethod(); i++ { + m := t.Method(i) + res = append(res, m.Name) + } + return res +} + // ControllerInfo holds information about the controller. type ControllerInfo struct { pattern string controllerType reflect.Type methods map[string]string handler http.Handler - runFunction FilterFunc + runFunction HandleFunc routerType int initialize func() ControllerInterface methodParams []*param.MethodParam + sessionOn bool } +type ControllerOption func(*ControllerInfo) + func (c *ControllerInfo) GetPattern() string { return c.pattern } +func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption { + return func(c *ControllerInfo) { + c.methods = parseMappingMethods(ctrlInterface, mappingMethod) + } +} + +func WithRouterSessionOn(sessionOn bool) ControllerOption { + return func(c *ControllerInfo) { + c.sessionOn = sessionOn + } +} + +type filterChainConfig struct { + pattern string + chain FilterChain + opts []FilterOpt +} + // ControllerRegister containers registered router rules, controller handlers and filters. type ControllerRegister struct { routers map[string]*Tree enablePolicy bool - policies map[string]*Tree enableFilter bool + policies map[string]*Tree filters [FinishRouter + 1][]*FilterRouter pool sync.Pool // the filter created by FilterChain chainRoot *FilterRouter + // keep registered chain and build it when serve http + filterChains []filterChainConfig + cfg *Config } @@ -162,12 +184,27 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { return beecontext.NewContext() }, }, - cfg: cfg, + cfg: cfg, + filterChains: make([]filterChainConfig, 0, 4), } res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) return res } +// Init will be executed when HttpServer start running +func (p *ControllerRegister) Init() { + for i := len(p.filterChains) - 1; i >= 0; i-- { + fc := p.filterChains[i] + root := p.chainRoot + filterFunc := fc.chain(func(ctx *beecontext.Context) { + var preFilterParams map[string]string + root.filter(ctx, p.getUrlPath(ctx), preFilterParams) + }) + p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...) + p.chainRoot.next = root + } +} + // Add controller handler and pattern rules to ControllerRegister. // usage: // default methods is the same name as method @@ -178,41 +215,64 @@ func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { // Add("/api/delete",&RestController{},"delete:DeleteFood") // Add("/api",&RestController{},"get,post:ApiFunc" // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") -func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - p.addWithMethodParams(pattern, c, nil, mappingMethods...) +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOption) { + p.addWithMethodParams(pattern, c, nil, opts...) } -func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { +func parseMappingMethods(c ControllerInterface, mappingMethods []string) map[string]string { reflectVal := reflect.ValueOf(c) t := reflect.Indirect(reflectVal).Type() methods := make(map[string]string) - if len(mappingMethods) > 0 { - semi := strings.Split(mappingMethods[0], ";") - for _, v := range semi { - colon := strings.Split(v, ":") - if len(colon) != 2 { - panic("method mapping format is invalid") + + if len(mappingMethods) == 0 { + return methods + } + + semi := strings.Split(mappingMethods[0], ";") + for _, v := range semi { + colon := strings.Split(v, ":") + if len(colon) != 2 { + panic("method mapping format is invalid") + } + comma := strings.Split(colon[0], ",") + for _, m := range comma { + if m != "*" && !HTTPMETHOD[strings.ToUpper(m)] { + panic(v + " is an invalid method mapping. Method doesn't exist " + m) } - comma := strings.Split(colon[0], ",") - for _, m := range comma { - if m == "*" || HTTPMETHOD[strings.ToUpper(m)] { - if val := reflectVal.MethodByName(colon[1]); val.IsValid() { - methods[strings.ToUpper(m)] = colon[1] - } else { - panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) - } - } else { - panic(v + " is an invalid method mapping. Method doesn't exist " + m) - } + if val := reflectVal.MethodByName(colon[1]); val.IsValid() { + methods[strings.ToUpper(m)] = colon[1] + continue } + panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) } } - route := &ControllerInfo{} - route.pattern = pattern - route.methods = methods - route.routerType = routerTypeBeego - route.controllerType = t + return methods +} + +func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) { + if len(route.methods) == 0 { + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + return + } + for k := range route.methods { + if k != "*" { + p.addToRouter(k, route.pattern, route) + continue + } + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + } +} + +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOption) { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + + route := p.createBeegoRouter(t, pattern) route.initialize = func() ControllerInterface { vc := reflect.New(route.controllerType) execController, ok := vc.Interface().(ControllerInterface) @@ -236,23 +296,18 @@ func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInt return execController } - route.methodParams = methodParams - if len(methods) == 0 { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } + for i := range opts { + opts[i](route) } + + globalSessionOn := p.cfg.WebConfig.Session.SessionOn + if !globalSessionOn && route.sessionOn { + logs.Warn("global sessionOn is false, sessionOn of router [%s] can't be set to true", route.pattern) + route.sessionOn = globalSessionOn + } + + p.addRouterForMethod(route) } func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { @@ -280,7 +335,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { for _, f := range a.Filters { p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) } - p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) + p.addWithMethodParams(a.Router, c, a.MethodParams, WithRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) } } } @@ -301,12 +356,274 @@ func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { p.pool.Put(ctx) } +// CtrlGet add get method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlGet("/api/:id", MyController.Ping) +// If the receiver of function Ping is pointer, you should use CtrlGet("/api/:id", (*MyController).Ping) +func (p *ControllerRegister) CtrlGet(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodGet, pattern, f) +} + +// CtrlPost add post method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPost("/api/:id", MyController.Ping) +// If the receiver of function Ping is pointer, you should use CtrlPost("/api/:id", (*MyController).Ping) +func (p *ControllerRegister) CtrlPost(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPost, pattern, f) +} + +// CtrlHead add head method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlHead("/api/:id", MyController.Ping) +// If the receiver of function Ping is pointer, you should use CtrlHead("/api/:id", (*MyController).Ping) +func (p *ControllerRegister) CtrlHead(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodHead, pattern, f) +} + +// CtrlPut add put method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPut("/api/:id", MyController.Ping) + +func (p *ControllerRegister) CtrlPut(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPut, pattern, f) +} + +// CtrlPatch add patch method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPatch("/api/:id", MyController.Ping) +func (p *ControllerRegister) CtrlPatch(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPatch, pattern, f) +} + +// CtrlDelete add delete method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlDelete("/api/:id", MyController.Ping) +func (p *ControllerRegister) CtrlDelete(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodDelete, pattern, f) +} + +// CtrlOptions add options method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlOptions("/api/:id", MyController.Ping) +func (p *ControllerRegister) CtrlOptions(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodOptions, pattern, f) +} + +// CtrlAny add all method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlAny("/api/:id", MyController.Ping) +func (p *ControllerRegister) CtrlAny(pattern string, f interface{}) { + p.AddRouterMethod("*", pattern, f) +} + +// AddRouterMethod add http method router +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// AddRouterMethod("get","/api/:id", MyController.Ping) +func (p *ControllerRegister) AddRouterMethod(httpMethod, pattern string, f interface{}) { + httpMethod = p.getUpperMethodString(httpMethod) + ct, methodName := getReflectTypeAndMethod(f) + + p.addBeegoTypeRouter(ct, methodName, httpMethod, pattern) +} + +// addBeegoTypeRouter add beego type router +func (p *ControllerRegister) addBeegoTypeRouter(ct reflect.Type, ctMethod, httpMethod, pattern string) { + route := p.createBeegoRouter(ct, pattern) + methods := p.getHttpMethodMapMethod(httpMethod, ctMethod) + route.methods = methods + + p.addRouterForMethod(route) +} + +// createBeegoRouter create beego router base on reflect type and pattern +func (p *ControllerRegister) createBeegoRouter(ct reflect.Type, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeBeego + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.controllerType = ct + return route +} + +// createRestfulRouter create restful router with filter function and pattern +func (p *ControllerRegister) createRestfulRouter(f HandleFunc, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeRESTFul + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.runFunction = f + return route +} + +// createHandlerRouter create handler router with handler and pattern +func (p *ControllerRegister) createHandlerRouter(h http.Handler, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeHandler + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.handler = h + return route +} + +// getHttpMethodMapMethod based on http method and controller method, if ctMethod is empty, then it will +// use http method as the controller method +func (p *ControllerRegister) getHttpMethodMapMethod(httpMethod, ctMethod string) map[string]string { + methods := make(map[string]string) + // not match-all sign, only add for the http method + if httpMethod != "*" { + + if ctMethod == "" { + ctMethod = httpMethod + } + methods[httpMethod] = ctMethod + return methods + } + + // add all http method + for val := range HTTPMETHOD { + if ctMethod == "" { + methods[val] = val + } else { + methods[val] = ctMethod + } + } + return methods +} + +// getUpperMethodString get upper string of method, and panic if the method +// is not valid +func (p *ControllerRegister) getUpperMethodString(method string) string { + method = strings.ToUpper(method) + if method != "*" && !HTTPMETHOD[method] { + panic("not support http method: " + method) + } + return method +} + +// get reflect controller type and method by controller method expression +func getReflectTypeAndMethod(f interface{}) (controllerType reflect.Type, method string) { + // check f is a function + funcType := reflect.TypeOf(f) + if funcType.Kind() != reflect.Func { + panic("not a method") + } + + // get function name + funcObj := runtime.FuncForPC(reflect.ValueOf(f).Pointer()) + if funcObj == nil { + panic("cannot find the method") + } + funcNameSli := strings.Split(funcObj.Name(), ".") + lFuncSli := len(funcNameSli) + if lFuncSli == 0 { + panic("invalid method full name: " + funcObj.Name()) + } + + method = funcNameSli[lFuncSli-1] + if len(method) == 0 { + panic("method name is empty") + } else if method[0] > 96 || method[0] < 65 { + panic(fmt.Sprintf("%s is not a public method", method)) + } + + // check only one param which is the method receiver + if numIn := funcType.NumIn(); numIn != 1 { + panic("invalid number of param in") + } + + controllerType = funcType.In(0) + + // check controller has the method + _, exists := controllerType.MethodByName(method) + if !exists { + panic(controllerType.String() + " has no method " + method) + } + + // check the receiver implement ControllerInterface + if controllerType.Kind() == reflect.Ptr { + controllerType = controllerType.Elem() + } + controller := reflect.New(controllerType) + _, ok := controller.Interface().(ControllerInterface) + if !ok { + panic(controllerType.String() + " is not implemented ControllerInterface") + } + + return +} + +// HandleFunc define how to process the request +type HandleFunc func(ctx *beecontext.Context) + // Get add get method // usage: // Get("/", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegister) Get(pattern string, f FilterFunc) { +func (p *ControllerRegister) Get(pattern string, f HandleFunc) { p.AddMethod("get", pattern, f) } @@ -315,7 +632,7 @@ func (p *ControllerRegister) Get(pattern string, f FilterFunc) { // Post("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegister) Post(pattern string, f FilterFunc) { +func (p *ControllerRegister) Post(pattern string, f HandleFunc) { p.AddMethod("post", pattern, f) } @@ -324,7 +641,7 @@ func (p *ControllerRegister) Post(pattern string, f FilterFunc) { // Put("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegister) Put(pattern string, f FilterFunc) { +func (p *ControllerRegister) Put(pattern string, f HandleFunc) { p.AddMethod("put", pattern, f) } @@ -333,7 +650,7 @@ func (p *ControllerRegister) Put(pattern string, f FilterFunc) { // Delete("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { +func (p *ControllerRegister) Delete(pattern string, f HandleFunc) { p.AddMethod("delete", pattern, f) } @@ -342,7 +659,7 @@ func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { // Head("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegister) Head(pattern string, f FilterFunc) { +func (p *ControllerRegister) Head(pattern string, f HandleFunc) { p.AddMethod("head", pattern, f) } @@ -351,7 +668,7 @@ func (p *ControllerRegister) Head(pattern string, f FilterFunc) { // Patch("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { +func (p *ControllerRegister) Patch(pattern string, f HandleFunc) { p.AddMethod("patch", pattern, f) } @@ -360,7 +677,7 @@ func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { // Options("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegister) Options(pattern string, f FilterFunc) { +func (p *ControllerRegister) Options(pattern string, f HandleFunc) { p.AddMethod("options", pattern, f) } @@ -369,7 +686,7 @@ func (p *ControllerRegister) Options(pattern string, f FilterFunc) { // Any("/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegister) Any(pattern string, f FilterFunc) { +func (p *ControllerRegister) Any(pattern string, f HandleFunc) { p.AddMethod("*", pattern, f) } @@ -378,41 +695,19 @@ func (p *ControllerRegister) Any(pattern string, f FilterFunc) { // AddMethod("get","/api/:id", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { - method = strings.ToUpper(method) - if method != "*" && !HTTPMETHOD[method] { - panic("not support http method: " + method) - } - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeRESTFul - route.runFunction = f - methods := make(map[string]string) - if method == "*" { - for val := range HTTPMETHOD { - methods[val] = val - } - } else { - methods[method] = method - } +func (p *ControllerRegister) AddMethod(method, pattern string, f HandleFunc) { + method = p.getUpperMethodString(method) + + route := p.createRestfulRouter(f, pattern) + methods := p.getHttpMethodMapMethod(method, "") route.methods = methods - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } + + p.addRouterForMethod(route) } // Handler add user defined Handler func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeHandler - route.handler = h + route := p.createHandlerRouter(h, pattern) if len(options) > 0 { if _, ok := options[0].(bool); ok { pattern = path.Join(pattern, "?:all(.*)") @@ -424,7 +719,7 @@ func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ... } // AddAuto router to ControllerRegister. -// example beego.AddAuto(&MainContorlller{}), +// example beego.AddAuto(&MainController{}), // MainController has method List and Page. // visit the url /main/list to execute List function // /main/page to execute Page function. @@ -433,7 +728,7 @@ func (p *ControllerRegister) AddAuto(c ControllerInterface) { } // AddAutoPrefix Add auto router to ControllerRegister with prefix. -// example beego.AddAutoPrefix("/admin",&MainContorlller{}), +// example beego.AddAutoPrefix("/admin",&MainController{}), // MainController has method List and Page. // visit the url /admin/main/list to execute List function // /admin/main/page to execute Page function. @@ -443,22 +738,30 @@ func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) ct := reflect.Indirect(reflectVal).Type() controllerName := strings.TrimSuffix(ct.Name(), "Controller") for i := 0; i < rt.NumMethod(); i++ { - if !utils.InSlice(rt.Method(i).Name, exceptMethod) { - route := &ControllerInfo{} - route.routerType = routerTypeBeego - route.methods = map[string]string{"*": rt.Method(i).Name} - route.controllerType = ct - pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") - patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") - patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) - patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) - route.pattern = pattern - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - p.addToRouter(m, patternInit, route) - p.addToRouter(m, patternFix, route) - p.addToRouter(m, patternFixInit, route) - } + methodName := rt.Method(i).Name + if !utils.InSlice(methodName, exceptMethod) { + p.addAutoPrefixMethod(prefix, controllerName, methodName, ct) + } + } +} + +func (p *ControllerRegister) addAutoPrefixMethod(prefix, controllerName, methodName string, ctrl reflect.Type) { + pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(methodName), "*") + patternInit := path.Join(prefix, controllerName, methodName, "*") + patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(methodName)) + patternFixInit := path.Join(prefix, controllerName, methodName) + + route := p.createBeegoRouter(ctrl, pattern) + route.methods = map[string]string{"*": methodName} + for m := range HTTPMETHOD { + + p.addToRouter(m, pattern, route) + + // only case sensitive, we add three more routes + if p.cfg.RouterCaseSensitive { + p.addToRouter(m, patternInit, route) + p.addToRouter(m, patternFix, route) + p.addToRouter(m, patternFixInit, route) } } } @@ -485,16 +788,12 @@ func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter Filter // } // } func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { - root := p.chainRoot - //filterFunc := chain(root.filterFunc) - filterFunc := chain(func(ctx *beecontext.Context) { - var preFilterParams map[string]string - root.filter(ctx, p.getUrlPath(ctx), preFilterParams) + opts = append([]FilterOpt{WithCaseSensitive(p.cfg.RouterCaseSensitive)}, opts...) + p.filterChains = append(p.filterChains, filterChainConfig{ + pattern: pattern, + chain: chain, + opts: opts, }) - opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) - p.chainRoot = newFilterRouter(pattern, filterFunc, opts...) - p.chainRoot.next = root - } // add Filter into @@ -559,7 +858,7 @@ func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName str for _, l := range t.leaves { if c, ok := l.runObject.(*ControllerInfo); ok { if c.routerType == routerTypeBeego && - strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { + strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), `/`+controllerName) { find := false if HTTPMETHOD[strings.ToUpper(methodName)] { if len(c.methods) == 0 { @@ -665,7 +964,6 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str // Implement http.Handler interface. func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - ctx := p.GetContext() ctx.Reset(rw, r) @@ -681,12 +979,15 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { r := ctx.Request rw := ctx.ResponseWriter.ResponseWriter var ( - runRouter reflect.Type - findRouter bool - runMethod string - methodParams []*param.MethodParam - routerInfo *ControllerInfo - isRunnable bool + runRouter reflect.Type + findRouter bool + runMethod string + methodParams []*param.MethodParam + routerInfo *ControllerInfo + isRunnable bool + currentSessionOn bool + originRouterInfo *ControllerInfo + originFindRouter bool ) if p.cfg.RecoverFunc != nil { @@ -752,7 +1053,12 @@ func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { } // session init - if p.cfg.WebConfig.Session.SessionOn { + currentSessionOn = p.cfg.WebConfig.Session.SessionOn + originRouterInfo, originFindRouter = p.FindRouter(ctx) + if originFindRouter { + currentSessionOn = originRouterInfo.sessionOn + } + if currentSessionOn { ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) if err != nil { logs.Error(err) @@ -995,7 +1301,7 @@ func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, ex // FindRouter Find Router info for URL func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { - var urlPath = context.Input.URL() + urlPath := context.Input.URL() if !p.cfg.RouterCaseSensitive { urlPath = strings.ToLower(urlPath) } diff --git a/server/web/router_test.go b/server/web/router_test.go index 87997322..5009d24e 100644 --- a/server/web/router_test.go +++ b/server/web/router_test.go @@ -16,16 +16,34 @@ package web import ( "bytes" + "fmt" "net/http" "net/http/httptest" "strings" "testing" "github.com/beego/beego/v2/core/logs" - "github.com/beego/beego/v2/server/web/context" ) +type PrefixTestController struct { + Controller +} + +func (ptc *PrefixTestController) PrefixList() { + ptc.Ctx.Output.Body([]byte("i am list in prefix test")) +} + +type TestControllerWithInterface struct{} + +func (m TestControllerWithInterface) Ping() { + fmt.Println("pong") +} + +func (m *TestControllerWithInterface) PingPointer() { + fmt.Println("pong pointer") +} + type TestController struct { Controller } @@ -87,10 +105,24 @@ func (jc *JSONController) Get() { jc.Ctx.Output.Body([]byte("ok")) } +func TestPrefixUrlFor(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/my/prefix/list", &PrefixTestController{}, WithRouterMethods(&PrefixTestController{}, "get:PrefixList")) + + if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` { + logs.Info(a) + t.Errorf("PrefixTestController.PrefixList must equal to /my/prefix/list") + } + if a := handler.URLFor(`TestController.PrefixList`); a != `` { + logs.Info(a) + t.Errorf("TestController.PrefixList must equal to empty string") + } +} + func TestUrlFor(t *testing.T) { handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") - handler.Add("/person/:last/:first", &TestController{}, "*:Param") + handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) + handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) if a := handler.URLFor("TestController.List"); a != "/api/list" { logs.Info(a) t.Errorf("TestController.List must equal to /api/list") @@ -103,19 +135,21 @@ func TestUrlFor(t *testing.T) { func TestUrlFor3(t *testing.T) { handler := NewControllerRegister() handler.AddAuto(&TestController{}) - if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" { + a := handler.URLFor("TestController.Myext") + if a != "/test/myext" && a != "/Test/Myext" { t.Errorf("TestController.Myext must equal to /test/myext, but get " + a) } - if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" { + a = handler.URLFor("TestController.GetURL") + if a != "/test/geturl" && a != "/Test/GetURL" { t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a) } } func TestUrlFor2(t *testing.T) { handler := NewControllerRegister() - handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") - handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") - handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") + handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) + handler.Add("/v1/:username/edit", &TestController{}, WithRouterMethods(&TestController{}, "get:GetURL")) + handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { logs.Info(handler.URLFor("TestController.GetURL")) @@ -145,7 +179,7 @@ func TestUserFunc(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") + handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) handler.ServeHTTP(w, r) if w.Body.String() != "i am list" { t.Errorf("user define func can't run") @@ -213,7 +247,6 @@ func TestAutoExtFunc(t *testing.T) { } func TestEscape(t *testing.T) { - r, _ := http.NewRequest("GET", "/search/%E4%BD%A0%E5%A5%BD", nil) w := httptest.NewRecorder() @@ -230,12 +263,11 @@ func TestEscape(t *testing.T) { } func TestRouteOk(t *testing.T) { - r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil) w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/person/:last/:first", &TestController{}, "get:GetParams") + handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "get:GetParams")) handler.ServeHTTP(w, r) body := w.Body.String() if body != "anderson+thomas+kungfu" { @@ -244,12 +276,11 @@ func TestRouteOk(t *testing.T) { } func TestManyRoute(t *testing.T) { - r, _ := http.NewRequest("GET", "/beego32-12.html", nil) w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter") + handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetManyRouter")) handler.ServeHTTP(w, r) body := w.Body.String() @@ -261,12 +292,11 @@ func TestManyRoute(t *testing.T) { // Test for issue #1669 func TestEmptyResponse(t *testing.T) { - r, _ := http.NewRequest("GET", "/beego-empty.html", nil) w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody") + handler.Add("/beego-empty.html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetEmptyBody")) handler.ServeHTTP(w, r) if body := w.Body.String(); body != "" { @@ -324,7 +354,7 @@ func TestAutoPrefix(t *testing.T) { } } -func TestRouterGet(t *testing.T) { +func TestCtrlGet(t *testing.T) { r, _ := http.NewRequest("GET", "/user", nil) w := httptest.NewRecorder() @@ -334,11 +364,11 @@ func TestRouterGet(t *testing.T) { }) handler.ServeHTTP(w, r) if w.Body.String() != "Get userlist" { - t.Errorf("TestRouterGet can't run") + t.Errorf("TestCtrlGet can't run") } } -func TestRouterPost(t *testing.T) { +func TestCtrlPost(t *testing.T) { r, _ := http.NewRequest("POST", "/user/123", nil) w := httptest.NewRecorder() @@ -348,7 +378,7 @@ func TestRouterPost(t *testing.T) { }) handler.ServeHTTP(w, r) if w.Body.String() != "123" { - t.Errorf("TestRouterPost can't run") + t.Errorf("TestCtrlPost can't run") } } @@ -750,3 +780,320 @@ func TestRouterEntityTooLargeCopyBody(t *testing.T) { t.Errorf("TestRouterRequestEntityTooLarge can't run") } } + +func TestRouterSessionSet(t *testing.T) { + oldGlobalSessionOn := BConfig.WebConfig.Session.SessionOn + defer func() { + BConfig.WebConfig.Session.SessionOn = oldGlobalSessionOn + }() + + // global sessionOn = false, router sessionOn = false + r, _ := http.NewRequest("GET", "/user", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = false, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + BConfig.WebConfig.Session.SessionOn = true + if err := registerSession(); err != nil { + t.Errorf("register session failed, error: %s", err.Error()) + } + // global sessionOn = true, router sessionOn = false + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = true, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") == "" { + t.Errorf("TestRotuerSessionSet failed") + } +} + +func TestRouterCtrlGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlGet("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlGet can't run") + } +} + +func TestRouterCtrlPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPost("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlPost can't run") + } +} + +func TestRouterCtrlHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlHead("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlHead can't run") + } +} + +func TestRouterCtrlPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPut("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlPut can't run") + } +} + +func TestRouterCtrlPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPatch("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlPatch can't run") + } +} + +func TestRouterCtrlDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlDelete("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlDelete can't run") + } +} + +func TestRouterCtrlAny(t *testing.T) { + handler := NewControllerRegister() + handler.CtrlAny("/user", ExampleController.Ping) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, "/user", nil) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlAny can't run, get the response is " + w.Body.String()) + } + } +} + +func TestRouterCtrlGetPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlGet("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlGetPointerMethod can't run") + } +} + +func TestRouterCtrlPostPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPost("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlPostPointerMethod can't run") + } +} + +func TestRouterCtrlHeadPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlHead("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlHeadPointerMethod can't run") + } +} + +func TestRouterCtrlPutPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPut("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlPutPointerMethod can't run") + } +} + +func TestRouterCtrlPatchPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPatch("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlPatchPointerMethod can't run") + } +} + +func TestRouterCtrlDeletePointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlDelete("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlDeletePointerMethod can't run") + } +} + +func TestRouterCtrlAnyPointerMethod(t *testing.T) { + handler := NewControllerRegister() + handler.CtrlAny("/user", (*ExampleController).PingPointer) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, "/user", nil) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlAnyPointerMethod can't run, get the response is " + w.Body.String()) + } + } +} + +func TestRouterAddRouterMethodPanicInvalidMethod(t *testing.T) { + method := "some random method" + message := "not support http method: " + strings.ToUpper(method) + defer func() { + err := recover() + if err != nil { // 产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicInvalidMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController.Ping) +} + +func TestRouterAddRouterMethodPanicNotAMethod(t *testing.T) { + method := http.MethodGet + message := "not a method" + defer func() { + err := recover() + if err != nil { // 产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotAMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController{}) +} + +func TestRouterAddRouterMethodPanicNotPublicMethod(t *testing.T) { + method := http.MethodGet + message := "ping is not a public method" + defer func() { + err := recover() + if err != nil { // 产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotPublicMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController.ping) +} + +func TestRouterAddRouterMethodPanicNotImplementInterface(t *testing.T) { + method := http.MethodGet + message := "web.TestControllerWithInterface is not implemented ControllerInterface" + defer func() { + err := recover() + if err != nil { // 产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotImplementInterface failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", TestControllerWithInterface.Ping) +} + +func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) { + method := http.MethodGet + message := "web.TestControllerWithInterface is not implemented ControllerInterface" + defer func() { + err := recover() + if err != nil { // 产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterPointerMethodPanicNotImplementInterface failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer) +} diff --git a/server/web/server.go b/server/web/server.go index f0a4f4ea..ec9b6ef9 100644 --- a/server/web/server.go +++ b/server/web/server.go @@ -32,18 +32,15 @@ import ( "golang.org/x/crypto/acme/autocert" "github.com/beego/beego/v2/core/logs" - beecontext "github.com/beego/beego/v2/server/web/context" - "github.com/beego/beego/v2/core/utils" + beecontext "github.com/beego/beego/v2/server/web/context" "github.com/beego/beego/v2/server/web/grace" ) -var ( - // BeeApp is an application instance - // If you are using single server, you could use this - // But if you need multiple servers, do not use this - BeeApp *HttpServer -) +// BeeApp is an application instance +// If you are using single server, you could use this +// But if you need multiple servers, do not use this +var BeeApp *HttpServer func init() { // create beego application @@ -81,10 +78,11 @@ type MiddleWare func(http.Handler) http.Handler // Run beego application. func (app *HttpServer) Run(addr string, mws ...MiddleWare) { - initBeforeHTTPRun() + // init... app.initAddr(addr) + app.Handlers.Init() addr = app.Cfg.Listen.HTTPAddr @@ -233,7 +231,6 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { endRunning <- true } }() - } if app.Cfg.Listen.EnableHTTP { go func() { @@ -267,7 +264,11 @@ func (app *HttpServer) Run(addr string, mws ...MiddleWare) { // Router see HttpServer.Router func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer { - return BeeApp.Router(rootpath, c, mappingMethods...) + return RouterWithOpts(rootpath, c, WithRouterMethods(c, mappingMethods...)) +} + +func RouterWithOpts(rootpath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { + return BeeApp.RouterWithOpts(rootpath, c, opts...) } // Router adds a patterned controller handler to BeeApp. @@ -287,7 +288,11 @@ func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *H // beego.Router("/api/update",&RestController{},"put:UpdateFood") // beego.Router("/api/delete",&RestController{},"delete:DeleteFood") func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer { - app.Handlers.Add(rootPath, c, mappingMethods...) + return app.RouterWithOpts(rootPath, c, WithRouterMethods(c, mappingMethods...)) +} + +func (app *HttpServer) RouterWithOpts(rootPath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { + app.Handlers.Add(rootPath, c, opts...) return app } @@ -432,7 +437,7 @@ func AutoRouter(c ControllerInterface) *HttpServer { // AutoRouter adds defined controller handler to BeeApp. // it's same to HttpServer.AutoRouter. -// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, +// if beego.AddAuto(&MainController{}) and MainController has methods List and Page, // visit the url /main/list to exec List function or /main/page to exec Page function. func (app *HttpServer) AutoRouter(c ControllerInterface) *HttpServer { app.Handlers.AddAuto(c) @@ -446,15 +451,175 @@ func AutoPrefix(prefix string, c ControllerInterface) *HttpServer { // AutoPrefix adds controller handler to BeeApp with prefix. // it's same to HttpServer.AutoRouterWithPrefix. -// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, +// if beego.AutoPrefix("/admin",&MainController{}) and MainController has methods List and Page, // visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. func (app *HttpServer) AutoPrefix(prefix string, c ControllerInterface) *HttpServer { app.Handlers.AddAutoPrefix(prefix, c) return app } +// CtrlGet see HttpServer.CtrlGet +func CtrlGet(rootpath string, f interface{}) { + BeeApp.CtrlGet(rootpath, f) +} + +// CtrlGet used to register router for CtrlGet method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlGet("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlGet(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlGet(rootpath, f) + return app +} + +// CtrlPost see HttpServer.CtrlGet +func CtrlPost(rootpath string, f interface{}) { + BeeApp.CtrlPost(rootpath, f) +} + +// CtrlPost used to register router for CtrlPost method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPost("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlPost(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlPost(rootpath, f) + return app +} + +// CtrlHead see HttpServer.CtrlHead +func CtrlHead(rootpath string, f interface{}) { + BeeApp.CtrlHead(rootpath, f) +} + +// CtrlHead used to register router for CtrlHead method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlHead("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlHead(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlHead(rootpath, f) + return app +} + +// CtrlPut see HttpServer.CtrlPut +func CtrlPut(rootpath string, f interface{}) { + BeeApp.CtrlPut(rootpath, f) +} + +// CtrlPut used to register router for CtrlPut method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPut("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlPut(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlPut(rootpath, f) + return app +} + +// CtrlPatch see HttpServer.CtrlPatch +func CtrlPatch(rootpath string, f interface{}) { + BeeApp.CtrlPatch(rootpath, f) +} + +// CtrlPatch used to register router for CtrlPatch method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPatch("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlPatch(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlPatch(rootpath, f) + return app +} + +// CtrlDelete see HttpServer.CtrlDelete +func CtrlDelete(rootpath string, f interface{}) { + BeeApp.CtrlDelete(rootpath, f) +} + +// CtrlDelete used to register router for CtrlDelete method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlDelete("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlDelete(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlDelete(rootpath, f) + return app +} + +// CtrlOptions see HttpServer.CtrlOptions +func CtrlOptions(rootpath string, f interface{}) { + BeeApp.CtrlOptions(rootpath, f) +} + +// CtrlOptions used to register router for CtrlOptions method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlOptions("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlOptions(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlOptions(rootpath, f) + return app +} + +// CtrlAny see HttpServer.CtrlAny +func CtrlAny(rootpath string, f interface{}) { + BeeApp.CtrlAny(rootpath, f) +} + +// CtrlAny used to register router for CtrlAny method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlAny("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlAny(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlAny(rootpath, f) + return app +} + // Get see HttpServer.Get -func Get(rootpath string, f FilterFunc) *HttpServer { +func Get(rootpath string, f HandleFunc) *HttpServer { return BeeApp.Get(rootpath, f) } @@ -463,13 +628,13 @@ func Get(rootpath string, f FilterFunc) *HttpServer { // beego.Get("/", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (app *HttpServer) Get(rootpath string, f FilterFunc) *HttpServer { +func (app *HttpServer) Get(rootpath string, f HandleFunc) *HttpServer { app.Handlers.Get(rootpath, f) return app } // Post see HttpServer.Post -func Post(rootpath string, f FilterFunc) *HttpServer { +func Post(rootpath string, f HandleFunc) *HttpServer { return BeeApp.Post(rootpath, f) } @@ -478,13 +643,13 @@ func Post(rootpath string, f FilterFunc) *HttpServer { // beego.Post("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (app *HttpServer) Post(rootpath string, f FilterFunc) *HttpServer { +func (app *HttpServer) Post(rootpath string, f HandleFunc) *HttpServer { app.Handlers.Post(rootpath, f) return app } // Delete see HttpServer.Delete -func Delete(rootpath string, f FilterFunc) *HttpServer { +func Delete(rootpath string, f HandleFunc) *HttpServer { return BeeApp.Delete(rootpath, f) } @@ -493,13 +658,13 @@ func Delete(rootpath string, f FilterFunc) *HttpServer { // beego.Delete("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (app *HttpServer) Delete(rootpath string, f FilterFunc) *HttpServer { +func (app *HttpServer) Delete(rootpath string, f HandleFunc) *HttpServer { app.Handlers.Delete(rootpath, f) return app } // Put see HttpServer.Put -func Put(rootpath string, f FilterFunc) *HttpServer { +func Put(rootpath string, f HandleFunc) *HttpServer { return BeeApp.Put(rootpath, f) } @@ -508,13 +673,13 @@ func Put(rootpath string, f FilterFunc) *HttpServer { // beego.Put("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (app *HttpServer) Put(rootpath string, f FilterFunc) *HttpServer { +func (app *HttpServer) Put(rootpath string, f HandleFunc) *HttpServer { app.Handlers.Put(rootpath, f) return app } // Head see HttpServer.Head -func Head(rootpath string, f FilterFunc) *HttpServer { +func Head(rootpath string, f HandleFunc) *HttpServer { return BeeApp.Head(rootpath, f) } @@ -523,13 +688,13 @@ func Head(rootpath string, f FilterFunc) *HttpServer { // beego.Head("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (app *HttpServer) Head(rootpath string, f FilterFunc) *HttpServer { +func (app *HttpServer) Head(rootpath string, f HandleFunc) *HttpServer { app.Handlers.Head(rootpath, f) return app } // Options see HttpServer.Options -func Options(rootpath string, f FilterFunc) *HttpServer { +func Options(rootpath string, f HandleFunc) *HttpServer { BeeApp.Handlers.Options(rootpath, f) return BeeApp } @@ -539,13 +704,13 @@ func Options(rootpath string, f FilterFunc) *HttpServer { // beego.Options("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (app *HttpServer) Options(rootpath string, f FilterFunc) *HttpServer { +func (app *HttpServer) Options(rootpath string, f HandleFunc) *HttpServer { app.Handlers.Options(rootpath, f) return app } // Patch see HttpServer.Patch -func Patch(rootpath string, f FilterFunc) *HttpServer { +func Patch(rootpath string, f HandleFunc) *HttpServer { return BeeApp.Patch(rootpath, f) } @@ -554,13 +719,13 @@ func Patch(rootpath string, f FilterFunc) *HttpServer { // beego.Patch("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (app *HttpServer) Patch(rootpath string, f FilterFunc) *HttpServer { +func (app *HttpServer) Patch(rootpath string, f HandleFunc) *HttpServer { app.Handlers.Patch(rootpath, f) return app } // Any see HttpServer.Any -func Any(rootpath string, f FilterFunc) *HttpServer { +func Any(rootpath string, f HandleFunc) *HttpServer { return BeeApp.Any(rootpath, f) } @@ -569,7 +734,7 @@ func Any(rootpath string, f FilterFunc) *HttpServer { // beego.Any("/api", func(ctx *context.Context){ // ctx.Output.Body("hello world") // }) -func (app *HttpServer) Any(rootpath string, f FilterFunc) *HttpServer { +func (app *HttpServer) Any(rootpath string, f HandleFunc) *HttpServer { app.Handlers.Any(rootpath, f) return app } @@ -695,21 +860,21 @@ func printTree(resultList *[][]string, t *Tree) { for _, l := range t.leaves { if v, ok := l.runObject.(*ControllerInfo); ok { if v.routerType == routerTypeBeego { - var result = []string{ + result := []string{ template.HTMLEscapeString(v.pattern), template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), template.HTMLEscapeString(v.controllerType.String()), } *resultList = append(*resultList, result) } else if v.routerType == routerTypeRESTFul { - var result = []string{ + result := []string{ template.HTMLEscapeString(v.pattern), template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), "", } *resultList = append(*resultList, result) } else if v.routerType == routerTypeHandler { - var result = []string{ + result := []string{ template.HTMLEscapeString(v.pattern), "", "", @@ -735,7 +900,7 @@ func (app *HttpServer) reportFilter() M { if bf := app.Handlers.filters[k]; len(bf) > 0 { resultList := new([][]string) for _, f := range bf { - var result = []string{ + result := []string{ // void xss template.HTMLEscapeString(f.pattern), template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)), diff --git a/server/web/server_test.go b/server/web/server_test.go index 0b0c601c..1ddf217e 100644 --- a/server/web/server_test.go +++ b/server/web/server_test.go @@ -15,16 +15,95 @@ package web import ( + "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" ) func TestNewHttpServerWithCfg(t *testing.T) { - BConfig.AppName = "Before" svr := NewHttpServerWithCfg(BConfig) svr.Cfg.AppName = "hello" assert.Equal(t, "hello", BConfig.AppName) - +} + +func TestServerCtrlGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + CtrlGet("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlGet can't run") + } +} + +func TestServerCtrlPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + CtrlPost("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlPost can't run") + } +} + +func TestServerCtrlHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + CtrlHead("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlHead can't run") + } +} + +func TestServerCtrlPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + CtrlPut("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlPut can't run") + } +} + +func TestServerCtrlPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + CtrlPatch("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlPatch can't run") + } +} + +func TestServerCtrlDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + CtrlDelete("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlDelete can't run") + } +} + +func TestServerCtrlAny(t *testing.T) { + CtrlAny("/user", ExampleController.Ping) + + for method := range HTTPMETHOD { + r, _ := http.NewRequest(method, "/user", nil) + w := httptest.NewRecorder() + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlAny can't run") + } + } } diff --git a/server/web/session/README.md b/server/web/session/README.md index 854fb590..8dd70f67 100644 --- a/server/web/session/README.md +++ b/server/web/session/README.md @@ -6,7 +6,7 @@ and `database/sql/driver`. ## How to install? - go get github.com/beego/beego/v2/session + go get github.com/beego/beego/v2/server/web/session ## What providers are supported? @@ -17,7 +17,7 @@ As of now this session manager support memory, file, Redis and MySQL. First you must import it import ( - "github.com/beego/beego/v2/session" + "github.com/beego/beego/v2/server/web/session" ) Then in you web app init the global session manager diff --git a/server/web/session/couchbase/sess_couchbase.go b/server/web/session/couchbase/sess_couchbase.go index ea94f501..b9075040 100644 --- a/server/web/session/couchbase/sess_couchbase.go +++ b/server/web/session/couchbase/sess_couchbase.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/couchbase" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/couchbase" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/server/web/session/ledis/ledis_session.go b/server/web/session/ledis/ledis_session.go index 8e34388b..80524e20 100644 --- a/server/web/session/ledis/ledis_session.go +++ b/server/web/session/ledis/ledis_session.go @@ -186,6 +186,7 @@ func (lp *Provider) SessionGC(context.Context) { func (lp *Provider) SessionAll(context.Context) int { return 0 } + func init() { session.Register("ledis", ledispder) } diff --git a/server/web/session/memcache/sess_memcache.go b/server/web/session/memcache/sess_memcache.go index 3f4c9842..11aee787 100644 --- a/server/web/session/memcache/sess_memcache.go +++ b/server/web/session/memcache/sess_memcache.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/memcache" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/memcache" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { @@ -38,13 +38,15 @@ import ( "strings" "sync" - "github.com/beego/beego/v2/server/web/session" - "github.com/bradfitz/gomemcache/memcache" + + "github.com/beego/beego/v2/server/web/session" ) -var mempder = &MemProvider{} -var client *memcache.Client +var ( + mempder = &MemProvider{} + client *memcache.Client +) // SessionStore memcache session store type SessionStore struct { diff --git a/server/web/session/mysql/sess_mysql.go b/server/web/session/mysql/sess_mysql.go index d76ec287..2755bded 100644 --- a/server/web/session/mysql/sess_mysql.go +++ b/server/web/session/mysql/sess_mysql.go @@ -28,8 +28,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/mysql" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/mysql" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { @@ -47,9 +47,10 @@ import ( "sync" "time" - "github.com/beego/beego/v2/server/web/session" // import mysql driver _ "github.com/go-sql-driver/mysql" + + "github.com/beego/beego/v2/server/web/session" ) var ( @@ -150,6 +151,8 @@ func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, if err == sql.ErrNoRows { c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix()) + } else if err != nil { + return nil, err } var kv map[interface{}]interface{} if len(sessiondata) == 0 { @@ -189,7 +192,10 @@ func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) ( if err == sql.ErrNoRows { c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) } - c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) + _, err = c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) + if err != nil { + return nil, err + } var kv map[interface{}]interface{} if len(sessiondata) == 0 { kv = make(map[interface{}]interface{}) diff --git a/server/web/session/postgres/sess_postgresql.go b/server/web/session/postgres/sess_postgresql.go index 7745ff5f..9914cc0c 100644 --- a/server/web/session/postgres/sess_postgresql.go +++ b/server/web/session/postgres/sess_postgresql.go @@ -38,8 +38,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/postgresql" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/postgresql" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { @@ -57,9 +57,10 @@ import ( "sync" "time" - "github.com/beego/beego/v2/server/web/session" // import postgresql Driver _ "github.com/lib/pq" + + "github.com/beego/beego/v2/server/web/session" ) var postgresqlpder = &Provider{} @@ -122,7 +123,6 @@ func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite } st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3", b, time.Now().Format(time.RFC3339), st.sid) - } // Provider postgresql session provider diff --git a/server/web/session/redis/sess_redis.go b/server/web/session/redis/sess_redis.go index e3d38be3..acc25f78 100644 --- a/server/web/session/redis/sess_redis.go +++ b/server/web/session/redis/sess_redis.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/server/web/session/redis/sess_redis_test.go b/server/web/session/redis/sess_redis_test.go index fe5c363b..ef9834ad 100644 --- a/server/web/session/redis/sess_redis_test.go +++ b/server/web/session/redis/sess_redis_test.go @@ -15,21 +15,22 @@ import ( ) func TestRedis(t *testing.T) { - sessionConfig := &session.ManagerConfig{ - CookieName: "gosessionid", - EnableSetCookie: true, - Gclifetime: 3600, - Maxlifetime: 3600, - Secure: false, - CookieLifeTime: 3600, - } - redisAddr := os.Getenv("REDIS_ADDR") if redisAddr == "" { redisAddr = "127.0.0.1:6379" } + redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr) + + sessionConfig := session.NewManagerConfig( + session.CfgCookieName(`gosessionid`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + session.CfgProviderConfig(redisConfig), + ) - sessionConfig.ProviderConfig = fmt.Sprintf("%s,100,,0,30", redisAddr) globalSession, err := session.NewManager("redis", sessionConfig) if err != nil { t.Fatal("could not create manager:", err) @@ -100,7 +101,6 @@ func TestRedis(t *testing.T) { } func TestProvider_SessionInit(t *testing.T) { - savePath := ` { "save_path": "my save path", "idle_timeout": "3s"} ` diff --git a/server/web/session/redis_cluster/redis_cluster.go b/server/web/session/redis_cluster/redis_cluster.go index e94dccc3..4db3bbe9 100644 --- a/server/web/session/redis_cluster/redis_cluster.go +++ b/server/web/session/redis_cluster/redis_cluster.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis_cluster" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis_cluster" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/server/web/session/redis_cluster/redis_cluster_test.go b/server/web/session/redis_cluster/redis_cluster_test.go index 0192cd87..e518af54 100644 --- a/server/web/session/redis_cluster/redis_cluster_test.go +++ b/server/web/session/redis_cluster/redis_cluster_test.go @@ -23,7 +23,6 @@ import ( ) func TestProvider_SessionInit(t *testing.T) { - savePath := ` { "save_path": "my save path", "idle_timeout": "3s"} ` diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel.go b/server/web/session/redis_sentinel/sess_redis_sentinel.go index 2d64c6b4..d18a3773 100644 --- a/server/web/session/redis_sentinel/sess_redis_sentinel.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel.go @@ -20,8 +20,8 @@ // // Usage: // import( -// _ "github.com/beego/beego/v2/session/redis_sentinel" -// "github.com/beego/beego/v2/session" +// _ "github.com/beego/beego/v2/server/web/session/redis_sentinel" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go index 0a8030ce..276fb5d8 100644 --- a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go +++ b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go @@ -13,15 +13,15 @@ import ( ) func TestRedisSentinel(t *testing.T) { - sessionConfig := &session.ManagerConfig{ - CookieName: "gosessionid", - EnableSetCookie: true, - Gclifetime: 3600, - Maxlifetime: 3600, - Secure: false, - CookieLifeTime: 3600, - ProviderConfig: "127.0.0.1:6379,100,,0,master", - } + sessionConfig := session.NewManagerConfig( + session.CfgCookieName(`gosessionid`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"), + ) globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) if e != nil { t.Log(e) @@ -90,11 +90,9 @@ func TestRedisSentinel(t *testing.T) { } sess.SessionRelease(nil, w) - } func TestProvider_SessionInit(t *testing.T) { - savePath := ` { "save_path": "my save path", "idle_timeout": "3s"} ` diff --git a/server/web/session/sess_cookie.go b/server/web/session/sess_cookie.go index 649f6510..622fb2fe 100644 --- a/server/web/session/sess_cookie.go +++ b/server/web/session/sess_cookie.go @@ -79,12 +79,14 @@ func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.Respons encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) st.lock.Unlock() if err == nil { - cookie := &http.Cookie{Name: cookiepder.config.CookieName, + cookie := &http.Cookie{ + Name: cookiepder.config.CookieName, Value: url.QueryEscape(encodedCookie), Path: "/", HttpOnly: true, Secure: cookiepder.config.Secure, - MaxAge: cookiepder.config.Maxage} + MaxAge: cookiepder.config.Maxage, + } http.SetCookie(w, cookie) } } diff --git a/server/web/session/sess_file.go b/server/web/session/sess_file.go index 90de9a79..14cec1d9 100644 --- a/server/web/session/sess_file.go +++ b/server/web/session/sess_file.go @@ -91,7 +91,7 @@ func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseW _, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) var f *os.File if err == nil { - f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) + f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0o777) if err != nil { SLogger.Println(err) return @@ -140,23 +140,32 @@ func (fp *FileProvider) SessionRead(ctx context.Context, sid string) (Store, err filepder.lock.Lock() defer filepder.lock.Unlock() - err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0755) + sessionPath := filepath.Join(fp.savePath, string(sid[0]), string(sid[1])) + sidPath := filepath.Join(sessionPath, sid) + err := os.MkdirAll(sessionPath, 0o755) if err != nil { SLogger.Println(err.Error()) } - _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) var f *os.File - if err == nil { - f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777) - } else if os.IsNotExist(err) { - f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - } else { + _, err = os.Stat(sidPath) + switch { + case err == nil: + f, err = os.OpenFile(sidPath, os.O_RDWR, 0o777) + if err != nil { + return nil, err + } + case os.IsNotExist(err): + f, err = os.Create(sidPath) + if err != nil { + return nil, err + } + default: return nil, err } defer f.Close() - os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) + os.Chtimes(sidPath, time.Now(), time.Now()) var kv map[interface{}]interface{} b, err := ioutil.ReadAll(f) if err != nil { @@ -211,9 +220,7 @@ func (fp *FileProvider) SessionGC(context.Context) { // it walks save path to count files. func (fp *FileProvider) SessionAll(context.Context) int { a := &activeSession{} - err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { - return a.visit(path, f, err) - }) + err := filepath.Walk(fp.savePath, a.visit) if err != nil { SLogger.Printf("filepath.Walk() returned %v\n", err) return 0 @@ -238,7 +245,7 @@ func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid strin return nil, fmt.Errorf("newsid %s exist", newSidFile) } - err = os.MkdirAll(newPath, 0755) + err = os.MkdirAll(newPath, 0o755) if err != nil { SLogger.Println(err.Error()) } @@ -265,7 +272,7 @@ func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid strin } } - ioutil.WriteFile(newSidFile, b, 0777) + ioutil.WriteFile(newSidFile, b, 0o777) os.Remove(oldSidFile) os.Chtimes(newSidFile, time.Now(), time.Now()) ss := &FileSessionStore{sid: sid, values: kv} diff --git a/server/web/session/sess_file_test.go b/server/web/session/sess_file_test.go index f40de69f..84860813 100644 --- a/server/web/session/sess_file_test.go +++ b/server/web/session/sess_file_test.go @@ -23,15 +23,15 @@ import ( "time" ) -const sid = "Session_id" -const sidNew = "Session_id_new" -const sessionPath = "./_session_runtime" - -var ( - mutex sync.Mutex +const ( + sid = "Session_id" + sidNew = "Session_id_new" + sessionPath = "./_session_runtime" ) -func TestFileProvider_SessionInit(t *testing.T) { +var mutex sync.Mutex + +func TestFileProviderSessionInit(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -48,7 +48,7 @@ func TestFileProvider_SessionInit(t *testing.T) { } } -func TestFileProvider_SessionExist(t *testing.T) { +func TestFileProviderSessionExist(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -79,7 +79,7 @@ func TestFileProvider_SessionExist(t *testing.T) { } } -func TestFileProvider_SessionExist2(t *testing.T) { +func TestFileProviderSessionExist2(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -113,7 +113,7 @@ func TestFileProvider_SessionExist2(t *testing.T) { } } -func TestFileProvider_SessionRead(t *testing.T) { +func TestFileProviderSessionRead(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -135,7 +135,7 @@ func TestFileProvider_SessionRead(t *testing.T) { } } -func TestFileProvider_SessionRead1(t *testing.T) { +func TestFileProviderSessionRead1(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -155,7 +155,7 @@ func TestFileProvider_SessionRead1(t *testing.T) { } } -func TestFileProvider_SessionAll(t *testing.T) { +func TestFileProviderSessionAll(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -178,7 +178,7 @@ func TestFileProvider_SessionAll(t *testing.T) { } } -func TestFileProvider_SessionRegenerate(t *testing.T) { +func TestFileProviderSessionRegenerate(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -222,7 +222,7 @@ func TestFileProvider_SessionRegenerate(t *testing.T) { } } -func TestFileProvider_SessionDestroy(t *testing.T) { +func TestFileProviderSessionDestroy(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -258,7 +258,7 @@ func TestFileProvider_SessionDestroy(t *testing.T) { } } -func TestFileProvider_SessionGC(t *testing.T) { +func TestFileProviderSessionGC(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -284,7 +284,7 @@ func TestFileProvider_SessionGC(t *testing.T) { } } -func TestFileSessionStore_Set(t *testing.T) { +func TestFileSessionStoreSet(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -303,7 +303,7 @@ func TestFileSessionStore_Set(t *testing.T) { } } -func TestFileSessionStore_Get(t *testing.T) { +func TestFileSessionStoreGet(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -324,7 +324,7 @@ func TestFileSessionStore_Get(t *testing.T) { } } -func TestFileSessionStore_Delete(t *testing.T) { +func TestFileSessionStoreDelete(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -347,7 +347,7 @@ func TestFileSessionStore_Delete(t *testing.T) { } } -func TestFileSessionStore_Flush(t *testing.T) { +func TestFileSessionStoreFlush(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -371,7 +371,7 @@ func TestFileSessionStore_Flush(t *testing.T) { } } -func TestFileSessionStore_SessionID(t *testing.T) { +func TestFileSessionStoreSessionID(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) @@ -393,7 +393,7 @@ func TestFileSessionStore_SessionID(t *testing.T) { } } -func TestFileSessionStore_SessionRelease(t *testing.T) { +func TestFileSessionStoreSessionRelease(t *testing.T) { mutex.Lock() defer mutex.Unlock() os.RemoveAll(sessionPath) diff --git a/server/web/session/session.go b/server/web/session/session.go index de63ed75..881de8ea 100644 --- a/server/web/session/session.go +++ b/server/web/session/session.go @@ -16,7 +16,7 @@ // // Usage: // import( -// "github.com/beego/beego/v2/session" +// "github.com/beego/beego/v2/server/web/session" // ) // // func init() { @@ -70,15 +70,11 @@ var provides = make(map[string]Provider) var SLogger = NewSessionLog(os.Stderr) // Register makes a session provide available by the provided name. -// If Register is called twice with the same name or if driver is nil, -// it panics. +// If provider is nil, it panic func Register(name string, provide Provider) { if provide == nil { panic("session: Register provide is nil") } - if _, dup := provides[name]; dup { - panic("session: Register called twice for provider " + name) - } provides[name] = provide } @@ -91,24 +87,6 @@ func GetProvider(name string) (Provider, error) { return provider, nil } -// ManagerConfig define the session config -type ManagerConfig struct { - CookieName string `json:"cookieName"` - EnableSetCookie bool `json:"enableSetCookie,omitempty"` - Gclifetime int64 `json:"gclifetime"` - Maxlifetime int64 `json:"maxLifetime"` - DisableHTTPOnly bool `json:"disableHTTPOnly"` - Secure bool `json:"secure"` - CookieLifeTime int `json:"cookieLifeTime"` - ProviderConfig string `json:"providerConfig"` - Domain string `json:"domain"` - SessionIDLength int64 `json:"sessionIDLength"` - EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` - SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` - EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` - SessionIDPrefix string `json:"sessionIDPrefix"` -} - // Manager contains Provider and its configuration. type Manager struct { provider Provider @@ -239,6 +217,7 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se HttpOnly: !manager.config.DisableHTTPOnly, Secure: manager.isSecure(r), Domain: manager.config.Domain, + SameSite: manager.config.CookieSameSite, } if manager.config.CookieLifeTime > 0 { cookie.MaxAge = manager.config.CookieLifeTime @@ -273,12 +252,15 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { manager.provider.SessionDestroy(nil, sid) if manager.config.EnableSetCookie { expiration := time.Now() - cookie = &http.Cookie{Name: manager.config.CookieName, + cookie = &http.Cookie{ + Name: manager.config.CookieName, Path: "/", HttpOnly: !manager.config.DisableHTTPOnly, Expires: expiration, MaxAge: -1, - Domain: manager.config.Domain} + Domain: manager.config.Domain, + SameSite: manager.config.CookieSameSite, + } http.SetCookie(w, cookie) } @@ -313,12 +295,14 @@ func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Reque if err != nil { return nil, err } - cookie = &http.Cookie{Name: manager.config.CookieName, + cookie = &http.Cookie{ + Name: manager.config.CookieName, Value: url.QueryEscape(sid), Path: "/", HttpOnly: !manager.config.DisableHTTPOnly, Secure: manager.isSecure(r), Domain: manager.config.Domain, + SameSite: manager.config.CookieSameSite, } } else { oldsid, err := url.QueryUnescape(cookie.Value) diff --git a/server/web/session/session_config.go b/server/web/session/session_config.go new file mode 100644 index 00000000..d9514003 --- /dev/null +++ b/server/web/session/session_config.go @@ -0,0 +1,143 @@ +package session + +import "net/http" + +// ManagerConfig define the session config +type ManagerConfig struct { + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + DisableHTTPOnly bool `json:"disableHTTPOnly"` + Secure bool `json:"secure"` + EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` + EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` + CookieName string `json:"cookieName"` + Gclifetime int64 `json:"gclifetime"` + Maxlifetime int64 `json:"maxLifetime"` + CookieLifeTime int `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` + Domain string `json:"domain"` + SessionIDLength int64 `json:"sessionIDLength"` + SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` + SessionIDPrefix string `json:"sessionIDPrefix"` + CookieSameSite http.SameSite `json:"cookieSameSite"` +} + +func (c *ManagerConfig) Opts(opts ...ManagerConfigOpt) { + for _, opt := range opts { + opt(c) + } +} + +type ManagerConfigOpt func(config *ManagerConfig) + +func NewManagerConfig(opts ...ManagerConfigOpt) *ManagerConfig { + config := &ManagerConfig{} + for _, opt := range opts { + opt(config) + } + return config +} + +// CfgCookieName set key of session id +func CfgCookieName(cookieName string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieName = cookieName + } +} + +// CfgCookieName set len of session id +func CfgSessionIdLength(length int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionIDLength = length + } +} + +// CfgSessionIdPrefix set prefix of session id +func CfgSessionIdPrefix(prefix string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionIDPrefix = prefix + } +} + +// CfgSetCookie whether set `Set-Cookie` header in HTTP response +func CfgSetCookie(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSetCookie = enable + } +} + +// CfgGcLifeTime set session gc lift time +func CfgGcLifeTime(lifeTime int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Gclifetime = lifeTime + } +} + +// CfgMaxLifeTime set session lift time +func CfgMaxLifeTime(lifeTime int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Maxlifetime = lifeTime + } +} + +// CfgGcLifeTime set session lift time +func CfgCookieLifeTime(lifeTime int) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieLifeTime = lifeTime + } +} + +// CfgProviderConfig configure session provider +func CfgProviderConfig(providerConfig string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.ProviderConfig = providerConfig + } +} + +// CfgDomain set cookie domain +func CfgDomain(domain string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Domain = domain + } +} + +// CfgSessionIdInHTTPHeader enable session id in http header +func CfgSessionIdInHTTPHeader(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSidInHTTPHeader = enable + } +} + +// CfgSetSessionNameInHTTPHeader set key of session id in http header +func CfgSetSessionNameInHTTPHeader(name string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionNameInHTTPHeader = name + } +} + +// EnableSidInURLQuery enable session id in query string +func CfgEnableSidInURLQuery(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSidInURLQuery = enable + } +} + +// DisableHTTPOnly set HTTPOnly for http.Cookie +func CfgHTTPOnly(HTTPOnly bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.DisableHTTPOnly = !HTTPOnly + } +} + +// CfgSecure set Secure for http.Cookie +func CfgSecure(Enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Secure = Enable + } +} + +// CfgSameSite set http.SameSite +func CfgSameSite(sameSite http.SameSite) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieSameSite = sameSite + } +} diff --git a/server/web/session/session_config_test.go b/server/web/session/session_config_test.go new file mode 100644 index 00000000..0ea7d22b --- /dev/null +++ b/server/web/session/session_config_test.go @@ -0,0 +1,222 @@ +package session + +import ( + "net/http" + "testing" +) + +func TestCfgCookieLifeTime(t *testing.T) { + value := 8754 + c := NewManagerConfig( + CfgCookieLifeTime(value), + ) + + if c.CookieLifeTime != value { + t.Error() + } +} + +func TestCfgDomain(t *testing.T) { + value := `http://domain.com` + c := NewManagerConfig( + CfgDomain(value), + ) + + if c.Domain != value { + t.Error() + } +} + +func TestCfgSameSite(t *testing.T) { + value := http.SameSiteLaxMode + c := NewManagerConfig( + CfgSameSite(value), + ) + + if c.CookieSameSite != value { + t.Error() + } +} + +func TestCfgSecure(t *testing.T) { + c := NewManagerConfig( + CfgSecure(true), + ) + + if c.Secure != true { + t.Error() + } +} + +func TestCfgSecure1(t *testing.T) { + c := NewManagerConfig( + CfgSecure(false), + ) + + if c.Secure != false { + t.Error() + } +} + +func TestCfgSessionIdPrefix(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgSessionIdPrefix(value), + ) + + if c.SessionIDPrefix != value { + t.Error() + } +} + +func TestCfgSetSessionNameInHTTPHeader(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgSetSessionNameInHTTPHeader(value), + ) + + if c.SessionNameInHTTPHeader != value { + t.Error() + } +} + +func TestCfgCookieName(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgCookieName(value), + ) + + if c.CookieName != value { + t.Error() + } +} + +func TestCfgEnableSidInURLQuery(t *testing.T) { + c := NewManagerConfig( + CfgEnableSidInURLQuery(true), + ) + + if c.EnableSidInURLQuery != true { + t.Error() + } +} + +func TestCfgGcLifeTime(t *testing.T) { + value := int64(5454) + c := NewManagerConfig( + CfgGcLifeTime(value), + ) + + if c.Gclifetime != value { + t.Error() + } +} + +func TestCfgHTTPOnly(t *testing.T) { + c := NewManagerConfig( + CfgHTTPOnly(true), + ) + + if c.DisableHTTPOnly != false { + t.Error() + } +} + +func TestCfgHTTPOnly2(t *testing.T) { + c := NewManagerConfig( + CfgHTTPOnly(false), + ) + + if c.DisableHTTPOnly != true { + t.Error() + } +} + +func TestCfgMaxLifeTime(t *testing.T) { + value := int64(5454) + c := NewManagerConfig( + CfgMaxLifeTime(value), + ) + + if c.Maxlifetime != value { + t.Error() + } +} + +func TestCfgProviderConfig(t *testing.T) { + value := `asodiuasldkj12i39809as` + c := NewManagerConfig( + CfgProviderConfig(value), + ) + + if c.ProviderConfig != value { + t.Error() + } +} + +func TestCfgSessionIdInHTTPHeader(t *testing.T) { + c := NewManagerConfig( + CfgSessionIdInHTTPHeader(true), + ) + + if c.EnableSidInHTTPHeader != true { + t.Error() + } +} + +func TestCfgSessionIdInHTTPHeader1(t *testing.T) { + c := NewManagerConfig( + CfgSessionIdInHTTPHeader(false), + ) + + if c.EnableSidInHTTPHeader != false { + t.Error() + } +} + +func TestCfgSessionIdLength(t *testing.T) { + value := int64(100) + c := NewManagerConfig( + CfgSessionIdLength(value), + ) + + if c.SessionIDLength != value { + t.Error() + } +} + +func TestCfgSetCookie(t *testing.T) { + c := NewManagerConfig( + CfgSetCookie(true), + ) + + if c.EnableSetCookie != true { + t.Error() + } +} + +func TestCfgSetCookie1(t *testing.T) { + c := NewManagerConfig( + CfgSetCookie(false), + ) + + if c.EnableSetCookie != false { + t.Error() + } +} + +func TestNewManagerConfig(t *testing.T) { + c := NewManagerConfig() + if c == nil { + t.Error() + } +} + +func TestManagerConfig_Opts(t *testing.T) { + c := NewManagerConfig() + c.Opts(CfgSetCookie(true)) + + if c.EnableSetCookie != true { + t.Error() + } +} diff --git a/server/web/session/session_provider_type.go b/server/web/session/session_provider_type.go new file mode 100644 index 00000000..78dc116d --- /dev/null +++ b/server/web/session/session_provider_type.go @@ -0,0 +1,18 @@ +package session + +type ProviderType string + +const ( + ProviderCookie ProviderType = `cookie` + ProviderFile ProviderType = `file` + ProviderMemory ProviderType = `memory` + ProviderCouchbase ProviderType = `couchbase` + ProviderLedis ProviderType = `ledis` + ProviderMemcache ProviderType = `memcache` + ProviderMysql ProviderType = `mysql` + ProviderPostgresql ProviderType = `postgresql` + ProviderRedis ProviderType = `redis` + ProviderRedisCluster ProviderType = `redis_cluster` + ProviderRedisSentinel ProviderType = `redis_sentinel` + ProviderSsdb ProviderType = `ssdb` +) diff --git a/server/web/staticfile.go b/server/web/staticfile.go index e5d3c3ed..7a120f68 100644 --- a/server/web/staticfile.go +++ b/server/web/staticfile.go @@ -29,7 +29,6 @@ import ( lru "github.com/hashicorp/golang-lru" "github.com/beego/beego/v2/core/logs" - "github.com/beego/beego/v2/server/web/context" ) @@ -76,7 +75,7 @@ func serverStaticRouter(ctx *context.Context) { return } - var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath) + enableCompress := BConfig.EnableGzip && isStaticCompress(filePath) var acceptEncoding string if enableCompress { acceptEncoding = context.ParseEncoding(ctx.Request) diff --git a/server/web/staticfile_test.go b/server/web/staticfile_test.go index 0725a2f8..621b4c17 100644 --- a/server/web/staticfile_test.go +++ b/server/web/staticfile_test.go @@ -12,8 +12,10 @@ import ( "testing" ) -var currentWorkDir, _ = os.Getwd() -var licenseFile = filepath.Join(currentWorkDir, "LICENSE") +var ( + currentWorkDir, _ = os.Getwd() + licenseFile = filepath.Join(currentWorkDir, "LICENSE") +) func testOpenFile(encoding string, content []byte, t *testing.T) { fi, _ := os.Stat(licenseFile) @@ -27,6 +29,7 @@ func testOpenFile(encoding string, content []byte, t *testing.T) { assetOpenFileAndContent(sch, reader, content, t) } + func TestOpenStaticFile_1(t *testing.T) { file, _ := os.Open(licenseFile) content, _ := ioutil.ReadAll(file) @@ -43,6 +46,7 @@ func TestOpenStaticFileGzip_1(t *testing.T) { testOpenFile("gzip", content, t) } + func TestOpenStaticFileDeflate_1(t *testing.T) { file, _ := os.Open(licenseFile) var zipBuf bytes.Buffer diff --git a/server/web/statistics.go b/server/web/statistics.go index 42bf9287..d0f3ebd9 100644 --- a/server/web/statistics.go +++ b/server/web/statistics.go @@ -66,7 +66,6 @@ func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController stri } m.urlmap[requestURL][requestMethod] = nb } - } else { if m.LengthLimit > 0 && m.LengthLimit <= len(m.urlmap) { return @@ -90,7 +89,7 @@ func (m *URLMap) GetMap() map[string]interface{} { m.lock.RLock() defer m.lock.RUnlock() - var fields = []string{"requestUrl", "method", "times", "used", "max used", "min used", "avg used"} + fields := []string{"requestUrl", "method", "times", "used", "max used", "min used", "avg used"} var resultLists [][]string content := make(map[string]interface{}) diff --git a/server/web/template.go b/server/web/template.go index 65935ca8..c683c565 100644 --- a/server/web/template.go +++ b/server/web/template.go @@ -202,9 +202,7 @@ func BuildTemplate(dir string, files ...string) error { root: dir, files: make(map[string][]string), } - err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error { - return self.visit(path, f, err) - }) + err = Walk(fs, dir, self.visit) if err != nil { fmt.Printf("Walk() returned %v\n", err) return err @@ -351,7 +349,6 @@ func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMod } } } - } return } diff --git a/server/web/template_test.go b/server/web/template_test.go index 1d82c2e2..2c2585ce 100644 --- a/server/web/template_test.go +++ b/server/web/template_test.go @@ -49,19 +49,18 @@ var block = `{{define "block"}} {{end}}` func TestTemplate(t *testing.T) { - wkdir, err := os.Getwd() - assert.Nil(t, err) - dir := filepath.Join(wkdir, "_beeTmp", "TestTemplate") + tmpDir := os.TempDir() + dir := filepath.Join(tmpDir, "_beeTmp", "TestTemplate") files := []string{ "header.tpl", "index.tpl", "blocks/block.tpl", } - if err := os.MkdirAll(dir, 0777); err != nil { + if err := os.MkdirAll(dir, 0o777); err != nil { t.Fatal(err) } for k, name := range files { - dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0o777) assert.Nil(t, dirErr) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) @@ -101,6 +100,7 @@ var menu = ` ` + var user = ` @@ -113,9 +113,8 @@ var user = ` ` func TestRelativeTemplate(t *testing.T) { - wkdir, err := os.Getwd() - assert.Nil(t, err) - dir := filepath.Join(wkdir, "_beeTmp") + tmpDir := os.TempDir() + dir := filepath.Join(tmpDir, "_beeTmp") // Just add dir to known viewPaths if err := AddViewPath(dir); err != nil { @@ -126,11 +125,11 @@ func TestRelativeTemplate(t *testing.T) { "easyui/public/menu.tpl", "easyui/rbac/user.tpl", } - if err := os.MkdirAll(dir, 0777); err != nil { + if err := os.MkdirAll(dir, 0o777); err != nil { t.Fatal(err) } for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0o777) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) } else { @@ -226,20 +225,20 @@ var output = ` ` func TestTemplateLayout(t *testing.T) { - wkdir, err := os.Getwd() + tmpDir, err := os.Getwd() assert.Nil(t, err) - dir := filepath.Join(wkdir, "_beeTmp", "TestTemplateLayout") + dir := filepath.Join(tmpDir, "_beeTmp", "TestTemplateLayout") files := []string{ "add.tpl", "layout_blog.tpl", } - if err := os.MkdirAll(dir, 0777); err != nil { + if err := os.MkdirAll(dir, 0o777); err != nil { t.Fatal(err) } for k, name := range files { - dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0o777) assert.Nil(t, dirErr) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) diff --git a/server/web/templatefunc.go b/server/web/templatefunc.go index 53c99018..c0db9990 100644 --- a/server/web/templatefunc.go +++ b/server/web/templatefunc.go @@ -25,13 +25,8 @@ import ( "strconv" "strings" "time" -) -const ( - formatTime = "15:04:05" - formatDate = "2006-01-02" - formatDateTime = "2006-01-02 15:04:05" - formatDateTimeT = "2006-01-02T15:04:05" + "github.com/beego/beego/v2/server/web/context" ) // Substr returns the substr from start to length. @@ -54,7 +49,6 @@ func Substr(s string, start, length int) string { // HTML2str returns escaping text convert from html. func HTML2str(html string) string { - re := regexp.MustCompile(`\<[\S\s]+?\>`) html = re.ReplaceAllStringFunc(html, strings.ToLower) @@ -255,7 +249,6 @@ func URLFor(endpoint string, values ...interface{}) string { // AssetsJs returns script tag with src string. func AssetsJs(text string) template.HTML { - text = "" return template.HTML(text) @@ -263,169 +256,16 @@ func AssetsJs(text string) template.HTML { // AssetsCSS returns stylesheet link tag with src string. func AssetsCSS(text string) template.HTML { - text = "" return template.HTML(text) } -// ParseForm will parse form values to struct via tag. -// Support for anonymous struct. -func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) error { - for i := 0; i < objT.NumField(); i++ { - fieldV := objV.Field(i) - if !fieldV.CanSet() { - continue - } - - fieldT := objT.Field(i) - if fieldT.Anonymous && fieldT.Type.Kind() == reflect.Struct { - err := parseFormToStruct(form, fieldT.Type, fieldV) - if err != nil { - return err - } - continue - } - - tags := strings.Split(fieldT.Tag.Get("form"), ",") - var tag string - if len(tags) == 0 || len(tags[0]) == 0 { - tag = fieldT.Name - } else if tags[0] == "-" { - continue - } else { - tag = tags[0] - } - - formValues := form[tag] - var value string - if len(formValues) == 0 { - defaultValue := fieldT.Tag.Get("default") - if defaultValue != "" { - value = defaultValue - } else { - continue - } - } - if len(formValues) == 1 { - value = formValues[0] - if value == "" { - continue - } - } - - switch fieldT.Type.Kind() { - case reflect.Bool: - if strings.ToLower(value) == "on" || strings.ToLower(value) == "1" || strings.ToLower(value) == "yes" { - fieldV.SetBool(true) - continue - } - if strings.ToLower(value) == "off" || strings.ToLower(value) == "0" || strings.ToLower(value) == "no" { - fieldV.SetBool(false) - continue - } - b, err := strconv.ParseBool(value) - if err != nil { - return err - } - fieldV.SetBool(b) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - x, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return err - } - fieldV.SetInt(x) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - x, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return err - } - fieldV.SetUint(x) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(value, 64) - if err != nil { - return err - } - fieldV.SetFloat(x) - case reflect.Interface: - fieldV.Set(reflect.ValueOf(value)) - case reflect.String: - fieldV.SetString(value) - case reflect.Struct: - switch fieldT.Type.String() { - case "time.Time": - var ( - t time.Time - err error - ) - if len(value) >= 25 { - value = value[:25] - t, err = time.ParseInLocation(time.RFC3339, value, time.Local) - } else if strings.HasSuffix(strings.ToUpper(value), "Z") { - t, err = time.ParseInLocation(time.RFC3339, value, time.Local) - } else if len(value) >= 19 { - if strings.Contains(value, "T") { - value = value[:19] - t, err = time.ParseInLocation(formatDateTimeT, value, time.Local) - } else { - value = value[:19] - t, err = time.ParseInLocation(formatDateTime, value, time.Local) - } - } else if len(value) >= 10 { - if len(value) > 10 { - value = value[:10] - } - t, err = time.ParseInLocation(formatDate, value, time.Local) - } else if len(value) >= 8 { - if len(value) > 8 { - value = value[:8] - } - t, err = time.ParseInLocation(formatTime, value, time.Local) - } - if err != nil { - return err - } - fieldV.Set(reflect.ValueOf(t)) - } - case reflect.Slice: - if fieldT.Type == sliceOfInts { - formVals := form[tag] - fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(int(1))), len(formVals), len(formVals))) - for i := 0; i < len(formVals); i++ { - val, err := strconv.Atoi(formVals[i]) - if err != nil { - return err - } - fieldV.Index(i).SetInt(int64(val)) - } - } else if fieldT.Type == sliceOfStrings { - formVals := form[tag] - fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(formVals), len(formVals))) - for i := 0; i < len(formVals); i++ { - fieldV.Index(i).SetString(formVals[i]) - } - } - } - } - return nil -} - // ParseForm will parse form values to struct via tag. func ParseForm(form url.Values, obj interface{}) error { - objT := reflect.TypeOf(obj) - objV := reflect.ValueOf(obj) - if !isStructPtr(objT) { - return fmt.Errorf("%v must be a struct pointer", obj) - } - objT = objT.Elem() - objV = objV.Elem() - - return parseFormToStruct(form, objT, objV) + return context.ParseForm(form, obj) } -var sliceOfInts = reflect.TypeOf([]int(nil)) -var sliceOfStrings = reflect.TypeOf([]string(nil)) - var unKind = map[reflect.Kind]bool{ reflect.Uintptr: true, reflect.Complex64: true, @@ -442,10 +282,11 @@ var unKind = map[reflect.Kind]bool{ // RenderForm will render object to form html. // obj must be a struct pointer. +// nolint func RenderForm(obj interface{}) template.HTML { objT := reflect.TypeOf(obj) objV := reflect.ValueOf(obj) - if !isStructPtr(objT) { + if objT.Kind() != reflect.Ptr || objT.Elem().Kind() != reflect.Struct { return template.HTML("") } objT = objT.Elem() @@ -504,7 +345,7 @@ func isValidForInput(fType string) bool { } // parseFormTag takes the stuct-tag of a StructField and parses the `form` value. -// returned are the form label, name-property, type and wether the field should be ignored. +// returned are the form label, name-property, type and whether the field should be ignored. func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) { tags := strings.Split(fieldT.Tag.Get("form"), ",") label = fieldT.Name + ": " @@ -550,10 +391,6 @@ func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id str return } -func isStructPtr(t reflect.Type) bool { - return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct -} - // go1.2 added template funcs. begin var ( errBadComparisonType = errors.New("invalid type for comparison") @@ -607,10 +444,23 @@ func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) { if err != nil { return false, err } - if k1 != k2 { - return false, errBadComparison - } truth := false + if k1 != k2 { + // Special case: Can compare integer values regardless of type's sign. + switch { + case k1 == intKind && k2 == uintKind: + truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint() + case k1 == uintKind && k2 == intKind: + truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int()) + default: + return false, errBadComparison + } + if truth { + return true, nil + } else { + return false, nil + } + } switch k1 { case boolKind: truth = v1.Bool() == v2.Bool() @@ -653,23 +503,32 @@ func lt(arg1, arg2 interface{}) (bool, error) { if err != nil { return false, err } - if k1 != k2 { - return false, errBadComparison - } truth := false - switch k1 { - case boolKind, complexKind: - return false, errBadComparisonType - case floatKind: - truth = v1.Float() < v2.Float() - case intKind: - truth = v1.Int() < v2.Int() - case stringKind: - truth = v1.String() < v2.String() - case uintKind: - truth = v1.Uint() < v2.Uint() - default: - panic("invalid kind") + if k1 != k2 { + // Special case: Can compare integer values regardless of type's sign. + switch { + case k1 == intKind && k2 == uintKind: + truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint() + case k1 == uintKind && k2 == intKind: + truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int()) + default: + return false, errBadComparison + } + } else { + switch k1 { + case boolKind, complexKind: + return false, errBadComparisonType + case floatKind: + truth = v1.Float() < v2.Float() + case intKind: + truth = v1.Int() < v2.Int() + case stringKind: + truth = v1.String() < v2.String() + case uintKind: + truth = v1.Uint() < v2.Uint() + default: + return false, errBadComparisonType + } } return truth, nil } diff --git a/server/web/templatefunc_test.go b/server/web/templatefunc_test.go index df5cfa40..403a8547 100644 --- a/server/web/templatefunc_test.go +++ b/server/web/templatefunc_test.go @@ -299,7 +299,6 @@ func TestParseFormTag(t *testing.T) { if !(name == "name" && !required) { t.Errorf("Form Tag containing only name and not required was not correctly parsed.") } - } func TestMapGet(t *testing.T) { @@ -378,3 +377,206 @@ func TestMapGet(t *testing.T) { t.Errorf("Error happens %v", err) } } + +func Test_eq(t *testing.T) { + tests := []struct { + a interface{} + b interface{} + result bool + }{ + {uint8(1), int(1), true}, + {uint8(3), int(1), false}, + {uint16(1), int(1), true}, + {uint16(3), int(1), false}, + {uint32(1), int(1), true}, + {uint32(3), int(1), false}, + {uint(1), int(1), true}, + {uint(3), int(1), false}, + {uint64(1), int(1), true}, + {uint64(3), int(1), false}, + {int8(-1), uint(1), false}, + {int16(-2), uint(1), false}, + {int32(-3), uint(1), false}, + {int64(-4), uint(1), false}, + {int8(1), uint(1), true}, + {int16(1), uint(1), true}, + {int32(1), uint(1), true}, + {int64(1), uint(1), true}, + {int8(-1), uint8(1), false}, + {int16(-2), uint8(1), false}, + {int32(-3), uint8(1), false}, + {int64(-4), uint8(1), false}, + {int8(1), uint8(1), true}, + {int16(1), uint8(1), true}, + {int32(1), uint8(1), true}, + {int64(1), uint8(1), true}, + {int8(-1), uint16(1), false}, + {int16(-2), uint16(1), false}, + {int32(-3), uint16(1), false}, + {int64(-4), uint16(1), false}, + {int8(1), uint16(1), true}, + {int16(1), uint16(1), true}, + {int32(1), uint16(1), true}, + {int64(1), uint16(1), true}, + {int8(-1), uint32(1), false}, + {int16(-2), uint32(1), false}, + {int32(-3), uint32(1), false}, + {int64(-4), uint32(1), false}, + {int8(1), uint32(1), true}, + {int16(1), uint32(1), true}, + {int32(1), uint32(1), true}, + {int64(1), uint32(1), true}, + {int8(-1), uint64(1), false}, + {int16(-2), uint64(1), false}, + {int32(-3), uint64(1), false}, + {int64(-4), uint64(1), false}, + {int8(1), uint64(1), true}, + {int16(1), uint64(1), true}, + {int32(1), uint64(1), true}, + {int64(1), uint64(1), true}, + } + + for _, test := range tests { + if res, err := eq(test.a, test.b); err != nil { + if res != test.result { + t.Errorf("a:%v(%T) equals b:%v(%T) should be %v", test.a, test.a, test.b, test.b, test.result) + } + } + } +} + +func Test_lt(t *testing.T) { + tests := []struct { + a interface{} + b interface{} + result bool + }{ + {uint8(1), int(3), true}, + {uint8(1), int(1), false}, + {uint8(3), int(1), false}, + {uint16(1), int(3), true}, + {uint16(1), int(1), false}, + {uint16(3), int(1), false}, + {uint32(1), int(3), true}, + {uint32(1), int(1), false}, + {uint32(3), int(1), false}, + {uint(1), int(3), true}, + {uint(1), int(1), false}, + {uint(3), int(1), false}, + {uint64(1), int(3), true}, + {uint64(1), int(1), false}, + {uint64(3), int(1), false}, + {int(-1), int(1), true}, + {int(1), int(1), false}, + {int(1), int(3), true}, + {int(3), int(1), false}, + {int8(-1), uint(1), true}, + {int8(1), uint(1), false}, + {int8(1), uint(3), true}, + {int8(3), uint(1), false}, + {int16(-1), uint(1), true}, + {int16(1), uint(1), false}, + {int16(1), uint(3), true}, + {int16(3), uint(1), false}, + {int32(-1), uint(1), true}, + {int32(1), uint(1), false}, + {int32(1), uint(3), true}, + {int32(3), uint(1), false}, + {int64(-1), uint(1), true}, + {int64(1), uint(1), false}, + {int64(1), uint(3), true}, + {int64(3), uint(1), false}, + {int(-1), uint(1), true}, + {int(1), uint(1), false}, + {int(1), uint(3), true}, + {int(3), uint(1), false}, + {int8(-1), uint8(1), true}, + {int8(1), uint8(1), false}, + {int8(1), uint8(3), true}, + {int8(3), uint8(1), false}, + {int16(-1), uint8(1), true}, + {int16(1), uint8(1), false}, + {int16(1), uint8(3), true}, + {int16(3), uint8(1), false}, + {int32(-1), uint8(1), true}, + {int32(1), uint8(1), false}, + {int32(1), uint8(3), true}, + {int32(3), uint8(1), false}, + {int64(-1), uint8(1), true}, + {int64(1), uint8(1), false}, + {int64(1), uint8(3), true}, + {int64(3), uint8(1), false}, + {int(-1), uint8(1), true}, + {int(1), uint8(1), false}, + {int(1), uint8(3), true}, + {int(3), uint8(1), false}, + {int8(-1), uint16(1), true}, + {int8(1), uint16(1), false}, + {int8(1), uint16(3), true}, + {int8(3), uint16(1), false}, + {int16(-1), uint16(1), true}, + {int16(1), uint16(1), false}, + {int16(1), uint16(3), true}, + {int16(3), uint16(1), false}, + {int32(-1), uint16(1), true}, + {int32(1), uint16(1), false}, + {int32(1), uint16(3), true}, + {int32(3), uint16(1), false}, + {int64(-1), uint16(1), true}, + {int64(1), uint16(1), false}, + {int64(1), uint16(3), true}, + {int64(3), uint16(1), false}, + {int(-1), uint16(1), true}, + {int(1), uint16(1), false}, + {int(1), uint16(3), true}, + {int(3), uint16(1), false}, + {int8(-1), uint32(1), true}, + {int8(1), uint32(1), false}, + {int8(1), uint32(3), true}, + {int8(3), uint32(1), false}, + {int16(-1), uint32(1), true}, + {int16(1), uint32(1), false}, + {int16(1), uint32(3), true}, + {int16(3), uint32(1), false}, + {int32(-1), uint32(1), true}, + {int32(1), uint32(1), false}, + {int32(1), uint32(3), true}, + {int32(3), uint32(1), false}, + {int64(-1), uint32(1), true}, + {int64(1), uint32(1), false}, + {int64(1), uint32(3), true}, + {int64(3), uint32(1), false}, + {int(-1), uint32(1), true}, + {int(1), uint32(1), false}, + {int(1), uint32(3), true}, + {int(3), uint32(1), false}, + {int8(-1), uint64(1), true}, + {int8(1), uint64(1), false}, + {int8(1), uint64(3), true}, + {int8(3), uint64(1), false}, + {int16(-1), uint64(1), true}, + {int16(1), uint64(1), false}, + {int16(1), uint64(3), true}, + {int16(3), uint64(1), false}, + {int32(-1), uint64(1), true}, + {int32(1), uint64(1), false}, + {int32(1), uint64(3), true}, + {int32(3), uint64(1), false}, + {int64(-1), uint64(1), true}, + {int64(1), uint64(1), false}, + {int64(1), uint64(3), true}, + {int64(3), uint64(1), false}, + {int(-1), uint64(1), true}, + {int(1), uint64(1), false}, + {int(1), uint64(3), true}, + {int(3), uint64(1), false}, + } + + for _, test := range tests { + if res, err := lt(test.a, test.b); err != nil { + if res != test.result { + t.Errorf("a:%v(%T) lt b:%v(%T) should be %v", test.a, test.a, test.b, test.b, test.result) + } + } + } +} diff --git a/server/web/tree.go b/server/web/tree.go index dc459c49..24a58a01 100644 --- a/server/web/tree.go +++ b/server/web/tree.go @@ -20,13 +20,10 @@ import ( "strings" "github.com/beego/beego/v2/core/utils" - "github.com/beego/beego/v2/server/web/context" ) -var ( - allowSuffixExt = []string{".json", ".xml", ".html"} -) +var allowSuffixExt = []string{".json", ".xml", ".html"} // Tree has three elements: FixRouter/wildcard/leaves // fixRouter stores Fixed Router @@ -210,9 +207,9 @@ func (t *Tree) AddRouter(pattern string, runObject interface{}) { func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) { if len(segments) == 0 { if reg != "" { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}) + t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}}, t.leaves...) } else { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards}) + t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards}}, t.leaves...) } } else { seg := segments[0] @@ -285,7 +282,7 @@ func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, // Match router to runObject & params func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { - if len(pattern) == 0 || pattern[0] != '/' { + if pattern == "" || pattern[0] != '/' { return nil } w := make([]string, 0, 20) @@ -294,13 +291,14 @@ func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { if len(pattern) > 0 { - i := 0 - for ; i < len(pattern) && pattern[i] == '/'; i++ { + i, l := 0, len(pattern) + for i < l && pattern[i] == '/' { + i++ } pattern = pattern[i:] } // Handle leaf nodes: - if len(pattern) == 0 { + if pattern == "" { for _, l := range t.leaves { if ok := l.match(treePattern, wildcardValues, ctx); ok { return l.runObject @@ -317,7 +315,8 @@ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string } var seg string i, l := 0, len(pattern) - for ; i < l && pattern[i] != '/'; i++ { + for i < l && pattern[i] != '/' { + i++ } if i == 0 { seg = pattern @@ -328,7 +327,7 @@ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string } for _, subTree := range t.fixrouters { if subTree.prefix == seg { - if len(pattern) != 0 && pattern[0] == '/' { + if pattern != "" && pattern[0] == '/' { treePattern = pattern[1:] } else { treePattern = pattern @@ -342,8 +341,9 @@ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string if runObject == nil && len(t.fixrouters) > 0 { // Filter the .json .xml .html extension for _, str := range allowSuffixExt { - if strings.HasSuffix(seg, str) { + if strings.HasSuffix(seg, str) && strings.HasSuffix(treePattern, seg) { for _, subTree := range t.fixrouters { + // strings.HasSuffix(treePattern, seg) avoid cases: /aaa.html/bbb could access /aaa/bbb if subTree.prefix == seg[:len(seg)-len(str)] { runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) if runObject != nil { diff --git a/server/web/tree_test.go b/server/web/tree_test.go index 3cb39c60..2e7fa6ce 100644 --- a/server/web/tree_test.go +++ b/server/web/tree_test.go @@ -17,6 +17,7 @@ package web import ( "strings" "testing" + "time" "github.com/beego/beego/v2/server/web/context" ) @@ -49,61 +50,84 @@ func notMatchTestInfo(pattern, url string) testInfo { } func init() { - routers = make([]testInfo, 0) - // match example - routers = append(routers, matchTestInfo("/topic/?:auth:int", "/topic", nil)) - routers = append(routers, matchTestInfo("/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"})) - routers = append(routers, matchTestInfo("/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"})) - routers = append(routers, matchTestInfo("/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"})) - routers = append(routers, matchTestInfo("/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"})) - routers = append(routers, matchTestInfo("/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"})) - routers = append(routers, matchTestInfo("/:id", "/123", map[string]string{":id": "123"})) - routers = append(routers, matchTestInfo("/hello/?:id", "/hello", map[string]string{":id": ""})) - routers = append(routers, matchTestInfo("/", "/", nil)) - routers = append(routers, matchTestInfo("/customer/login", "/customer/login", nil)) - routers = append(routers, matchTestInfo("/customer/login", "/customer/login.json", map[string]string{":ext": "json"})) - routers = append(routers, matchTestInfo("/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"})) - routers = append(routers, matchTestInfo("/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"})) - routers = append(routers, matchTestInfo("/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"})) - routers = append(routers, matchTestInfo("/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"})) - routers = append(routers, matchTestInfo("/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"})) - routers = append(routers, matchTestInfo("/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"})) - routers = append(routers, matchTestInfo("/thumbnail/:size/uploads/*", "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"})) - routers = append(routers, matchTestInfo("/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"})) - routers = append(routers, matchTestInfo("/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"})) - routers = append(routers, matchTestInfo("/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"})) - routers = append(routers, matchTestInfo("/dl/:width:int/:height:int/*.*", "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"})) - routers = append(routers, matchTestInfo("/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"})) - routers = append(routers, matchTestInfo("/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"})) - routers = append(routers, matchTestInfo("/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"})) - routers = append(routers, matchTestInfo("/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"})) - routers = append(routers, matchTestInfo("/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"})) - routers = append(routers, matchTestInfo("/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"})) - routers = append(routers, matchTestInfo("/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"})) - routers = append(routers, matchTestInfo("/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"})) - routers = append(routers, matchTestInfo("/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"})) - routers = append(routers, matchTestInfo("/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"})) - routers = append(routers, matchTestInfo("/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"})) - routers = append(routers, matchTestInfo("/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"})) - routers = append(routers, matchTestInfo("/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"})) - routers = append(routers, matchTestInfo("/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"})) - routers = append(routers, matchTestInfo("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"})) - routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"})) - routers = append(routers, matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"})) + const ( + abcHTML = "/suffix/abc.html" + abcSuffix = "/abc/suffix/*" + ) - // not match example + routers = []testInfo{ + // match example + matchTestInfo("/topic/?:auth:int", "/topic", nil), + matchTestInfo("/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}), + matchTestInfo("/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}), + matchTestInfo("/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}), + matchTestInfo("/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}), + matchTestInfo("/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}), + matchTestInfo("/:id", "/123", map[string]string{":id": "123"}), + matchTestInfo("/hello/?:id", "/hello", map[string]string{":id": ""}), + matchTestInfo("/", "/", nil), + matchTestInfo("/customer/login", "/customer/login", nil), + matchTestInfo("/customer/login", "/customer/login.json", map[string]string{":ext": "json"}), + matchTestInfo("/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}), + matchTestInfo("/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}), + matchTestInfo("/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}), + matchTestInfo("/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}), + matchTestInfo("/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}), + matchTestInfo("/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}), + matchTestInfo("/thumbnail/:size/uploads/*", "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}), + matchTestInfo("/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}), + matchTestInfo("/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}), + matchTestInfo("/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}), + matchTestInfo("/dl/:width:int/:height:int/*.*", "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}), + matchTestInfo("/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}), + matchTestInfo("/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}), + matchTestInfo("/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}), + matchTestInfo("/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}), + matchTestInfo("/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}), + matchTestInfo("/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}), + matchTestInfo("/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}), + matchTestInfo("/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}), + matchTestInfo("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}), + matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}), + matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}), + matchTestInfo("/?:year/?:month/?:day", "/2020/11/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"}), + matchTestInfo("/?:year/?:month/?:day", "/2020/11", map[string]string{":year": "2020", ":month": "11"}), + matchTestInfo("/?:year", "/2020", map[string]string{":year": "2020"}), + matchTestInfo("/?:year([0-9]+)/?:month([0-9]+)/mid/?:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"}), + matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/mid/10", map[string]string{":year": "2020", ":day": "10"}), + matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/11/mid", map[string]string{":year": "2020", ":month": "11"}), + matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/mid/10/24", map[string]string{":day": "10", ":hour": "24"}), + matchTestInfo("/?:year([0-9]+)/:month([0-9]+)/mid/:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"}), + matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10/24", map[string]string{":month": "11", ":day": "10"}), + matchTestInfo("/?:year/:month/mid/:day/?:hour", "/2020/11/mid/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"}), + matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10", map[string]string{":month": "11", ":day": "10"}), - // https://github.com/beego/beego/v2/issues/3865 - routers = append(routers, notMatchTestInfo("/read_:id:int\\.htm", "/read_222htm")) - routers = append(routers, notMatchTestInfo("/read_:id:int\\.htm", "/read_222_htm")) - routers = append(routers, notMatchTestInfo("/read_:id:int\\.htm", " /read_262shtm")) + // not match example + // https://github.com/beego/beego/v2/issues/3865 + notMatchTestInfo("/read_:id:int\\.htm", "/read_222htm"), + notMatchTestInfo("/read_:id:int\\.htm", "/read_222_htm"), + notMatchTestInfo("/read_:id:int\\.htm", " /read_262shtm"), + // test .html, .json not suffix + notMatchTestInfo(abcHTML, "/suffix.html/abc"), + matchTestInfo("/suffix/abc", abcHTML, nil), + matchTestInfo("/suffix/*", abcHTML, nil), + notMatchTestInfo("/suffix/*", "/suffix.html/a"), + notMatchTestInfo(abcSuffix, "/abc/suffix.html/a"), + matchTestInfo(abcSuffix, "/abc/suffix/a", nil), + notMatchTestInfo(abcSuffix, "/abc.j/suffix/a"), + } } func TestTreeRouters(t *testing.T) { for _, r := range routers { shouldMatch := r.shouldMatchOrNot - tr := NewTree() tr.AddRouter(r.pattern, "astaxie") ctx := context.NewContext() @@ -112,7 +136,7 @@ func TestTreeRouters(t *testing.T) { if obj != nil { t.Fatal("pattern:", r.pattern, ", should not match", r.requestUrl) } else { - return + continue } } if obj == nil || obj.(string) != "astaxie" { @@ -128,6 +152,7 @@ func TestTreeRouters(t *testing.T) { } } } + time.Sleep(time.Second) } func TestStaticPath(t *testing.T) { @@ -280,6 +305,7 @@ func TestAddTree5(t *testing.T) { t.Fatal("url /v1/shop/ need match router /v1/shop/ ") } } + func TestSplitPath(t *testing.T) { a := splitPath("") if len(a) != 0 { @@ -308,7 +334,6 @@ func TestSplitPath(t *testing.T) { } func TestSplitSegment(t *testing.T) { - items := map[string]struct { isReg bool params []string diff --git a/server/web/unregroute_test.go b/server/web/unregroute_test.go index c675ae7d..703497d3 100644 --- a/server/web/unregroute_test.go +++ b/server/web/unregroute_test.go @@ -27,13 +27,14 @@ import ( // that embed parent routers. // -const contentRootOriginal = "ok-original-root" -const contentLevel1Original = "ok-original-level1" -const contentLevel2Original = "ok-original-level2" - -const contentRootReplacement = "ok-replacement-root" -const contentLevel1Replacement = "ok-replacement-level1" -const contentLevel2Replacement = "ok-replacement-level2" +const ( + contentRootOriginal = "ok-original-root" + contentLevel1Original = "ok-original-level1" + contentLevel2Original = "ok-original-level2" + contentRootReplacement = "ok-replacement-root" + contentLevel1Replacement = "ok-replacement-level1" + contentLevel2Replacement = "ok-replacement-level2" +) // TestPreUnregController will supply content for the original routes, // before unregistration @@ -44,9 +45,11 @@ type TestPreUnregController struct { func (tc *TestPreUnregController) GetFixedRoot() { tc.Ctx.Output.Body([]byte(contentRootOriginal)) } + func (tc *TestPreUnregController) GetFixedLevel1() { tc.Ctx.Output.Body([]byte(contentLevel1Original)) } + func (tc *TestPreUnregController) GetFixedLevel2() { tc.Ctx.Output.Body([]byte(contentLevel2Original)) } @@ -60,9 +63,11 @@ type TestPostUnregController struct { func (tc *TestPostUnregController) GetFixedRoot() { tc.Ctx.Output.Body([]byte(contentRootReplacement)) } + func (tc *TestPostUnregController) GetFixedLevel1() { tc.Ctx.Output.Body([]byte(contentLevel1Replacement)) } + func (tc *TestPostUnregController) GetFixedLevel2() { tc.Ctx.Output.Body([]byte(contentLevel2Replacement)) } @@ -71,13 +76,12 @@ func (tc *TestPostUnregController) GetFixedLevel2() { // In this case, for a path like "/level1/level2" or "/level1", those actions // should remain intact, and continue to serve the original content. func TestUnregisterFixedRouteRoot(t *testing.T) { - - var method = "GET" + method := "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, "Test original root", @@ -96,7 +100,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { // Replace the root path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot") + handler.Add("/", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot")) // Test replacement root (expect change) testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) @@ -106,20 +110,18 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { // Test level 2 (expect no change from the original) testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) - } // TestUnregisterFixedRouteLevel1 replaces just the "/level1" fixed route path. // In this case, for a path like "/level1/level2" or "/", those actions // should remain intact, and continue to serve the original content. func TestUnregisterFixedRouteLevel1(t *testing.T) { - - var method = "GET" + method := "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -146,7 +148,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { // Replace the "level1" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) @@ -156,20 +158,18 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { // Test level 2 (expect no change from the original) testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) - } // TestUnregisterFixedRouteLevel2 unregisters just the "/level1/level2" fixed // route path. In this case, for a path like "/level1" or "/", those actions // should remain intact, and continue to serve the original content. func TestUnregisterFixedRouteLevel2(t *testing.T) { - - var method = "GET" + method := "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { // Replace the "/level1/level2" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2") + handler.Add("/level1/level2", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) @@ -206,7 +206,6 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { // Test level 2 (expect change) testHelperFnContentCheck(t, handler, "Test level 2 (expect change)", method, "/level1/level2", contentLevel2Replacement) - } func testHelperFnContentCheck(t *testing.T, handler *ControllerRegister, diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 00000000..1a12fb33 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,7 @@ +sonar.organization=beego +sonar.projectKey=beego_beego + +# relative paths to source directories. More details and properties are described +# in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/ +sonar.sources=. +sonar.exclusions=**/*_test.go \ No newline at end of file diff --git a/task/govenor_command.go b/task/govenor_command.go index 20238435..5435fdf1 100644 --- a/task/govenor_command.go +++ b/task/govenor_command.go @@ -24,8 +24,7 @@ import ( "github.com/beego/beego/v2/core/admin" ) -type listTaskCommand struct { -} +type listTaskCommand struct{} func (l *listTaskCommand) Execute(params ...interface{}) *admin.Result { resultList := make([][]string, 0, len(globalTaskManager.adminTaskList)) @@ -45,8 +44,7 @@ func (l *listTaskCommand) Execute(params ...interface{}) *admin.Result { } } -type runTaskCommand struct { -} +type runTaskCommand struct{} func (r *runTaskCommand) Execute(params ...interface{}) *admin.Result { if len(params) == 0 { @@ -83,7 +81,6 @@ func (r *runTaskCommand) Execute(params ...interface{}) *admin.Result { Error: errors.New(fmt.Sprintf("task with name %s not found", tn)), } } - } func registerCommands() { diff --git a/task/governor_command_test.go b/task/governor_command_test.go index 00ed37f2..c3547cdf 100644 --- a/task/governor_command_test.go +++ b/task/governor_command_test.go @@ -55,6 +55,10 @@ func (c *countTask) GetPrev(ctx context.Context) time.Time { return time.Now() } +func (c *countTask) GetTimeout(ctx context.Context) time.Duration { + return 0 +} + func TestRunTaskCommand_Execute(t *testing.T) { task := &countTask{} AddTask("count", task) diff --git a/task/task.go b/task/task.go index 2ea34f24..09a08144 100644 --- a/task/task.go +++ b/task/task.go @@ -37,6 +37,7 @@ type taskManager struct { stop chan bool changed chan bool started bool + wait sync.WaitGroup } func newTaskManager() *taskManager { @@ -109,6 +110,7 @@ type Tasker interface { GetNext(ctx context.Context) time.Time SetPrev(context.Context, time.Time) GetPrev(ctx context.Context) time.Time + GetTimeout(ctx context.Context) time.Duration } // task error @@ -127,14 +129,14 @@ type Task struct { DoFunc TaskFunc Prev time.Time Next time.Time - Errlist []*taskerr // like errtime:errinfo - ErrLimit int // max length for the errlist, 0 stand for no limit - errCnt int // records the error count during the execution + Timeout time.Duration // timeout duration + Errlist []*taskerr // like errtime:errinfo + ErrLimit int // max length for the errlist, 0 stand for no limit + errCnt int // records the error count during the execution } // NewTask add new task with name, time and func -func NewTask(tname string, spec string, f TaskFunc) *Task { - +func NewTask(tname string, spec string, f TaskFunc, opts ...Option) *Task { task := &Task{ Taskname: tname, DoFunc: f, @@ -144,6 +146,11 @@ func NewTask(tname string, spec string, f TaskFunc) *Task { // we only store the pointer, so it won't use too many space Errlist: make([]*taskerr, 100, 100), } + + for _, opt := range opts { + opt.apply(task) + } + task.SetCron(spec) return task } @@ -196,6 +203,31 @@ func (t *Task) GetPrev(context.Context) time.Time { return t.Prev } +// GetTimeout get timeout duration of this task +func (t *Task) GetTimeout(context.Context) time.Duration { + return t.Timeout +} + +// Option interface +type Option interface { + apply(*Task) +} + +// optionFunc return a function to set task element +type optionFunc func(*Task) + +// apply option to task +func (f optionFunc) apply(t *Task) { + f(t) +} + +// TimeoutOption return a option to set timeout duration for task +func TimeoutOption(timeout time.Duration) Option { + return optionFunc(func(t *Task) { + t.Timeout = timeout + }) +} + // six columns mean: // second:0-59 // minute:0-59 @@ -318,7 +350,6 @@ func (t *Task) parseSpec(spec string) *Schedule { // Next set schedule to next time func (s *Schedule) Next(t time.Time) time.Time { - // Start at the earliest possible time (the upcoming second). t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond) @@ -439,6 +470,11 @@ func ClearTask() { globalTaskManager.ClearTask() } +// GracefulShutdown wait all task done +func GracefulShutdown() <-chan struct{} { + return globalTaskManager.GracefulShutdown() +} + // StartTask start all tasks func (m *taskManager) StartTask() { m.taskLock.Lock() @@ -455,14 +491,12 @@ func (m *taskManager) StartTask() { func (m *taskManager) run() { now := time.Now().Local() - m.taskLock.Lock() - for _, t := range m.adminTaskList { - t.SetNext(nil, now) - } - m.taskLock.Unlock() + // first run the tasks, so set all tasks next run time. + m.setTasksStartTime(now) for { // we only use RLock here because NewMapSorter copy the reference, do not change any thing + // here, we sort all task and get first task running time (effective). m.taskLock.RLock() sortList := NewMapSorter(m.adminTaskList) m.taskLock.RUnlock() @@ -475,37 +509,78 @@ func (m *taskManager) run() { } else { effective = sortList.Vals[0].GetNext(context.Background()) } + select { - case now = <-time.After(effective.Sub(now)): - // Run every entry whose next time was this effective time. - for _, e := range sortList.Vals { - if e.GetNext(context.Background()) != effective { - break - } - go e.Run(nil) - e.SetPrev(context.Background(), e.GetNext(context.Background())) - e.SetNext(nil, effective) - } + case now = <-time.After(effective.Sub(now)): // wait for effective time + m.runNextTasks(sortList, effective) continue - case <-m.changed: + case <-m.changed: // tasks have been changed, set all tasks run again now now = time.Now().Local() - m.taskLock.Lock() - for _, t := range m.adminTaskList { - t.SetNext(nil, now) - } - m.taskLock.Unlock() + m.setTasksStartTime(now) continue - case <-m.stop: - m.taskLock.Lock() - if m.started { - m.started = false - } - m.taskLock.Unlock() + case <-m.stop: // manager is stopped, and mark manager is stopped + m.markManagerStop() return } } } +// setTasksStartTime is set all tasks next running time +func (m *taskManager) setTasksStartTime(now time.Time) { + m.taskLock.Lock() + for _, task := range m.adminTaskList { + task.SetNext(context.Background(), now) + } + m.taskLock.Unlock() +} + +// markManagerStop it sets manager to be stopped +func (m *taskManager) markManagerStop() { + m.taskLock.Lock() + if m.started { + m.started = false + } + m.taskLock.Unlock() +} + +// runNextTasks it runs next task which next run time is equal to effective +func (m *taskManager) runNextTasks(sortList *MapSorter, effective time.Time) { + // Run every entry whose next time was this effective time. + i := 0 + for _, e := range sortList.Vals { + i++ + if e.GetNext(context.Background()) != effective { + break + } + + // check if timeout is on, if yes passing the timeout context + ctx := context.Background() + m.wait.Add(1) + if duration := e.GetTimeout(ctx); duration != 0 { + go func(e Tasker) { + defer m.wait.Done() + ctx, cancelFunc := context.WithTimeout(ctx, duration) + defer cancelFunc() + err := e.Run(ctx) + if err != nil { + log.Printf("tasker.run err: %s\n", err.Error()) + } + }(e) + } else { + go func(e Tasker) { + defer m.wait.Done() + err := e.Run(ctx) + if err != nil { + log.Printf("tasker.run err: %s\n", err.Error()) + } + }(e) + } + + e.SetPrev(context.Background(), e.GetNext(context.Background())) + e.SetNext(context.Background(), effective) + } +} + // StopTask stop all tasks func (m *taskManager) StopTask() { go func() { @@ -513,6 +588,17 @@ func (m *taskManager) StopTask() { }() } +// GracefulShutdown wait all task done +func (m *taskManager) GracefulShutdown() <-chan struct{} { + done := make(chan struct{}) + go func() { + m.stop <- true + m.wait.Wait() + close(done) + }() + return done +} + // AddTask add task with name func (m *taskManager) AddTask(taskname string, t Tasker) { isChanged := false @@ -529,7 +615,6 @@ func (m *taskManager) AddTask(taskname string, t Tasker) { m.changed <- true }() } - } // DeleteTask delete task with name @@ -602,6 +687,7 @@ func (ms *MapSorter) Less(i, j int) bool { } return ms.Vals[i].GetNext(context.Background()).Before(ms.Vals[j].GetNext(context.Background())) } + func (ms *MapSorter) Swap(i, j int) { ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] @@ -620,7 +706,6 @@ func getField(field string, r bounds) uint64 { // getRange returns the bits indicated by the given expression: // number | number "-" number [ "/" number ] func getRange(expr string, r bounds) uint64 { - var ( start, end, step uint rangeAndStep = strings.Split(expr, "/") diff --git a/task/task_test.go b/task/task_test.go index 5e117cbd..8d274e8f 100644 --- a/task/task_test.go +++ b/task/task_test.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "testing" "time" @@ -90,6 +91,57 @@ func TestSpec(t *testing.T) { } } +func TestTimeout(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + wg := &sync.WaitGroup{} + wg.Add(2) + once1, once2 := sync.Once{}, sync.Once{} + + tk1 := NewTask("tk1", "0/10 * * ? * *", + func(ctx context.Context) error { + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + once1.Do(func() { + fmt.Println("tk1 done") + wg.Done() + }) + return errors.New("timeout") + default: + } + return nil + }, TimeoutOption(3*time.Second), + ) + + tk2 := NewTask("tk2", "0/11 * * ? * *", + func(ctx context.Context) error { + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + return errors.New("timeout") + default: + once2.Do(func() { + fmt.Println("tk2 done") + wg.Done() + }) + } + return nil + }, + ) + + m.AddTask("tk1", tk1) + m.AddTask("tk2", tk2) + m.StartTask() + defer m.StopTask() + + select { + case <-time.After(19 * time.Second): + t.Error("TestTimeout failed") + case <-wait(wg): + } +} + func TestTask_Run(t *testing.T) { cnt := -1 task := func(ctx context.Context) error { @@ -109,6 +161,43 @@ func TestTask_Run(t *testing.T) { assert.Equal(t, "Hello, world! 101", l[1].errinfo) } +func TestCrudTask(t *testing.T) { + m := newTaskManager() + m.AddTask("my-task1", NewTask("my-task1", "0/30 * * * * *", func(ctx context.Context) error { + return nil + })) + + m.AddTask("my-task2", NewTask("my-task2", "0/30 * * * * *", func(ctx context.Context) error { + return nil + })) + + m.DeleteTask("my-task1") + assert.Equal(t, 1, len(m.adminTaskList)) + + m.ClearTask() + assert.Equal(t, 0, len(m.adminTaskList)) +} + +func TestGracefulShutdown(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + waitDone := atomic.Value{} + waitDone.Store(false) + tk := NewTask("everySecond", "* * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + time.Sleep(2 * time.Second) + waitDone.Store(true) + return nil + }) + m.AddTask("taska", tk) + m.StartTask() + time.Sleep(1 * time.Second) + shutdown := m.GracefulShutdown() + assert.False(t, waitDone.Load().(bool)) + <-shutdown + assert.True(t, waitDone.Load().(bool)) +} + func wait(wg *sync.WaitGroup) chan bool { ch := make(chan bool) go func() { diff --git a/test/bindata.go b/test/bindata.go index 6dbc08ab..632f5a47 100644 --- a/test/bindata.go +++ b/test/bindata.go @@ -56,18 +56,23 @@ type bindataFileInfo struct { func (fi bindataFileInfo) Name() string { return fi.name } + func (fi bindataFileInfo) Size() int64 { return fi.size } + func (fi bindataFileInfo) Mode() os.FileMode { return fi.mode } + func (fi bindataFileInfo) ModTime() time.Time { return fi.modTime } + func (fi bindataFileInfo) IsDir() bool { return false } + func (fi bindataFileInfo) Sys() interface{} { return nil } @@ -87,7 +92,7 @@ func viewsBlocksBlockTpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "views/blocks/block.tpl", size: 50, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + info := bindataFileInfo{name: "views/blocks/block.tpl", size: 50, mode: os.FileMode(0o664), modTime: time.Unix(1541431067, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -107,7 +112,7 @@ func viewsHeaderTpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "views/header.tpl", size: 52, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + info := bindataFileInfo{name: "views/header.tpl", size: 52, mode: os.FileMode(0o664), modTime: time.Unix(1541431067, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -127,7 +132,7 @@ func viewsIndexTpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "views/index.tpl", size: 255, mode: os.FileMode(436), modTime: time.Unix(1541434906, 0)} + info := bindataFileInfo{name: "views/index.tpl", size: 255, mode: os.FileMode(0o664), modTime: time.Unix(1541434906, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -249,7 +254,7 @@ func RestoreAsset(dir, name string) error { if err != nil { return err } - err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0755)) + err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0o755)) if err != nil { return err } @@ -287,9 +292,7 @@ func _filePath(dir, name string) string { } func assetFS() *assetfs.AssetFS { - assetInfo := func(path string) (os.FileInfo, error) { - return os.Stat(path) - } + assetInfo := os.Stat for k := range _bintree.Children { return &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, AssetInfo: assetInfo, Prefix: k} }