beego/parser.go
2020-12-22 21:29:23 +08:00

591 lines
14 KiB
Go

// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package beego
import (
"encoding/json"
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"unicode"
"github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils"
)
// globalRouterTemplate is the skeleton of the generated router file. The
// {{.routersDir}}, {{.globalimport}} and {{.globalinfo}} markers are filled in
// by genRouterCode with plain string replacement.
var globalRouterTemplate = `package {{.routersDir}}
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/context/param"{{.globalimport}}
)
func init() {
{{.globalinfo}}
}
`
var (
// lastupdateFilename is the file in which pkgLastupdate is persisted as JSON.
lastupdateFilename = "lastupdate.tmp"
// pkgLastupdate maps a package directory to the newest modification time
// (in Unix nanoseconds) recorded for it by savetoFile.
pkgLastupdate map[string]int64
// genInfoList collects the parsed controller comments, keyed by
// "<pkgpath>:<controllerName>".
genInfoList map[string][]ControllerComments
// routerHooks translates the position name written in an @Filter
// annotation into the corresponding filter position constant.
routerHooks = map[string]int{
"beego.BeforeStatic": BeforeStatic,
"beego.BeforeRouter": BeforeRouter,
"beego.BeforeExec": BeforeExec,
"beego.AfterExec": AfterExec,
"beego.FinishRouter": FinishRouter,
}
// routerHooksMapping is the inverse of routerHooks; genRouterCode uses it
// to print a filter position back as source text.
routerHooksMapping = map[int]string{
BeforeStatic: "beego.BeforeStatic",
BeforeRouter: "beego.BeforeRouter",
BeforeExec: "beego.BeforeExec",
AfterExec: "beego.AfterExec",
FinishRouter: "beego.FinishRouter",
}
)
// commentFilename is the name of the generated router file that genRouterCode
// creates inside the routers directory.
const commentFilename = "commentsRouter.go"
// init allocates the timestamp cache so it can be written to before any cache
// file has been loaded.
func init() {
	pkgLastupdate = map[string]int64{}
}
// parserPkg scans the Go files in pkgRealpath for annotated controller methods
// and regenerates the comment router file from them.
//
// The scan is skipped when compareFile reports that nothing changed. Only
// methods declared on a pointer receiver are handed to parserComments; pkgpath
// is passed through to it and ends up in the genInfoList key. Parse errors and
// annotation errors are returned to the caller.
func parserPkg(pkgRealpath, pkgpath string) error {
	if !compareFile(pkgRealpath) {
		logs.Info(pkgRealpath + " no changed")
		return nil
	}
	genInfoList = make(map[string][]ControllerComments)

	// Only regular, non-hidden *.go files take part in the scan.
	isGoSource := func(info os.FileInfo) bool {
		if info.IsDir() {
			return false
		}
		fileName := info.Name()
		return !strings.HasPrefix(fileName, ".") && strings.HasSuffix(fileName, ".go")
	}

	fileSet := token.NewFileSet()
	astPkgs, err := parser.ParseDir(fileSet, pkgRealpath, isGoSource, parser.ParseComments)
	if err != nil {
		return err
	}

	for _, pkg := range astPkgs {
		for _, fl := range pkg.Files {
			for _, d := range fl.Decls {
				fn, isFunc := d.(*ast.FuncDecl)
				if !isFunc || fn.Recv == nil {
					continue
				}
				// Check that the receiver is a pointer before handing the
				// method to the comment parser.
				recv, isPtr := fn.Recv.List[0].Type.(*ast.StarExpr)
				if !isPtr {
					continue
				}
				if err = parserComments(fn, fmt.Sprint(recv.X), pkgpath); err != nil {
					return err
				}
			}
		}
	}

	genRouterCode(pkgRealpath)
	savetoFile(pkgRealpath)
	return nil
}
// parsedComment is the result of parsing one @router line of a doc comment,
// together with the @Param, @Filter and @Import annotations found in the same
// comment.
type parsedComment struct {
// routerPath is the path given after @router.
routerPath string
// methods holds the HTTP methods from the bracketed list; parseComment
// falls back to "get" when the list is absent.
methods []string
// params is keyed by the name of the function parameter the annotation
// applies to.
params map[string]parsedParam
filters []parsedFilter
imports []parsedImport
}
// parsedImport is one @Import annotation: the import path and an optional
// alias (empty when none was given).
type parsedImport struct {
importPath string
importAlias string
}
// parsedFilter is one @Filter annotation.
type parsedFilter struct {
// pattern is the first argument of the annotation.
pattern string
// pos is the position constant looked up in routerHooks from the second
// argument.
pos int
// filter is the third argument; genRouterCode prints it unquoted into the
// generated source.
filter string
// params holds the trailing true/false arguments. buildFilters reads the
// first as ReturnOnOutput and the second as ResetParams.
params []bool
}
// parsedParam is one @Param annotation.
type parsedParam struct {
// name is the part of the first argument before an optional "=>".
name string
// datatype is the third argument of the annotation.
datatype string
// location is the second argument; buildMethodParam recognises "body",
// "header" and "path".
location string
// defValue is only set when the annotation has six arguments.
defValue string
required bool
}
// parserComments parses the doc comment of method f and appends one
// ControllerComments entry to genInfoList for every @router annotation that
// carries a path. Entries are stored under the key
// "<pkgpath>:<controllerName>". A method without a doc comment is ignored.
func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
	if f.Doc == nil {
		return nil
	}
	routes, err := parseComment(f.Doc.List)
	if err != nil {
		return err
	}
	for _, route := range routes {
		if route.routerPath == "" {
			continue
		}
		key := pkgpath + ":" + controllerName
		genInfoList[key] = append(genInfoList[key], ControllerComments{
			Method:           f.Name.String(),
			Router:           route.routerPath,
			AllowHTTPMethods: route.methods,
			MethodParams:     buildMethodParams(f.Type.Params.List, route),
			FilterComments:   buildFilters(route.filters),
			ImportComments:   buildImports(route.imports),
		})
	}
	return nil
}
// buildImports converts parsed @Import annotations into
// ControllerImportComments values, preserving their order. It returns nil when
// pis is empty.
func buildImports(pis []parsedImport) []*ControllerImportComments {
	var converted []*ControllerImportComments
	for i := range pis {
		entry := new(ControllerImportComments)
		entry.ImportPath = pis[i].importPath
		entry.ImportAlias = pis[i].importAlias
		converted = append(converted, entry)
	}
	return converted
}
// buildFilters converts parsed @Filter annotations into
// ControllerFilterComments values, preserving their order. The first optional
// flag of an annotation becomes ReturnOnOutput and the second ResetParams;
// missing flags stay false. It returns nil when pfs is empty.
func buildFilters(pfs []parsedFilter) []*ControllerFilterComments {
	var converted []*ControllerFilterComments
	for i := range pfs {
		src := pfs[i]
		entry := &ControllerFilterComments{
			Filter:  src.filter,
			Pattern: src.pattern,
			Pos:     src.pos,
		}
		if n := len(src.params); n > 0 {
			entry.ReturnOnOutput = src.params[0]
			if n > 1 {
				entry.ResetParams = src.params[1]
			}
		}
		converted = append(converted, entry)
	}
	return converted
}
// buildMethodParams builds one MethodParam for every named parameter in
// funcParams, in declaration order. A field that declares several names
// (a, b int) contributes one entry per name; unnamed fields contribute none.
func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam {
	methodParams := make([]*param.MethodParam, 0, len(funcParams))
	for i := range funcParams {
		field := funcParams[i]
		for j := range field.Names {
			methodParams = append(methodParams, buildMethodParam(field, field.Names[j].Name, pc))
		}
	}
	return methodParams
}
// buildMethodParam builds the MethodParam for the function parameter called
// name.
//
// When the comment has an @Param annotation for that parameter, the
// annotation supplies the public name, the required flag, the location
// ("body", "header" or "path") and the default value. Otherwise the parameter
// keeps its own name and is marked as a path parameter if it appears in the
// router path.
func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam {
	options := make([]param.MethodParamOption, 0)

	annotated, documented := pc.params[name]
	if !documented {
		if paramInPath(name, pc.routerPath) {
			options = append(options, param.InPath)
		}
		return param.New(name, options...)
	}

	// Build the parameter from the annotation.
	if annotated.required {
		options = append(options, param.IsRequired)
	}
	if loc := annotated.location; loc == "body" {
		options = append(options, param.InBody)
	} else if loc == "header" {
		options = append(options, param.InHeader)
	} else if loc == "path" {
		options = append(options, param.InPath)
	}
	if annotated.defValue != "" {
		options = append(options, param.Default(annotated.defValue))
	}
	return param.New(annotated.name, options...)
}
// paramInPath reports whether route contains the placeholder ":name" either
// as its final characters or immediately followed by a slash.
func paramInPath(name, route string) bool {
	placeholder := ":" + name
	if strings.HasSuffix(route, placeholder) {
		return true
	}
	return strings.Contains(route, placeholder+"/")
}
// routeRegex matches an "@router <path> [<methods>]" annotation. Group 1 is
// the path; the optional group 2 is the content of the square brackets.
var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`)
// parseComment extracts the routing annotations from the lines of a doc
// comment.
//
// Recognised annotations:
//
//	@Param  <name>[=><funcParam>] <location> <datatype> ...
//	@Import <path> [<alias>]
//	@Filter <pattern> <position> <filter> [true|false ...]
//	@router <path> [<method>,<method>...]
//
// For @Param, an annotation with five fields takes the fourth as the
// "required" flag; one with six fields takes the fourth as the default value
// and the fifth as the "required" flag. The parameter is stored under
// <funcParam> when "=>" is used and under <name> otherwise.
//
// One parsedComment is returned per @router line, in source order; all of
// them share the params, filters and imports collected from the whole
// comment. The returned slice is never nil on success.
//
// Malformed @Import and @Filter lines are logged and skipped. An @Param line
// with fewer than four fields is logged; it is skipped only when it has fewer
// than the three fields that are read unconditionally. A line that starts
// with @router but does not match routeRegex aborts parsing with an error.
func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) {
	pcs = []*parsedComment{}
	params := map[string]parsedParam{}
	filters := []parsedFilter{}
	imports := []parsedImport{}

	// Strip the leading slashes and surrounding blanks once; every pass below
	// works on the same cleaned-up text.
	texts := make([]string, len(lines))
	for i, c := range lines {
		texts[i] = strings.TrimSpace(strings.TrimLeft(c.Text, "/"))
	}

	for _, t := range texts {
		if !strings.HasPrefix(t, "@Param") {
			continue
		}
		pv := getparams(strings.TrimSpace(strings.TrimPrefix(t, "@Param")))
		if len(pv) < 4 {
			logs.Error("Invalid @Param format. Needs at least 4 parameters")
		}
		if len(pv) < 3 {
			// pv[0], pv[1] and pv[2] are read below; skip the annotation
			// instead of panicking with an index out of range.
			continue
		}
		p := parsedParam{}
		names := strings.SplitN(pv[0], "=>", 2)
		p.name = names[0]
		funcParamName := p.name
		if len(names) > 1 {
			funcParamName = names[1]
		}
		p.location = pv[1]
		p.datatype = pv[2]
		// A value that does not parse as a bool leaves required at false.
		switch len(pv) {
		case 5:
			p.required, _ = strconv.ParseBool(pv[3])
		case 6:
			p.defValue = pv[3]
			p.required, _ = strconv.ParseBool(pv[4])
		}
		params[funcParamName] = p
	}

	for _, t := range texts {
		if !strings.HasPrefix(t, "@Import") {
			continue
		}
		iv := getparams(strings.TrimSpace(strings.TrimPrefix(t, "@Import")))
		if len(iv) == 0 || len(iv) > 2 {
			logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters")
			continue
		}
		p := parsedImport{importPath: iv[0]}
		if len(iv) == 2 {
			p.importAlias = iv[1]
		}
		imports = append(imports, p)
	}

filterLoop:
	for _, t := range texts {
		if !strings.HasPrefix(t, "@Filter") {
			continue
		}
		fv := getparams(strings.TrimSpace(strings.TrimPrefix(t, "@Filter")))
		if len(fv) < 3 {
			logs.Error("Invalid @Filter format. Needs at least 3 parameters")
			continue filterLoop
		}
		p := parsedFilter{}
		p.pattern = fv[0]
		posName := fv[1]
		pos, exists := routerHooks[posName]
		if !exists {
			logs.Error("Invalid @Filter pos: ", posName)
			continue filterLoop
		}
		p.pos = pos
		p.filter = fv[2]
		for _, fvParam := range fv[3:] {
			switch fvParam {
			case "true":
				p.params = append(p.params, true)
			case "false":
				p.params = append(p.params, false)
			default:
				// One bad flag invalidates the whole annotation.
				logs.Error("Invalid @Filter param: ", fvParam)
				continue filterLoop
			}
		}
		filters = append(filters, p)
	}

	for _, t := range texts {
		if !strings.HasPrefix(t, "@router") {
			continue
		}
		matches := routeRegex.FindStringSubmatch(t)
		if len(matches) != 3 {
			return nil, errors.New("Router information is missing")
		}
		pc := &parsedComment{
			routerPath: matches[1],
			params:     params,
			filters:    filters,
			imports:    imports,
		}
		if methods := matches[2]; methods == "" {
			pc.methods = []string{"get"}
		} else {
			pc.methods = strings.Split(methods, ",")
		}
		pcs = append(pcs, pc)
	}
	return pcs, nil
}
// getparams splits an annotation argument string into fields. It was
// originally copied from bee's g_docs.go.
//
// Fields are separated by runs of whitespace. A double quote toggles quoted
// mode, in which whitespace is kept as part of the field; the quote
// characters themselves are dropped. For example the arguments of
//
//	@Param query form string true "The email for login"
//
// are split into the five fields
//
//	query, form, string, true, The email for login
//
// An explicitly quoted empty string ("") yields an empty field wherever it
// appears, including at the very end of the input. Empty or all-blank input
// yields nil.
func getparams(str string) []string {
	var (
		fields  []string
		current []rune
		inField bool // a field has been started, even if it is still empty
		quoted  bool
	)
	for _, c := range str {
		switch {
		case c == '"':
			inField = true
			quoted = !quoted
		case unicode.IsSpace(c) && !quoted:
			if inField {
				fields = append(fields, string(current))
				// string(current) copied the runes, so the buffer can be reused.
				current = current[:0]
				inField = false
			}
		default:
			inField = true
			current = append(current, c)
		}
	}
	// Flush on inField rather than on len(current) so that a trailing ""
	// is kept, exactly like an empty quoted field in the middle of the input.
	if inField {
		fields = append(fields, string(current))
	}
	return fields
}
// genRouterCode renders every entry of genInfoList as Go source and writes the
// result to commentFilename inside the directory returned by getRouterDir.
// Nothing is written when genInfoList produces no entries. It panics when the
// output file cannot be created.
func genRouterCode(pkgRealpath string) {
// The error from Mkdir is ignored (it also fails when the directory already
// exists).
os.Mkdir(getRouterDir(pkgRealpath), 0755)
logs.Info("generate router from comments")
var (
globalinfo string
globalimport string
sortKey []string
)
// Walk the controllers in sorted key order so they are emitted in a stable
// sequence.
for k := range genInfoList {
sortKey = append(sortKey, k)
}
sort.Strings(sortKey)
for _, k := range sortKey {
cList := genInfoList[k]
sort.Sort(ControllerCommentsSlice(cList))
for _, c := range cList {
// allmethod is the source text for the AllowHTTPMethods field.
allmethod := "nil"
if len(c.AllowHTTPMethods) > 0 {
allmethod = "[]string{"
for _, m := range c.AllowHTTPMethods {
allmethod += `"` + m + `",`
}
allmethod = strings.TrimRight(allmethod, ",") + "}"
}
// params is the source text for the Params field.
// NOTE(review): the map key k is emitted without quotes while the value is
// quoted; that looks like it yields invalid Go unless k is an identifier.
// The parser in this file never fills c.Params - confirm before relying on it.
params := "nil"
if len(c.Params) > 0 {
params = "[]map[string]string{"
for _, p := range c.Params {
for k, v := range p {
params = params + `map[string]string{` + k + `:"` + v + `"},`
}
}
params = strings.TrimRight(params, ",") + "}"
}
// methodParams is the source text for the MethodParams field. Each entry
// is rendered with fmt.Sprint; presumably MethodParam formats itself as
// Go source - TODO confirm against the param package.
methodParams := "param.Make("
if len(c.MethodParams) > 0 {
lines := make([]string, 0, len(c.MethodParams))
for _, m := range c.MethodParams {
lines = append(lines, fmt.Sprint(m))
}
methodParams += "\n " +
strings.Join(lines, ",\n ") +
",\n "
}
methodParams += ")"
// Collect the import lines of this entry, skipping any line that is already
// part of globalimport.
imports := ""
if len(c.ImportComments) > 0 {
for _, i := range c.ImportComments {
var s string
if i.ImportAlias != "" {
s = fmt.Sprintf(`
%s "%s"`, i.ImportAlias, i.ImportPath)
} else {
s = fmt.Sprintf(`
"%s"`, i.ImportPath)
}
if !strings.Contains(globalimport, s) {
imports += s
}
}
}
// filters is the source text for the Filters field; the position is printed
// through routerHooksMapping and the filter expression is printed as is.
filters := ""
if len(c.FilterComments) > 0 {
for _, f := range c.FilterComments {
filters += fmt.Sprintf(` &beego.ControllerFilter{
Pattern: "%s",
Pos: %s,
Filter: %s,
ReturnOnOutput: %v,
ResetParams: %v,
},`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams)
}
}
if filters == "" {
filters = "nil"
} else {
filters = fmt.Sprintf(`[]*beego.ControllerFilter{
%s
}`, filters)
}
globalimport += imports
// Append the registration statement for this controller method.
globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{
Method: "` + strings.TrimSpace(c.Method) + `",
` + "Router: `" + c.Router + "`" + `,
AllowHTTPMethods: ` + allmethod + `,
MethodParams: ` + methodParams + `,
Filters: ` + filters + `,
Params: ` + params + `})
`
}
}
// Fill the template and write the file only when something was generated.
if globalinfo != "" {
f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename))
if err != nil {
panic(err)
}
defer f.Close()
routersDir := AppConfig.DefaultString("routersdir", "routers")
content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1)
content = strings.Replace(content, "{{.routersDir}}", routersDir, -1)
content = strings.Replace(content, "{{.globalimport}}", globalimport, -1)
// NOTE(review): the error returned by WriteString is not checked.
f.WriteString(content)
}
}
// compareFile reports whether the routers for pkgRealpath have to be
// (re)generated.
//
// It returns true when the generated router file does not exist, when the
// timestamp cache (lastupdateFilename) is missing, unreadable or not valid
// JSON, when the package directory cannot be read, or when the directory has
// an entry modified after the timestamp cached for pkgRealpath. It returns
// false only when the cache holds a timestamp for pkgRealpath that is at
// least as new as the newest modification time in the directory.
//
// As a side effect the contents of the cache file are loaded into
// pkgLastupdate.
func compareFile(pkgRealpath string) bool {
	if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) {
		return true
	}
	if !utils.FileExists(lastupdateFilename) {
		return true
	}
	content, err := ioutil.ReadFile(lastupdateFilename)
	if err != nil {
		return true
	}
	err = json.Unmarshal(content, &pkgLastupdate)
	// A cache file containing JSON null resets the map to nil; restore it so
	// that later writes to pkgLastupdate do not panic.
	if pkgLastupdate == nil {
		pkgLastupdate = make(map[string]int64)
	}
	if err != nil {
		// A corrupt cache cannot prove that nothing changed.
		return true
	}
	lastupdate, err := getpathTime(pkgRealpath)
	if err != nil {
		return true
	}
	cached, ok := pkgLastupdate[pkgRealpath]
	return !ok || lastupdate > cached
}
// savetoFile records the newest modification time found in pkgRealpath in
// pkgLastupdate and writes the whole cache as JSON to lastupdateFilename.
// It is best effort: when the directory cannot be read or the cache cannot be
// encoded nothing is written, and a failed write is not reported.
func savetoFile(pkgRealpath string) {
	newest, err := getpathTime(pkgRealpath)
	if err != nil {
		return
	}
	pkgLastupdate[pkgRealpath] = newest
	if encoded, encErr := json.Marshal(pkgLastupdate); encErr == nil {
		ioutil.WriteFile(lastupdateFilename, encoded, os.ModePerm)
	}
}
// getpathTime returns the newest modification time, in Unix nanoseconds,
// among the direct entries of the directory pkgRealpath. It returns 0 for an
// empty directory and the ReadDir error when the directory cannot be listed.
func getpathTime(pkgRealpath string) (lastupdate int64, err error) {
	entries, readErr := ioutil.ReadDir(pkgRealpath)
	if readErr != nil {
		return 0, readErr
	}
	var newest int64
	for i := range entries {
		if modified := entries[i].ModTime().UnixNano(); modified > newest {
			newest = modified
		}
	}
	return newest, nil
}
// getRouterDir returns the directory in which the generated router file
// belongs.
//
// Starting at the parent of pkgRealpath it walks upwards and returns the
// first existing directory named after the "routersdir" configuration value
// (default "routers"). When the walk reaches AppPath without finding one, the
// not yet existing routers directory directly below AppPath is returned. When
// the walk reaches the filesystem root instead (pkgRealpath is not located
// below AppPath), it also falls back to the routers directory below AppPath
// rather than looping forever.
func getRouterDir(pkgRealpath string) string {
	// The configured name does not change between iterations.
	routersDir := AppConfig.DefaultString("routersdir", "routers")
	dir := filepath.Dir(pkgRealpath)
	for {
		d := filepath.Join(dir, routersDir)
		if utils.FileExists(d) {
			return d
		}
		if r, _ := filepath.Rel(dir, AppPath); r == "." {
			return d
		}
		parent := filepath.Dir(dir)
		if parent == dir {
			// filepath.Dir of a root is the root itself: nothing is left to
			// search.
			return filepath.Join(AppPath, routersDir)
		}
		dir = parent
	}
}