commit
						9d936c58bf
					
				
							
								
								
									
										159
									
								
								pkg/orm/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								pkg/orm/README.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,159 @@ | ||||
| # beego orm | ||||
| 
 | ||||
| [](https://drone.io/github.com/astaxie/beego/latest) | ||||
| 
 | ||||
| A powerful orm framework for go. | ||||
| 
 | ||||
| It is heavily influenced by Django ORM, SQLAlchemy. | ||||
| 
 | ||||
| **Support Database:** | ||||
| 
 | ||||
| * MySQL: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) | ||||
| * PostgreSQL: [github.com/lib/pq](https://github.com/lib/pq) | ||||
| * Sqlite3: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) | ||||
| 
 | ||||
| Passed all test, but need more feedback. | ||||
| 
 | ||||
| **Features:** | ||||
| 
 | ||||
| * full go type support | ||||
| * easy for usage, simple CRUD operation | ||||
| * auto join with relation table | ||||
| * cross DataBase compatible query | ||||
| * Raw SQL query / mapper without orm model | ||||
| * full test keep stable and strong | ||||
| 
 | ||||
| more features please read the docs | ||||
| 
 | ||||
| **Install:** | ||||
| 
 | ||||
| 	go get github.com/astaxie/beego/orm | ||||
| 
 | ||||
| ## Changelog | ||||
| 
 | ||||
| * 2013-08-19: support table auto create | ||||
| * 2013-08-13: update test for database types | ||||
| * 2013-08-13: go type support, such as int8, uint8, byte, rune | ||||
| * 2013-08-13: date / datetime timezone support very well | ||||
| 
 | ||||
| ## Quick Start | ||||
| 
 | ||||
| #### Simple Usage | ||||
| 
 | ||||
| ```go | ||||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/astaxie/beego/orm" | ||||
| 	_ "github.com/go-sql-driver/mysql" // import your used driver | ||||
| ) | ||||
| 
 | ||||
| // Model Struct | ||||
| type User struct { | ||||
| 	Id   int    `orm:"auto"` | ||||
| 	Name string `orm:"size(100)"` | ||||
| } | ||||
| 
 | ||||
| func init() { | ||||
| 	// register model | ||||
| 	orm.RegisterModel(new(User)) | ||||
| 
 | ||||
| 	// set default database | ||||
| 	orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) | ||||
| 	 | ||||
| 	// create table | ||||
| 	orm.RunSyncdb("default", false, true)	 | ||||
| } | ||||
| 
 | ||||
| func main() { | ||||
| 	o := orm.NewOrm() | ||||
| 
 | ||||
| 	user := User{Name: "slene"} | ||||
| 
 | ||||
| 	// insert | ||||
| 	id, err := o.Insert(&user) | ||||
| 
 | ||||
| 	// update | ||||
| 	user.Name = "astaxie" | ||||
| 	num, err := o.Update(&user) | ||||
| 
 | ||||
| 	// read one | ||||
| 	u := User{Id: user.Id} | ||||
| 	err = o.Read(&u) | ||||
| 
 | ||||
| 	// delete | ||||
| 	num, err = o.Delete(&u)	 | ||||
| } | ||||
| ``` | ||||
| 
 | ||||
| #### Next with relation | ||||
| 
 | ||||
| ```go | ||||
| type Post struct { | ||||
| 	Id    int    `orm:"auto"` | ||||
| 	Title string `orm:"size(100)"` | ||||
| 	User  *User  `orm:"rel(fk)"` | ||||
| } | ||||
| 
 | ||||
| var posts []*Post | ||||
| qs := o.QueryTable("post") | ||||
| num, err := qs.Filter("User__Name", "slene").All(&posts) | ||||
| ``` | ||||
| 
 | ||||
| #### Use Raw sql | ||||
| 
 | ||||
| If you don't like ORM,use Raw SQL to query / mapping without ORM setting | ||||
| 
 | ||||
| ```go | ||||
| var maps []Params | ||||
| num, err := o.Raw("SELECT id FROM user WHERE name = ?", "slene").Values(&maps) | ||||
| if num > 0 { | ||||
| 	fmt.Println(maps[0]["id"]) | ||||
| } | ||||
| ``` | ||||
| 
 | ||||
| #### Transaction | ||||
| 
 | ||||
| ```go | ||||
| o.Begin() | ||||
| ... | ||||
| user := User{Name: "slene"} | ||||
| id, err := o.Insert(&user) | ||||
| if err == nil { | ||||
| 	o.Commit() | ||||
| } else { | ||||
| 	o.Rollback() | ||||
| } | ||||
| 
 | ||||
| ``` | ||||
| 
 | ||||
| #### Debug Log Queries | ||||
| 
 | ||||
| In development env, you can simple use | ||||
| 
 | ||||
| ```go | ||||
| func main() { | ||||
| 	orm.Debug = true | ||||
| ... | ||||
| ``` | ||||
| 
 | ||||
| enable log queries. | ||||
| 
 | ||||
| output include all queries, such as exec / prepare / transaction. | ||||
| 
 | ||||
| like this: | ||||
| 
 | ||||
| ```go | ||||
| [ORM] - 2013-08-09 13:18:16 - [Queries/default] - [    db.Exec /     0.4ms] - [INSERT INTO `user` (`name`) VALUES (?)] - `slene` | ||||
| ... | ||||
| ``` | ||||
| 
 | ||||
| note: not recommend use this in product env. | ||||
| 
 | ||||
| ## Docs | ||||
| 
 | ||||
| more details and examples in docs and test | ||||
| 
 | ||||
| [documents](http://beego.me/docs/mvc/model/overview.md) | ||||
| 
 | ||||
							
								
								
									
										283
									
								
								pkg/orm/cmd.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										283
									
								
								pkg/orm/cmd.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,283 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"flag" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| type commander interface { | ||||
| 	Parse([]string) | ||||
| 	Run() error | ||||
| } | ||||
| 
 | ||||
| var ( | ||||
| 	commands = make(map[string]commander) | ||||
| ) | ||||
| 
 | ||||
| // print help. | ||||
| func printHelp(errs ...string) { | ||||
| 	content := `orm command usage: | ||||
| 
 | ||||
|     syncdb     - auto create tables | ||||
|     sqlall     - print sql of create tables | ||||
|     help       - print this help | ||||
| ` | ||||
| 
 | ||||
| 	if len(errs) > 0 { | ||||
| 		fmt.Println(errs[0]) | ||||
| 	} | ||||
| 	fmt.Println(content) | ||||
| 	os.Exit(2) | ||||
| } | ||||
| 
 | ||||
| // RunCommand listen for orm command and then run it if command arguments passed. | ||||
| func RunCommand() { | ||||
| 	if len(os.Args) < 2 || os.Args[1] != "orm" { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	BootStrap() | ||||
| 
 | ||||
| 	args := argString(os.Args[2:]) | ||||
| 	name := args.Get(0) | ||||
| 
 | ||||
| 	if name == "help" { | ||||
| 		printHelp() | ||||
| 	} | ||||
| 
 | ||||
| 	if cmd, ok := commands[name]; ok { | ||||
| 		cmd.Parse(os.Args[3:]) | ||||
| 		cmd.Run() | ||||
| 		os.Exit(0) | ||||
| 	} else { | ||||
| 		if name == "" { | ||||
| 			printHelp() | ||||
| 		} else { | ||||
| 			printHelp(fmt.Sprintf("unknown command %s", name)) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // sync database struct command interface. | ||||
| type commandSyncDb struct { | ||||
| 	al        *alias | ||||
| 	force     bool | ||||
| 	verbose   bool | ||||
| 	noInfo    bool | ||||
| 	rtOnError bool | ||||
| } | ||||
| 
 | ||||
| // parse orm command line arguments. | ||||
| func (d *commandSyncDb) Parse(args []string) { | ||||
| 	var name string | ||||
| 
 | ||||
| 	flagSet := flag.NewFlagSet("orm command: syncdb", flag.ExitOnError) | ||||
| 	flagSet.StringVar(&name, "db", "default", "DataBase alias name") | ||||
| 	flagSet.BoolVar(&d.force, "force", false, "drop tables before create") | ||||
| 	flagSet.BoolVar(&d.verbose, "v", false, "verbose info") | ||||
| 	flagSet.Parse(args) | ||||
| 
 | ||||
| 	d.al = getDbAlias(name) | ||||
| } | ||||
| 
 | ||||
| // run orm line command. | ||||
| func (d *commandSyncDb) Run() error { | ||||
| 	var drops []string | ||||
| 	if d.force { | ||||
| 		drops = getDbDropSQL(d.al) | ||||
| 	} | ||||
| 
 | ||||
| 	db := d.al.DB | ||||
| 
 | ||||
| 	if d.force { | ||||
| 		for i, mi := range modelCache.allOrdered() { | ||||
| 			query := drops[i] | ||||
| 			if !d.noInfo { | ||||
| 				fmt.Printf("drop table `%s`\n", mi.table) | ||||
| 			} | ||||
| 			_, err := db.Exec(query) | ||||
| 			if d.verbose { | ||||
| 				fmt.Printf("    %s\n\n", query) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				if d.rtOnError { | ||||
| 					return err | ||||
| 				} | ||||
| 				fmt.Printf("    %s\n", err.Error()) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	sqls, indexes := getDbCreateSQL(d.al) | ||||
| 
 | ||||
| 	tables, err := d.al.DbBaser.GetTables(db) | ||||
| 	if err != nil { | ||||
| 		if d.rtOnError { | ||||
| 			return err | ||||
| 		} | ||||
| 		fmt.Printf("    %s\n", err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	for i, mi := range modelCache.allOrdered() { | ||||
| 		if tables[mi.table] { | ||||
| 			if !d.noInfo { | ||||
| 				fmt.Printf("table `%s` already exists, skip\n", mi.table) | ||||
| 			} | ||||
| 
 | ||||
| 			var fields []*fieldInfo | ||||
| 			columns, err := d.al.DbBaser.GetColumns(db, mi.table) | ||||
| 			if err != nil { | ||||
| 				if d.rtOnError { | ||||
| 					return err | ||||
| 				} | ||||
| 				fmt.Printf("    %s\n", err.Error()) | ||||
| 			} | ||||
| 
 | ||||
| 			for _, fi := range mi.fields.fieldsDB { | ||||
| 				if _, ok := columns[fi.column]; !ok { | ||||
| 					fields = append(fields, fi) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			for _, fi := range fields { | ||||
| 				query := getColumnAddQuery(d.al, fi) | ||||
| 
 | ||||
| 				if !d.noInfo { | ||||
| 					fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table) | ||||
| 				} | ||||
| 
 | ||||
| 				_, err := db.Exec(query) | ||||
| 				if d.verbose { | ||||
| 					fmt.Printf("    %s\n", query) | ||||
| 				} | ||||
| 				if err != nil { | ||||
| 					if d.rtOnError { | ||||
| 						return err | ||||
| 					} | ||||
| 					fmt.Printf("    %s\n", err.Error()) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			for _, idx := range indexes[mi.table] { | ||||
| 				if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { | ||||
| 					if !d.noInfo { | ||||
| 						fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) | ||||
| 					} | ||||
| 
 | ||||
| 					query := idx.SQL | ||||
| 					_, err := db.Exec(query) | ||||
| 					if d.verbose { | ||||
| 						fmt.Printf("    %s\n", query) | ||||
| 					} | ||||
| 					if err != nil { | ||||
| 						if d.rtOnError { | ||||
| 							return err | ||||
| 						} | ||||
| 						fmt.Printf("    %s\n", err.Error()) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		if !d.noInfo { | ||||
| 			fmt.Printf("create table `%s` \n", mi.table) | ||||
| 		} | ||||
| 
 | ||||
| 		queries := []string{sqls[i]} | ||||
| 		for _, idx := range indexes[mi.table] { | ||||
| 			queries = append(queries, idx.SQL) | ||||
| 		} | ||||
| 
 | ||||
| 		for _, query := range queries { | ||||
| 			_, err := db.Exec(query) | ||||
| 			if d.verbose { | ||||
| 				query = "    " + strings.Join(strings.Split(query, "\n"), "\n    ") | ||||
| 				fmt.Println(query) | ||||
| 			} | ||||
| 			if err != nil { | ||||
| 				if d.rtOnError { | ||||
| 					return err | ||||
| 				} | ||||
| 				fmt.Printf("    %s\n", err.Error()) | ||||
| 			} | ||||
| 		} | ||||
| 		if d.verbose { | ||||
| 			fmt.Println("") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // database creation commander interface implement. | ||||
| type commandSQLAll struct { | ||||
| 	al *alias | ||||
| } | ||||
| 
 | ||||
| // parse orm command line arguments. | ||||
| func (d *commandSQLAll) Parse(args []string) { | ||||
| 	var name string | ||||
| 
 | ||||
| 	flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError) | ||||
| 	flagSet.StringVar(&name, "db", "default", "DataBase alias name") | ||||
| 	flagSet.Parse(args) | ||||
| 
 | ||||
| 	d.al = getDbAlias(name) | ||||
| } | ||||
| 
 | ||||
| // run orm line command. | ||||
| func (d *commandSQLAll) Run() error { | ||||
| 	sqls, indexes := getDbCreateSQL(d.al) | ||||
| 	var all []string | ||||
| 	for i, mi := range modelCache.allOrdered() { | ||||
| 		queries := []string{sqls[i]} | ||||
| 		for _, idx := range indexes[mi.table] { | ||||
| 			queries = append(queries, idx.SQL) | ||||
| 		} | ||||
| 		sql := strings.Join(queries, "\n") | ||||
| 		all = append(all, sql) | ||||
| 	} | ||||
| 	fmt.Println(strings.Join(all, "\n\n")) | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func init() { | ||||
| 	commands["syncdb"] = new(commandSyncDb) | ||||
| 	commands["sqlall"] = new(commandSQLAll) | ||||
| } | ||||
| 
 | ||||
| // RunSyncdb run syncdb command line. | ||||
| // name means table's alias name. default is "default". | ||||
| // force means run next sql if the current is error. | ||||
| // verbose means show all info when running command or not. | ||||
| func RunSyncdb(name string, force bool, verbose bool) error { | ||||
| 	BootStrap() | ||||
| 
 | ||||
| 	al := getDbAlias(name) | ||||
| 	cmd := new(commandSyncDb) | ||||
| 	cmd.al = al | ||||
| 	cmd.force = force | ||||
| 	cmd.noInfo = !verbose | ||||
| 	cmd.verbose = verbose | ||||
| 	cmd.rtOnError = true | ||||
| 	return cmd.Run() | ||||
| } | ||||
							
								
								
									
										320
									
								
								pkg/orm/cmd_utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										320
									
								
								pkg/orm/cmd_utils.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,320 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| type dbIndex struct { | ||||
| 	Table string | ||||
| 	Name  string | ||||
| 	SQL   string | ||||
| } | ||||
| 
 | ||||
| // create database drop sql. | ||||
| func getDbDropSQL(al *alias) (sqls []string) { | ||||
| 	if len(modelCache.cache) == 0 { | ||||
| 		fmt.Println("no Model found, need register your model") | ||||
| 		os.Exit(2) | ||||
| 	} | ||||
| 
 | ||||
| 	Q := al.DbBaser.TableQuote() | ||||
| 
 | ||||
| 	for _, mi := range modelCache.allOrdered() { | ||||
| 		sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) | ||||
| 	} | ||||
| 	return sqls | ||||
| } | ||||
| 
 | ||||
| // get database column type string. | ||||
| func getColumnTyp(al *alias, fi *fieldInfo) (col string) { | ||||
| 	T := al.DbBaser.DbTypes() | ||||
| 	fieldType := fi.fieldType | ||||
| 	fieldSize := fi.size | ||||
| 
 | ||||
| checkColumn: | ||||
| 	switch fieldType { | ||||
| 	case TypeBooleanField: | ||||
| 		col = T["bool"] | ||||
| 	case TypeVarCharField: | ||||
| 		if al.Driver == DRPostgres && fi.toText { | ||||
| 			col = T["string-text"] | ||||
| 		} else { | ||||
| 			col = fmt.Sprintf(T["string"], fieldSize) | ||||
| 		} | ||||
| 	case TypeCharField: | ||||
| 		col = fmt.Sprintf(T["string-char"], fieldSize) | ||||
| 	case TypeTextField: | ||||
| 		col = T["string-text"] | ||||
| 	case TypeTimeField: | ||||
| 		col = T["time.Time-clock"] | ||||
| 	case TypeDateField: | ||||
| 		col = T["time.Time-date"] | ||||
| 	case TypeDateTimeField: | ||||
| 		col = T["time.Time"] | ||||
| 	case TypeBitField: | ||||
| 		col = T["int8"] | ||||
| 	case TypeSmallIntegerField: | ||||
| 		col = T["int16"] | ||||
| 	case TypeIntegerField: | ||||
| 		col = T["int32"] | ||||
| 	case TypeBigIntegerField: | ||||
| 		if al.Driver == DRSqlite { | ||||
| 			fieldType = TypeIntegerField | ||||
| 			goto checkColumn | ||||
| 		} | ||||
| 		col = T["int64"] | ||||
| 	case TypePositiveBitField: | ||||
| 		col = T["uint8"] | ||||
| 	case TypePositiveSmallIntegerField: | ||||
| 		col = T["uint16"] | ||||
| 	case TypePositiveIntegerField: | ||||
| 		col = T["uint32"] | ||||
| 	case TypePositiveBigIntegerField: | ||||
| 		col = T["uint64"] | ||||
| 	case TypeFloatField: | ||||
| 		col = T["float64"] | ||||
| 	case TypeDecimalField: | ||||
| 		s := T["float64-decimal"] | ||||
| 		if !strings.Contains(s, "%d") { | ||||
| 			col = s | ||||
| 		} else { | ||||
| 			col = fmt.Sprintf(s, fi.digits, fi.decimals) | ||||
| 		} | ||||
| 	case TypeJSONField: | ||||
| 		if al.Driver != DRPostgres { | ||||
| 			fieldType = TypeVarCharField | ||||
| 			goto checkColumn | ||||
| 		} | ||||
| 		col = T["json"] | ||||
| 	case TypeJsonbField: | ||||
| 		if al.Driver != DRPostgres { | ||||
| 			fieldType = TypeVarCharField | ||||
| 			goto checkColumn | ||||
| 		} | ||||
| 		col = T["jsonb"] | ||||
| 	case RelForeignKey, RelOneToOne: | ||||
| 		fieldType = fi.relModelInfo.fields.pk.fieldType | ||||
| 		fieldSize = fi.relModelInfo.fields.pk.size | ||||
| 		goto checkColumn | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // create alter sql string. | ||||
| func getColumnAddQuery(al *alias, fi *fieldInfo) string { | ||||
| 	Q := al.DbBaser.TableQuote() | ||||
| 	typ := getColumnTyp(al, fi) | ||||
| 
 | ||||
| 	if !fi.null { | ||||
| 		typ += " " + "NOT NULL" | ||||
| 	} | ||||
| 
 | ||||
| 	return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", | ||||
| 		Q, fi.mi.table, Q, | ||||
| 		Q, fi.column, Q, | ||||
| 		typ, getColumnDefault(fi), | ||||
| 	) | ||||
| } | ||||
| 
 | ||||
| // create database creation string. | ||||
| func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { | ||||
| 	if len(modelCache.cache) == 0 { | ||||
| 		fmt.Println("no Model found, need register your model") | ||||
| 		os.Exit(2) | ||||
| 	} | ||||
| 
 | ||||
| 	Q := al.DbBaser.TableQuote() | ||||
| 	T := al.DbBaser.DbTypes() | ||||
| 	sep := fmt.Sprintf("%s, %s", Q, Q) | ||||
| 
 | ||||
| 	tableIndexes = make(map[string][]dbIndex) | ||||
| 
 | ||||
| 	for _, mi := range modelCache.allOrdered() { | ||||
| 		sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) | ||||
| 		sql += fmt.Sprintf("--  Table Structure for `%s`\n", mi.fullName) | ||||
| 		sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) | ||||
| 
 | ||||
| 		sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) | ||||
| 
 | ||||
| 		columns := make([]string, 0, len(mi.fields.fieldsDB)) | ||||
| 
 | ||||
| 		sqlIndexes := [][]string{} | ||||
| 
 | ||||
| 		for _, fi := range mi.fields.fieldsDB { | ||||
| 
 | ||||
| 			column := fmt.Sprintf("    %s%s%s ", Q, fi.column, Q) | ||||
| 			col := getColumnTyp(al, fi) | ||||
| 
 | ||||
| 			if fi.auto { | ||||
| 				switch al.Driver { | ||||
| 				case DRSqlite, DRPostgres: | ||||
| 					column += T["auto"] | ||||
| 				default: | ||||
| 					column += col + " " + T["auto"] | ||||
| 				} | ||||
| 			} else if fi.pk { | ||||
| 				column += col + " " + T["pk"] | ||||
| 			} else { | ||||
| 				column += col | ||||
| 
 | ||||
| 				if !fi.null { | ||||
| 					column += " " + "NOT NULL" | ||||
| 				} | ||||
| 
 | ||||
| 				//if fi.initial.String() != "" { | ||||
| 				//	column += " DEFAULT " + fi.initial.String() | ||||
| 				//} | ||||
| 
 | ||||
| 				// Append attribute DEFAULT | ||||
| 				column += getColumnDefault(fi) | ||||
| 
 | ||||
| 				if fi.unique { | ||||
| 					column += " " + "UNIQUE" | ||||
| 				} | ||||
| 
 | ||||
| 				if fi.index { | ||||
| 					sqlIndexes = append(sqlIndexes, []string{fi.column}) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if strings.Contains(column, "%COL%") { | ||||
| 				column = strings.Replace(column, "%COL%", fi.column, -1) | ||||
| 			} | ||||
| 			 | ||||
| 			if fi.description != "" && al.Driver!=DRSqlite { | ||||
| 				column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) | ||||
| 			} | ||||
| 
 | ||||
| 			columns = append(columns, column) | ||||
| 		} | ||||
| 
 | ||||
| 		if mi.model != nil { | ||||
| 			allnames := getTableUnique(mi.addrField) | ||||
| 			if !mi.manual && len(mi.uniques) > 0 { | ||||
| 				allnames = append(allnames, mi.uniques) | ||||
| 			} | ||||
| 			for _, names := range allnames { | ||||
| 				cols := make([]string, 0, len(names)) | ||||
| 				for _, name := range names { | ||||
| 					if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { | ||||
| 						cols = append(cols, fi.column) | ||||
| 					} else { | ||||
| 						panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) | ||||
| 					} | ||||
| 				} | ||||
| 				column := fmt.Sprintf("    UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) | ||||
| 				columns = append(columns, column) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		sql += strings.Join(columns, ",\n") | ||||
| 		sql += "\n)" | ||||
| 
 | ||||
| 		if al.Driver == DRMySQL { | ||||
| 			var engine string | ||||
| 			if mi.model != nil { | ||||
| 				engine = getTableEngine(mi.addrField) | ||||
| 			} | ||||
| 			if engine == "" { | ||||
| 				engine = al.Engine | ||||
| 			} | ||||
| 			sql += " ENGINE=" + engine | ||||
| 		} | ||||
| 
 | ||||
| 		sql += ";" | ||||
| 		sqls = append(sqls, sql) | ||||
| 
 | ||||
| 		if mi.model != nil { | ||||
| 			for _, names := range getTableIndex(mi.addrField) { | ||||
| 				cols := make([]string, 0, len(names)) | ||||
| 				for _, name := range names { | ||||
| 					if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { | ||||
| 						cols = append(cols, fi.column) | ||||
| 					} else { | ||||
| 						panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) | ||||
| 					} | ||||
| 				} | ||||
| 				sqlIndexes = append(sqlIndexes, cols) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		for _, names := range sqlIndexes { | ||||
| 			name := mi.table + "_" + strings.Join(names, "_") | ||||
| 			cols := strings.Join(names, sep) | ||||
| 			sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) | ||||
| 
 | ||||
| 			index := dbIndex{} | ||||
| 			index.Table = mi.table | ||||
| 			index.Name = name | ||||
| 			index.SQL = sql | ||||
| 
 | ||||
| 			tableIndexes[mi.table] = append(tableIndexes[mi.table], index) | ||||
| 		} | ||||
| 
 | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands | ||||
| func getColumnDefault(fi *fieldInfo) string { | ||||
| 	var ( | ||||
| 		v, t, d string | ||||
| 	) | ||||
| 
 | ||||
| 	// Skip default attribute if field is in relations | ||||
| 	if fi.rel || fi.reverse { | ||||
| 		return v | ||||
| 	} | ||||
| 
 | ||||
| 	t = " DEFAULT '%s' " | ||||
| 
 | ||||
| 	// These defaults will be useful if there no config value orm:"default" and NOT NULL is on | ||||
| 	switch fi.fieldType { | ||||
| 	case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: | ||||
| 		return v | ||||
| 
 | ||||
| 	case TypeBitField, TypeSmallIntegerField, TypeIntegerField, | ||||
| 		TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, | ||||
| 		TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, | ||||
| 		TypeDecimalField: | ||||
| 		t = " DEFAULT %s " | ||||
| 		d = "0" | ||||
| 	case TypeBooleanField: | ||||
| 		t = " DEFAULT %s " | ||||
| 		d = "FALSE" | ||||
| 	case TypeJSONField, TypeJsonbField: | ||||
| 		d = "{}" | ||||
| 	} | ||||
| 
 | ||||
| 	if fi.colDefault { | ||||
| 		if !fi.initial.Exist() { | ||||
| 			v = fmt.Sprintf(t, "") | ||||
| 		} else { | ||||
| 			v = fmt.Sprintf(t, fi.initial.String()) | ||||
| 		} | ||||
| 	} else { | ||||
| 		if !fi.null { | ||||
| 			v = fmt.Sprintf(t, d) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return v | ||||
| } | ||||
							
								
								
									
										1902
									
								
								pkg/orm/db.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1902
									
								
								pkg/orm/db.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										466
									
								
								pkg/orm/db_alias.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										466
									
								
								pkg/orm/db_alias.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,466 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	lru "github.com/hashicorp/golang-lru" | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // DriverType database driver constant int. | ||||
| type DriverType int | ||||
| 
 | ||||
| // Enum the Database driver | ||||
| const ( | ||||
| 	_          DriverType = iota // int enum type | ||||
| 	DRMySQL                      // mysql | ||||
| 	DRSqlite                     // sqlite | ||||
| 	DROracle                     // oracle | ||||
| 	DRPostgres                   // pgsql | ||||
| 	DRTiDB                       // TiDB | ||||
| ) | ||||
| 
 | ||||
| // database driver string. | ||||
| type driver string | ||||
| 
 | ||||
| // get type constant int of current driver.. | ||||
| func (d driver) Type() DriverType { | ||||
| 	a, _ := dataBaseCache.get(string(d)) | ||||
| 	return a.Driver | ||||
| } | ||||
| 
 | ||||
| // get name of current driver | ||||
| func (d driver) Name() string { | ||||
| 	return string(d) | ||||
| } | ||||
| 
 | ||||
| // check driver iis implemented Driver interface or not. | ||||
| var _ Driver = new(driver) | ||||
| 
 | ||||
| var ( | ||||
| 	dataBaseCache = &_dbCache{cache: make(map[string]*alias)} | ||||
| 	drivers       = map[string]DriverType{ | ||||
| 		"mysql":    DRMySQL, | ||||
| 		"postgres": DRPostgres, | ||||
| 		"sqlite3":  DRSqlite, | ||||
| 		"tidb":     DRTiDB, | ||||
| 		"oracle":   DROracle, | ||||
| 		"oci8":     DROracle, // github.com/mattn/go-oci8 | ||||
| 		"ora":      DROracle, //https://github.com/rana/ora | ||||
| 	} | ||||
| 	dbBasers = map[DriverType]dbBaser{ | ||||
| 		DRMySQL:    newdbBaseMysql(), | ||||
| 		DRSqlite:   newdbBaseSqlite(), | ||||
| 		DROracle:   newdbBaseOracle(), | ||||
| 		DRPostgres: newdbBasePostgres(), | ||||
| 		DRTiDB:     newdbBaseTidb(), | ||||
| 	} | ||||
| ) | ||||
| 
 | ||||
| // database alias cacher. | ||||
| type _dbCache struct { | ||||
| 	mux   sync.RWMutex | ||||
| 	cache map[string]*alias | ||||
| } | ||||
| 
 | ||||
| // add database alias with original name. | ||||
| func (ac *_dbCache) add(name string, al *alias) (added bool) { | ||||
| 	ac.mux.Lock() | ||||
| 	defer ac.mux.Unlock() | ||||
| 	if _, ok := ac.cache[name]; !ok { | ||||
| 		ac.cache[name] = al | ||||
| 		added = true | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // get database alias if cached. | ||||
| func (ac *_dbCache) get(name string) (al *alias, ok bool) { | ||||
| 	ac.mux.RLock() | ||||
| 	defer ac.mux.RUnlock() | ||||
| 	al, ok = ac.cache[name] | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // get default alias. | ||||
| func (ac *_dbCache) getDefault() (al *alias) { | ||||
| 	al, _ = ac.get("default") | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| type DB struct { | ||||
| 	*sync.RWMutex | ||||
| 	DB             *sql.DB | ||||
| 	stmtDecorators *lru.Cache | ||||
| } | ||||
| 
 | ||||
| func (d *DB) Begin() (*sql.Tx, error) { | ||||
| 	return d.DB.Begin() | ||||
| } | ||||
| 
 | ||||
| func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { | ||||
| 	return d.DB.BeginTx(ctx, opts) | ||||
| } | ||||
| 
 | ||||
| //su must call release to release *sql.Stmt after using | ||||
| func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { | ||||
| 	d.RLock() | ||||
| 	c, ok := d.stmtDecorators.Get(query) | ||||
| 	if ok { | ||||
| 		c.(*stmtDecorator).acquire() | ||||
| 		d.RUnlock() | ||||
| 		return c.(*stmtDecorator), nil | ||||
| 	} | ||||
| 	d.RUnlock() | ||||
| 
 | ||||
| 	d.Lock() | ||||
| 	c, ok = d.stmtDecorators.Get(query) | ||||
| 	if ok { | ||||
| 		c.(*stmtDecorator).acquire() | ||||
| 		d.Unlock() | ||||
| 		return c.(*stmtDecorator), nil | ||||
| 	} | ||||
| 
 | ||||
| 	stmt, err := d.Prepare(query) | ||||
| 	if err != nil { | ||||
| 		d.Unlock() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	sd := newStmtDecorator(stmt) | ||||
| 	sd.acquire() | ||||
| 	d.stmtDecorators.Add(query, sd) | ||||
| 	d.Unlock() | ||||
| 
 | ||||
| 	return sd, nil | ||||
| } | ||||
| 
 | ||||
| func (d *DB) Prepare(query string) (*sql.Stmt, error) { | ||||
| 	return d.DB.Prepare(query) | ||||
| } | ||||
| 
 | ||||
| func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { | ||||
| 	return d.DB.PrepareContext(ctx, query) | ||||
| } | ||||
| 
 | ||||
| func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { | ||||
| 	sd, err := d.getStmtDecorator(query) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	stmt := sd.getStmt() | ||||
| 	defer sd.release() | ||||
| 	return stmt.Exec(args...) | ||||
| } | ||||
| 
 | ||||
| func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||
| 	sd, err := d.getStmtDecorator(query) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	stmt := sd.getStmt() | ||||
| 	defer sd.release() | ||||
| 	return stmt.ExecContext(ctx, args...) | ||||
| } | ||||
| 
 | ||||
| func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	sd, err := d.getStmtDecorator(query) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	stmt := sd.getStmt() | ||||
| 	defer sd.release() | ||||
| 	return stmt.Query(args...) | ||||
| } | ||||
| 
 | ||||
| func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	sd, err := d.getStmtDecorator(query) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	stmt := sd.getStmt() | ||||
| 	defer sd.release() | ||||
| 	return stmt.QueryContext(ctx, args...) | ||||
| } | ||||
| 
 | ||||
| func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { | ||||
| 	sd, err := d.getStmtDecorator(query) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	stmt := sd.getStmt() | ||||
| 	defer sd.release() | ||||
| 	return stmt.QueryRow(args...) | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||
| 	sd, err := d.getStmtDecorator(query) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	stmt := sd.getStmt() | ||||
| 	defer sd.release() | ||||
| 	return stmt.QueryRowContext(ctx, args) | ||||
| } | ||||
| 
 | ||||
| type alias struct { | ||||
| 	Name         string | ||||
| 	Driver       DriverType | ||||
| 	DriverName   string | ||||
| 	DataSource   string | ||||
| 	MaxIdleConns int | ||||
| 	MaxOpenConns int | ||||
| 	DB           *DB | ||||
| 	DbBaser      dbBaser | ||||
| 	TZ           *time.Location | ||||
| 	Engine       string | ||||
| } | ||||
| 
 | ||||
| func detectTZ(al *alias) { | ||||
| 	// orm timezone system match database | ||||
| 	// default use Local | ||||
| 	al.TZ = DefaultTimeLoc | ||||
| 
 | ||||
| 	if al.DriverName == "sphinx" { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	switch al.Driver { | ||||
| 	case DRMySQL: | ||||
| 		row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") | ||||
| 		var tz string | ||||
| 		row.Scan(&tz) | ||||
| 		if len(tz) >= 8 { | ||||
| 			if tz[0] != '-' { | ||||
| 				tz = "+" + tz | ||||
| 			} | ||||
| 			t, err := time.Parse("-07:00:00", tz) | ||||
| 			if err == nil { | ||||
| 				if t.Location().String() != "" { | ||||
| 					al.TZ = t.Location() | ||||
| 				} | ||||
| 			} else { | ||||
| 				DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		// get default engine from current database | ||||
| 		row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") | ||||
| 		var engine string | ||||
| 		var tx bool | ||||
| 		row.Scan(&engine, &tx) | ||||
| 
 | ||||
| 		if engine != "" { | ||||
| 			al.Engine = engine | ||||
| 		} else { | ||||
| 			al.Engine = "INNODB" | ||||
| 		} | ||||
| 
 | ||||
| 	case DRSqlite, DROracle: | ||||
| 		al.TZ = time.UTC | ||||
| 
 | ||||
| 	case DRPostgres: | ||||
| 		row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") | ||||
| 		var tz string | ||||
| 		row.Scan(&tz) | ||||
| 		loc, err := time.LoadLocation(tz) | ||||
| 		if err == nil { | ||||
| 			al.TZ = loc | ||||
| 		} else { | ||||
| 			DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { | ||||
| 	al := new(alias) | ||||
| 	al.Name = aliasName | ||||
| 	al.DriverName = driverName | ||||
| 	al.DB = &DB{ | ||||
| 		RWMutex:        new(sync.RWMutex), | ||||
| 		DB:             db, | ||||
| 		stmtDecorators: newStmtDecoratorLruWithEvict(), | ||||
| 	} | ||||
| 
 | ||||
| 	if dr, ok := drivers[driverName]; ok { | ||||
| 		al.DbBaser = dbBasers[dr] | ||||
| 		al.Driver = dr | ||||
| 	} else { | ||||
| 		return nil, fmt.Errorf("driver name `%s` have not registered", driverName) | ||||
| 	} | ||||
| 
 | ||||
| 	err := db.Ping() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	if !dataBaseCache.add(aliasName, al) { | ||||
| 		return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) | ||||
| 	} | ||||
| 
 | ||||
| 	return al, nil | ||||
| } | ||||
| 
 | ||||
| // AddAliasWthDB add a aliasName for the drivename | ||||
| func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { | ||||
| 	_, err := addAliasWthDB(aliasName, driverName, db) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| // RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. | ||||
| func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { | ||||
| 	var ( | ||||
| 		err error | ||||
| 		db  *sql.DB | ||||
| 		al  *alias | ||||
| 	) | ||||
| 
 | ||||
| 	db, err = sql.Open(driverName, dataSource) | ||||
| 	if err != nil { | ||||
| 		err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) | ||||
| 		goto end | ||||
| 	} | ||||
| 
 | ||||
| 	al, err = addAliasWthDB(aliasName, driverName, db) | ||||
| 	if err != nil { | ||||
| 		goto end | ||||
| 	} | ||||
| 
 | ||||
| 	al.DataSource = dataSource | ||||
| 
 | ||||
| 	detectTZ(al) | ||||
| 
 | ||||
| 	for i, v := range params { | ||||
| 		switch i { | ||||
| 		case 0: | ||||
| 			SetMaxIdleConns(al.Name, v) | ||||
| 		case 1: | ||||
| 			SetMaxOpenConns(al.Name, v) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| end: | ||||
| 	if err != nil { | ||||
| 		if db != nil { | ||||
| 			db.Close() | ||||
| 		} | ||||
| 		DebugLog.Println(err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| // RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. | ||||
| func RegisterDriver(driverName string, typ DriverType) error { | ||||
| 	if t, ok := drivers[driverName]; !ok { | ||||
| 		drivers[driverName] = typ | ||||
| 	} else { | ||||
| 		if t != typ { | ||||
| 			return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // SetDataBaseTZ Change the database default used timezone | ||||
| func SetDataBaseTZ(aliasName string, tz *time.Location) error { | ||||
| 	if al, ok := dataBaseCache.get(aliasName); ok { | ||||
| 		al.TZ = tz | ||||
| 	} else { | ||||
| 		return fmt.Errorf("DataBase alias name `%s` not registered", aliasName) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name | ||||
| func SetMaxIdleConns(aliasName string, maxIdleConns int) { | ||||
| 	al := getDbAlias(aliasName) | ||||
| 	al.MaxIdleConns = maxIdleConns | ||||
| 	al.DB.DB.SetMaxIdleConns(maxIdleConns) | ||||
| } | ||||
| 
 | ||||
| // SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name | ||||
| func SetMaxOpenConns(aliasName string, maxOpenConns int) { | ||||
| 	al := getDbAlias(aliasName) | ||||
| 	al.MaxOpenConns = maxOpenConns | ||||
| 	al.DB.DB.SetMaxOpenConns(maxOpenConns) | ||||
| 	// for tip go 1.2 | ||||
| 	if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { | ||||
| 		fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // GetDB Get *sql.DB from registered database by db alias name. | ||||
| // Use "default" as alias name if you not set. | ||||
| func GetDB(aliasNames ...string) (*sql.DB, error) { | ||||
| 	var name string | ||||
| 	if len(aliasNames) > 0 { | ||||
| 		name = aliasNames[0] | ||||
| 	} else { | ||||
| 		name = "default" | ||||
| 	} | ||||
| 	al, ok := dataBaseCache.get(name) | ||||
| 	if ok { | ||||
| 		return al.DB.DB, nil | ||||
| 	} | ||||
| 	return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) | ||||
| } | ||||
| 
 | ||||
| type stmtDecorator struct { | ||||
| 	wg sync.WaitGroup | ||||
| 	stmt *sql.Stmt | ||||
| } | ||||
| 
 | ||||
| func (s *stmtDecorator) getStmt() *sql.Stmt { | ||||
| 	return s.stmt | ||||
| } | ||||
| 
 | ||||
| // acquire will add one | ||||
| // since this method will be used inside read lock scope, | ||||
| // so we can not do more things here | ||||
| // we should think about refactor this | ||||
| func (s *stmtDecorator) acquire() { | ||||
| 	s.wg.Add(1) | ||||
| } | ||||
| 
 | ||||
| func (s *stmtDecorator) release() { | ||||
| 	s.wg.Done() | ||||
| } | ||||
| 
 | ||||
| //garbage recycle for stmt | ||||
| func (s *stmtDecorator) destroy() { | ||||
| 	go func() { | ||||
| 		s.wg.Wait() | ||||
| 		_ = s.stmt.Close() | ||||
| 	}() | ||||
| } | ||||
| 
 | ||||
| func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { | ||||
| 	return &stmtDecorator{ | ||||
| 		stmt: sqlStmt, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func newStmtDecoratorLruWithEvict() *lru.Cache { | ||||
| 	cache, _ := lru.NewWithEvict(1000, func(key interface{}, value interface{}) { | ||||
| 		value.(*stmtDecorator).destroy() | ||||
| 	}) | ||||
| 	return cache | ||||
| } | ||||
							
								
								
									
										183
									
								
								pkg/orm/db_mysql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										183
									
								
								pkg/orm/db_mysql.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,183 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // mysql operators. | ||||
| var mysqlOperators = map[string]string{ | ||||
| 	"exact":     "= ?", | ||||
| 	"iexact":    "LIKE ?", | ||||
| 	"contains":  "LIKE BINARY ?", | ||||
| 	"icontains": "LIKE ?", | ||||
| 	// "regex":       "REGEXP BINARY ?", | ||||
| 	// "iregex":      "REGEXP ?", | ||||
| 	"gt":          "> ?", | ||||
| 	"gte":         ">= ?", | ||||
| 	"lt":          "< ?", | ||||
| 	"lte":         "<= ?", | ||||
| 	"eq":          "= ?", | ||||
| 	"ne":          "!= ?", | ||||
| 	"startswith":  "LIKE BINARY ?", | ||||
| 	"endswith":    "LIKE BINARY ?", | ||||
| 	"istartswith": "LIKE ?", | ||||
| 	"iendswith":   "LIKE ?", | ||||
| } | ||||
| 
 | ||||
| // mysql column field types. | ||||
| var mysqlTypes = map[string]string{ | ||||
| 	"auto":            "AUTO_INCREMENT NOT NULL PRIMARY KEY", | ||||
| 	"pk":              "NOT NULL PRIMARY KEY", | ||||
| 	"bool":            "bool", | ||||
| 	"string":          "varchar(%d)", | ||||
| 	"string-char":     "char(%d)", | ||||
| 	"string-text":     "longtext", | ||||
| 	"time.Time-date":  "date", | ||||
| 	"time.Time":       "datetime", | ||||
| 	"int8":            "tinyint", | ||||
| 	"int16":           "smallint", | ||||
| 	"int32":           "integer", | ||||
| 	"int64":           "bigint", | ||||
| 	"uint8":           "tinyint unsigned", | ||||
| 	"uint16":          "smallint unsigned", | ||||
| 	"uint32":          "integer unsigned", | ||||
| 	"uint64":          "bigint unsigned", | ||||
| 	"float64":         "double precision", | ||||
| 	"float64-decimal": "numeric(%d, %d)", | ||||
| } | ||||
| 
 | ||||
| // mysql dbBaser implementation. | ||||
| type dbBaseMysql struct { | ||||
| 	dbBase | ||||
| } | ||||
| 
 | ||||
| var _ dbBaser = new(dbBaseMysql) | ||||
| 
 | ||||
| // get mysql operator. | ||||
| func (d *dbBaseMysql) OperatorSQL(operator string) string { | ||||
| 	return mysqlOperators[operator] | ||||
| } | ||||
| 
 | ||||
| // get mysql table field types. | ||||
| func (d *dbBaseMysql) DbTypes() map[string]string { | ||||
| 	return mysqlTypes | ||||
| } | ||||
| 
 | ||||
| // show table sql for mysql. | ||||
| func (d *dbBaseMysql) ShowTablesQuery() string { | ||||
| 	return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" | ||||
| } | ||||
| 
 | ||||
| // show columns sql of table for mysql. | ||||
| func (d *dbBaseMysql) ShowColumnsQuery(table string) string { | ||||
| 	return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ | ||||
| 		"WHERE table_schema = DATABASE() AND table_name = '%s'", table) | ||||
| } | ||||
| 
 | ||||
| // execute sql to check index exist. | ||||
| func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { | ||||
| 	row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ | ||||
| 		"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) | ||||
| 	var cnt int | ||||
| 	row.Scan(&cnt) | ||||
| 	return cnt > 0 | ||||
| } | ||||
| 
 | ||||
| // InsertOrUpdate a row | ||||
| // If your primary key or unique column conflict will update | ||||
| // If no will insert | ||||
| // Add "`" for mysql sql building | ||||
| func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { | ||||
| 	var iouStr string | ||||
| 	argsMap := map[string]string{} | ||||
| 
 | ||||
| 	iouStr = "ON DUPLICATE KEY UPDATE" | ||||
| 
 | ||||
| 	//Get on the key-value pairs | ||||
| 	for _, v := range args { | ||||
| 		kv := strings.Split(v, "=") | ||||
| 		if len(kv) == 2 { | ||||
| 			argsMap[strings.ToLower(kv[0])] = kv[1] | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	isMulti := false | ||||
| 	names := make([]string, 0, len(mi.fields.dbcols)-1) | ||||
| 	Q := d.ins.TableQuote() | ||||
| 	values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	marks := make([]string, len(names)) | ||||
| 	updateValues := make([]interface{}, 0) | ||||
| 	updates := make([]string, len(names)) | ||||
| 
 | ||||
| 	for i, v := range names { | ||||
| 		marks[i] = "?" | ||||
| 		valueStr := argsMap[strings.ToLower(v)] | ||||
| 		if valueStr != "" { | ||||
| 			updates[i] = "`" + v + "`" + "=" + valueStr | ||||
| 		} else { | ||||
| 			updates[i] = "`" + v + "`" + "=?" | ||||
| 			updateValues = append(updateValues, values[i]) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	values = append(values, updateValues...) | ||||
| 
 | ||||
| 	sep := fmt.Sprintf("%s, %s", Q, Q) | ||||
| 	qmarks := strings.Join(marks, ", ") | ||||
| 	qupdates := strings.Join(updates, ", ") | ||||
| 	columns := strings.Join(names, sep) | ||||
| 
 | ||||
| 	multi := len(values) / len(names) | ||||
| 
 | ||||
| 	if isMulti { | ||||
| 		qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks | ||||
| 	} | ||||
| 	//conflitValue maybe is a int,can`t use fmt.Sprintf | ||||
| 	query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) | ||||
| 
 | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	if isMulti || !d.ins.HasReturningID(mi, &query) { | ||||
| 		res, err := q.Exec(query, values...) | ||||
| 		if err == nil { | ||||
| 			if isMulti { | ||||
| 				return res.RowsAffected() | ||||
| 			} | ||||
| 			return res.LastInsertId() | ||||
| 		} | ||||
| 		return 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	row := q.QueryRow(query, values...) | ||||
| 	var id int64 | ||||
| 	err = row.Scan(&id) | ||||
| 	return id, err | ||||
| } | ||||
| 
 | ||||
| // create new mysql dbBaser. | ||||
| func newdbBaseMysql() dbBaser { | ||||
| 	b := new(dbBaseMysql) | ||||
| 	b.ins = b | ||||
| 	return b | ||||
| } | ||||
							
								
								
									
										137
									
								
								pkg/orm/db_oracle.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								pkg/orm/db_oracle.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,137 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // oracle operators. | ||||
| var oracleOperators = map[string]string{ | ||||
| 	"exact":       "= ?", | ||||
| 	"gt":          "> ?", | ||||
| 	"gte":         ">= ?", | ||||
| 	"lt":          "< ?", | ||||
| 	"lte":         "<= ?", | ||||
| 	"//iendswith": "LIKE ?", | ||||
| } | ||||
| 
 | ||||
| // oracle column field types. | ||||
| var oracleTypes = map[string]string{ | ||||
| 	"pk":              "NOT NULL PRIMARY KEY", | ||||
| 	"bool":            "bool", | ||||
| 	"string":          "VARCHAR2(%d)", | ||||
| 	"string-char":     "CHAR(%d)", | ||||
| 	"string-text":     "VARCHAR2(%d)", | ||||
| 	"time.Time-date":  "DATE", | ||||
| 	"time.Time":       "TIMESTAMP", | ||||
| 	"int8":            "INTEGER", | ||||
| 	"int16":           "INTEGER", | ||||
| 	"int32":           "INTEGER", | ||||
| 	"int64":           "INTEGER", | ||||
| 	"uint8":           "INTEGER", | ||||
| 	"uint16":          "INTEGER", | ||||
| 	"uint32":          "INTEGER", | ||||
| 	"uint64":          "INTEGER", | ||||
| 	"float64":         "NUMBER", | ||||
| 	"float64-decimal": "NUMBER(%d, %d)", | ||||
| } | ||||
| 
 | ||||
| // oracle dbBaser | ||||
| type dbBaseOracle struct { | ||||
| 	dbBase | ||||
| } | ||||
| 
 | ||||
| var _ dbBaser = new(dbBaseOracle) | ||||
| 
 | ||||
| // create oracle dbBaser. | ||||
| func newdbBaseOracle() dbBaser { | ||||
| 	b := new(dbBaseOracle) | ||||
| 	b.ins = b | ||||
| 	return b | ||||
| } | ||||
| 
 | ||||
| // OperatorSQL get oracle operator. | ||||
| func (d *dbBaseOracle) OperatorSQL(operator string) string { | ||||
| 	return oracleOperators[operator] | ||||
| } | ||||
| 
 | ||||
| // DbTypes get oracle table field types. | ||||
| func (d *dbBaseOracle) DbTypes() map[string]string { | ||||
| 	return oracleTypes | ||||
| } | ||||
| 
 | ||||
| //ShowTablesQuery show all the tables in database | ||||
| func (d *dbBaseOracle) ShowTablesQuery() string { | ||||
| 	return "SELECT TABLE_NAME FROM USER_TABLES" | ||||
| } | ||||
| 
 | ||||
| // Oracle | ||||
| func (d *dbBaseOracle) ShowColumnsQuery(table string) string { | ||||
| 	return fmt.Sprintf("SELECT COLUMN_NAME FROM ALL_TAB_COLUMNS "+ | ||||
| 		"WHERE TABLE_NAME ='%s'", strings.ToUpper(table)) | ||||
| } | ||||
| 
 | ||||
| // check index is exist | ||||
| func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { | ||||
| 	row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ | ||||
| 		"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ | ||||
| 		"AND  USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) | ||||
| 
 | ||||
| 	var cnt int | ||||
| 	row.Scan(&cnt) | ||||
| 	return cnt > 0 | ||||
| } | ||||
| 
 | ||||
| // execute insert sql with given struct and given values. | ||||
| // insert the given values, not the field values in struct. | ||||
| func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { | ||||
| 	Q := d.ins.TableQuote() | ||||
| 
 | ||||
| 	marks := make([]string, len(names)) | ||||
| 	for i := range marks { | ||||
| 		marks[i] = ":" + names[i] | ||||
| 	} | ||||
| 
 | ||||
| 	sep := fmt.Sprintf("%s, %s", Q, Q) | ||||
| 	qmarks := strings.Join(marks, ", ") | ||||
| 	columns := strings.Join(names, sep) | ||||
| 
 | ||||
| 	multi := len(values) / len(names) | ||||
| 
 | ||||
| 	if isMulti { | ||||
| 		qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks | ||||
| 	} | ||||
| 
 | ||||
| 	query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) | ||||
| 
 | ||||
| 	d.ins.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	if isMulti || !d.ins.HasReturningID(mi, &query) { | ||||
| 		res, err := q.Exec(query, values...) | ||||
| 		if err == nil { | ||||
| 			if isMulti { | ||||
| 				return res.RowsAffected() | ||||
| 			} | ||||
| 			return res.LastInsertId() | ||||
| 		} | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	row := q.QueryRow(query, values...) | ||||
| 	var id int64 | ||||
| 	err := row.Scan(&id) | ||||
| 	return id, err | ||||
| } | ||||
							
								
								
									
										189
									
								
								pkg/orm/db_postgres.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										189
									
								
								pkg/orm/db_postgres.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,189 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| ) | ||||
| 
 | ||||
| // postgresql operators. | ||||
| var postgresOperators = map[string]string{ | ||||
| 	"exact":       "= ?", | ||||
| 	"iexact":      "= UPPER(?)", | ||||
| 	"contains":    "LIKE ?", | ||||
| 	"icontains":   "LIKE UPPER(?)", | ||||
| 	"gt":          "> ?", | ||||
| 	"gte":         ">= ?", | ||||
| 	"lt":          "< ?", | ||||
| 	"lte":         "<= ?", | ||||
| 	"eq":          "= ?", | ||||
| 	"ne":          "!= ?", | ||||
| 	"startswith":  "LIKE ?", | ||||
| 	"endswith":    "LIKE ?", | ||||
| 	"istartswith": "LIKE UPPER(?)", | ||||
| 	"iendswith":   "LIKE UPPER(?)", | ||||
| } | ||||
| 
 | ||||
| // postgresql column field types. | ||||
| var postgresTypes = map[string]string{ | ||||
| 	"auto":            "serial NOT NULL PRIMARY KEY", | ||||
| 	"pk":              "NOT NULL PRIMARY KEY", | ||||
| 	"bool":            "bool", | ||||
| 	"string":          "varchar(%d)", | ||||
| 	"string-char":     "char(%d)", | ||||
| 	"string-text":     "text", | ||||
| 	"time.Time-date":  "date", | ||||
| 	"time.Time":       "timestamp with time zone", | ||||
| 	"int8":            `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, | ||||
| 	"int16":           "smallint", | ||||
| 	"int32":           "integer", | ||||
| 	"int64":           "bigint", | ||||
| 	"uint8":           `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, | ||||
| 	"uint16":          `integer CHECK("%COL%" >= 0)`, | ||||
| 	"uint32":          `bigint CHECK("%COL%" >= 0)`, | ||||
| 	"uint64":          `bigint CHECK("%COL%" >= 0)`, | ||||
| 	"float64":         "double precision", | ||||
| 	"float64-decimal": "numeric(%d, %d)", | ||||
| 	"json":            "json", | ||||
| 	"jsonb":           "jsonb", | ||||
| } | ||||
| 
 | ||||
| // postgresql dbBaser. | ||||
| type dbBasePostgres struct { | ||||
| 	dbBase | ||||
| } | ||||
| 
 | ||||
| var _ dbBaser = new(dbBasePostgres) | ||||
| 
 | ||||
| // get postgresql operator. | ||||
| func (d *dbBasePostgres) OperatorSQL(operator string) string { | ||||
| 	return postgresOperators[operator] | ||||
| } | ||||
| 
 | ||||
| // generate functioned sql string, such as contains(text). | ||||
| func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { | ||||
| 	switch operator { | ||||
| 	case "contains", "startswith", "endswith": | ||||
| 		*leftCol = fmt.Sprintf("%s::text", *leftCol) | ||||
| 	case "iexact", "icontains", "istartswith", "iendswith": | ||||
| 		*leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // postgresql unsupports updating joined record. | ||||
| func (d *dbBasePostgres) SupportUpdateJoin() bool { | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func (d *dbBasePostgres) MaxLimit() uint64 { | ||||
| 	return 0 | ||||
| } | ||||
| 
 | ||||
| // postgresql quote is ". | ||||
| func (d *dbBasePostgres) TableQuote() string { | ||||
| 	return `"` | ||||
| } | ||||
| 
 | ||||
| // postgresql value placeholder is $n. | ||||
| // replace default ? to $n. | ||||
| func (d *dbBasePostgres) ReplaceMarks(query *string) { | ||||
| 	q := *query | ||||
| 	num := 0 | ||||
| 	for _, c := range q { | ||||
| 		if c == '?' { | ||||
| 			num++ | ||||
| 		} | ||||
| 	} | ||||
| 	if num == 0 { | ||||
| 		return | ||||
| 	} | ||||
| 	data := make([]byte, 0, len(q)+num) | ||||
| 	num = 1 | ||||
| 	for i := 0; i < len(q); i++ { | ||||
| 		c := q[i] | ||||
| 		if c == '?' { | ||||
| 			data = append(data, '$') | ||||
| 			data = append(data, []byte(strconv.Itoa(num))...) | ||||
| 			num++ | ||||
| 		} else { | ||||
| 			data = append(data, c) | ||||
| 		} | ||||
| 	} | ||||
| 	*query = string(data) | ||||
| } | ||||
| 
 | ||||
| // make returning sql support for postgresql. | ||||
| func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { | ||||
| 	fi := mi.fields.pk | ||||
| 	if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	if query != nil { | ||||
| 		*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column) | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| // sync auto key | ||||
| func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { | ||||
| 	if len(autoFields) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	Q := d.ins.TableQuote() | ||||
| 	for _, name := range autoFields { | ||||
| 		query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", | ||||
| 			mi.table, name, | ||||
| 			Q, name, Q, | ||||
| 			Q, mi.table, Q) | ||||
| 		if _, err := db.Exec(query); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // show table sql for postgresql. | ||||
| func (d *dbBasePostgres) ShowTablesQuery() string { | ||||
| 	return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" | ||||
| } | ||||
| 
 | ||||
| // show table columns sql for postgresql. | ||||
| func (d *dbBasePostgres) ShowColumnsQuery(table string) string { | ||||
| 	return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) | ||||
| } | ||||
| 
 | ||||
| // get column types of postgresql. | ||||
| func (d *dbBasePostgres) DbTypes() map[string]string { | ||||
| 	return postgresTypes | ||||
| } | ||||
| 
 | ||||
| // check index exist in postgresql. | ||||
| func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { | ||||
| 	query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) | ||||
| 	row := db.QueryRow(query) | ||||
| 	var cnt int | ||||
| 	row.Scan(&cnt) | ||||
| 	return cnt > 0 | ||||
| } | ||||
| 
 | ||||
| // create new postgresql dbBaser. | ||||
| func newdbBasePostgres() dbBaser { | ||||
| 	b := new(dbBasePostgres) | ||||
| 	b.ins = b | ||||
| 	return b | ||||
| } | ||||
							
								
								
									
										161
									
								
								pkg/orm/db_sqlite.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										161
									
								
								pkg/orm/db_sqlite.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,161 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // sqlite operators. | ||||
| var sqliteOperators = map[string]string{ | ||||
| 	"exact":       "= ?", | ||||
| 	"iexact":      "LIKE ? ESCAPE '\\'", | ||||
| 	"contains":    "LIKE ? ESCAPE '\\'", | ||||
| 	"icontains":   "LIKE ? ESCAPE '\\'", | ||||
| 	"gt":          "> ?", | ||||
| 	"gte":         ">= ?", | ||||
| 	"lt":          "< ?", | ||||
| 	"lte":         "<= ?", | ||||
| 	"eq":          "= ?", | ||||
| 	"ne":          "!= ?", | ||||
| 	"startswith":  "LIKE ? ESCAPE '\\'", | ||||
| 	"endswith":    "LIKE ? ESCAPE '\\'", | ||||
| 	"istartswith": "LIKE ? ESCAPE '\\'", | ||||
| 	"iendswith":   "LIKE ? ESCAPE '\\'", | ||||
| } | ||||
| 
 | ||||
| // sqlite column types. | ||||
| var sqliteTypes = map[string]string{ | ||||
| 	"auto":            "integer NOT NULL PRIMARY KEY AUTOINCREMENT", | ||||
| 	"pk":              "NOT NULL PRIMARY KEY", | ||||
| 	"bool":            "bool", | ||||
| 	"string":          "varchar(%d)", | ||||
| 	"string-char":     "character(%d)", | ||||
| 	"string-text":     "text", | ||||
| 	"time.Time-date":  "date", | ||||
| 	"time.Time":       "datetime", | ||||
| 	"int8":            "tinyint", | ||||
| 	"int16":           "smallint", | ||||
| 	"int32":           "integer", | ||||
| 	"int64":           "bigint", | ||||
| 	"uint8":           "tinyint unsigned", | ||||
| 	"uint16":          "smallint unsigned", | ||||
| 	"uint32":          "integer unsigned", | ||||
| 	"uint64":          "bigint unsigned", | ||||
| 	"float64":         "real", | ||||
| 	"float64-decimal": "decimal", | ||||
| } | ||||
| 
 | ||||
| // sqlite dbBaser. | ||||
| type dbBaseSqlite struct { | ||||
| 	dbBase | ||||
| } | ||||
| 
 | ||||
| var _ dbBaser = new(dbBaseSqlite) | ||||
| 
 | ||||
| // override base db read for update behavior as SQlite does not support syntax | ||||
| func (d *dbBaseSqlite) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { | ||||
| 	if isForUpdate { | ||||
| 		DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") | ||||
| 	} | ||||
| 	return d.dbBase.Read(q, mi, ind, tz, cols, false) | ||||
| } | ||||
| 
 | ||||
| // get sqlite operator. | ||||
| func (d *dbBaseSqlite) OperatorSQL(operator string) string { | ||||
| 	return sqliteOperators[operator] | ||||
| } | ||||
| 
 | ||||
| // generate functioned sql for sqlite. | ||||
| // only support DATE(text). | ||||
| func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { | ||||
| 	if fi.fieldType == TypeDateField { | ||||
| 		*leftCol = fmt.Sprintf("DATE(%s)", *leftCol) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // unable updating joined record in sqlite. | ||||
| func (d *dbBaseSqlite) SupportUpdateJoin() bool { | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // max int in sqlite. | ||||
| func (d *dbBaseSqlite) MaxLimit() uint64 { | ||||
| 	return 9223372036854775807 | ||||
| } | ||||
| 
 | ||||
| // get column types in sqlite. | ||||
| func (d *dbBaseSqlite) DbTypes() map[string]string { | ||||
| 	return sqliteTypes | ||||
| } | ||||
| 
 | ||||
| // get show tables sql in sqlite. | ||||
| func (d *dbBaseSqlite) ShowTablesQuery() string { | ||||
| 	return "SELECT name FROM sqlite_master WHERE type = 'table'" | ||||
| } | ||||
| 
 | ||||
| // get columns in sqlite. | ||||
| func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { | ||||
| 	query := d.ins.ShowColumnsQuery(table) | ||||
| 	rows, err := db.Query(query) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	columns := make(map[string][3]string) | ||||
| 	for rows.Next() { | ||||
| 		var tmp, name, typ, null sql.NullString | ||||
| 		err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		columns[name.String] = [3]string{name.String, typ.String, null.String} | ||||
| 	} | ||||
| 
 | ||||
| 	return columns, nil | ||||
| } | ||||
| 
 | ||||
| // get show columns sql in sqlite. | ||||
| func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { | ||||
| 	return fmt.Sprintf("pragma table_info('%s')", table) | ||||
| } | ||||
| 
 | ||||
| // check index exist in sqlite. | ||||
| func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { | ||||
| 	query := fmt.Sprintf("PRAGMA index_list('%s')", table) | ||||
| 	rows, err := db.Query(query) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	defer rows.Close() | ||||
| 	for rows.Next() { | ||||
| 		var tmp, index sql.NullString | ||||
| 		rows.Scan(&tmp, &index, &tmp, &tmp, &tmp) | ||||
| 		if name == index.String { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| // create new sqlite dbBaser. | ||||
| func newdbBaseSqlite() dbBaser { | ||||
| 	b := new(dbBaseSqlite) | ||||
| 	b.ins = b | ||||
| 	return b | ||||
| } | ||||
							
								
								
									
										482
									
								
								pkg/orm/db_tables.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										482
									
								
								pkg/orm/db_tables.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,482 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // table info struct. | ||||
| type dbTable struct { | ||||
| 	id    int | ||||
| 	index string | ||||
| 	name  string | ||||
| 	names []string | ||||
| 	sel   bool | ||||
| 	inner bool | ||||
| 	mi    *modelInfo | ||||
| 	fi    *fieldInfo | ||||
| 	jtl   *dbTable | ||||
| } | ||||
| 
 | ||||
| // tables collection struct, contains some tables. | ||||
| type dbTables struct { | ||||
| 	tablesM map[string]*dbTable | ||||
| 	tables  []*dbTable | ||||
| 	mi      *modelInfo | ||||
| 	base    dbBaser | ||||
| 	skipEnd bool | ||||
| } | ||||
| 
 | ||||
| // set table info to collection. | ||||
| // if not exist, create new. | ||||
| func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { | ||||
| 	name := strings.Join(names, ExprSep) | ||||
| 	if j, ok := t.tablesM[name]; ok { | ||||
| 		j.name = name | ||||
| 		j.mi = mi | ||||
| 		j.fi = fi | ||||
| 		j.inner = inner | ||||
| 	} else { | ||||
| 		i := len(t.tables) + 1 | ||||
| 		jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} | ||||
| 		t.tablesM[name] = jt | ||||
| 		t.tables = append(t.tables, jt) | ||||
| 	} | ||||
| 	return t.tablesM[name] | ||||
| } | ||||
| 
 | ||||
| // add table info to collection. | ||||
| func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { | ||||
| 	name := strings.Join(names, ExprSep) | ||||
| 	if _, ok := t.tablesM[name]; !ok { | ||||
| 		i := len(t.tables) + 1 | ||||
| 		jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil} | ||||
| 		t.tablesM[name] = jt | ||||
| 		t.tables = append(t.tables, jt) | ||||
| 		return jt, true | ||||
| 	} | ||||
| 	return t.tablesM[name], false | ||||
| } | ||||
| 
 | ||||
| // get table info in collection. | ||||
| func (t *dbTables) get(name string) (*dbTable, bool) { | ||||
| 	j, ok := t.tablesM[name] | ||||
| 	return j, ok | ||||
| } | ||||
| 
 | ||||
| // get related fields info in recursive depth loop. | ||||
| // loop once, depth decreases one. | ||||
| func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { | ||||
| 	if depth < 0 || fi.fieldType == RelManyToMany { | ||||
| 		return related | ||||
| 	} | ||||
| 
 | ||||
| 	if prefix == "" { | ||||
| 		prefix = fi.name | ||||
| 	} else { | ||||
| 		prefix = prefix + ExprSep + fi.name | ||||
| 	} | ||||
| 	related = append(related, prefix) | ||||
| 
 | ||||
| 	depth-- | ||||
| 	for _, fi := range fi.relModelInfo.fields.fieldsRel { | ||||
| 		related = t.loopDepth(depth, prefix, fi, related) | ||||
| 	} | ||||
| 
 | ||||
| 	return related | ||||
| } | ||||
| 
 | ||||
| // parse related fields. | ||||
| func (t *dbTables) parseRelated(rels []string, depth int) { | ||||
| 
 | ||||
| 	relsNum := len(rels) | ||||
| 	related := make([]string, relsNum) | ||||
| 	copy(related, rels) | ||||
| 
 | ||||
| 	relDepth := depth | ||||
| 
 | ||||
| 	if relsNum != 0 { | ||||
| 		relDepth = 0 | ||||
| 	} | ||||
| 
 | ||||
| 	relDepth-- | ||||
| 	for _, fi := range t.mi.fields.fieldsRel { | ||||
| 		related = t.loopDepth(relDepth, "", fi, related) | ||||
| 	} | ||||
| 
 | ||||
| 	for i, s := range related { | ||||
| 		var ( | ||||
| 			exs    = strings.Split(s, ExprSep) | ||||
| 			names  = make([]string, 0, len(exs)) | ||||
| 			mmi    = t.mi | ||||
| 			cancel = true | ||||
| 			jtl    *dbTable | ||||
| 		) | ||||
| 
 | ||||
| 		inner := true | ||||
| 
 | ||||
| 		for _, ex := range exs { | ||||
| 			if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { | ||||
| 				names = append(names, fi.name) | ||||
| 				mmi = fi.relModelInfo | ||||
| 
 | ||||
| 				if fi.null || t.skipEnd { | ||||
| 					inner = false | ||||
| 				} | ||||
| 
 | ||||
| 				jt := t.set(names, mmi, fi, inner) | ||||
| 				jt.jtl = jtl | ||||
| 
 | ||||
| 				if fi.reverse { | ||||
| 					cancel = false | ||||
| 				} | ||||
| 
 | ||||
| 				if cancel { | ||||
| 					jt.sel = depth > 0 | ||||
| 
 | ||||
| 					if i < relsNum { | ||||
| 						jt.sel = true | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				jtl = jt | ||||
| 
 | ||||
| 			} else { | ||||
| 				panic(fmt.Errorf("unknown model/table name `%s`", ex)) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // generate join string. | ||||
| func (t *dbTables) getJoinSQL() (join string) { | ||||
| 	Q := t.base.TableQuote() | ||||
| 
 | ||||
| 	for _, jt := range t.tables { | ||||
| 		if jt.inner { | ||||
| 			join += "INNER JOIN " | ||||
| 		} else { | ||||
| 			join += "LEFT OUTER JOIN " | ||||
| 		} | ||||
| 		var ( | ||||
| 			table  string | ||||
| 			t1, t2 string | ||||
| 			c1, c2 string | ||||
| 		) | ||||
| 		t1 = "T0" | ||||
| 		if jt.jtl != nil { | ||||
| 			t1 = jt.jtl.index | ||||
| 		} | ||||
| 		t2 = jt.index | ||||
| 		table = jt.mi.table | ||||
| 
 | ||||
| 		switch { | ||||
| 		case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: | ||||
| 			c1 = jt.fi.mi.fields.pk.column | ||||
| 			for _, ffi := range jt.mi.fields.fieldsRel { | ||||
| 				if jt.fi.mi == ffi.relModelInfo { | ||||
| 					c2 = ffi.column | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
| 		default: | ||||
| 			c1 = jt.fi.column | ||||
| 			c2 = jt.fi.relModelInfo.fields.pk.column | ||||
| 
 | ||||
| 			if jt.fi.reverse { | ||||
| 				c1 = jt.mi.fields.pk.column | ||||
| 				c2 = jt.fi.reverseFieldInfo.column | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2, | ||||
| 			t2, Q, c2, Q, t1, Q, c1, Q) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // parse orm model struct field tag expression. | ||||
| func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { | ||||
| 	var ( | ||||
| 		jtl *dbTable | ||||
| 		fi  *fieldInfo | ||||
| 		fiN *fieldInfo | ||||
| 		mmi = mi | ||||
| 	) | ||||
| 
 | ||||
| 	num := len(exprs) - 1 | ||||
| 	var names []string | ||||
| 
 | ||||
| 	inner := true | ||||
| 
 | ||||
| loopFor: | ||||
| 	for i, ex := range exprs { | ||||
| 
 | ||||
| 		var ok, okN bool | ||||
| 
 | ||||
| 		if fiN != nil { | ||||
| 			fi = fiN | ||||
| 			ok = true | ||||
| 			fiN = nil | ||||
| 		} | ||||
| 
 | ||||
| 		if i == 0 { | ||||
| 			fi, ok = mmi.fields.GetByAny(ex) | ||||
| 		} | ||||
| 
 | ||||
| 		_ = okN | ||||
| 
 | ||||
| 		if ok { | ||||
| 
 | ||||
| 			isRel := fi.rel || fi.reverse | ||||
| 
 | ||||
| 			names = append(names, fi.name) | ||||
| 
 | ||||
| 			switch { | ||||
| 			case fi.rel: | ||||
| 				mmi = fi.relModelInfo | ||||
| 				if fi.fieldType == RelManyToMany { | ||||
| 					mmi = fi.relThroughModelInfo | ||||
| 				} | ||||
| 			case fi.reverse: | ||||
| 				mmi = fi.reverseFieldInfo.mi | ||||
| 			} | ||||
| 
 | ||||
| 			if i < num { | ||||
| 				fiN, okN = mmi.fields.GetByAny(exprs[i+1]) | ||||
| 			} | ||||
| 
 | ||||
| 			if isRel && (!fi.mi.isThrough || num != i) { | ||||
| 				if fi.null || t.skipEnd { | ||||
| 					inner = false | ||||
| 				} | ||||
| 
 | ||||
| 				if t.skipEnd && okN || !t.skipEnd { | ||||
| 					if t.skipEnd && okN && fiN.pk { | ||||
| 						goto loopEnd | ||||
| 					} | ||||
| 
 | ||||
| 					jt, _ := t.add(names, mmi, fi, inner) | ||||
| 					jt.jtl = jtl | ||||
| 					jtl = jt | ||||
| 				} | ||||
| 
 | ||||
| 			} | ||||
| 
 | ||||
| 			if num != i { | ||||
| 				continue | ||||
| 			} | ||||
| 
 | ||||
| 		loopEnd: | ||||
| 
 | ||||
| 			if i == 0 || jtl == nil { | ||||
| 				index = "T0" | ||||
| 			} else { | ||||
| 				index = jtl.index | ||||
| 			} | ||||
| 
 | ||||
| 			info = fi | ||||
| 
 | ||||
| 			if jtl == nil { | ||||
| 				name = fi.name | ||||
| 			} else { | ||||
| 				name = jtl.name + ExprSep + fi.name | ||||
| 			} | ||||
| 
 | ||||
| 			switch { | ||||
| 			case fi.rel: | ||||
| 
 | ||||
| 			case fi.reverse: | ||||
| 				switch fi.reverseFieldInfo.fieldType { | ||||
| 				case RelOneToOne, RelForeignKey: | ||||
| 					index = jtl.index | ||||
| 					info = fi.reverseFieldInfo.mi.fields.pk | ||||
| 					name = info.name | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			break loopFor | ||||
| 
 | ||||
| 		} else { | ||||
| 			index = "" | ||||
| 			name = "" | ||||
| 			info = nil | ||||
| 			success = false | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	success = index != "" && info != nil | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // generate condition sql. | ||||
| func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { | ||||
| 	if cond == nil || cond.IsEmpty() { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	Q := t.base.TableQuote() | ||||
| 
 | ||||
| 	mi := t.mi | ||||
| 
 | ||||
| 	for i, p := range cond.params { | ||||
| 		if i > 0 { | ||||
| 			if p.isOr { | ||||
| 				where += "OR " | ||||
| 			} else { | ||||
| 				where += "AND " | ||||
| 			} | ||||
| 		} | ||||
| 		if p.isNot { | ||||
| 			where += "NOT " | ||||
| 		} | ||||
| 		if p.isCond { | ||||
| 			w, ps := t.getCondSQL(p.cond, true, tz) | ||||
| 			if w != "" { | ||||
| 				w = fmt.Sprintf("( %s) ", w) | ||||
| 			} | ||||
| 			where += w | ||||
| 			params = append(params, ps...) | ||||
| 		} else { | ||||
| 			exprs := p.exprs | ||||
| 
 | ||||
| 			num := len(exprs) - 1 | ||||
| 			operator := "" | ||||
| 			if operators[exprs[num]] { | ||||
| 				operator = exprs[num] | ||||
| 				exprs = exprs[:num] | ||||
| 			} | ||||
| 
 | ||||
| 			index, _, fi, suc := t.parseExprs(mi, exprs) | ||||
| 			if !suc { | ||||
| 				panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) | ||||
| 			} | ||||
| 
 | ||||
| 			if operator == "" { | ||||
| 				operator = "exact" | ||||
| 			} | ||||
| 
 | ||||
| 			var operSQL string | ||||
| 			var args []interface{} | ||||
| 			if p.isRaw { | ||||
| 				operSQL = p.sql | ||||
| 			} else { | ||||
| 				operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) | ||||
| 			} | ||||
| 
 | ||||
| 			leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) | ||||
| 			t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) | ||||
| 
 | ||||
| 			where += fmt.Sprintf("%s %s ", leftCol, operSQL) | ||||
| 			params = append(params, args...) | ||||
| 
 | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if !sub && where != "" { | ||||
| 		where = "WHERE " + where | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // generate group sql. | ||||
| func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { | ||||
| 	if len(groups) == 0 { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	Q := t.base.TableQuote() | ||||
| 
 | ||||
| 	groupSqls := make([]string, 0, len(groups)) | ||||
| 	for _, group := range groups { | ||||
| 		exprs := strings.Split(group, ExprSep) | ||||
| 
 | ||||
| 		index, _, fi, suc := t.parseExprs(t.mi, exprs) | ||||
| 		if !suc { | ||||
| 			panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) | ||||
| 		} | ||||
| 
 | ||||
| 		groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) | ||||
| 	} | ||||
| 
 | ||||
| 	groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // generate order sql. | ||||
| func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { | ||||
| 	if len(orders) == 0 { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	Q := t.base.TableQuote() | ||||
| 
 | ||||
| 	orderSqls := make([]string, 0, len(orders)) | ||||
| 	for _, order := range orders { | ||||
| 		asc := "ASC" | ||||
| 		if order[0] == '-' { | ||||
| 			asc = "DESC" | ||||
| 			order = order[1:] | ||||
| 		} | ||||
| 		exprs := strings.Split(order, ExprSep) | ||||
| 
 | ||||
| 		index, _, fi, suc := t.parseExprs(t.mi, exprs) | ||||
| 		if !suc { | ||||
| 			panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) | ||||
| 		} | ||||
| 
 | ||||
| 		orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) | ||||
| 	} | ||||
| 
 | ||||
| 	orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // generate limit sql. | ||||
| func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { | ||||
| 	if limit == 0 { | ||||
| 		limit = int64(DefaultRowsLimit) | ||||
| 	} | ||||
| 	if limit < 0 { | ||||
| 		// no limit | ||||
| 		if offset > 0 { | ||||
| 			maxLimit := t.base.MaxLimit() | ||||
| 			if maxLimit == 0 { | ||||
| 				limits = fmt.Sprintf("OFFSET %d", offset) | ||||
| 			} else { | ||||
| 				limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) | ||||
| 			} | ||||
| 		} | ||||
| 	} else if offset <= 0 { | ||||
| 		limits = fmt.Sprintf("LIMIT %d", limit) | ||||
| 	} else { | ||||
| 		limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // crete new tables collection. | ||||
| func newDbTables(mi *modelInfo, base dbBaser) *dbTables { | ||||
| 	tables := &dbTables{} | ||||
| 	tables.tablesM = make(map[string]*dbTable) | ||||
| 	tables.mi = mi | ||||
| 	tables.base = base | ||||
| 	return tables | ||||
| } | ||||
							
								
								
									
										63
									
								
								pkg/orm/db_tidb.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								pkg/orm/db_tidb.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,63 @@ | ||||
| // Copyright 2015 TiDB Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| ) | ||||
| 
 | ||||
| // mysql dbBaser implementation. | ||||
| type dbBaseTidb struct { | ||||
| 	dbBase | ||||
| } | ||||
| 
 | ||||
| var _ dbBaser = new(dbBaseTidb) | ||||
| 
 | ||||
| // get mysql operator. | ||||
| func (d *dbBaseTidb) OperatorSQL(operator string) string { | ||||
| 	return mysqlOperators[operator] | ||||
| } | ||||
| 
 | ||||
| // get mysql table field types. | ||||
| func (d *dbBaseTidb) DbTypes() map[string]string { | ||||
| 	return mysqlTypes | ||||
| } | ||||
| 
 | ||||
| // show table sql for mysql. | ||||
| func (d *dbBaseTidb) ShowTablesQuery() string { | ||||
| 	return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" | ||||
| } | ||||
| 
 | ||||
| // show columns sql of table for mysql. | ||||
| func (d *dbBaseTidb) ShowColumnsQuery(table string) string { | ||||
| 	return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ | ||||
| 		"WHERE table_schema = DATABASE() AND table_name = '%s'", table) | ||||
| } | ||||
| 
 | ||||
| // execute sql to check index exist. | ||||
| func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { | ||||
| 	row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ | ||||
| 		"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) | ||||
| 	var cnt int | ||||
| 	row.Scan(&cnt) | ||||
| 	return cnt > 0 | ||||
| } | ||||
| 
 | ||||
| // create new mysql dbBaser. | ||||
| func newdbBaseTidb() dbBaser { | ||||
| 	b := new(dbBaseTidb) | ||||
| 	b.ins = b | ||||
| 	return b | ||||
| } | ||||
							
								
								
									
										177
									
								
								pkg/orm/db_utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										177
									
								
								pkg/orm/db_utils.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,177 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // get table alias. | ||||
| func getDbAlias(name string) *alias { | ||||
| 	if al, ok := dataBaseCache.get(name); ok { | ||||
| 		return al | ||||
| 	} | ||||
| 	panic(fmt.Errorf("unknown DataBase alias name %s", name)) | ||||
| } | ||||
| 
 | ||||
| // get pk column info. | ||||
| func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { | ||||
| 	fi := mi.fields.pk | ||||
| 
 | ||||
| 	v := ind.FieldByIndex(fi.fieldIndex) | ||||
| 	if fi.fieldType&IsPositiveIntegerField > 0 { | ||||
| 		vu := v.Uint() | ||||
| 		exist = vu > 0 | ||||
| 		value = vu | ||||
| 	} else if fi.fieldType&IsIntegerField > 0 { | ||||
| 		vu := v.Int() | ||||
| 		exist = true | ||||
| 		value = vu | ||||
| 	} else if fi.fieldType&IsRelField > 0 { | ||||
| 		_, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v)) | ||||
| 	} else { | ||||
| 		vu := v.String() | ||||
| 		exist = vu != "" | ||||
| 		value = vu | ||||
| 	} | ||||
| 
 | ||||
| 	column = fi.column | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // get fields description as flatted string. | ||||
| func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { | ||||
| 
 | ||||
| outFor: | ||||
| 	for _, arg := range args { | ||||
| 		val := reflect.ValueOf(arg) | ||||
| 
 | ||||
| 		if arg == nil { | ||||
| 			params = append(params, arg) | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		kind := val.Kind() | ||||
| 		if kind == reflect.Ptr { | ||||
| 			val = val.Elem() | ||||
| 			kind = val.Kind() | ||||
| 			arg = val.Interface() | ||||
| 		} | ||||
| 
 | ||||
| 		switch kind { | ||||
| 		case reflect.String: | ||||
| 			v := val.String() | ||||
| 			if fi != nil { | ||||
| 				if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { | ||||
| 					var t time.Time | ||||
| 					var err error | ||||
| 					if len(v) >= 19 { | ||||
| 						s := v[:19] | ||||
| 						t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) | ||||
| 					} else if len(v) >= 10 { | ||||
| 						s := v | ||||
| 						if len(v) > 10 { | ||||
| 							s = v[:10] | ||||
| 						} | ||||
| 						t, err = time.ParseInLocation(formatDate, s, tz) | ||||
| 					} else { | ||||
| 						s := v | ||||
| 						if len(s) > 8 { | ||||
| 							s = v[:8] | ||||
| 						} | ||||
| 						t, err = time.ParseInLocation(formatTime, s, tz) | ||||
| 					} | ||||
| 					if err == nil { | ||||
| 						if fi.fieldType == TypeDateField { | ||||
| 							v = t.In(tz).Format(formatDate) | ||||
| 						} else if fi.fieldType == TypeDateTimeField { | ||||
| 							v = t.In(tz).Format(formatDateTime) | ||||
| 						} else { | ||||
| 							v = t.In(tz).Format(formatTime) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			arg = v | ||||
| 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 			arg = val.Int() | ||||
| 		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||||
| 			arg = val.Uint() | ||||
| 		case reflect.Float32: | ||||
| 			arg, _ = StrTo(ToStr(arg)).Float64() | ||||
| 		case reflect.Float64: | ||||
| 			arg = val.Float() | ||||
| 		case reflect.Bool: | ||||
| 			arg = val.Bool() | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			if _, ok := arg.([]byte); ok { | ||||
| 				continue outFor | ||||
| 			} | ||||
| 
 | ||||
| 			var args []interface{} | ||||
| 			for i := 0; i < val.Len(); i++ { | ||||
| 				v := val.Index(i) | ||||
| 
 | ||||
| 				var vu interface{} | ||||
| 				if v.CanInterface() { | ||||
| 					vu = v.Interface() | ||||
| 				} | ||||
| 
 | ||||
| 				if vu == nil { | ||||
| 					continue | ||||
| 				} | ||||
| 
 | ||||
| 				args = append(args, vu) | ||||
| 			} | ||||
| 
 | ||||
| 			if len(args) > 0 { | ||||
| 				p := getFlatParams(fi, args, tz) | ||||
| 				params = append(params, p...) | ||||
| 			} | ||||
| 			continue outFor | ||||
| 		case reflect.Struct: | ||||
| 			if v, ok := arg.(time.Time); ok { | ||||
| 				if fi != nil && fi.fieldType == TypeDateField { | ||||
| 					arg = v.In(tz).Format(formatDate) | ||||
| 				} else if fi != nil && fi.fieldType == TypeDateTimeField { | ||||
| 					arg = v.In(tz).Format(formatDateTime) | ||||
| 				} else if fi != nil && fi.fieldType == TypeTimeField { | ||||
| 					arg = v.In(tz).Format(formatTime) | ||||
| 				} else { | ||||
| 					arg = v.In(tz).Format(formatDateTime) | ||||
| 				} | ||||
| 			} else { | ||||
| 				typ := val.Type() | ||||
| 				name := getFullName(typ) | ||||
| 				var value interface{} | ||||
| 				if mmi, ok := modelCache.getByFullName(name); ok { | ||||
| 					if _, vu, exist := getExistPk(mmi, val); exist { | ||||
| 						value = vu | ||||
| 					} | ||||
| 				} | ||||
| 				arg = value | ||||
| 
 | ||||
| 				if arg == nil { | ||||
| 					panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		params = append(params, arg) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										99
									
								
								pkg/orm/models.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								pkg/orm/models.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,99 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"sync" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	odCascade             = "cascade" | ||||
| 	odSetNULL             = "set_null" | ||||
| 	odSetDefault          = "set_default" | ||||
| 	odDoNothing           = "do_nothing" | ||||
| 	defaultStructTagName  = "orm" | ||||
| 	defaultStructTagDelim = ";" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	modelCache = &_modelCache{ | ||||
| 		cache:           make(map[string]*modelInfo), | ||||
| 		cacheByFullName: make(map[string]*modelInfo), | ||||
| 	} | ||||
| ) | ||||
| 
 | ||||
| // model info collection | ||||
| type _modelCache struct { | ||||
| 	sync.RWMutex    // only used outsite for bootStrap | ||||
| 	orders          []string | ||||
| 	cache           map[string]*modelInfo | ||||
| 	cacheByFullName map[string]*modelInfo | ||||
| 	done            bool | ||||
| } | ||||
| 
 | ||||
| // get all model info | ||||
| func (mc *_modelCache) all() map[string]*modelInfo { | ||||
| 	m := make(map[string]*modelInfo, len(mc.cache)) | ||||
| 	for k, v := range mc.cache { | ||||
| 		m[k] = v | ||||
| 	} | ||||
| 	return m | ||||
| } | ||||
| 
 | ||||
| // get ordered model info | ||||
| func (mc *_modelCache) allOrdered() []*modelInfo { | ||||
| 	m := make([]*modelInfo, 0, len(mc.orders)) | ||||
| 	for _, table := range mc.orders { | ||||
| 		m = append(m, mc.cache[table]) | ||||
| 	} | ||||
| 	return m | ||||
| } | ||||
| 
 | ||||
| // get model info by table name | ||||
| func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { | ||||
| 	mi, ok = mc.cache[table] | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // get model info by full name | ||||
| func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { | ||||
| 	mi, ok = mc.cacheByFullName[name] | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // set model info to collection | ||||
| func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { | ||||
| 	mii := mc.cache[table] | ||||
| 	mc.cache[table] = mi | ||||
| 	mc.cacheByFullName[mi.fullName] = mi | ||||
| 	if mii == nil { | ||||
| 		mc.orders = append(mc.orders, table) | ||||
| 	} | ||||
| 	return mii | ||||
| } | ||||
| 
 | ||||
| // clean all model info. | ||||
| func (mc *_modelCache) clean() { | ||||
| 	mc.orders = make([]string, 0) | ||||
| 	mc.cache = make(map[string]*modelInfo) | ||||
| 	mc.cacheByFullName = make(map[string]*modelInfo) | ||||
| 	mc.done = false | ||||
| } | ||||
| 
 | ||||
| // ResetModelCache Clean model cache. Then you can re-RegisterModel. | ||||
| // Common use this api for test case. | ||||
| func ResetModelCache() { | ||||
| 	modelCache.clean() | ||||
| } | ||||
							
								
								
									
										347
									
								
								pkg/orm/models_boot.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										347
									
								
								pkg/orm/models_boot.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,347 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"runtime/debug" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // register models. | ||||
| // PrefixOrSuffix means table name prefix or suffix. | ||||
| // isPrefix whether the prefix is prefix or suffix | ||||
| func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { | ||||
| 	val := reflect.ValueOf(model) | ||||
| 	typ := reflect.Indirect(val).Type() | ||||
| 
 | ||||
| 	if val.Kind() != reflect.Ptr { | ||||
| 		panic(fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ))) | ||||
| 	} | ||||
| 	// For this case: | ||||
| 	// u := &User{} | ||||
| 	// registerModel(&u) | ||||
| 	if typ.Kind() == reflect.Ptr { | ||||
| 		panic(fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) | ||||
| 	} | ||||
| 
 | ||||
| 	table := getTableName(val) | ||||
| 
 | ||||
| 	if PrefixOrSuffix != "" { | ||||
| 		if isPrefix { | ||||
| 			table = PrefixOrSuffix + table | ||||
| 		} else { | ||||
| 			table = table + PrefixOrSuffix | ||||
| 		} | ||||
| 	} | ||||
| 	// models's fullname is pkgpath + struct name | ||||
| 	name := getFullName(typ) | ||||
| 	if _, ok := modelCache.getByFullName(name); ok { | ||||
| 		fmt.Printf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name) | ||||
| 		os.Exit(2) | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := modelCache.get(table); ok { | ||||
| 		fmt.Printf("<orm.RegisterModel> table name `%s` repeat register, must be unique\n", table) | ||||
| 		os.Exit(2) | ||||
| 	} | ||||
| 
 | ||||
| 	mi := newModelInfo(val) | ||||
| 	if mi.fields.pk == nil { | ||||
| 	outFor: | ||||
| 		for _, fi := range mi.fields.fieldsDB { | ||||
| 			if strings.ToLower(fi.name) == "id" { | ||||
| 				switch fi.addrValue.Elem().Kind() { | ||||
| 				case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: | ||||
| 					fi.auto = true | ||||
| 					fi.pk = true | ||||
| 					mi.fields.pk = fi | ||||
| 					break outFor | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if mi.fields.pk == nil { | ||||
| 			fmt.Printf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name) | ||||
| 			os.Exit(2) | ||||
| 		} | ||||
| 
 | ||||
| 	} | ||||
| 
 | ||||
| 	mi.table = table | ||||
| 	mi.pkg = typ.PkgPath() | ||||
| 	mi.model = model | ||||
| 	mi.manual = true | ||||
| 
 | ||||
| 	modelCache.set(table, mi) | ||||
| } | ||||
| 
 | ||||
| // bootstrap models | ||||
| func bootStrap() { | ||||
| 	if modelCache.done { | ||||
| 		return | ||||
| 	} | ||||
| 	var ( | ||||
| 		err    error | ||||
| 		models map[string]*modelInfo | ||||
| 	) | ||||
| 	if dataBaseCache.getDefault() == nil { | ||||
| 		err = fmt.Errorf("must have one register DataBase alias named `default`") | ||||
| 		goto end | ||||
| 	} | ||||
| 
 | ||||
| 	// set rel and reverse model | ||||
| 	// RelManyToMany set the relTable | ||||
| 	models = modelCache.all() | ||||
| 	for _, mi := range models { | ||||
| 		for _, fi := range mi.fields.columns { | ||||
| 			if fi.rel || fi.reverse { | ||||
| 				elm := fi.addrValue.Type().Elem() | ||||
| 				if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { | ||||
| 					elm = elm.Elem() | ||||
| 				} | ||||
| 				// check the rel or reverse model already register | ||||
| 				name := getFullName(elm) | ||||
| 				mii, ok := modelCache.getByFullName(name) | ||||
| 				if !ok || mii.pkg != elm.PkgPath() { | ||||
| 					err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) | ||||
| 					goto end | ||||
| 				} | ||||
| 				fi.relModelInfo = mii | ||||
| 
 | ||||
| 				switch fi.fieldType { | ||||
| 				case RelManyToMany: | ||||
| 					if fi.relThrough != "" { | ||||
| 						if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { | ||||
| 							pn := fi.relThrough[:i] | ||||
| 							rmi, ok := modelCache.getByFullName(fi.relThrough) | ||||
| 							if !ok || pn != rmi.pkg { | ||||
| 								err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) | ||||
| 								goto end | ||||
| 							} | ||||
| 							fi.relThroughModelInfo = rmi | ||||
| 							fi.relTable = rmi.table | ||||
| 						} else { | ||||
| 							err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) | ||||
| 							goto end | ||||
| 						} | ||||
| 					} else { | ||||
| 						i := newM2MModelInfo(mi, mii) | ||||
| 						if fi.relTable != "" { | ||||
| 							i.table = fi.relTable | ||||
| 						} | ||||
| 						if v := modelCache.set(i.table, i); v != nil { | ||||
| 							err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) | ||||
| 							goto end | ||||
| 						} | ||||
| 						fi.relTable = i.table | ||||
| 						fi.relThroughModelInfo = i | ||||
| 					} | ||||
| 
 | ||||
| 					fi.relThroughModelInfo.isThrough = true | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// check the rel filed while the relModelInfo also has filed point to current model | ||||
| 	// if not exist, add a new field to the relModelInfo | ||||
| 	models = modelCache.all() | ||||
| 	for _, mi := range models { | ||||
| 		for _, fi := range mi.fields.fieldsRel { | ||||
| 			switch fi.fieldType { | ||||
| 			case RelForeignKey, RelOneToOne, RelManyToMany: | ||||
| 				inModel := false | ||||
| 				for _, ffi := range fi.relModelInfo.fields.fieldsReverse { | ||||
| 					if ffi.relModelInfo == mi { | ||||
| 						inModel = true | ||||
| 						break | ||||
| 					} | ||||
| 				} | ||||
| 				if !inModel { | ||||
| 					rmi := fi.relModelInfo | ||||
| 					ffi := new(fieldInfo) | ||||
| 					ffi.name = mi.name | ||||
| 					ffi.column = ffi.name | ||||
| 					ffi.fullName = rmi.fullName + "." + ffi.name | ||||
| 					ffi.reverse = true | ||||
| 					ffi.relModelInfo = mi | ||||
| 					ffi.mi = rmi | ||||
| 					if fi.fieldType == RelOneToOne { | ||||
| 						ffi.fieldType = RelReverseOne | ||||
| 					} else { | ||||
| 						ffi.fieldType = RelReverseMany | ||||
| 					} | ||||
| 					if !rmi.fields.Add(ffi) { | ||||
| 						added := false | ||||
| 						for cnt := 0; cnt < 5; cnt++ { | ||||
| 							ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) | ||||
| 							ffi.column = ffi.name | ||||
| 							ffi.fullName = rmi.fullName + "." + ffi.name | ||||
| 							if added = rmi.fields.Add(ffi); added { | ||||
| 								break | ||||
| 							} | ||||
| 						} | ||||
| 						if !added { | ||||
| 							panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	models = modelCache.all() | ||||
| 	for _, mi := range models { | ||||
| 		for _, fi := range mi.fields.fieldsRel { | ||||
| 			switch fi.fieldType { | ||||
| 			case RelManyToMany: | ||||
| 				for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { | ||||
| 					switch ffi.fieldType { | ||||
| 					case RelOneToOne, RelForeignKey: | ||||
| 						if ffi.relModelInfo == fi.relModelInfo { | ||||
| 							fi.reverseFieldInfoTwo = ffi | ||||
| 						} | ||||
| 						if ffi.relModelInfo == mi { | ||||
| 							fi.reverseField = ffi.name | ||||
| 							fi.reverseFieldInfo = ffi | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 				if fi.reverseFieldInfoTwo == nil { | ||||
| 					err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", | ||||
| 						fi.relThroughModelInfo.fullName) | ||||
| 					goto end | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	models = modelCache.all() | ||||
| 	for _, mi := range models { | ||||
| 		for _, fi := range mi.fields.fieldsReverse { | ||||
| 			switch fi.fieldType { | ||||
| 			case RelReverseOne: | ||||
| 				found := false | ||||
| 			mForA: | ||||
| 				for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { | ||||
| 					if ffi.relModelInfo == mi { | ||||
| 						found = true | ||||
| 						fi.reverseField = ffi.name | ||||
| 						fi.reverseFieldInfo = ffi | ||||
| 
 | ||||
| 						ffi.reverseField = fi.name | ||||
| 						ffi.reverseFieldInfo = fi | ||||
| 						break mForA | ||||
| 					} | ||||
| 				} | ||||
| 				if !found { | ||||
| 					err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) | ||||
| 					goto end | ||||
| 				} | ||||
| 			case RelReverseMany: | ||||
| 				found := false | ||||
| 			mForB: | ||||
| 				for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { | ||||
| 					if ffi.relModelInfo == mi { | ||||
| 						found = true | ||||
| 						fi.reverseField = ffi.name | ||||
| 						fi.reverseFieldInfo = ffi | ||||
| 
 | ||||
| 						ffi.reverseField = fi.name | ||||
| 						ffi.reverseFieldInfo = fi | ||||
| 
 | ||||
| 						break mForB | ||||
| 					} | ||||
| 				} | ||||
| 				if !found { | ||||
| 				mForC: | ||||
| 					for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { | ||||
| 						conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || | ||||
| 							fi.relTable != "" && fi.relTable == ffi.relTable || | ||||
| 							fi.relThrough == "" && fi.relTable == "" | ||||
| 						if ffi.relModelInfo == mi && conditions { | ||||
| 							found = true | ||||
| 
 | ||||
| 							fi.reverseField = ffi.reverseFieldInfoTwo.name | ||||
| 							fi.reverseFieldInfo = ffi.reverseFieldInfoTwo | ||||
| 							fi.relThroughModelInfo = ffi.relThroughModelInfo | ||||
| 							fi.reverseFieldInfoTwo = ffi.reverseFieldInfo | ||||
| 							fi.reverseFieldInfoM2M = ffi | ||||
| 							ffi.reverseFieldInfoM2M = fi | ||||
| 
 | ||||
| 							break mForC | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 				if !found { | ||||
| 					err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) | ||||
| 					goto end | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| end: | ||||
| 	if err != nil { | ||||
| 		fmt.Println(err) | ||||
| 		debug.PrintStack() | ||||
| 		os.Exit(2) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // RegisterModel register models | ||||
| func RegisterModel(models ...interface{}) { | ||||
| 	if modelCache.done { | ||||
| 		panic(fmt.Errorf("RegisterModel must be run before BootStrap")) | ||||
| 	} | ||||
| 	RegisterModelWithPrefix("", models...) | ||||
| } | ||||
| 
 | ||||
| // RegisterModelWithPrefix register models with a prefix | ||||
| func RegisterModelWithPrefix(prefix string, models ...interface{}) { | ||||
| 	if modelCache.done { | ||||
| 		panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap")) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, model := range models { | ||||
| 		registerModel(prefix, model, true) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // RegisterModelWithSuffix register models with a suffix | ||||
| func RegisterModelWithSuffix(suffix string, models ...interface{}) { | ||||
| 	if modelCache.done { | ||||
| 		panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap")) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, model := range models { | ||||
| 		registerModel(suffix, model, false) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // BootStrap bootstrap models. | ||||
| // make all model parsed and can not add more models | ||||
| func BootStrap() { | ||||
| 	modelCache.Lock() | ||||
| 	defer modelCache.Unlock() | ||||
| 	if modelCache.done { | ||||
| 		return | ||||
| 	} | ||||
| 	bootStrap() | ||||
| 	modelCache.done = true | ||||
| } | ||||
							
								
								
									
										783
									
								
								pkg/orm/models_fields.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										783
									
								
								pkg/orm/models_fields.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,783 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // Define the Type enum | ||||
| const ( | ||||
| 	TypeBooleanField = 1 << iota | ||||
| 	TypeVarCharField | ||||
| 	TypeCharField | ||||
| 	TypeTextField | ||||
| 	TypeTimeField | ||||
| 	TypeDateField | ||||
| 	TypeDateTimeField | ||||
| 	TypeBitField | ||||
| 	TypeSmallIntegerField | ||||
| 	TypeIntegerField | ||||
| 	TypeBigIntegerField | ||||
| 	TypePositiveBitField | ||||
| 	TypePositiveSmallIntegerField | ||||
| 	TypePositiveIntegerField | ||||
| 	TypePositiveBigIntegerField | ||||
| 	TypeFloatField | ||||
| 	TypeDecimalField | ||||
| 	TypeJSONField | ||||
| 	TypeJsonbField | ||||
| 	RelForeignKey | ||||
| 	RelOneToOne | ||||
| 	RelManyToMany | ||||
| 	RelReverseOne | ||||
| 	RelReverseMany | ||||
| ) | ||||
| 
 | ||||
| // Define some logic enum | ||||
| const ( | ||||
| 	IsIntegerField         = ^-TypePositiveBigIntegerField >> 6 << 7 | ||||
| 	IsPositiveIntegerField = ^-TypePositiveBigIntegerField >> 10 << 11 | ||||
| 	IsRelField             = ^-RelReverseMany >> 18 << 19 | ||||
| 	IsFieldType            = ^-RelReverseMany<<1 + 1 | ||||
| ) | ||||
| 
 | ||||
| // BooleanField A true/false field. | ||||
| type BooleanField bool | ||||
| 
 | ||||
| // Value return the BooleanField | ||||
| func (e BooleanField) Value() bool { | ||||
| 	return bool(e) | ||||
| } | ||||
| 
 | ||||
| // Set will set the BooleanField | ||||
| func (e *BooleanField) Set(d bool) { | ||||
| 	*e = BooleanField(d) | ||||
| } | ||||
| 
 | ||||
| // String format the Bool to string | ||||
| func (e *BooleanField) String() string { | ||||
| 	return strconv.FormatBool(e.Value()) | ||||
| } | ||||
| 
 | ||||
| // FieldType return BooleanField the type | ||||
| func (e *BooleanField) FieldType() int { | ||||
| 	return TypeBooleanField | ||||
| } | ||||
| 
 | ||||
| // SetRaw set the interface to bool | ||||
| func (e *BooleanField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case bool: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		v, err := StrTo(d).Bool() | ||||
| 		if err == nil { | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 		return err | ||||
| 	default: | ||||
| 		return fmt.Errorf("<BooleanField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return the current value | ||||
| func (e *BooleanField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify the BooleanField implement the Fielder interface | ||||
| var _ Fielder = new(BooleanField) | ||||
| 
 | ||||
| // CharField A string field | ||||
| // required values tag: size | ||||
| // The size is enforced at the database level and in models’s validation. | ||||
| // eg: `orm:"size(120)"` | ||||
| type CharField string | ||||
| 
 | ||||
| // Value return the CharField's Value | ||||
| func (e CharField) Value() string { | ||||
| 	return string(e) | ||||
| } | ||||
| 
 | ||||
| // Set CharField value | ||||
| func (e *CharField) Set(d string) { | ||||
| 	*e = CharField(d) | ||||
| } | ||||
| 
 | ||||
| // String return the CharField | ||||
| func (e *CharField) String() string { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // FieldType return the enum type | ||||
| func (e *CharField) FieldType() int { | ||||
| 	return TypeVarCharField | ||||
| } | ||||
| 
 | ||||
| // SetRaw set the interface to string | ||||
| func (e *CharField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case string: | ||||
| 		e.Set(d) | ||||
| 	default: | ||||
| 		return fmt.Errorf("<CharField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return the CharField value | ||||
| func (e *CharField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify CharField implement Fielder | ||||
| var _ Fielder = new(CharField) | ||||
| 
 | ||||
| // TimeField A time, represented in go by a time.Time instance. | ||||
| // only time values like 10:00:00 | ||||
| // Has a few extra, optional attr tag: | ||||
| // | ||||
| // auto_now: | ||||
| // Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. | ||||
| // Note that the current date is always used; it’s not just a default value that you can override. | ||||
| // | ||||
| // auto_now_add: | ||||
| // Automatically set the field to now when the object is first created. Useful for creation of timestamps. | ||||
| // Note that the current date is always used; it’s not just a default value that you can override. | ||||
| // | ||||
| // eg: `orm:"auto_now"` or `orm:"auto_now_add"` | ||||
| type TimeField time.Time | ||||
| 
 | ||||
| // Value return the time.Time | ||||
| func (e TimeField) Value() time.Time { | ||||
| 	return time.Time(e) | ||||
| } | ||||
| 
 | ||||
| // Set set the TimeField's value | ||||
| func (e *TimeField) Set(d time.Time) { | ||||
| 	*e = TimeField(d) | ||||
| } | ||||
| 
 | ||||
| // String convert time to string | ||||
| func (e *TimeField) String() string { | ||||
| 	return e.Value().String() | ||||
| } | ||||
| 
 | ||||
| // FieldType return enum type Date | ||||
| func (e *TimeField) FieldType() int { | ||||
| 	return TypeDateField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert the interface to time.Time. Allow string and time.Time | ||||
| func (e *TimeField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case time.Time: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		v, err := timeParse(d, formatTime) | ||||
| 		if err == nil { | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 		return err | ||||
| 	default: | ||||
| 		return fmt.Errorf("<TimeField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return time value | ||||
| func (e *TimeField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| var _ Fielder = new(TimeField) | ||||
| 
 | ||||
| // DateField A date, represented in go by a time.Time instance. | ||||
| // only date values like 2006-01-02 | ||||
| // Has a few extra, optional attr tag: | ||||
| // | ||||
| // auto_now: | ||||
| // Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. | ||||
| // Note that the current date is always used; it’s not just a default value that you can override. | ||||
| // | ||||
| // auto_now_add: | ||||
| // Automatically set the field to now when the object is first created. Useful for creation of timestamps. | ||||
| // Note that the current date is always used; it’s not just a default value that you can override. | ||||
| // | ||||
| // eg: `orm:"auto_now"` or `orm:"auto_now_add"` | ||||
| type DateField time.Time | ||||
| 
 | ||||
| // Value return the time.Time | ||||
| func (e DateField) Value() time.Time { | ||||
| 	return time.Time(e) | ||||
| } | ||||
| 
 | ||||
| // Set set the DateField's value | ||||
| func (e *DateField) Set(d time.Time) { | ||||
| 	*e = DateField(d) | ||||
| } | ||||
| 
 | ||||
| // String convert datetime to string | ||||
| func (e *DateField) String() string { | ||||
| 	return e.Value().String() | ||||
| } | ||||
| 
 | ||||
| // FieldType return enum type Date | ||||
| func (e *DateField) FieldType() int { | ||||
| 	return TypeDateField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert the interface to time.Time. Allow string and time.Time | ||||
| func (e *DateField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case time.Time: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		v, err := timeParse(d, formatDate) | ||||
| 		if err == nil { | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 		return err | ||||
| 	default: | ||||
| 		return fmt.Errorf("<DateField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return Date value | ||||
| func (e *DateField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify DateField implement fielder interface | ||||
| var _ Fielder = new(DateField) | ||||
| 
 | ||||
| // DateTimeField A date, represented in go by a time.Time instance. | ||||
| // datetime values like 2006-01-02 15:04:05 | ||||
| // Takes the same extra arguments as DateField. | ||||
| type DateTimeField time.Time | ||||
| 
 | ||||
| // Value return the datetime value | ||||
| func (e DateTimeField) Value() time.Time { | ||||
| 	return time.Time(e) | ||||
| } | ||||
| 
 | ||||
| // Set set the time.Time to datetime | ||||
| func (e *DateTimeField) Set(d time.Time) { | ||||
| 	*e = DateTimeField(d) | ||||
| } | ||||
| 
 | ||||
| // String return the time's String | ||||
| func (e *DateTimeField) String() string { | ||||
| 	return e.Value().String() | ||||
| } | ||||
| 
 | ||||
| // FieldType return the enum TypeDateTimeField | ||||
| func (e *DateTimeField) FieldType() int { | ||||
| 	return TypeDateTimeField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert the string or time.Time to DateTimeField | ||||
| func (e *DateTimeField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case time.Time: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		v, err := timeParse(d, formatDateTime) | ||||
| 		if err == nil { | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 		return err | ||||
| 	default: | ||||
| 		return fmt.Errorf("<DateTimeField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return the datetime value | ||||
| func (e *DateTimeField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify datetime implement fielder | ||||
| var _ Fielder = new(DateTimeField) | ||||
| 
 | ||||
| // FloatField A floating-point number represented in go by a float32 value. | ||||
| type FloatField float64 | ||||
| 
 | ||||
| // Value return the FloatField value | ||||
| func (e FloatField) Value() float64 { | ||||
| 	return float64(e) | ||||
| } | ||||
| 
 | ||||
| // Set the Float64 | ||||
| func (e *FloatField) Set(d float64) { | ||||
| 	*e = FloatField(d) | ||||
| } | ||||
| 
 | ||||
| // String return the string | ||||
| func (e *FloatField) String() string { | ||||
| 	return ToStr(e.Value(), -1, 32) | ||||
| } | ||||
| 
 | ||||
| // FieldType return the enum type | ||||
| func (e *FloatField) FieldType() int { | ||||
| 	return TypeFloatField | ||||
| } | ||||
| 
 | ||||
| // SetRaw converter interface Float64 float32 or string to FloatField | ||||
| func (e *FloatField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case float32: | ||||
| 		e.Set(float64(d)) | ||||
| 	case float64: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		v, err := StrTo(d).Float64() | ||||
| 		if err == nil { | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 		return err | ||||
| 	default: | ||||
| 		return fmt.Errorf("<FloatField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return the FloatField value | ||||
| func (e *FloatField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify FloatField implement Fielder | ||||
| var _ Fielder = new(FloatField) | ||||
| 
 | ||||
| // SmallIntegerField -32768 to 32767 | ||||
| type SmallIntegerField int16 | ||||
| 
 | ||||
| // Value return int16 value | ||||
| func (e SmallIntegerField) Value() int16 { | ||||
| 	return int16(e) | ||||
| } | ||||
| 
 | ||||
| // Set the SmallIntegerField value | ||||
| func (e *SmallIntegerField) Set(d int16) { | ||||
| 	*e = SmallIntegerField(d) | ||||
| } | ||||
| 
 | ||||
| // String convert smallint to string | ||||
| func (e *SmallIntegerField) String() string { | ||||
| 	return ToStr(e.Value()) | ||||
| } | ||||
| 
 | ||||
| // FieldType return enum type SmallIntegerField | ||||
| func (e *SmallIntegerField) FieldType() int { | ||||
| 	return TypeSmallIntegerField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert interface int16/string to int16 | ||||
| func (e *SmallIntegerField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case int16: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		v, err := StrTo(d).Int16() | ||||
| 		if err == nil { | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 		return err | ||||
| 	default: | ||||
| 		return fmt.Errorf("<SmallIntegerField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return smallint value | ||||
| func (e *SmallIntegerField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify SmallIntegerField implement Fielder | ||||
| var _ Fielder = new(SmallIntegerField) | ||||
| 
 | ||||
| // IntegerField -2147483648 to 2147483647 | ||||
| type IntegerField int32 | ||||
| 
 | ||||
| // Value return the int32 | ||||
| func (e IntegerField) Value() int32 { | ||||
| 	return int32(e) | ||||
| } | ||||
| 
 | ||||
| // Set IntegerField value | ||||
| func (e *IntegerField) Set(d int32) { | ||||
| 	*e = IntegerField(d) | ||||
| } | ||||
| 
 | ||||
| // String convert Int32 to string | ||||
| func (e *IntegerField) String() string { | ||||
| 	return ToStr(e.Value()) | ||||
| } | ||||
| 
 | ||||
| // FieldType return the enum type | ||||
| func (e *IntegerField) FieldType() int { | ||||
| 	return TypeIntegerField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert interface int32/string to int32 | ||||
| func (e *IntegerField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case int32: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		v, err := StrTo(d).Int32() | ||||
| 		if err == nil { | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 		return err | ||||
| 	default: | ||||
| 		return fmt.Errorf("<IntegerField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return IntegerField value | ||||
| func (e *IntegerField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify IntegerField implement Fielder | ||||
| var _ Fielder = new(IntegerField) | ||||
| 
 | ||||
| // BigIntegerField -9223372036854775808 to 9223372036854775807. | ||||
| type BigIntegerField int64 | ||||
| 
 | ||||
| // Value return int64 | ||||
| func (e BigIntegerField) Value() int64 { | ||||
| 	return int64(e) | ||||
| } | ||||
| 
 | ||||
| // Set the BigIntegerField value | ||||
| func (e *BigIntegerField) Set(d int64) { | ||||
| 	*e = BigIntegerField(d) | ||||
| } | ||||
| 
 | ||||
| // String convert BigIntegerField to string | ||||
| func (e *BigIntegerField) String() string { | ||||
| 	return ToStr(e.Value()) | ||||
| } | ||||
| 
 | ||||
| // FieldType return enum type | ||||
| func (e *BigIntegerField) FieldType() int { | ||||
| 	return TypeBigIntegerField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert interface int64/string to int64 | ||||
| func (e *BigIntegerField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case int64: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		v, err := StrTo(d).Int64() | ||||
| 		if err == nil { | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 		return err | ||||
| 	default: | ||||
| 		return fmt.Errorf("<BigIntegerField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return BigIntegerField value | ||||
| func (e *BigIntegerField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify BigIntegerField implement Fielder | ||||
| var _ Fielder = new(BigIntegerField) | ||||
| 
 | ||||
| // PositiveSmallIntegerField 0 to 65535 | ||||
| type PositiveSmallIntegerField uint16 | ||||
| 
 | ||||
| // Value return uint16 | ||||
| func (e PositiveSmallIntegerField) Value() uint16 { | ||||
| 	return uint16(e) | ||||
| } | ||||
| 
 | ||||
| // Set PositiveSmallIntegerField value | ||||
| func (e *PositiveSmallIntegerField) Set(d uint16) { | ||||
| 	*e = PositiveSmallIntegerField(d) | ||||
| } | ||||
| 
 | ||||
| // String convert uint16 to string | ||||
| func (e *PositiveSmallIntegerField) String() string { | ||||
| 	return ToStr(e.Value()) | ||||
| } | ||||
| 
 | ||||
| // FieldType return enum type | ||||
| func (e *PositiveSmallIntegerField) FieldType() int { | ||||
| 	return TypePositiveSmallIntegerField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert Interface uint16/string to uint16 | ||||
| func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case uint16: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		v, err := StrTo(d).Uint16() | ||||
| 		if err == nil { | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 		return err | ||||
| 	default: | ||||
| 		return fmt.Errorf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue returns PositiveSmallIntegerField value | ||||
| func (e *PositiveSmallIntegerField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify PositiveSmallIntegerField implement Fielder | ||||
| var _ Fielder = new(PositiveSmallIntegerField) | ||||
| 
 | ||||
| // PositiveIntegerField 0 to 4294967295 | ||||
| type PositiveIntegerField uint32 | ||||
| 
 | ||||
| // Value return PositiveIntegerField value. Uint32 | ||||
| func (e PositiveIntegerField) Value() uint32 { | ||||
| 	return uint32(e) | ||||
| } | ||||
| 
 | ||||
| // Set the PositiveIntegerField value | ||||
| func (e *PositiveIntegerField) Set(d uint32) { | ||||
| 	*e = PositiveIntegerField(d) | ||||
| } | ||||
| 
 | ||||
| // String convert PositiveIntegerField to string | ||||
| func (e *PositiveIntegerField) String() string { | ||||
| 	return ToStr(e.Value()) | ||||
| } | ||||
| 
 | ||||
| // FieldType return enum type | ||||
| func (e *PositiveIntegerField) FieldType() int { | ||||
| 	return TypePositiveIntegerField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert interface uint32/string to Uint32 | ||||
| func (e *PositiveIntegerField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case uint32: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		v, err := StrTo(d).Uint32() | ||||
| 		if err == nil { | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 		return err | ||||
| 	default: | ||||
| 		return fmt.Errorf("<PositiveIntegerField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return the PositiveIntegerField Value | ||||
| func (e *PositiveIntegerField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify PositiveIntegerField implement Fielder | ||||
| var _ Fielder = new(PositiveIntegerField) | ||||
| 
 | ||||
| // PositiveBigIntegerField 0 to 18446744073709551615 | ||||
| type PositiveBigIntegerField uint64 | ||||
| 
 | ||||
| // Value return uint64 | ||||
| func (e PositiveBigIntegerField) Value() uint64 { | ||||
| 	return uint64(e) | ||||
| } | ||||
| 
 | ||||
| // Set PositiveBigIntegerField value | ||||
| func (e *PositiveBigIntegerField) Set(d uint64) { | ||||
| 	*e = PositiveBigIntegerField(d) | ||||
| } | ||||
| 
 | ||||
| // String convert PositiveBigIntegerField to string | ||||
| func (e *PositiveBigIntegerField) String() string { | ||||
| 	return ToStr(e.Value()) | ||||
| } | ||||
| 
 | ||||
| // FieldType return enum type | ||||
| func (e *PositiveBigIntegerField) FieldType() int { | ||||
| 	return TypePositiveIntegerField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert interface uint64/string to Uint64 | ||||
| func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case uint64: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		v, err := StrTo(d).Uint64() | ||||
| 		if err == nil { | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 		return err | ||||
| 	default: | ||||
| 		return fmt.Errorf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return PositiveBigIntegerField value | ||||
| func (e *PositiveBigIntegerField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify PositiveBigIntegerField implement Fielder | ||||
| var _ Fielder = new(PositiveBigIntegerField) | ||||
| 
 | ||||
| // TextField A large text field. | ||||
| type TextField string | ||||
| 
 | ||||
| // Value return TextField value | ||||
| func (e TextField) Value() string { | ||||
| 	return string(e) | ||||
| } | ||||
| 
 | ||||
| // Set the TextField value | ||||
| func (e *TextField) Set(d string) { | ||||
| 	*e = TextField(d) | ||||
| } | ||||
| 
 | ||||
| // String convert TextField to string | ||||
| func (e *TextField) String() string { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // FieldType return enum type | ||||
| func (e *TextField) FieldType() int { | ||||
| 	return TypeTextField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert interface string to string | ||||
| func (e *TextField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case string: | ||||
| 		e.Set(d) | ||||
| 	default: | ||||
| 		return fmt.Errorf("<TextField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return TextField value | ||||
| func (e *TextField) RawValue() interface{} { | ||||
| 	return e.Value() | ||||
| } | ||||
| 
 | ||||
| // verify TextField implement Fielder | ||||
| var _ Fielder = new(TextField) | ||||
| 
 | ||||
| // JSONField postgres json field. | ||||
| type JSONField string | ||||
| 
 | ||||
| // Value return JSONField value | ||||
| func (j JSONField) Value() string { | ||||
| 	return string(j) | ||||
| } | ||||
| 
 | ||||
| // Set the JSONField value | ||||
| func (j *JSONField) Set(d string) { | ||||
| 	*j = JSONField(d) | ||||
| } | ||||
| 
 | ||||
| // String convert JSONField to string | ||||
| func (j *JSONField) String() string { | ||||
| 	return j.Value() | ||||
| } | ||||
| 
 | ||||
| // FieldType return enum type | ||||
| func (j *JSONField) FieldType() int { | ||||
| 	return TypeJSONField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert interface string to string | ||||
| func (j *JSONField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case string: | ||||
| 		j.Set(d) | ||||
| 	default: | ||||
| 		return fmt.Errorf("<JSONField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return JSONField value | ||||
| func (j *JSONField) RawValue() interface{} { | ||||
| 	return j.Value() | ||||
| } | ||||
| 
 | ||||
| // verify JSONField implement Fielder | ||||
| var _ Fielder = new(JSONField) | ||||
| 
 | ||||
| // JsonbField postgres json field. | ||||
| type JsonbField string | ||||
| 
 | ||||
| // Value return JsonbField value | ||||
| func (j JsonbField) Value() string { | ||||
| 	return string(j) | ||||
| } | ||||
| 
 | ||||
| // Set the JsonbField value | ||||
| func (j *JsonbField) Set(d string) { | ||||
| 	*j = JsonbField(d) | ||||
| } | ||||
| 
 | ||||
| // String convert JsonbField to string | ||||
| func (j *JsonbField) String() string { | ||||
| 	return j.Value() | ||||
| } | ||||
| 
 | ||||
| // FieldType return enum type | ||||
| func (j *JsonbField) FieldType() int { | ||||
| 	return TypeJsonbField | ||||
| } | ||||
| 
 | ||||
| // SetRaw convert interface string to string | ||||
| func (j *JsonbField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case string: | ||||
| 		j.Set(d) | ||||
| 	default: | ||||
| 		return fmt.Errorf("<JsonbField.SetRaw> unknown value `%s`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // RawValue return JsonbField value | ||||
| func (j *JsonbField) RawValue() interface{} { | ||||
| 	return j.Value() | ||||
| } | ||||
| 
 | ||||
| // verify JsonbField implement Fielder | ||||
| var _ Fielder = new(JsonbField) | ||||
							
								
								
									
										473
									
								
								pkg/orm/models_info_f.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										473
									
								
								pkg/orm/models_info_f.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,473 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| var errSkipField = errors.New("skip field") | ||||
| 
 | ||||
| // field info collection | ||||
| type fields struct { | ||||
| 	pk            *fieldInfo | ||||
| 	columns       map[string]*fieldInfo | ||||
| 	fields        map[string]*fieldInfo | ||||
| 	fieldsLow     map[string]*fieldInfo | ||||
| 	fieldsByType  map[int][]*fieldInfo | ||||
| 	fieldsRel     []*fieldInfo | ||||
| 	fieldsReverse []*fieldInfo | ||||
| 	fieldsDB      []*fieldInfo | ||||
| 	rels          []*fieldInfo | ||||
| 	orders        []string | ||||
| 	dbcols        []string | ||||
| } | ||||
| 
 | ||||
| // add field info | ||||
| func (f *fields) Add(fi *fieldInfo) (added bool) { | ||||
| 	if f.fields[fi.name] == nil && f.columns[fi.column] == nil { | ||||
| 		f.columns[fi.column] = fi | ||||
| 		f.fields[fi.name] = fi | ||||
| 		f.fieldsLow[strings.ToLower(fi.name)] = fi | ||||
| 	} else { | ||||
| 		return | ||||
| 	} | ||||
| 	if _, ok := f.fieldsByType[fi.fieldType]; !ok { | ||||
| 		f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) | ||||
| 	} | ||||
| 	f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) | ||||
| 	f.orders = append(f.orders, fi.column) | ||||
| 	if fi.dbcol { | ||||
| 		f.dbcols = append(f.dbcols, fi.column) | ||||
| 		f.fieldsDB = append(f.fieldsDB, fi) | ||||
| 	} | ||||
| 	if fi.rel { | ||||
| 		f.fieldsRel = append(f.fieldsRel, fi) | ||||
| 	} | ||||
| 	if fi.reverse { | ||||
| 		f.fieldsReverse = append(f.fieldsReverse, fi) | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| // get field info by name | ||||
| func (f *fields) GetByName(name string) *fieldInfo { | ||||
| 	return f.fields[name] | ||||
| } | ||||
| 
 | ||||
| // get field info by column name | ||||
| func (f *fields) GetByColumn(column string) *fieldInfo { | ||||
| 	return f.columns[column] | ||||
| } | ||||
| 
 | ||||
| // get field info by string, name is prior | ||||
| func (f *fields) GetByAny(name string) (*fieldInfo, bool) { | ||||
| 	if fi, ok := f.fields[name]; ok { | ||||
| 		return fi, ok | ||||
| 	} | ||||
| 	if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok { | ||||
| 		return fi, ok | ||||
| 	} | ||||
| 	if fi, ok := f.columns[name]; ok { | ||||
| 		return fi, ok | ||||
| 	} | ||||
| 	return nil, false | ||||
| } | ||||
| 
 | ||||
| // create new field info collection | ||||
| func newFields() *fields { | ||||
| 	f := new(fields) | ||||
| 	f.fields = make(map[string]*fieldInfo) | ||||
| 	f.fieldsLow = make(map[string]*fieldInfo) | ||||
| 	f.columns = make(map[string]*fieldInfo) | ||||
| 	f.fieldsByType = make(map[int][]*fieldInfo) | ||||
| 	return f | ||||
| } | ||||
| 
 | ||||
| // single field info | ||||
| type fieldInfo struct { | ||||
| 	mi                  *modelInfo | ||||
| 	fieldIndex          []int | ||||
| 	fieldType           int | ||||
| 	dbcol               bool // table column fk and onetoone | ||||
| 	inModel             bool | ||||
| 	name                string | ||||
| 	fullName            string | ||||
| 	column              string | ||||
| 	addrValue           reflect.Value | ||||
| 	sf                  reflect.StructField | ||||
| 	auto                bool | ||||
| 	pk                  bool | ||||
| 	null                bool | ||||
| 	index               bool | ||||
| 	unique              bool | ||||
| 	colDefault          bool  // whether has default tag | ||||
| 	initial             StrTo // store the default value | ||||
| 	size                int | ||||
| 	toText              bool | ||||
| 	autoNow             bool | ||||
| 	autoNowAdd          bool | ||||
| 	rel                 bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true | ||||
| 	reverse             bool | ||||
| 	reverseField        string | ||||
| 	reverseFieldInfo    *fieldInfo | ||||
| 	reverseFieldInfoTwo *fieldInfo | ||||
| 	reverseFieldInfoM2M *fieldInfo | ||||
| 	relTable            string | ||||
| 	relThrough          string | ||||
| 	relThroughModelInfo *modelInfo | ||||
| 	relModelInfo        *modelInfo | ||||
| 	digits              int | ||||
| 	decimals            int | ||||
| 	isFielder           bool // implement Fielder interface | ||||
| 	onDelete            string | ||||
| 	description         string | ||||
| } | ||||
| 
 | ||||
| // new field info | ||||
| func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) { | ||||
| 	var ( | ||||
| 		tag       string | ||||
| 		tagValue  string | ||||
| 		initial   StrTo // store the default value | ||||
| 		fieldType int | ||||
| 		attrs     map[string]bool | ||||
| 		tags      map[string]string | ||||
| 		addrField reflect.Value | ||||
| 	) | ||||
| 
 | ||||
| 	fi = new(fieldInfo) | ||||
| 
 | ||||
| 	// if field which CanAddr is the follow type | ||||
| 	//  A value is addressable if it is an element of a slice, | ||||
| 	//  an element of an addressable array, a field of an | ||||
| 	//  addressable struct, or the result of dereferencing a pointer. | ||||
| 	addrField = field | ||||
| 	if field.CanAddr() && field.Kind() != reflect.Ptr { | ||||
| 		addrField = field.Addr() | ||||
| 		if _, ok := addrField.Interface().(Fielder); !ok { | ||||
| 			if field.Kind() == reflect.Slice { | ||||
| 				addrField = field | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName)) | ||||
| 
 | ||||
| 	if _, ok := attrs["-"]; ok { | ||||
| 		return nil, errSkipField | ||||
| 	} | ||||
| 
 | ||||
| 	digits := tags["digits"] | ||||
| 	decimals := tags["decimals"] | ||||
| 	size := tags["size"] | ||||
| 	onDelete := tags["on_delete"] | ||||
| 
 | ||||
| 	initial.Clear() | ||||
| 	if v, ok := tags["default"]; ok { | ||||
| 		initial.Set(v) | ||||
| 	} | ||||
| 
 | ||||
| checkType: | ||||
| 	switch f := addrField.Interface().(type) { | ||||
| 	case Fielder: | ||||
| 		fi.isFielder = true | ||||
| 		if field.Kind() == reflect.Ptr { | ||||
| 			err = fmt.Errorf("the model Fielder can not be use ptr") | ||||
| 			goto end | ||||
| 		} | ||||
| 		fieldType = f.FieldType() | ||||
| 		if fieldType&IsRelField > 0 { | ||||
| 			err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42") | ||||
| 			goto end | ||||
| 		} | ||||
| 	default: | ||||
| 		tag = "rel" | ||||
| 		tagValue = tags[tag] | ||||
| 		if tagValue != "" { | ||||
| 			switch tagValue { | ||||
| 			case "fk": | ||||
| 				fieldType = RelForeignKey | ||||
| 				break checkType | ||||
| 			case "one": | ||||
| 				fieldType = RelOneToOne | ||||
| 				break checkType | ||||
| 			case "m2m": | ||||
| 				fieldType = RelManyToMany | ||||
| 				if tv := tags["rel_table"]; tv != "" { | ||||
| 					fi.relTable = tv | ||||
| 				} else if tv := tags["rel_through"]; tv != "" { | ||||
| 					fi.relThrough = tv | ||||
| 				} | ||||
| 				break checkType | ||||
| 			default: | ||||
| 				err = fmt.Errorf("rel only allow these value: fk, one, m2m") | ||||
| 				goto wrongTag | ||||
| 			} | ||||
| 		} | ||||
| 		tag = "reverse" | ||||
| 		tagValue = tags[tag] | ||||
| 		if tagValue != "" { | ||||
| 			switch tagValue { | ||||
| 			case "one": | ||||
| 				fieldType = RelReverseOne | ||||
| 				break checkType | ||||
| 			case "many": | ||||
| 				fieldType = RelReverseMany | ||||
| 				if tv := tags["rel_table"]; tv != "" { | ||||
| 					fi.relTable = tv | ||||
| 				} else if tv := tags["rel_through"]; tv != "" { | ||||
| 					fi.relThrough = tv | ||||
| 				} | ||||
| 				break checkType | ||||
| 			default: | ||||
| 				err = fmt.Errorf("reverse only allow these value: one, many") | ||||
| 				goto wrongTag | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		fieldType, err = getFieldType(addrField) | ||||
| 		if err != nil { | ||||
| 			goto end | ||||
| 		} | ||||
| 		if fieldType == TypeVarCharField { | ||||
| 			switch tags["type"] { | ||||
| 			case "char": | ||||
| 				fieldType = TypeCharField | ||||
| 			case "text": | ||||
| 				fieldType = TypeTextField | ||||
| 			case "json": | ||||
| 				fieldType = TypeJSONField | ||||
| 			case "jsonb": | ||||
| 				fieldType = TypeJsonbField | ||||
| 			} | ||||
| 		} | ||||
| 		if fieldType == TypeFloatField && (digits != "" || decimals != "") { | ||||
| 			fieldType = TypeDecimalField | ||||
| 		} | ||||
| 		if fieldType == TypeDateTimeField && tags["type"] == "date" { | ||||
| 			fieldType = TypeDateField | ||||
| 		} | ||||
| 		if fieldType == TypeTimeField && tags["type"] == "time" { | ||||
| 			fieldType = TypeTimeField | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// check the rel and reverse type | ||||
| 	// rel should Ptr | ||||
| 	// reverse should slice []*struct | ||||
| 	switch fieldType { | ||||
| 	case RelForeignKey, RelOneToOne, RelReverseOne: | ||||
| 		if field.Kind() != reflect.Ptr { | ||||
| 			err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name()) | ||||
| 			goto end | ||||
| 		} | ||||
| 	case RelManyToMany, RelReverseMany: | ||||
| 		if field.Kind() != reflect.Slice { | ||||
| 			err = fmt.Errorf("rel/reverse:many field must be slice") | ||||
| 			goto end | ||||
| 		} else { | ||||
| 			if field.Type().Elem().Kind() != reflect.Ptr { | ||||
| 				err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name()) | ||||
| 				goto end | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if fieldType&IsFieldType == 0 { | ||||
| 		err = fmt.Errorf("wrong field type") | ||||
| 		goto end | ||||
| 	} | ||||
| 
 | ||||
| 	fi.fieldType = fieldType | ||||
| 	fi.name = sf.Name | ||||
| 	fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) | ||||
| 	fi.addrValue = addrField | ||||
| 	fi.sf = sf | ||||
| 	fi.fullName = mi.fullName + mName + "." + sf.Name | ||||
| 
 | ||||
| 	fi.description = tags["description"] | ||||
| 	fi.null = attrs["null"] | ||||
| 	fi.index = attrs["index"] | ||||
| 	fi.auto = attrs["auto"] | ||||
| 	fi.pk = attrs["pk"] | ||||
| 	fi.unique = attrs["unique"] | ||||
| 
 | ||||
| 	// Mark object property if there is attribute "default" in the orm configuration | ||||
| 	if _, ok := tags["default"]; ok { | ||||
| 		fi.colDefault = true | ||||
| 	} | ||||
| 
 | ||||
| 	switch fieldType { | ||||
| 	case RelManyToMany, RelReverseMany, RelReverseOne: | ||||
| 		fi.null = false | ||||
| 		fi.index = false | ||||
| 		fi.auto = false | ||||
| 		fi.pk = false | ||||
| 		fi.unique = false | ||||
| 	default: | ||||
| 		fi.dbcol = true | ||||
| 	} | ||||
| 
 | ||||
| 	switch fieldType { | ||||
| 	case RelForeignKey, RelOneToOne, RelManyToMany: | ||||
| 		fi.rel = true | ||||
| 		if fieldType == RelOneToOne { | ||||
| 			fi.unique = true | ||||
| 		} | ||||
| 	case RelReverseMany, RelReverseOne: | ||||
| 		fi.reverse = true | ||||
| 	} | ||||
| 
 | ||||
| 	if fi.rel && fi.dbcol { | ||||
| 		switch onDelete { | ||||
| 		case odCascade, odDoNothing: | ||||
| 		case odSetDefault: | ||||
| 			if !initial.Exist() { | ||||
| 				err = errors.New("on_delete: set_default need set field a default value") | ||||
| 				goto end | ||||
| 			} | ||||
| 		case odSetNULL: | ||||
| 			if !fi.null { | ||||
| 				err = errors.New("on_delete: set_null need set field null") | ||||
| 				goto end | ||||
| 			} | ||||
| 		default: | ||||
| 			if onDelete == "" { | ||||
| 				onDelete = odCascade | ||||
| 			} else { | ||||
| 				err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) | ||||
| 				goto end | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		fi.onDelete = onDelete | ||||
| 	} | ||||
| 
 | ||||
| 	switch fieldType { | ||||
| 	case TypeBooleanField: | ||||
| 	case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField: | ||||
| 		if size != "" { | ||||
| 			v, e := StrTo(size).Int32() | ||||
| 			if e != nil { | ||||
| 				err = fmt.Errorf("wrong size value `%s`", size) | ||||
| 			} else { | ||||
| 				fi.size = int(v) | ||||
| 			} | ||||
| 		} else { | ||||
| 			fi.size = 255 | ||||
| 			fi.toText = true | ||||
| 		} | ||||
| 	case TypeTextField: | ||||
| 		fi.index = false | ||||
| 		fi.unique = false | ||||
| 	case TypeTimeField, TypeDateField, TypeDateTimeField: | ||||
| 		if attrs["auto_now"] { | ||||
| 			fi.autoNow = true | ||||
| 		} else if attrs["auto_now_add"] { | ||||
| 			fi.autoNowAdd = true | ||||
| 		} | ||||
| 	case TypeFloatField: | ||||
| 	case TypeDecimalField: | ||||
| 		d1 := digits | ||||
| 		d2 := decimals | ||||
| 		v1, er1 := StrTo(d1).Int8() | ||||
| 		v2, er2 := StrTo(d2).Int8() | ||||
| 		if er1 != nil || er2 != nil { | ||||
| 			err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) | ||||
| 			goto end | ||||
| 		} | ||||
| 		fi.digits = int(v1) | ||||
| 		fi.decimals = int(v2) | ||||
| 	default: | ||||
| 		switch { | ||||
| 		case fieldType&IsIntegerField > 0: | ||||
| 		case fieldType&IsRelField > 0: | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if fieldType&IsIntegerField == 0 { | ||||
| 		if fi.auto { | ||||
| 			err = fmt.Errorf("non-integer type cannot set auto") | ||||
| 			goto end | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if fi.auto || fi.pk { | ||||
| 		if fi.auto { | ||||
| 			switch addrField.Elem().Kind() { | ||||
| 			case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: | ||||
| 			default: | ||||
| 				err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind()) | ||||
| 				goto end | ||||
| 			} | ||||
| 			fi.pk = true | ||||
| 		} | ||||
| 		fi.null = false | ||||
| 		fi.index = false | ||||
| 		fi.unique = false | ||||
| 	} | ||||
| 
 | ||||
| 	if fi.unique { | ||||
| 		fi.index = false | ||||
| 	} | ||||
| 
 | ||||
| 	// can not set default for these type | ||||
| 	if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { | ||||
| 		initial.Clear() | ||||
| 	} | ||||
| 
 | ||||
| 	if initial.Exist() { | ||||
| 		v := initial | ||||
| 		switch fieldType { | ||||
| 		case TypeBooleanField: | ||||
| 			_, err = v.Bool() | ||||
| 		case TypeFloatField, TypeDecimalField: | ||||
| 			_, err = v.Float64() | ||||
| 		case TypeBitField: | ||||
| 			_, err = v.Int8() | ||||
| 		case TypeSmallIntegerField: | ||||
| 			_, err = v.Int16() | ||||
| 		case TypeIntegerField: | ||||
| 			_, err = v.Int32() | ||||
| 		case TypeBigIntegerField: | ||||
| 			_, err = v.Int64() | ||||
| 		case TypePositiveBitField: | ||||
| 			_, err = v.Uint8() | ||||
| 		case TypePositiveSmallIntegerField: | ||||
| 			_, err = v.Uint16() | ||||
| 		case TypePositiveIntegerField: | ||||
| 			_, err = v.Uint32() | ||||
| 		case TypePositiveBigIntegerField: | ||||
| 			_, err = v.Uint64() | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			tag, tagValue = "default", tags["default"] | ||||
| 			goto wrongTag | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	fi.initial = initial | ||||
| end: | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return | ||||
| wrongTag: | ||||
| 	return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err) | ||||
| } | ||||
							
								
								
									
										148
									
								
								pkg/orm/models_info_m.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										148
									
								
								pkg/orm/models_info_m.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,148 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| ) | ||||
| 
 | ||||
| // single model info | ||||
| type modelInfo struct { | ||||
| 	pkg       string | ||||
| 	name      string | ||||
| 	fullName  string | ||||
| 	table     string | ||||
| 	model     interface{} | ||||
| 	fields    *fields | ||||
| 	manual    bool | ||||
| 	addrField reflect.Value //store the original struct value | ||||
| 	uniques   []string | ||||
| 	isThrough bool | ||||
| } | ||||
| 
 | ||||
| // new model info | ||||
| func newModelInfo(val reflect.Value) (mi *modelInfo) { | ||||
| 	mi = &modelInfo{} | ||||
| 	mi.fields = newFields() | ||||
| 	ind := reflect.Indirect(val) | ||||
| 	mi.addrField = val | ||||
| 	mi.name = ind.Type().Name() | ||||
| 	mi.fullName = getFullName(ind.Type()) | ||||
| 	addModelFields(mi, ind, "", []int{}) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // index: FieldByIndex returns the nested field corresponding to index | ||||
| func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) { | ||||
| 	var ( | ||||
| 		err error | ||||
| 		fi  *fieldInfo | ||||
| 		sf  reflect.StructField | ||||
| 	) | ||||
| 
 | ||||
| 	for i := 0; i < ind.NumField(); i++ { | ||||
| 		field := ind.Field(i) | ||||
| 		sf = ind.Type().Field(i) | ||||
| 		// if the field is unexported skip | ||||
| 		if sf.PkgPath != "" { | ||||
| 			continue | ||||
| 		} | ||||
| 		// add anonymous struct fields | ||||
| 		if sf.Anonymous { | ||||
| 			addModelFields(mi, field, mName+"."+sf.Name, append(index, i)) | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		fi, err = newFieldInfo(mi, field, sf, mName) | ||||
| 		if err == errSkipField { | ||||
| 			err = nil | ||||
| 			continue | ||||
| 		} else if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		//record current field index | ||||
| 		fi.fieldIndex = append(fi.fieldIndex, index...) | ||||
| 		fi.fieldIndex = append(fi.fieldIndex, i) | ||||
| 		fi.mi = mi | ||||
| 		fi.inModel = true | ||||
| 		if !mi.fields.Add(fi) { | ||||
| 			err = fmt.Errorf("duplicate column name: %s", fi.column) | ||||
| 			break | ||||
| 		} | ||||
| 		if fi.pk { | ||||
| 			if mi.fields.pk != nil { | ||||
| 				err = fmt.Errorf("one model must have one pk field only") | ||||
| 				break | ||||
| 			} else { | ||||
| 				mi.fields.pk = fi | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) | ||||
| 		os.Exit(2) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // combine related model info to new model info. | ||||
| // prepare for relation models query. | ||||
| func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) { | ||||
| 	mi = new(modelInfo) | ||||
| 	mi.fields = newFields() | ||||
| 	mi.table = m1.table + "_" + m2.table + "s" | ||||
| 	mi.name = camelString(mi.table) | ||||
| 	mi.fullName = m1.pkg + "." + mi.name | ||||
| 
 | ||||
| 	fa := new(fieldInfo) // pk | ||||
| 	f1 := new(fieldInfo) // m1 table RelForeignKey | ||||
| 	f2 := new(fieldInfo) // m2 table RelForeignKey | ||||
| 	fa.fieldType = TypeBigIntegerField | ||||
| 	fa.auto = true | ||||
| 	fa.pk = true | ||||
| 	fa.dbcol = true | ||||
| 	fa.name = "Id" | ||||
| 	fa.column = "id" | ||||
| 	fa.fullName = mi.fullName + "." + fa.name | ||||
| 
 | ||||
| 	f1.dbcol = true | ||||
| 	f2.dbcol = true | ||||
| 	f1.fieldType = RelForeignKey | ||||
| 	f2.fieldType = RelForeignKey | ||||
| 	f1.name = camelString(m1.table) | ||||
| 	f2.name = camelString(m2.table) | ||||
| 	f1.fullName = mi.fullName + "." + f1.name | ||||
| 	f2.fullName = mi.fullName + "." + f2.name | ||||
| 	f1.column = m1.table + "_id" | ||||
| 	f2.column = m2.table + "_id" | ||||
| 	f1.rel = true | ||||
| 	f2.rel = true | ||||
| 	f1.relTable = m1.table | ||||
| 	f2.relTable = m2.table | ||||
| 	f1.relModelInfo = m1 | ||||
| 	f2.relModelInfo = m2 | ||||
| 	f1.mi = mi | ||||
| 	f2.mi = mi | ||||
| 
 | ||||
| 	mi.fields.Add(fa) | ||||
| 	mi.fields.Add(f1) | ||||
| 	mi.fields.Add(f2) | ||||
| 	mi.fields.pk = fa | ||||
| 
 | ||||
| 	mi.uniques = []string{f1.column, f2.column} | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										497
									
								
								pkg/orm/models_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										497
									
								
								pkg/orm/models_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,497 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	_ "github.com/go-sql-driver/mysql" | ||||
| 	_ "github.com/lib/pq" | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| 	// As tidb can't use go get, so disable the tidb testing now | ||||
| 	// _ "github.com/pingcap/tidb" | ||||
| ) | ||||
| 
 | ||||
| // A slice string field. | ||||
| type SliceStringField []string | ||||
| 
 | ||||
| func (e SliceStringField) Value() []string { | ||||
| 	return []string(e) | ||||
| } | ||||
| 
 | ||||
| func (e *SliceStringField) Set(d []string) { | ||||
| 	*e = SliceStringField(d) | ||||
| } | ||||
| 
 | ||||
| func (e *SliceStringField) Add(v string) { | ||||
| 	*e = append(*e, v) | ||||
| } | ||||
| 
 | ||||
| func (e *SliceStringField) String() string { | ||||
| 	return strings.Join(e.Value(), ",") | ||||
| } | ||||
| 
 | ||||
| func (e *SliceStringField) FieldType() int { | ||||
| 	return TypeVarCharField | ||||
| } | ||||
| 
 | ||||
| func (e *SliceStringField) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case []string: | ||||
| 		e.Set(d) | ||||
| 	case string: | ||||
| 		if len(d) > 0 { | ||||
| 			parts := strings.Split(d, ",") | ||||
| 			v := make([]string, 0, len(parts)) | ||||
| 			for _, p := range parts { | ||||
| 				v = append(v, strings.TrimSpace(p)) | ||||
| 			} | ||||
| 			e.Set(v) | ||||
| 		} | ||||
| 	default: | ||||
| 		return fmt.Errorf("<SliceStringField.SetRaw> unknown value `%v`", value) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (e *SliceStringField) RawValue() interface{} { | ||||
| 	return e.String() | ||||
| } | ||||
| 
 | ||||
| var _ Fielder = new(SliceStringField) | ||||
| 
 | ||||
| // A json field. | ||||
| type JSONFieldTest struct { | ||||
| 	Name string | ||||
| 	Data string | ||||
| } | ||||
| 
 | ||||
| func (e *JSONFieldTest) String() string { | ||||
| 	data, _ := json.Marshal(e) | ||||
| 	return string(data) | ||||
| } | ||||
| 
 | ||||
| func (e *JSONFieldTest) FieldType() int { | ||||
| 	return TypeTextField | ||||
| } | ||||
| 
 | ||||
| func (e *JSONFieldTest) SetRaw(value interface{}) error { | ||||
| 	switch d := value.(type) { | ||||
| 	case string: | ||||
| 		return json.Unmarshal([]byte(d), e) | ||||
| 	default: | ||||
| 		return fmt.Errorf("<JSONField.SetRaw> unknown value `%v`", value) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (e *JSONFieldTest) RawValue() interface{} { | ||||
| 	return e.String() | ||||
| } | ||||
| 
 | ||||
| var _ Fielder = new(JSONFieldTest) | ||||
| 
 | ||||
| type Data struct { | ||||
| 	ID       int `orm:"column(id)"` | ||||
| 	Boolean  bool | ||||
| 	Char     string    `orm:"size(50)"` | ||||
| 	Text     string    `orm:"type(text)"` | ||||
| 	JSON     string    `orm:"type(json);default({\"name\":\"json\"})"` | ||||
| 	Jsonb    string    `orm:"type(jsonb)"` | ||||
| 	Time     time.Time `orm:"type(time)"` | ||||
| 	Date     time.Time `orm:"type(date)"` | ||||
| 	DateTime time.Time `orm:"column(datetime)"` | ||||
| 	Byte     byte | ||||
| 	Rune     rune | ||||
| 	Int      int | ||||
| 	Int8     int8 | ||||
| 	Int16    int16 | ||||
| 	Int32    int32 | ||||
| 	Int64    int64 | ||||
| 	Uint     uint | ||||
| 	Uint8    uint8 | ||||
| 	Uint16   uint16 | ||||
| 	Uint32   uint32 | ||||
| 	Uint64   uint64 | ||||
| 	Float32  float32 | ||||
| 	Float64  float64 | ||||
| 	Decimal  float64 `orm:"digits(8);decimals(4)"` | ||||
| } | ||||
| 
 | ||||
| type DataNull struct { | ||||
| 	ID          int             `orm:"column(id)"` | ||||
| 	Boolean     bool            `orm:"null"` | ||||
| 	Char        string          `orm:"null;size(50)"` | ||||
| 	Text        string          `orm:"null;type(text)"` | ||||
| 	JSON        string          `orm:"type(json);null"` | ||||
| 	Jsonb       string          `orm:"type(jsonb);null"` | ||||
| 	Time        time.Time       `orm:"null;type(time)"` | ||||
| 	Date        time.Time       `orm:"null;type(date)"` | ||||
| 	DateTime    time.Time       `orm:"null;column(datetime)"` | ||||
| 	Byte        byte            `orm:"null"` | ||||
| 	Rune        rune            `orm:"null"` | ||||
| 	Int         int             `orm:"null"` | ||||
| 	Int8        int8            `orm:"null"` | ||||
| 	Int16       int16           `orm:"null"` | ||||
| 	Int32       int32           `orm:"null"` | ||||
| 	Int64       int64           `orm:"null"` | ||||
| 	Uint        uint            `orm:"null"` | ||||
| 	Uint8       uint8           `orm:"null"` | ||||
| 	Uint16      uint16          `orm:"null"` | ||||
| 	Uint32      uint32          `orm:"null"` | ||||
| 	Uint64      uint64          `orm:"null"` | ||||
| 	Float32     float32         `orm:"null"` | ||||
| 	Float64     float64         `orm:"null"` | ||||
| 	Decimal     float64         `orm:"digits(8);decimals(4);null"` | ||||
| 	NullString  sql.NullString  `orm:"null"` | ||||
| 	NullBool    sql.NullBool    `orm:"null"` | ||||
| 	NullFloat64 sql.NullFloat64 `orm:"null"` | ||||
| 	NullInt64   sql.NullInt64   `orm:"null"` | ||||
| 	BooleanPtr  *bool           `orm:"null"` | ||||
| 	CharPtr     *string         `orm:"null;size(50)"` | ||||
| 	TextPtr     *string         `orm:"null;type(text)"` | ||||
| 	BytePtr     *byte           `orm:"null"` | ||||
| 	RunePtr     *rune           `orm:"null"` | ||||
| 	IntPtr      *int            `orm:"null"` | ||||
| 	Int8Ptr     *int8           `orm:"null"` | ||||
| 	Int16Ptr    *int16          `orm:"null"` | ||||
| 	Int32Ptr    *int32          `orm:"null"` | ||||
| 	Int64Ptr    *int64          `orm:"null"` | ||||
| 	UintPtr     *uint           `orm:"null"` | ||||
| 	Uint8Ptr    *uint8          `orm:"null"` | ||||
| 	Uint16Ptr   *uint16         `orm:"null"` | ||||
| 	Uint32Ptr   *uint32         `orm:"null"` | ||||
| 	Uint64Ptr   *uint64         `orm:"null"` | ||||
| 	Float32Ptr  *float32        `orm:"null"` | ||||
| 	Float64Ptr  *float64        `orm:"null"` | ||||
| 	DecimalPtr  *float64        `orm:"digits(8);decimals(4);null"` | ||||
| 	TimePtr     *time.Time      `orm:"null;type(time)"` | ||||
| 	DatePtr     *time.Time      `orm:"null;type(date)"` | ||||
| 	DateTimePtr *time.Time      `orm:"null"` | ||||
| } | ||||
| 
 | ||||
| type String string | ||||
| type Boolean bool | ||||
| type Byte byte | ||||
| type Rune rune | ||||
| type Int int | ||||
| type Int8 int8 | ||||
| type Int16 int16 | ||||
| type Int32 int32 | ||||
| type Int64 int64 | ||||
| type Uint uint | ||||
| type Uint8 uint8 | ||||
| type Uint16 uint16 | ||||
| type Uint32 uint32 | ||||
| type Uint64 uint64 | ||||
| type Float32 float64 | ||||
| type Float64 float64 | ||||
| 
 | ||||
| type DataCustom struct { | ||||
| 	ID      int `orm:"column(id)"` | ||||
| 	Boolean Boolean | ||||
| 	Char    string `orm:"size(50)"` | ||||
| 	Text    string `orm:"type(text)"` | ||||
| 	Byte    Byte | ||||
| 	Rune    Rune | ||||
| 	Int     Int | ||||
| 	Int8    Int8 | ||||
| 	Int16   Int16 | ||||
| 	Int32   Int32 | ||||
| 	Int64   Int64 | ||||
| 	Uint    Uint | ||||
| 	Uint8   Uint8 | ||||
| 	Uint16  Uint16 | ||||
| 	Uint32  Uint32 | ||||
| 	Uint64  Uint64 | ||||
| 	Float32 Float32 | ||||
| 	Float64 Float64 | ||||
| 	Decimal Float64 `orm:"digits(8);decimals(4)"` | ||||
| } | ||||
| 
 | ||||
| // only for mysql | ||||
| type UserBig struct { | ||||
| 	ID   uint64 `orm:"column(id)"` | ||||
| 	Name string | ||||
| } | ||||
| 
 | ||||
| type User struct { | ||||
| 	ID           int    `orm:"column(id)"` | ||||
| 	UserName     string `orm:"size(30);unique"` | ||||
| 	Email        string `orm:"size(100)"` | ||||
| 	Password     string `orm:"size(100)"` | ||||
| 	Status       int16  `orm:"column(Status)"` | ||||
| 	IsStaff      bool | ||||
| 	IsActive     bool      `orm:"default(true)"` | ||||
| 	Created      time.Time `orm:"auto_now_add;type(date)"` | ||||
| 	Updated      time.Time `orm:"auto_now"` | ||||
| 	Profile      *Profile  `orm:"null;rel(one);on_delete(set_null)"` | ||||
| 	Posts        []*Post   `orm:"reverse(many)" json:"-"` | ||||
| 	ShouldSkip   string    `orm:"-"` | ||||
| 	Nums         int | ||||
| 	Langs        SliceStringField `orm:"size(100)"` | ||||
| 	Extra        JSONFieldTest    `orm:"type(text)"` | ||||
| 	unexport     bool             `orm:"-"` | ||||
| 	unexportBool bool | ||||
| } | ||||
| 
 | ||||
| func (u *User) TableIndex() [][]string { | ||||
| 	return [][]string{ | ||||
| 		{"Id", "UserName"}, | ||||
| 		{"Id", "Created"}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (u *User) TableUnique() [][]string { | ||||
| 	return [][]string{ | ||||
| 		{"UserName", "Email"}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func NewUser() *User { | ||||
| 	obj := new(User) | ||||
| 	return obj | ||||
| } | ||||
| 
 | ||||
| type Profile struct { | ||||
| 	ID       int `orm:"column(id)"` | ||||
| 	Age      int16 | ||||
| 	Money    float64 | ||||
| 	User     *User `orm:"reverse(one)" json:"-"` | ||||
| 	BestPost *Post `orm:"rel(one);null"` | ||||
| } | ||||
| 
 | ||||
| func (u *Profile) TableName() string { | ||||
| 	return "user_profile" | ||||
| } | ||||
| 
 | ||||
| func NewProfile() *Profile { | ||||
| 	obj := new(Profile) | ||||
| 	return obj | ||||
| } | ||||
| 
 | ||||
| type Post struct { | ||||
| 	ID      int       `orm:"column(id)"` | ||||
| 	User    *User     `orm:"rel(fk)"` | ||||
| 	Title   string    `orm:"size(60)"` | ||||
| 	Content string    `orm:"type(text)"` | ||||
| 	Created time.Time `orm:"auto_now_add"` | ||||
| 	Updated time.Time `orm:"auto_now"` | ||||
| 	Tags    []*Tag    `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` | ||||
| } | ||||
| 
 | ||||
| func (u *Post) TableIndex() [][]string { | ||||
| 	return [][]string{ | ||||
| 		{"Id", "Created"}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func NewPost() *Post { | ||||
| 	obj := new(Post) | ||||
| 	return obj | ||||
| } | ||||
| 
 | ||||
| type Tag struct { | ||||
| 	ID       int     `orm:"column(id)"` | ||||
| 	Name     string  `orm:"size(30)"` | ||||
| 	BestPost *Post   `orm:"rel(one);null"` | ||||
| 	Posts    []*Post `orm:"reverse(many)" json:"-"` | ||||
| } | ||||
| 
 | ||||
| func NewTag() *Tag { | ||||
| 	obj := new(Tag) | ||||
| 	return obj | ||||
| } | ||||
| 
 | ||||
| type PostTags struct { | ||||
| 	ID   int   `orm:"column(id)"` | ||||
| 	Post *Post `orm:"rel(fk)"` | ||||
| 	Tag  *Tag  `orm:"rel(fk)"` | ||||
| } | ||||
| 
 | ||||
| func (m *PostTags) TableName() string { | ||||
| 	return "prefix_post_tags" | ||||
| } | ||||
| 
 | ||||
| type Comment struct { | ||||
| 	ID      int       `orm:"column(id)"` | ||||
| 	Post    *Post     `orm:"rel(fk);column(post)"` | ||||
| 	Content string    `orm:"type(text)"` | ||||
| 	Parent  *Comment  `orm:"null;rel(fk)"` | ||||
| 	Created time.Time `orm:"auto_now_add"` | ||||
| } | ||||
| 
 | ||||
| func NewComment() *Comment { | ||||
| 	obj := new(Comment) | ||||
| 	return obj | ||||
| } | ||||
| 
 | ||||
| type Group struct { | ||||
| 	ID          int `orm:"column(gid);size(32)"` | ||||
| 	Name        string | ||||
| 	Permissions []*Permission `orm:"reverse(many)" json:"-"` | ||||
| } | ||||
| 
 | ||||
| type Permission struct { | ||||
| 	ID     int `orm:"column(id)"` | ||||
| 	Name   string | ||||
| 	Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` | ||||
| } | ||||
| 
 | ||||
| type GroupPermissions struct { | ||||
| 	ID         int         `orm:"column(id)"` | ||||
| 	Group      *Group      `orm:"rel(fk)"` | ||||
| 	Permission *Permission `orm:"rel(fk)"` | ||||
| } | ||||
| 
 | ||||
| type ModelID struct { | ||||
| 	ID int64 | ||||
| } | ||||
| 
 | ||||
| type ModelBase struct { | ||||
| 	ModelID | ||||
| 
 | ||||
| 	Created time.Time `orm:"auto_now_add;type(datetime)"` | ||||
| 	Updated time.Time `orm:"auto_now;type(datetime)"` | ||||
| } | ||||
| 
 | ||||
| type InLine struct { | ||||
| 	// Common Fields | ||||
| 	ModelBase | ||||
| 
 | ||||
| 	// Other Fields | ||||
| 	Name  string `orm:"unique"` | ||||
| 	Email string | ||||
| } | ||||
| 
 | ||||
| func NewInLine() *InLine { | ||||
| 	return new(InLine) | ||||
| } | ||||
| 
 | ||||
| type InLineOneToOne struct { | ||||
| 	// Common Fields | ||||
| 	ModelBase | ||||
| 
 | ||||
| 	Note   string | ||||
| 	InLine *InLine `orm:"rel(fk);column(inline)"` | ||||
| } | ||||
| 
 | ||||
| func NewInLineOneToOne() *InLineOneToOne { | ||||
| 	return new(InLineOneToOne) | ||||
| } | ||||
| 
 | ||||
| type IntegerPk struct { | ||||
| 	ID    int64 `orm:"pk"` | ||||
| 	Value string | ||||
| } | ||||
| 
 | ||||
| type UintPk struct { | ||||
| 	ID   uint32 `orm:"pk"` | ||||
| 	Name string | ||||
| } | ||||
| 
 | ||||
| type PtrPk struct { | ||||
| 	ID       *IntegerPk `orm:"pk;rel(one)"` | ||||
| 	Positive bool | ||||
| } | ||||
| 
 | ||||
| var DBARGS = struct { | ||||
| 	Driver string | ||||
| 	Source string | ||||
| 	Debug  string | ||||
| }{ | ||||
| 	os.Getenv("ORM_DRIVER"), | ||||
| 	os.Getenv("ORM_SOURCE"), | ||||
| 	os.Getenv("ORM_DEBUG"), | ||||
| } | ||||
| 
 | ||||
| var ( | ||||
| 	IsMysql    = DBARGS.Driver == "mysql" | ||||
| 	IsSqlite   = DBARGS.Driver == "sqlite3" | ||||
| 	IsPostgres = DBARGS.Driver == "postgres" | ||||
| 	IsTidb     = DBARGS.Driver == "tidb" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	dORM     Ormer | ||||
| 	dDbBaser dbBaser | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	helpinfo = `need driver and source! | ||||
| 
 | ||||
| 	Default DB Drivers. | ||||
| 	 | ||||
| 	  driver: url | ||||
| 	   mysql: https://github.com/go-sql-driver/mysql | ||||
| 	 sqlite3: https://github.com/mattn/go-sqlite3 | ||||
| 	postgres: https://github.com/lib/pq | ||||
| 	tidb: https://github.com/pingcap/tidb | ||||
| 	 | ||||
| 	usage: | ||||
| 	 | ||||
| 	go get -u github.com/astaxie/beego/orm | ||||
| 	go get -u github.com/go-sql-driver/mysql | ||||
| 	go get -u github.com/mattn/go-sqlite3 | ||||
| 	go get -u github.com/lib/pq | ||||
| 	go get -u github.com/pingcap/tidb | ||||
| 	 | ||||
| 	#### MySQL | ||||
| 	mysql -u root -e 'create database orm_test;' | ||||
| 	export ORM_DRIVER=mysql | ||||
| 	export ORM_SOURCE="root:@/orm_test?charset=utf8" | ||||
| 	go test -v github.com/astaxie/beego/orm | ||||
| 	 | ||||
| 	 | ||||
| 	#### Sqlite3 | ||||
| 	export ORM_DRIVER=sqlite3 | ||||
| 	export ORM_SOURCE='file:memory_test?mode=memory' | ||||
| 	go test -v github.com/astaxie/beego/orm | ||||
| 	 | ||||
| 	 | ||||
| 	#### PostgreSQL | ||||
| 	psql -c 'create database orm_test;' -U postgres | ||||
| 	export ORM_DRIVER=postgres | ||||
| 	export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" | ||||
| 	go test -v github.com/astaxie/beego/orm | ||||
| 	 | ||||
| 	#### TiDB | ||||
| 	export ORM_DRIVER=tidb | ||||
| 	export ORM_SOURCE='memory://test/test' | ||||
| 	go test -v github.com/astaxie/beego/orm | ||||
| 	 | ||||
| 	` | ||||
| ) | ||||
| 
 | ||||
| func init() { | ||||
| 	Debug, _ = StrTo(DBARGS.Debug).Bool() | ||||
| 
 | ||||
| 	if DBARGS.Driver == "" || DBARGS.Source == "" { | ||||
| 		fmt.Println(helpinfo) | ||||
| 		os.Exit(2) | ||||
| 	} | ||||
| 
 | ||||
| 	RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) | ||||
| 
 | ||||
| 	alias := getDbAlias("default") | ||||
| 	if alias.Driver == DRMySQL { | ||||
| 		alias.Engine = "INNODB" | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
							
								
								
									
										227
									
								
								pkg/orm/models_utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										227
									
								
								pkg/orm/models_utils.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,227 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // 1 is attr | ||||
| // 2 is tag | ||||
| var supportTag = map[string]int{ | ||||
| 	"-":            1, | ||||
| 	"null":         1, | ||||
| 	"index":        1, | ||||
| 	"unique":       1, | ||||
| 	"pk":           1, | ||||
| 	"auto":         1, | ||||
| 	"auto_now":     1, | ||||
| 	"auto_now_add": 1, | ||||
| 	"size":         2, | ||||
| 	"column":       2, | ||||
| 	"default":      2, | ||||
| 	"rel":          2, | ||||
| 	"reverse":      2, | ||||
| 	"rel_table":    2, | ||||
| 	"rel_through":  2, | ||||
| 	"digits":       2, | ||||
| 	"decimals":     2, | ||||
| 	"on_delete":    2, | ||||
| 	"type":         2, | ||||
| 	"description":  2, | ||||
| } | ||||
| 
 | ||||
| // get reflect.Type name with package path. | ||||
| func getFullName(typ reflect.Type) string { | ||||
| 	return typ.PkgPath() + "." + typ.Name() | ||||
| } | ||||
| 
 | ||||
| // getTableName get struct table name. | ||||
| // If the struct implement the TableName, then get the result as tablename | ||||
| // else use the struct name which will apply snakeString. | ||||
| func getTableName(val reflect.Value) string { | ||||
| 	if fun := val.MethodByName("TableName"); fun.IsValid() { | ||||
| 		vals := fun.Call([]reflect.Value{}) | ||||
| 		// has return and the first val is string | ||||
| 		if len(vals) > 0 && vals[0].Kind() == reflect.String { | ||||
| 			return vals[0].String() | ||||
| 		} | ||||
| 	} | ||||
| 	return snakeString(reflect.Indirect(val).Type().Name()) | ||||
| } | ||||
| 
 | ||||
| // get table engine, myisam or innodb. | ||||
| func getTableEngine(val reflect.Value) string { | ||||
| 	fun := val.MethodByName("TableEngine") | ||||
| 	if fun.IsValid() { | ||||
| 		vals := fun.Call([]reflect.Value{}) | ||||
| 		if len(vals) > 0 && vals[0].Kind() == reflect.String { | ||||
| 			return vals[0].String() | ||||
| 		} | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // get table index from method. | ||||
| func getTableIndex(val reflect.Value) [][]string { | ||||
| 	fun := val.MethodByName("TableIndex") | ||||
| 	if fun.IsValid() { | ||||
| 		vals := fun.Call([]reflect.Value{}) | ||||
| 		if len(vals) > 0 && vals[0].CanInterface() { | ||||
| 			if d, ok := vals[0].Interface().([][]string); ok { | ||||
| 				return d | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // get table unique from method | ||||
| func getTableUnique(val reflect.Value) [][]string { | ||||
| 	fun := val.MethodByName("TableUnique") | ||||
| 	if fun.IsValid() { | ||||
| 		vals := fun.Call([]reflect.Value{}) | ||||
| 		if len(vals) > 0 && vals[0].CanInterface() { | ||||
| 			if d, ok := vals[0].Interface().([][]string); ok { | ||||
| 				return d | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // get snaked column name | ||||
| func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { | ||||
| 	column := col | ||||
| 	if col == "" { | ||||
| 		column = nameStrategyMap[nameStrategy](sf.Name) | ||||
| 	} | ||||
| 	switch ft { | ||||
| 	case RelForeignKey, RelOneToOne: | ||||
| 		if len(col) == 0 { | ||||
| 			column = column + "_id" | ||||
| 		} | ||||
| 	case RelManyToMany, RelReverseMany, RelReverseOne: | ||||
| 		column = sf.Name | ||||
| 	} | ||||
| 	return column | ||||
| } | ||||
| 
 | ||||
| // return field type as type constant from reflect.Value | ||||
| func getFieldType(val reflect.Value) (ft int, err error) { | ||||
| 	switch val.Type() { | ||||
| 	case reflect.TypeOf(new(int8)): | ||||
| 		ft = TypeBitField | ||||
| 	case reflect.TypeOf(new(int16)): | ||||
| 		ft = TypeSmallIntegerField | ||||
| 	case reflect.TypeOf(new(int32)), | ||||
| 		reflect.TypeOf(new(int)): | ||||
| 		ft = TypeIntegerField | ||||
| 	case reflect.TypeOf(new(int64)): | ||||
| 		ft = TypeBigIntegerField | ||||
| 	case reflect.TypeOf(new(uint8)): | ||||
| 		ft = TypePositiveBitField | ||||
| 	case reflect.TypeOf(new(uint16)): | ||||
| 		ft = TypePositiveSmallIntegerField | ||||
| 	case reflect.TypeOf(new(uint32)), | ||||
| 		reflect.TypeOf(new(uint)): | ||||
| 		ft = TypePositiveIntegerField | ||||
| 	case reflect.TypeOf(new(uint64)): | ||||
| 		ft = TypePositiveBigIntegerField | ||||
| 	case reflect.TypeOf(new(float32)), | ||||
| 		reflect.TypeOf(new(float64)): | ||||
| 		ft = TypeFloatField | ||||
| 	case reflect.TypeOf(new(bool)): | ||||
| 		ft = TypeBooleanField | ||||
| 	case reflect.TypeOf(new(string)): | ||||
| 		ft = TypeVarCharField | ||||
| 	case reflect.TypeOf(new(time.Time)): | ||||
| 		ft = TypeDateTimeField | ||||
| 	default: | ||||
| 		elm := reflect.Indirect(val) | ||||
| 		switch elm.Kind() { | ||||
| 		case reflect.Int8: | ||||
| 			ft = TypeBitField | ||||
| 		case reflect.Int16: | ||||
| 			ft = TypeSmallIntegerField | ||||
| 		case reflect.Int32, reflect.Int: | ||||
| 			ft = TypeIntegerField | ||||
| 		case reflect.Int64: | ||||
| 			ft = TypeBigIntegerField | ||||
| 		case reflect.Uint8: | ||||
| 			ft = TypePositiveBitField | ||||
| 		case reflect.Uint16: | ||||
| 			ft = TypePositiveSmallIntegerField | ||||
| 		case reflect.Uint32, reflect.Uint: | ||||
| 			ft = TypePositiveIntegerField | ||||
| 		case reflect.Uint64: | ||||
| 			ft = TypePositiveBigIntegerField | ||||
| 		case reflect.Float32, reflect.Float64: | ||||
| 			ft = TypeFloatField | ||||
| 		case reflect.Bool: | ||||
| 			ft = TypeBooleanField | ||||
| 		case reflect.String: | ||||
| 			ft = TypeVarCharField | ||||
| 		default: | ||||
| 			if elm.Interface() == nil { | ||||
| 				panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val)) | ||||
| 			} | ||||
| 			switch elm.Interface().(type) { | ||||
| 			case sql.NullInt64: | ||||
| 				ft = TypeBigIntegerField | ||||
| 			case sql.NullFloat64: | ||||
| 				ft = TypeFloatField | ||||
| 			case sql.NullBool: | ||||
| 				ft = TypeBooleanField | ||||
| 			case sql.NullString: | ||||
| 				ft = TypeVarCharField | ||||
| 			case time.Time: | ||||
| 				ft = TypeDateTimeField | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	if ft&IsFieldType == 0 { | ||||
| 		err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // parse struct tag string | ||||
| func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) { | ||||
| 	attrs = make(map[string]bool) | ||||
| 	tags = make(map[string]string) | ||||
| 	for _, v := range strings.Split(data, defaultStructTagDelim) { | ||||
| 		if v == "" { | ||||
| 			continue | ||||
| 		} | ||||
| 		v = strings.TrimSpace(v) | ||||
| 		if t := strings.ToLower(v); supportTag[t] == 1 { | ||||
| 			attrs[t] = true | ||||
| 		} else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 { | ||||
| 			name := t[:i] | ||||
| 			if supportTag[name] == 2 { | ||||
| 				v = v[i+1 : len(v)-1] | ||||
| 				tags[name] = v | ||||
| 			} | ||||
| 		} else { | ||||
| 			DebugLog.Println("unsupport orm tag", v) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										579
									
								
								pkg/orm/orm.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										579
									
								
								pkg/orm/orm.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,579 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| // +build go1.8 | ||||
| 
 | ||||
| // Package orm provide ORM for MySQL/PostgreSQL/sqlite | ||||
| // Simple Usage | ||||
| // | ||||
| //	package main | ||||
| // | ||||
| //	import ( | ||||
| //		"fmt" | ||||
| //		"github.com/astaxie/beego/orm" | ||||
| //		_ "github.com/go-sql-driver/mysql" // import your used driver | ||||
| //	) | ||||
| // | ||||
| //	// Model Struct | ||||
| //	type User struct { | ||||
| //		Id   int    `orm:"auto"` | ||||
| //		Name string `orm:"size(100)"` | ||||
| //	} | ||||
| // | ||||
| //	func init() { | ||||
| //		orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) | ||||
| //	} | ||||
| // | ||||
| //	func main() { | ||||
| //		o := orm.NewOrm() | ||||
| //		user := User{Name: "slene"} | ||||
| //		// insert | ||||
| //		id, err := o.Insert(&user) | ||||
| //		// update | ||||
| //		user.Name = "astaxie" | ||||
| //		num, err := o.Update(&user) | ||||
| //		// read one | ||||
| //		u := User{Id: user.Id} | ||||
| //		err = o.Read(&u) | ||||
| //		// delete | ||||
| //		num, err = o.Delete(&u) | ||||
| //	} | ||||
| // | ||||
| // more docs: http://beego.me/docs/mvc/model/overview.md | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // DebugQueries define the debug | ||||
| const ( | ||||
| 	DebugQueries = iota | ||||
| ) | ||||
| 
 | ||||
| // Define common vars | ||||
| var ( | ||||
| 	Debug            = false | ||||
| 	DebugLog         = NewLog(os.Stdout) | ||||
| 	DefaultRowsLimit = -1 | ||||
| 	DefaultRelsDepth = 2 | ||||
| 	DefaultTimeLoc   = time.Local | ||||
| 	ErrTxHasBegan    = errors.New("<Ormer.Begin> transaction already begin") | ||||
| 	ErrTxDone        = errors.New("<Ormer.Commit/Rollback> transaction not begin") | ||||
| 	ErrMultiRows     = errors.New("<QuerySeter> return multi rows") | ||||
| 	ErrNoRows        = errors.New("<QuerySeter> no row found") | ||||
| 	ErrStmtClosed    = errors.New("<QuerySeter> stmt already closed") | ||||
| 	ErrArgs          = errors.New("<Ormer> args error may be empty") | ||||
| 	ErrNotImplement  = errors.New("have not implement") | ||||
| ) | ||||
| 
 | ||||
| // Params stores the Params | ||||
| type Params map[string]interface{} | ||||
| 
 | ||||
| // ParamsList stores paramslist | ||||
| type ParamsList []interface{} | ||||
| 
 | ||||
| type orm struct { | ||||
| 	alias *alias | ||||
| 	db    dbQuerier | ||||
| 	isTx  bool | ||||
| } | ||||
| 
 | ||||
| var _ Ormer = new(orm) | ||||
| 
 | ||||
| // get model info and model reflect value | ||||
| func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { | ||||
| 	val := reflect.ValueOf(md) | ||||
| 	ind = reflect.Indirect(val) | ||||
| 	typ := ind.Type() | ||||
| 	if needPtr && val.Kind() != reflect.Ptr { | ||||
| 		panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ))) | ||||
| 	} | ||||
| 	name := getFullName(typ) | ||||
| 	if mi, ok := modelCache.getByFullName(name); ok { | ||||
| 		return mi, ind | ||||
| 	} | ||||
| 	panic(fmt.Errorf("<Ormer> table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) | ||||
| } | ||||
| 
 | ||||
| // get field info from model info by given field name | ||||
| func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { | ||||
| 	fi, ok := mi.fields.GetByAny(name) | ||||
| 	if !ok { | ||||
| 		panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.fullName)) | ||||
| 	} | ||||
| 	return fi | ||||
| } | ||||
| 
 | ||||
| // read data to model | ||||
| func (o *orm) Read(md interface{}, cols ...string) error { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) | ||||
| } | ||||
| 
 | ||||
| // read data to model, like Read(), but use "SELECT FOR UPDATE" form | ||||
| func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) | ||||
| } | ||||
| 
 | ||||
| // Try to read a row from the database, or insert one if it doesn't exist | ||||
| func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { | ||||
| 	cols = append([]string{col1}, cols...) | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) | ||||
| 	if err == ErrNoRows { | ||||
| 		// Create | ||||
| 		id, err := o.Insert(md) | ||||
| 		return (err == nil), id, err | ||||
| 	} | ||||
| 
 | ||||
| 	id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) | ||||
| 	if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { | ||||
| 		id = int64(vid.Uint()) | ||||
| 	} else if mi.fields.pk.rel { | ||||
| 		return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) | ||||
| 	} else { | ||||
| 		id = vid.Int() | ||||
| 	} | ||||
| 
 | ||||
| 	return false, id, err | ||||
| } | ||||
| 
 | ||||
| // insert model data to database | ||||
| func (o *orm) Insert(md interface{}) (int64, error) { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) | ||||
| 	if err != nil { | ||||
| 		return id, err | ||||
| 	} | ||||
| 
 | ||||
| 	o.setPk(mi, ind, id) | ||||
| 
 | ||||
| 	return id, nil | ||||
| } | ||||
| 
 | ||||
| // set auto pk field | ||||
| func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { | ||||
| 	if mi.fields.pk.auto { | ||||
| 		if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { | ||||
| 			ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) | ||||
| 		} else { | ||||
| 			ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // insert some models to database | ||||
| func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { | ||||
| 	var cnt int64 | ||||
| 
 | ||||
| 	sind := reflect.Indirect(reflect.ValueOf(mds)) | ||||
| 
 | ||||
| 	switch sind.Kind() { | ||||
| 	case reflect.Array, reflect.Slice: | ||||
| 		if sind.Len() == 0 { | ||||
| 			return cnt, ErrArgs | ||||
| 		} | ||||
| 	default: | ||||
| 		return cnt, ErrArgs | ||||
| 	} | ||||
| 
 | ||||
| 	if bulk <= 1 { | ||||
| 		for i := 0; i < sind.Len(); i++ { | ||||
| 			ind := reflect.Indirect(sind.Index(i)) | ||||
| 			mi, _ := o.getMiInd(ind.Interface(), false) | ||||
| 			id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) | ||||
| 			if err != nil { | ||||
| 				return cnt, err | ||||
| 			} | ||||
| 
 | ||||
| 			o.setPk(mi, ind, id) | ||||
| 
 | ||||
| 			cnt++ | ||||
| 		} | ||||
| 	} else { | ||||
| 		mi, _ := o.getMiInd(sind.Index(0).Interface(), false) | ||||
| 		return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) | ||||
| 	} | ||||
| 	return cnt, nil | ||||
| } | ||||
| 
 | ||||
| // InsertOrUpdate data to database | ||||
| func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) | ||||
| 	if err != nil { | ||||
| 		return id, err | ||||
| 	} | ||||
| 
 | ||||
| 	o.setPk(mi, ind, id) | ||||
| 
 | ||||
| 	return id, nil | ||||
| } | ||||
| 
 | ||||
| // update model to database. | ||||
| // cols set the columns those want to update. | ||||
| func (o *orm) Update(md interface{}, cols ...string) (int64, error) { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) | ||||
| } | ||||
| 
 | ||||
| // delete model in database | ||||
| // cols shows the delete conditions values read from. default is pk | ||||
| func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) | ||||
| 	if err != nil { | ||||
| 		return num, err | ||||
| 	} | ||||
| 	if num > 0 { | ||||
| 		o.setPk(mi, ind, 0) | ||||
| 	} | ||||
| 	return num, nil | ||||
| } | ||||
| 
 | ||||
| // create a models to models queryer | ||||
| func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	fi := o.getFieldInfo(mi, name) | ||||
| 
 | ||||
| 	switch { | ||||
| 	case fi.fieldType == RelManyToMany: | ||||
| 	case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough: | ||||
| 	default: | ||||
| 		panic(fmt.Errorf("<Ormer.QueryM2M> model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) | ||||
| 	} | ||||
| 
 | ||||
| 	return newQueryM2M(md, o, mi, fi, ind) | ||||
| } | ||||
| 
 | ||||
| // load related models to md model. | ||||
| // args are limit, offset int and order string. | ||||
| // | ||||
| // example: | ||||
| // 	orm.LoadRelated(post,"Tags") | ||||
| // 	for _,tag := range post.Tags{...} | ||||
| // | ||||
| // make sure the relation is defined in model struct tags. | ||||
| func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { | ||||
| 	_, fi, ind, qseter := o.queryRelated(md, name) | ||||
| 
 | ||||
| 	qs := qseter.(*querySet) | ||||
| 
 | ||||
| 	var relDepth int | ||||
| 	var limit, offset int64 | ||||
| 	var order string | ||||
| 	for i, arg := range args { | ||||
| 		switch i { | ||||
| 		case 0: | ||||
| 			if v, ok := arg.(bool); ok { | ||||
| 				if v { | ||||
| 					relDepth = DefaultRelsDepth | ||||
| 				} | ||||
| 			} else if v, ok := arg.(int); ok { | ||||
| 				relDepth = v | ||||
| 			} | ||||
| 		case 1: | ||||
| 			limit = ToInt64(arg) | ||||
| 		case 2: | ||||
| 			offset = ToInt64(arg) | ||||
| 		case 3: | ||||
| 			order, _ = arg.(string) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	switch fi.fieldType { | ||||
| 	case RelOneToOne, RelForeignKey, RelReverseOne: | ||||
| 		limit = 1 | ||||
| 		offset = 0 | ||||
| 	} | ||||
| 
 | ||||
| 	qs.limit = limit | ||||
| 	qs.offset = offset | ||||
| 	qs.relDepth = relDepth | ||||
| 
 | ||||
| 	if len(order) > 0 { | ||||
| 		qs.orders = []string{order} | ||||
| 	} | ||||
| 
 | ||||
| 	find := ind.FieldByIndex(fi.fieldIndex) | ||||
| 
 | ||||
| 	var nums int64 | ||||
| 	var err error | ||||
| 	switch fi.fieldType { | ||||
| 	case RelOneToOne, RelForeignKey, RelReverseOne: | ||||
| 		val := reflect.New(find.Type().Elem()) | ||||
| 		container := val.Interface() | ||||
| 		err = qs.One(container) | ||||
| 		if err == nil { | ||||
| 			find.Set(val) | ||||
| 			nums = 1 | ||||
| 		} | ||||
| 	default: | ||||
| 		nums, err = qs.All(find.Addr().Interface()) | ||||
| 	} | ||||
| 
 | ||||
| 	return nums, err | ||||
| } | ||||
| 
 | ||||
| // return a QuerySeter for related models to md model. | ||||
| // it can do all, update, delete in QuerySeter. | ||||
| // example: | ||||
| // 	qs := orm.QueryRelated(post,"Tag") | ||||
| //  qs.All(&[]*Tag{}) | ||||
| // | ||||
| func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { | ||||
| 	// is this api needed ? | ||||
| 	_, _, _, qs := o.queryRelated(md, name) | ||||
| 	return qs | ||||
| } | ||||
| 
 | ||||
| // get QuerySeter for related models to md model | ||||
| func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { | ||||
| 	mi, ind := o.getMiInd(md, true) | ||||
| 	fi := o.getFieldInfo(mi, name) | ||||
| 
 | ||||
| 	_, _, exist := getExistPk(mi, ind) | ||||
| 	if !exist { | ||||
| 		panic(ErrMissPK) | ||||
| 	} | ||||
| 
 | ||||
| 	var qs *querySet | ||||
| 
 | ||||
| 	switch fi.fieldType { | ||||
| 	case RelOneToOne, RelForeignKey, RelManyToMany: | ||||
| 		if !fi.inModel { | ||||
| 			break | ||||
| 		} | ||||
| 		qs = o.getRelQs(md, mi, fi) | ||||
| 	case RelReverseOne, RelReverseMany: | ||||
| 		if !fi.inModel { | ||||
| 			break | ||||
| 		} | ||||
| 		qs = o.getReverseQs(md, mi, fi) | ||||
| 	} | ||||
| 
 | ||||
| 	if qs == nil { | ||||
| 		panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field", md, name)) | ||||
| 	} | ||||
| 
 | ||||
| 	return mi, fi, ind, qs | ||||
| } | ||||
| 
 | ||||
| // get reverse relation QuerySeter | ||||
| func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { | ||||
| 	switch fi.fieldType { | ||||
| 	case RelReverseOne, RelReverseMany: | ||||
| 	default: | ||||
| 		panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName)) | ||||
| 	} | ||||
| 
 | ||||
| 	var q *querySet | ||||
| 
 | ||||
| 	if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough { | ||||
| 		q = newQuerySet(o, fi.relModelInfo).(*querySet) | ||||
| 		q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) | ||||
| 	} else { | ||||
| 		q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet) | ||||
| 		q.cond = NewCondition().And(fi.reverseFieldInfo.column, md) | ||||
| 	} | ||||
| 
 | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| // get relation QuerySeter | ||||
| func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { | ||||
| 	switch fi.fieldType { | ||||
| 	case RelOneToOne, RelForeignKey, RelManyToMany: | ||||
| 	default: | ||||
| 		panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName)) | ||||
| 	} | ||||
| 
 | ||||
| 	q := newQuerySet(o, fi.relModelInfo).(*querySet) | ||||
| 	q.cond = NewCondition() | ||||
| 
 | ||||
| 	if fi.fieldType == RelManyToMany { | ||||
| 		q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) | ||||
| 	} else { | ||||
| 		q.cond = q.cond.And(fi.reverseFieldInfo.column, md) | ||||
| 	} | ||||
| 
 | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| // return a QuerySeter for table operations. | ||||
| // table name can be string or struct. | ||||
| // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), | ||||
| func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { | ||||
| 	var name string | ||||
| 	if table, ok := ptrStructOrTableName.(string); ok { | ||||
| 		name = nameStrategyMap[defaultNameStrategy](table) | ||||
| 		if mi, ok := modelCache.get(name); ok { | ||||
| 			qs = newQuerySet(o, mi) | ||||
| 		} | ||||
| 	} else { | ||||
| 		name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) | ||||
| 		if mi, ok := modelCache.getByFullName(name); ok { | ||||
| 			qs = newQuerySet(o, mi) | ||||
| 		} | ||||
| 	} | ||||
| 	if qs == nil { | ||||
| 		panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name)) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // switch to another registered database driver by given name. | ||||
| func (o *orm) Using(name string) error { | ||||
| 	if o.isTx { | ||||
| 		panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db")) | ||||
| 	} | ||||
| 	if al, ok := dataBaseCache.get(name); ok { | ||||
| 		o.alias = al | ||||
| 		if Debug { | ||||
| 			o.db = newDbQueryLog(al, al.DB) | ||||
| 		} else { | ||||
| 			o.db = al.DB | ||||
| 		} | ||||
| 	} else { | ||||
| 		return fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", name) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // begin transaction | ||||
| func (o *orm) Begin() error { | ||||
| 	return o.BeginTx(context.Background(), nil) | ||||
| } | ||||
| 
 | ||||
| func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { | ||||
| 	if o.isTx { | ||||
| 		return ErrTxHasBegan | ||||
| 	} | ||||
| 	var tx *sql.Tx | ||||
| 	tx, err := o.db.(txer).BeginTx(ctx, opts) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	o.isTx = true | ||||
| 	if Debug { | ||||
| 		o.db.(*dbQueryLog).SetDB(tx) | ||||
| 	} else { | ||||
| 		o.db = tx | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // commit transaction | ||||
| func (o *orm) Commit() error { | ||||
| 	if !o.isTx { | ||||
| 		return ErrTxDone | ||||
| 	} | ||||
| 	err := o.db.(txEnder).Commit() | ||||
| 	if err == nil { | ||||
| 		o.isTx = false | ||||
| 		o.Using(o.alias.Name) | ||||
| 	} else if err == sql.ErrTxDone { | ||||
| 		return ErrTxDone | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| // rollback transaction | ||||
| func (o *orm) Rollback() error { | ||||
| 	if !o.isTx { | ||||
| 		return ErrTxDone | ||||
| 	} | ||||
| 	err := o.db.(txEnder).Rollback() | ||||
| 	if err == nil { | ||||
| 		o.isTx = false | ||||
| 		o.Using(o.alias.Name) | ||||
| 	} else if err == sql.ErrTxDone { | ||||
| 		return ErrTxDone | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| // return a raw query seter for raw sql string. | ||||
| func (o *orm) Raw(query string, args ...interface{}) RawSeter { | ||||
| 	return newRawSet(o, query, args) | ||||
| } | ||||
| 
 | ||||
| // return current using database Driver | ||||
| func (o *orm) Driver() Driver { | ||||
| 	return driver(o.alias.Name) | ||||
| } | ||||
| 
 | ||||
| // return sql.DBStats for current database | ||||
| func (o *orm) DBStats() *sql.DBStats { | ||||
| 	if o.alias != nil && o.alias.DB != nil { | ||||
| 		stats := o.alias.DB.DB.Stats() | ||||
| 		return &stats | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // NewOrm create new orm | ||||
| func NewOrm() Ormer { | ||||
| 	BootStrap() // execute only once | ||||
| 
 | ||||
| 	o := new(orm) | ||||
| 	err := o.Using("default") | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	return o | ||||
| } | ||||
| 
 | ||||
| // NewOrmWithDB create a new ormer object with specify *sql.DB for query | ||||
| func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { | ||||
| 	var al *alias | ||||
| 
 | ||||
| 	if dr, ok := drivers[driverName]; ok { | ||||
| 		al = new(alias) | ||||
| 		al.DbBaser = dbBasers[dr] | ||||
| 		al.Driver = dr | ||||
| 	} else { | ||||
| 		return nil, fmt.Errorf("driver name `%s` have not registered", driverName) | ||||
| 	} | ||||
| 
 | ||||
| 	al.Name = aliasName | ||||
| 	al.DriverName = driverName | ||||
| 	al.DB = &DB{ | ||||
| 		RWMutex:        new(sync.RWMutex), | ||||
| 		DB:             db, | ||||
| 		stmtDecorators: newStmtDecoratorLruWithEvict(), | ||||
| 	} | ||||
| 
 | ||||
| 	detectTZ(al) | ||||
| 
 | ||||
| 	o := new(orm) | ||||
| 	o.alias = al | ||||
| 
 | ||||
| 	if Debug { | ||||
| 		o.db = newDbQueryLog(o.alias, db) | ||||
| 	} else { | ||||
| 		o.db = db | ||||
| 	} | ||||
| 
 | ||||
| 	return o, nil | ||||
| } | ||||
							
								
								
									
										153
									
								
								pkg/orm/orm_conds.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										153
									
								
								pkg/orm/orm_conds.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,153 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // ExprSep define the expression separation | ||||
| const ( | ||||
| 	ExprSep = "__" | ||||
| ) | ||||
| 
 | ||||
| type condValue struct { | ||||
| 	exprs  []string | ||||
| 	args   []interface{} | ||||
| 	cond   *Condition | ||||
| 	isOr   bool | ||||
| 	isNot  bool | ||||
| 	isCond bool | ||||
| 	isRaw  bool | ||||
| 	sql    string | ||||
| } | ||||
| 
 | ||||
| // Condition struct. | ||||
| // work for WHERE conditions. | ||||
| type Condition struct { | ||||
| 	params []condValue | ||||
| } | ||||
| 
 | ||||
| // NewCondition return new condition struct | ||||
| func NewCondition() *Condition { | ||||
| 	c := &Condition{} | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| // Raw add raw sql to condition | ||||
| func (c Condition) Raw(expr string, sql string) *Condition { | ||||
| 	if len(sql) == 0 { | ||||
| 		panic(fmt.Errorf("<Condition.Raw> sql cannot empty")) | ||||
| 	} | ||||
| 	c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), sql: sql, isRaw: true}) | ||||
| 	return &c | ||||
| } | ||||
| 
 | ||||
| // And add expression to condition | ||||
| func (c Condition) And(expr string, args ...interface{}) *Condition { | ||||
| 	if expr == "" || len(args) == 0 { | ||||
| 		panic(fmt.Errorf("<Condition.And> args cannot empty")) | ||||
| 	} | ||||
| 	c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args}) | ||||
| 	return &c | ||||
| } | ||||
| 
 | ||||
| // AndNot add NOT expression to condition | ||||
| func (c Condition) AndNot(expr string, args ...interface{}) *Condition { | ||||
| 	if expr == "" || len(args) == 0 { | ||||
| 		panic(fmt.Errorf("<Condition.AndNot> args cannot empty")) | ||||
| 	} | ||||
| 	c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true}) | ||||
| 	return &c | ||||
| } | ||||
| 
 | ||||
| // AndCond combine a condition to current condition | ||||
| func (c *Condition) AndCond(cond *Condition) *Condition { | ||||
| 	c = c.clone() | ||||
| 	if c == cond { | ||||
| 		panic(fmt.Errorf("<Condition.AndCond> cannot use self as sub cond")) | ||||
| 	} | ||||
| 	if cond != nil { | ||||
| 		c.params = append(c.params, condValue{cond: cond, isCond: true}) | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| // AndNotCond combine a AND NOT condition to current condition | ||||
| func (c *Condition) AndNotCond(cond *Condition) *Condition { | ||||
| 	c = c.clone() | ||||
| 	if c == cond { | ||||
| 		panic(fmt.Errorf("<Condition.AndNotCond> cannot use self as sub cond")) | ||||
| 	} | ||||
| 
 | ||||
| 	if cond != nil { | ||||
| 		c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true}) | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| // Or add OR expression to condition | ||||
| func (c Condition) Or(expr string, args ...interface{}) *Condition { | ||||
| 	if expr == "" || len(args) == 0 { | ||||
| 		panic(fmt.Errorf("<Condition.Or> args cannot empty")) | ||||
| 	} | ||||
| 	c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true}) | ||||
| 	return &c | ||||
| } | ||||
| 
 | ||||
| // OrNot add OR NOT expression to condition | ||||
| func (c Condition) OrNot(expr string, args ...interface{}) *Condition { | ||||
| 	if expr == "" || len(args) == 0 { | ||||
| 		panic(fmt.Errorf("<Condition.OrNot> args cannot empty")) | ||||
| 	} | ||||
| 	c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true}) | ||||
| 	return &c | ||||
| } | ||||
| 
 | ||||
| // OrCond combine a OR condition to current condition | ||||
| func (c *Condition) OrCond(cond *Condition) *Condition { | ||||
| 	c = c.clone() | ||||
| 	if c == cond { | ||||
| 		panic(fmt.Errorf("<Condition.OrCond> cannot use self as sub cond")) | ||||
| 	} | ||||
| 	if cond != nil { | ||||
| 		c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true}) | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| // OrNotCond combine a OR NOT condition to current condition | ||||
| func (c *Condition) OrNotCond(cond *Condition) *Condition { | ||||
| 	c = c.clone() | ||||
| 	if c == cond { | ||||
| 		panic(fmt.Errorf("<Condition.OrNotCond> cannot use self as sub cond")) | ||||
| 	} | ||||
| 
 | ||||
| 	if cond != nil { | ||||
| 		c.params = append(c.params, condValue{cond: cond, isCond: true, isNot: true, isOr: true}) | ||||
| 	} | ||||
| 	return c | ||||
| } | ||||
| 
 | ||||
| // IsEmpty check the condition arguments are empty or not. | ||||
| func (c *Condition) IsEmpty() bool { | ||||
| 	return len(c.params) == 0 | ||||
| } | ||||
| 
 | ||||
| // clone clone a condition | ||||
| func (c Condition) clone() *Condition { | ||||
| 	return &c | ||||
| } | ||||
							
								
								
									
										222
									
								
								pkg/orm/orm_log.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										222
									
								
								pkg/orm/orm_log.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,222 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // Log implement the log.Logger | ||||
| type Log struct { | ||||
| 	*log.Logger | ||||
| } | ||||
| 
 | ||||
| //costomer log func | ||||
| var LogFunc func(query map[string]interface{}) | ||||
| 
 | ||||
| // NewLog set io.Writer to create a Logger. | ||||
| func NewLog(out io.Writer) *Log { | ||||
| 	d := new(Log) | ||||
| 	d.Logger = log.New(out, "[ORM]", log.LstdFlags) | ||||
| 	return d | ||||
| } | ||||
| 
 | ||||
| func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { | ||||
| 	var logMap = make(map[string]interface{}) | ||||
| 	sub := time.Now().Sub(t) / 1e5 | ||||
| 	elsp := float64(int(sub)) / 10.0 | ||||
| 	logMap["cost_time"] = elsp | ||||
| 	flag := "  OK" | ||||
| 	if err != nil { | ||||
| 		flag = "FAIL" | ||||
| 	} | ||||
| 	logMap["flag"] = flag | ||||
| 	con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query) | ||||
| 	cons := make([]string, 0, len(args)) | ||||
| 	for _, arg := range args { | ||||
| 		cons = append(cons, fmt.Sprintf("%v", arg)) | ||||
| 	} | ||||
| 	if len(cons) > 0 { | ||||
| 		con += fmt.Sprintf(" - `%s`", strings.Join(cons, "`, `")) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		con += " - " + err.Error() | ||||
| 	} | ||||
| 	logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) | ||||
| 	if LogFunc != nil{ | ||||
| 		LogFunc(logMap) | ||||
| 	} | ||||
| 	DebugLog.Println(con) | ||||
| } | ||||
| 
 | ||||
| // statement query logger struct. | ||||
| // if dev mode, use stmtQueryLog, or use stmtQuerier. | ||||
| type stmtQueryLog struct { | ||||
| 	alias *alias | ||||
| 	query string | ||||
| 	stmt  stmtQuerier | ||||
| } | ||||
| 
 | ||||
| var _ stmtQuerier = new(stmtQueryLog) | ||||
| 
 | ||||
| func (d *stmtQueryLog) Close() error { | ||||
| 	a := time.Now() | ||||
| 	err := d.stmt.Close() | ||||
| 	debugLogQueies(d.alias, "st.Close", d.query, a, err) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { | ||||
| 	a := time.Now() | ||||
| 	res, err := d.stmt.Exec(args...) | ||||
| 	debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) | ||||
| 	return res, err | ||||
| } | ||||
| 
 | ||||
| func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { | ||||
| 	a := time.Now() | ||||
| 	res, err := d.stmt.Query(args...) | ||||
| 	debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) | ||||
| 	return res, err | ||||
| } | ||||
| 
 | ||||
| func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { | ||||
| 	a := time.Now() | ||||
| 	res := d.stmt.QueryRow(args...) | ||||
| 	debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) | ||||
| 	return res | ||||
| } | ||||
| 
 | ||||
| func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier { | ||||
| 	d := new(stmtQueryLog) | ||||
| 	d.stmt = stmt | ||||
| 	d.alias = alias | ||||
| 	d.query = query | ||||
| 	return d | ||||
| } | ||||
| 
 | ||||
| // database query logger struct. | ||||
| // if dev mode, use dbQueryLog, or use dbQuerier. | ||||
| type dbQueryLog struct { | ||||
| 	alias *alias | ||||
| 	db    dbQuerier | ||||
| 	tx    txer | ||||
| 	txe   txEnder | ||||
| } | ||||
| 
 | ||||
| var _ dbQuerier = new(dbQueryLog) | ||||
| var _ txer = new(dbQueryLog) | ||||
| var _ txEnder = new(dbQueryLog) | ||||
| 
 | ||||
| func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { | ||||
| 	a := time.Now() | ||||
| 	stmt, err := d.db.Prepare(query) | ||||
| 	debugLogQueies(d.alias, "db.Prepare", query, a, err) | ||||
| 	return stmt, err | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { | ||||
| 	a := time.Now() | ||||
| 	stmt, err := d.db.PrepareContext(ctx, query) | ||||
| 	debugLogQueies(d.alias, "db.Prepare", query, a, err) | ||||
| 	return stmt, err | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { | ||||
| 	a := time.Now() | ||||
| 	res, err := d.db.Exec(query, args...) | ||||
| 	debugLogQueies(d.alias, "db.Exec", query, a, err, args...) | ||||
| 	return res, err | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { | ||||
| 	a := time.Now() | ||||
| 	res, err := d.db.ExecContext(ctx, query, args...) | ||||
| 	debugLogQueies(d.alias, "db.Exec", query, a, err, args...) | ||||
| 	return res, err | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	a := time.Now() | ||||
| 	res, err := d.db.Query(query, args...) | ||||
| 	debugLogQueies(d.alias, "db.Query", query, a, err, args...) | ||||
| 	return res, err | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	a := time.Now() | ||||
| 	res, err := d.db.QueryContext(ctx, query, args...) | ||||
| 	debugLogQueies(d.alias, "db.Query", query, a, err, args...) | ||||
| 	return res, err | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { | ||||
| 	a := time.Now() | ||||
| 	res := d.db.QueryRow(query, args...) | ||||
| 	debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) | ||||
| 	return res | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { | ||||
| 	a := time.Now() | ||||
| 	res := d.db.QueryRowContext(ctx, query, args...) | ||||
| 	debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) | ||||
| 	return res | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) Begin() (*sql.Tx, error) { | ||||
| 	a := time.Now() | ||||
| 	tx, err := d.db.(txer).Begin() | ||||
| 	debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err) | ||||
| 	return tx, err | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { | ||||
| 	a := time.Now() | ||||
| 	tx, err := d.db.(txer).BeginTx(ctx, opts) | ||||
| 	debugLogQueies(d.alias, "db.BeginTx", "START TRANSACTION", a, err) | ||||
| 	return tx, err | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) Commit() error { | ||||
| 	a := time.Now() | ||||
| 	err := d.db.(txEnder).Commit() | ||||
| 	debugLogQueies(d.alias, "tx.Commit", "COMMIT", a, err) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) Rollback() error { | ||||
| 	a := time.Now() | ||||
| 	err := d.db.(txEnder).Rollback() | ||||
| 	debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err) | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (d *dbQueryLog) SetDB(db dbQuerier) { | ||||
| 	d.db = db | ||||
| } | ||||
| 
 | ||||
| func newDbQueryLog(alias *alias, db dbQuerier) dbQuerier { | ||||
| 	d := new(dbQueryLog) | ||||
| 	d.alias = alias | ||||
| 	d.db = db | ||||
| 	return d | ||||
| } | ||||
							
								
								
									
										87
									
								
								pkg/orm/orm_object.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								pkg/orm/orm_object.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,87 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| ) | ||||
| 
 | ||||
| // an insert queryer struct | ||||
| type insertSet struct { | ||||
| 	mi     *modelInfo | ||||
| 	orm    *orm | ||||
| 	stmt   stmtQuerier | ||||
| 	closed bool | ||||
| } | ||||
| 
 | ||||
| var _ Inserter = new(insertSet) | ||||
| 
 | ||||
| // insert model ignore it's registered or not. | ||||
| func (o *insertSet) Insert(md interface{}) (int64, error) { | ||||
| 	if o.closed { | ||||
| 		return 0, ErrStmtClosed | ||||
| 	} | ||||
| 	val := reflect.ValueOf(md) | ||||
| 	ind := reflect.Indirect(val) | ||||
| 	typ := ind.Type() | ||||
| 	name := getFullName(typ) | ||||
| 	if val.Kind() != reflect.Ptr { | ||||
| 		panic(fmt.Errorf("<Inserter.Insert> cannot use non-ptr model struct `%s`", name)) | ||||
| 	} | ||||
| 	if name != o.mi.fullName { | ||||
| 		panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name)) | ||||
| 	} | ||||
| 	id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) | ||||
| 	if err != nil { | ||||
| 		return id, err | ||||
| 	} | ||||
| 	if id > 0 { | ||||
| 		if o.mi.fields.pk.auto { | ||||
| 			if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { | ||||
| 				ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) | ||||
| 			} else { | ||||
| 				ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return id, nil | ||||
| } | ||||
| 
 | ||||
| // close insert queryer statement | ||||
| func (o *insertSet) Close() error { | ||||
| 	if o.closed { | ||||
| 		return ErrStmtClosed | ||||
| 	} | ||||
| 	o.closed = true | ||||
| 	return o.stmt.Close() | ||||
| } | ||||
| 
 | ||||
| // create new insert queryer. | ||||
| func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { | ||||
| 	bi := new(insertSet) | ||||
| 	bi.orm = orm | ||||
| 	bi.mi = mi | ||||
| 	st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if Debug { | ||||
| 		bi.stmt = newStmtQueryLog(orm.alias, st, query) | ||||
| 	} else { | ||||
| 		bi.stmt = st | ||||
| 	} | ||||
| 	return bi, nil | ||||
| } | ||||
							
								
								
									
										140
									
								
								pkg/orm/orm_querym2m.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										140
									
								
								pkg/orm/orm_querym2m.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,140 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import "reflect" | ||||
| 
 | ||||
| // model to model struct | ||||
| type queryM2M struct { | ||||
| 	md  interface{} | ||||
| 	mi  *modelInfo | ||||
| 	fi  *fieldInfo | ||||
| 	qs  *querySet | ||||
| 	ind reflect.Value | ||||
| } | ||||
| 
 | ||||
| // add models to origin models when creating queryM2M. | ||||
| // example: | ||||
| // 	m2m := orm.QueryM2M(post,"Tag") | ||||
| // 	m2m.Add(&Tag1{},&Tag2{}) | ||||
| //  for _,tag := range post.Tags{} | ||||
| // | ||||
| // make sure the relation is defined in post model struct tag. | ||||
| func (o *queryM2M) Add(mds ...interface{}) (int64, error) { | ||||
| 	fi := o.fi | ||||
| 	mi := fi.relThroughModelInfo | ||||
| 	mfi := fi.reverseFieldInfo | ||||
| 	rfi := fi.reverseFieldInfoTwo | ||||
| 
 | ||||
| 	orm := o.qs.orm | ||||
| 	dbase := orm.alias.DbBaser | ||||
| 
 | ||||
| 	var models []interface{} | ||||
| 	var otherValues []interface{} | ||||
| 	var otherNames []string | ||||
| 
 | ||||
| 	for _, colname := range mi.fields.dbcols { | ||||
| 		if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column && | ||||
| 			mi.fields.columns[colname] != mi.fields.pk { | ||||
| 			otherNames = append(otherNames, colname) | ||||
| 		} | ||||
| 	} | ||||
| 	for i, md := range mds { | ||||
| 		if reflect.Indirect(reflect.ValueOf(md)).Kind() != reflect.Struct && i > 0 { | ||||
| 			otherValues = append(otherValues, md) | ||||
| 			mds = append(mds[:i], mds[i+1:]...) | ||||
| 		} | ||||
| 	} | ||||
| 	for _, md := range mds { | ||||
| 		val := reflect.ValueOf(md) | ||||
| 		if val.Kind() == reflect.Slice || val.Kind() == reflect.Array { | ||||
| 			for i := 0; i < val.Len(); i++ { | ||||
| 				v := val.Index(i) | ||||
| 				if v.CanInterface() { | ||||
| 					models = append(models, v.Interface()) | ||||
| 				} | ||||
| 			} | ||||
| 		} else { | ||||
| 			models = append(models, md) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	_, v1, exist := getExistPk(o.mi, o.ind) | ||||
| 	if !exist { | ||||
| 		panic(ErrMissPK) | ||||
| 	} | ||||
| 
 | ||||
| 	names := []string{mfi.column, rfi.column} | ||||
| 
 | ||||
| 	values := make([]interface{}, 0, len(models)*2) | ||||
| 	for _, md := range models { | ||||
| 
 | ||||
| 		ind := reflect.Indirect(reflect.ValueOf(md)) | ||||
| 		var v2 interface{} | ||||
| 		if ind.Kind() != reflect.Struct { | ||||
| 			v2 = ind.Interface() | ||||
| 		} else { | ||||
| 			_, v2, exist = getExistPk(fi.relModelInfo, ind) | ||||
| 			if !exist { | ||||
| 				panic(ErrMissPK) | ||||
| 			} | ||||
| 		} | ||||
| 		values = append(values, v1, v2) | ||||
| 
 | ||||
| 	} | ||||
| 	names = append(names, otherNames...) | ||||
| 	values = append(values, otherValues...) | ||||
| 	return dbase.InsertValue(orm.db, mi, true, names, values) | ||||
| } | ||||
| 
 | ||||
| // remove models following the origin model relationship | ||||
| func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { | ||||
| 	fi := o.fi | ||||
| 	qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) | ||||
| 
 | ||||
| 	return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() | ||||
| } | ||||
| 
 | ||||
| // check model is existed in relationship of origin model | ||||
| func (o *queryM2M) Exist(md interface{}) bool { | ||||
| 	fi := o.fi | ||||
| 	return o.qs.Filter(fi.reverseFieldInfo.name, o.md). | ||||
| 		Filter(fi.reverseFieldInfoTwo.name, md).Exist() | ||||
| } | ||||
| 
 | ||||
| // clean all models in related of origin model | ||||
| func (o *queryM2M) Clear() (int64, error) { | ||||
| 	fi := o.fi | ||||
| 	return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() | ||||
| } | ||||
| 
 | ||||
| // count all related models of origin model | ||||
| func (o *queryM2M) Count() (int64, error) { | ||||
| 	fi := o.fi | ||||
| 	return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() | ||||
| } | ||||
| 
 | ||||
| var _ QueryM2Mer = new(queryM2M) | ||||
| 
 | ||||
| // create new M2M queryer. | ||||
| func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { | ||||
| 	qm2m := new(queryM2M) | ||||
| 	qm2m.md = md | ||||
| 	qm2m.mi = mi | ||||
| 	qm2m.fi = fi | ||||
| 	qm2m.ind = ind | ||||
| 	qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet) | ||||
| 	return qm2m | ||||
| } | ||||
							
								
								
									
										300
									
								
								pkg/orm/orm_queryset.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										300
									
								
								pkg/orm/orm_queryset.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,300 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| ) | ||||
| 
 | ||||
| type colValue struct { | ||||
| 	value int64 | ||||
| 	opt   operator | ||||
| } | ||||
| 
 | ||||
| type operator int | ||||
| 
 | ||||
| // define Col operations | ||||
| const ( | ||||
| 	ColAdd operator = iota | ||||
| 	ColMinus | ||||
| 	ColMultiply | ||||
| 	ColExcept | ||||
| 	ColBitAnd | ||||
| 	ColBitRShift | ||||
| 	ColBitLShift | ||||
| 	ColBitXOR | ||||
| 	ColBitOr | ||||
| ) | ||||
| 
 | ||||
| // ColValue do the field raw changes. e.g Nums = Nums + 10. usage: | ||||
| // 	Params{ | ||||
| // 		"Nums": ColValue(Col_Add, 10), | ||||
| // 	} | ||||
| func ColValue(opt operator, value interface{}) interface{} { | ||||
| 	switch opt { | ||||
| 	case ColAdd, ColMinus, ColMultiply, ColExcept, ColBitAnd, ColBitRShift, | ||||
| 		ColBitLShift, ColBitXOR, ColBitOr: | ||||
| 	default: | ||||
| 		panic(fmt.Errorf("orm.ColValue wrong operator")) | ||||
| 	} | ||||
| 	v, err := StrTo(ToStr(value)).Int64() | ||||
| 	if err != nil { | ||||
| 		panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err)) | ||||
| 	} | ||||
| 	var val colValue | ||||
| 	val.value = v | ||||
| 	val.opt = opt | ||||
| 	return val | ||||
| } | ||||
| 
 | ||||
| // real query struct | ||||
| type querySet struct { | ||||
| 	mi         *modelInfo | ||||
| 	cond       *Condition | ||||
| 	related    []string | ||||
| 	relDepth   int | ||||
| 	limit      int64 | ||||
| 	offset     int64 | ||||
| 	groups     []string | ||||
| 	orders     []string | ||||
| 	distinct   bool | ||||
| 	forupdate  bool | ||||
| 	orm        *orm | ||||
| 	ctx        context.Context | ||||
| 	forContext bool | ||||
| } | ||||
| 
 | ||||
| var _ QuerySeter = new(querySet) | ||||
| 
 | ||||
| // add condition expression to QuerySeter. | ||||
| func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { | ||||
| 	if o.cond == nil { | ||||
| 		o.cond = NewCondition() | ||||
| 	} | ||||
| 	o.cond = o.cond.And(expr, args...) | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // add raw sql to querySeter. | ||||
| func (o querySet) FilterRaw(expr string, sql string) QuerySeter { | ||||
| 	if o.cond == nil { | ||||
| 		o.cond = NewCondition() | ||||
| 	} | ||||
| 	o.cond = o.cond.Raw(expr, sql) | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // add NOT condition to querySeter. | ||||
| func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { | ||||
| 	if o.cond == nil { | ||||
| 		o.cond = NewCondition() | ||||
| 	} | ||||
| 	o.cond = o.cond.AndNot(expr, args...) | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // set offset number | ||||
| func (o *querySet) setOffset(num interface{}) { | ||||
| 	o.offset = ToInt64(num) | ||||
| } | ||||
| 
 | ||||
| // add LIMIT value. | ||||
| // args[0] means offset, e.g. LIMIT num,offset. | ||||
| func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { | ||||
| 	o.limit = ToInt64(limit) | ||||
| 	if len(args) > 0 { | ||||
| 		o.setOffset(args[0]) | ||||
| 	} | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // add OFFSET value | ||||
| func (o querySet) Offset(offset interface{}) QuerySeter { | ||||
| 	o.setOffset(offset) | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // add GROUP expression | ||||
| func (o querySet) GroupBy(exprs ...string) QuerySeter { | ||||
| 	o.groups = exprs | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // add ORDER expression. | ||||
| // "column" means ASC, "-column" means DESC. | ||||
| func (o querySet) OrderBy(exprs ...string) QuerySeter { | ||||
| 	o.orders = exprs | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // add DISTINCT to SELECT | ||||
| func (o querySet) Distinct() QuerySeter { | ||||
| 	o.distinct = true | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // add FOR UPDATE to SELECT | ||||
| func (o querySet) ForUpdate() QuerySeter { | ||||
| 	o.forupdate = true | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // set relation model to query together. | ||||
| // it will query relation models and assign to parent model. | ||||
| func (o querySet) RelatedSel(params ...interface{}) QuerySeter { | ||||
| 	if len(params) == 0 { | ||||
| 		o.relDepth = DefaultRelsDepth | ||||
| 	} else { | ||||
| 		for _, p := range params { | ||||
| 			switch val := p.(type) { | ||||
| 			case string: | ||||
| 				o.related = append(o.related, val) | ||||
| 			case int: | ||||
| 				o.relDepth = val | ||||
| 			default: | ||||
| 				panic(fmt.Errorf("<QuerySeter.RelatedSel> wrong param kind: %v", val)) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // set condition to QuerySeter. | ||||
| func (o querySet) SetCond(cond *Condition) QuerySeter { | ||||
| 	o.cond = cond | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // get condition from QuerySeter | ||||
| func (o querySet) GetCond() *Condition { | ||||
| 	return o.cond | ||||
| } | ||||
| 
 | ||||
| // return QuerySeter execution result number | ||||
| func (o *querySet) Count() (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // check result empty or not after QuerySeter executed | ||||
| func (o *querySet) Exist() bool { | ||||
| 	cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) | ||||
| 	return cnt > 0 | ||||
| } | ||||
| 
 | ||||
| // execute update with parameters | ||||
| func (o *querySet) Update(values Params) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // execute delete | ||||
| func (o *querySet) Delete() (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // return a insert queryer. | ||||
| // it can be used in times. | ||||
| // example: | ||||
| // 	i,err := sq.PrepareInsert() | ||||
| // 	i.Add(&user1{},&user2{}) | ||||
| func (o *querySet) PrepareInsert() (Inserter, error) { | ||||
| 	return newInsertSet(o.orm, o.mi) | ||||
| } | ||||
| 
 | ||||
| // query all data and map to containers. | ||||
| // cols means the columns when querying. | ||||
| func (o *querySet) All(container interface{}, cols ...string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) | ||||
| } | ||||
| 
 | ||||
| // query one row data and map to containers. | ||||
| // cols means the columns when querying. | ||||
| func (o *querySet) One(container interface{}, cols ...string) error { | ||||
| 	o.limit = 1 | ||||
| 	num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if num == 0 { | ||||
| 		return ErrNoRows | ||||
| 	} | ||||
| 
 | ||||
| 	if num > 1 { | ||||
| 		return ErrMultiRows | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // query all data and map to []map[string]interface. | ||||
| // expres means condition expression. | ||||
| // it converts data to []map[column]value. | ||||
| func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // query all data and map to [][]interface | ||||
| // it converts data to [][column_index]value | ||||
| func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // query all data and map to []interface. | ||||
| // it's designed for one row record set, auto change to []value, not [][column]value. | ||||
| func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { | ||||
| 	return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) | ||||
| } | ||||
| 
 | ||||
| // query all rows into map[string]interface with specify key and value column name. | ||||
| // keyCol = "name", valueCol = "value" | ||||
| // table data | ||||
| // name  | value | ||||
| // total | 100 | ||||
| // found | 200 | ||||
| // to map[string]interface{}{ | ||||
| // 	"total": 100, | ||||
| // 	"found": 200, | ||||
| // } | ||||
| func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { | ||||
| 	panic(ErrNotImplement) | ||||
| } | ||||
| 
 | ||||
| // query all rows into struct with specify key and value column name. | ||||
| // keyCol = "name", valueCol = "value" | ||||
| // table data | ||||
| // name  | value | ||||
| // total | 100 | ||||
| // found | 200 | ||||
| // to struct { | ||||
| // 	Total int | ||||
| // 	Found int | ||||
| // } | ||||
| func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { | ||||
| 	panic(ErrNotImplement) | ||||
| } | ||||
| 
 | ||||
| // set context to QuerySeter. | ||||
| func (o querySet) WithContext(ctx context.Context) QuerySeter { | ||||
| 	o.ctx = ctx | ||||
| 	o.forContext = true | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // create new QuerySeter. | ||||
| func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { | ||||
| 	o := new(querySet) | ||||
| 	o.mi = mi | ||||
| 	o.orm = orm | ||||
| 	return o | ||||
| } | ||||
							
								
								
									
										867
									
								
								pkg/orm/orm_raw.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										867
									
								
								pkg/orm/orm_raw.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,867 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // raw sql string prepared statement | ||||
| type rawPrepare struct { | ||||
| 	rs     *rawSet | ||||
| 	stmt   stmtQuerier | ||||
| 	closed bool | ||||
| } | ||||
| 
 | ||||
| func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) { | ||||
| 	if o.closed { | ||||
| 		return nil, ErrStmtClosed | ||||
| 	} | ||||
| 	return o.stmt.Exec(args...) | ||||
| } | ||||
| 
 | ||||
| func (o *rawPrepare) Close() error { | ||||
| 	o.closed = true | ||||
| 	return o.stmt.Close() | ||||
| } | ||||
| 
 | ||||
| func newRawPreparer(rs *rawSet) (RawPreparer, error) { | ||||
| 	o := new(rawPrepare) | ||||
| 	o.rs = rs | ||||
| 
 | ||||
| 	query := rs.query | ||||
| 	rs.orm.alias.DbBaser.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	st, err := rs.orm.db.Prepare(query) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if Debug { | ||||
| 		o.stmt = newStmtQueryLog(rs.orm.alias, st, query) | ||||
| 	} else { | ||||
| 		o.stmt = st | ||||
| 	} | ||||
| 	return o, nil | ||||
| } | ||||
| 
 | ||||
| // raw query seter | ||||
| type rawSet struct { | ||||
| 	query string | ||||
| 	args  []interface{} | ||||
| 	orm   *orm | ||||
| } | ||||
| 
 | ||||
| var _ RawSeter = new(rawSet) | ||||
| 
 | ||||
| // set args for every query | ||||
| func (o rawSet) SetArgs(args ...interface{}) RawSeter { | ||||
| 	o.args = args | ||||
| 	return &o | ||||
| } | ||||
| 
 | ||||
| // execute raw sql and return sql.Result | ||||
| func (o *rawSet) Exec() (sql.Result, error) { | ||||
| 	query := o.query | ||||
| 	o.orm.alias.DbBaser.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	args := getFlatParams(nil, o.args, o.orm.alias.TZ) | ||||
| 	return o.orm.db.Exec(query, args...) | ||||
| } | ||||
| 
 | ||||
| // set field value to row container | ||||
| func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { | ||||
| 	switch ind.Kind() { | ||||
| 	case reflect.Bool: | ||||
| 		if value == nil { | ||||
| 			ind.SetBool(false) | ||||
| 		} else if v, ok := value.(bool); ok { | ||||
| 			ind.SetBool(v) | ||||
| 		} else { | ||||
| 			v, _ := StrTo(ToStr(value)).Bool() | ||||
| 			ind.SetBool(v) | ||||
| 		} | ||||
| 
 | ||||
| 	case reflect.String: | ||||
| 		if value == nil { | ||||
| 			ind.SetString("") | ||||
| 		} else { | ||||
| 			ind.SetString(ToStr(value)) | ||||
| 		} | ||||
| 
 | ||||
| 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 		if value == nil { | ||||
| 			ind.SetInt(0) | ||||
| 		} else { | ||||
| 			val := reflect.ValueOf(value) | ||||
| 			switch val.Kind() { | ||||
| 			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 				ind.SetInt(val.Int()) | ||||
| 			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||||
| 				ind.SetInt(int64(val.Uint())) | ||||
| 			default: | ||||
| 				v, _ := StrTo(ToStr(value)).Int64() | ||||
| 				ind.SetInt(v) | ||||
| 			} | ||||
| 		} | ||||
| 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||||
| 		if value == nil { | ||||
| 			ind.SetUint(0) | ||||
| 		} else { | ||||
| 			val := reflect.ValueOf(value) | ||||
| 			switch val.Kind() { | ||||
| 			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 				ind.SetUint(uint64(val.Int())) | ||||
| 			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||||
| 				ind.SetUint(val.Uint()) | ||||
| 			default: | ||||
| 				v, _ := StrTo(ToStr(value)).Uint64() | ||||
| 				ind.SetUint(v) | ||||
| 			} | ||||
| 		} | ||||
| 	case reflect.Float64, reflect.Float32: | ||||
| 		if value == nil { | ||||
| 			ind.SetFloat(0) | ||||
| 		} else { | ||||
| 			val := reflect.ValueOf(value) | ||||
| 			switch val.Kind() { | ||||
| 			case reflect.Float64: | ||||
| 				ind.SetFloat(val.Float()) | ||||
| 			default: | ||||
| 				v, _ := StrTo(ToStr(value)).Float64() | ||||
| 				ind.SetFloat(v) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 	case reflect.Struct: | ||||
| 		if value == nil { | ||||
| 			ind.Set(reflect.Zero(ind.Type())) | ||||
| 			return | ||||
| 		} | ||||
| 		switch ind.Interface().(type) { | ||||
| 		case time.Time: | ||||
| 			var str string | ||||
| 			switch d := value.(type) { | ||||
| 			case time.Time: | ||||
| 				o.orm.alias.DbBaser.TimeFromDB(&d, o.orm.alias.TZ) | ||||
| 				ind.Set(reflect.ValueOf(d)) | ||||
| 			case []byte: | ||||
| 				str = string(d) | ||||
| 			case string: | ||||
| 				str = d | ||||
| 			} | ||||
| 			if str != "" { | ||||
| 				if len(str) >= 19 { | ||||
| 					str = str[:19] | ||||
| 					t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ) | ||||
| 					if err == nil { | ||||
| 						t = t.In(DefaultTimeLoc) | ||||
| 						ind.Set(reflect.ValueOf(t)) | ||||
| 					} | ||||
| 				} else if len(str) >= 10 { | ||||
| 					str = str[:10] | ||||
| 					t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc) | ||||
| 					if err == nil { | ||||
| 						ind.Set(reflect.ValueOf(t)) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		case sql.NullString, sql.NullInt64, sql.NullFloat64, sql.NullBool: | ||||
| 			indi := reflect.New(ind.Type()).Interface() | ||||
| 			sc, ok := indi.(sql.Scanner) | ||||
| 			if !ok { | ||||
| 				return | ||||
| 			} | ||||
| 			err := sc.Scan(value) | ||||
| 			if err == nil { | ||||
| 				ind.Set(reflect.Indirect(reflect.ValueOf(sc))) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 	case reflect.Ptr: | ||||
| 		if value == nil { | ||||
| 			ind.Set(reflect.Zero(ind.Type())) | ||||
| 			break | ||||
| 		} | ||||
| 		ind.Set(reflect.New(ind.Type().Elem())) | ||||
| 		o.setFieldValue(reflect.Indirect(ind), value) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // set field value in loop for slice container | ||||
| func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { | ||||
| 	nInds := *nIndsPtr | ||||
| 
 | ||||
| 	cur := 0 | ||||
| 	for i := 0; i < len(sInds); i++ { | ||||
| 		sInd := sInds[i] | ||||
| 		eTyp := eTyps[i] | ||||
| 
 | ||||
| 		typ := eTyp | ||||
| 		isPtr := false | ||||
| 		if typ.Kind() == reflect.Ptr { | ||||
| 			isPtr = true | ||||
| 			typ = typ.Elem() | ||||
| 		} | ||||
| 		if typ.Kind() == reflect.Ptr { | ||||
| 			isPtr = true | ||||
| 			typ = typ.Elem() | ||||
| 		} | ||||
| 
 | ||||
| 		var nInd reflect.Value | ||||
| 		if init { | ||||
| 			nInd = reflect.New(sInd.Type()).Elem() | ||||
| 		} else { | ||||
| 			nInd = nInds[i] | ||||
| 		} | ||||
| 
 | ||||
| 		val := reflect.New(typ) | ||||
| 		ind := val.Elem() | ||||
| 
 | ||||
| 		tpName := ind.Type().String() | ||||
| 
 | ||||
| 		if ind.Kind() == reflect.Struct { | ||||
| 			if tpName == "time.Time" { | ||||
| 				value := reflect.ValueOf(refs[cur]).Elem().Interface() | ||||
| 				if isPtr && value == nil { | ||||
| 					val = reflect.New(val.Type()).Elem() | ||||
| 				} else { | ||||
| 					o.setFieldValue(ind, value) | ||||
| 				} | ||||
| 				cur++ | ||||
| 			} | ||||
| 
 | ||||
| 		} else { | ||||
| 			value := reflect.ValueOf(refs[cur]).Elem().Interface() | ||||
| 			if isPtr && value == nil { | ||||
| 				val = reflect.New(val.Type()).Elem() | ||||
| 			} else { | ||||
| 				o.setFieldValue(ind, value) | ||||
| 			} | ||||
| 			cur++ | ||||
| 		} | ||||
| 
 | ||||
| 		if nInd.Kind() == reflect.Slice { | ||||
| 			if isPtr { | ||||
| 				nInd = reflect.Append(nInd, val) | ||||
| 			} else { | ||||
| 				nInd = reflect.Append(nInd, ind) | ||||
| 			} | ||||
| 		} else { | ||||
| 			if isPtr { | ||||
| 				nInd.Set(val) | ||||
| 			} else { | ||||
| 				nInd.Set(ind) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		nInds[i] = nInd | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // query data and map to container | ||||
| func (o *rawSet) QueryRow(containers ...interface{}) error { | ||||
| 	var ( | ||||
| 		refs  = make([]interface{}, 0, len(containers)) | ||||
| 		sInds []reflect.Value | ||||
| 		eTyps []reflect.Type | ||||
| 		sMi   *modelInfo | ||||
| 	) | ||||
| 	structMode := false | ||||
| 	for _, container := range containers { | ||||
| 		val := reflect.ValueOf(container) | ||||
| 		ind := reflect.Indirect(val) | ||||
| 
 | ||||
| 		if val.Kind() != reflect.Ptr { | ||||
| 			panic(fmt.Errorf("<RawSeter.QueryRow> all args must be use ptr")) | ||||
| 		} | ||||
| 
 | ||||
| 		etyp := ind.Type() | ||||
| 		typ := etyp | ||||
| 		if typ.Kind() == reflect.Ptr { | ||||
| 			typ = typ.Elem() | ||||
| 		} | ||||
| 
 | ||||
| 		sInds = append(sInds, ind) | ||||
| 		eTyps = append(eTyps, etyp) | ||||
| 
 | ||||
| 		if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { | ||||
| 			if len(containers) > 1 { | ||||
| 				panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384")) | ||||
| 			} | ||||
| 
 | ||||
| 			structMode = true | ||||
| 			fn := getFullName(typ) | ||||
| 			if mi, ok := modelCache.getByFullName(fn); ok { | ||||
| 				sMi = mi | ||||
| 			} | ||||
| 		} else { | ||||
| 			var ref interface{} | ||||
| 			refs = append(refs, &ref) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	query := o.query | ||||
| 	o.orm.alias.DbBaser.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	args := getFlatParams(nil, o.args, o.orm.alias.TZ) | ||||
| 	rows, err := o.orm.db.Query(query, args...) | ||||
| 	if err != nil { | ||||
| 		if err == sql.ErrNoRows { | ||||
| 			return ErrNoRows | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	defer rows.Close() | ||||
| 
 | ||||
| 	if rows.Next() { | ||||
| 		if structMode { | ||||
| 			columns, err := rows.Columns() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 
 | ||||
| 			columnsMp := make(map[string]interface{}, len(columns)) | ||||
| 
 | ||||
| 			refs = make([]interface{}, 0, len(columns)) | ||||
| 			for _, col := range columns { | ||||
| 				var ref interface{} | ||||
| 				columnsMp[col] = &ref | ||||
| 				refs = append(refs, &ref) | ||||
| 			} | ||||
| 
 | ||||
| 			if err := rows.Scan(refs...); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 
 | ||||
| 			ind := sInds[0] | ||||
| 
 | ||||
| 			if ind.Kind() == reflect.Ptr { | ||||
| 				if ind.IsNil() || !ind.IsValid() { | ||||
| 					ind.Set(reflect.New(eTyps[0].Elem())) | ||||
| 				} | ||||
| 				ind = ind.Elem() | ||||
| 			} | ||||
| 
 | ||||
| 			if sMi != nil { | ||||
| 				for _, col := range columns { | ||||
| 					if fi := sMi.fields.GetByColumn(col); fi != nil { | ||||
| 						value := reflect.ValueOf(columnsMp[col]).Elem().Interface() | ||||
| 						field := ind.FieldByIndex(fi.fieldIndex) | ||||
| 						if fi.fieldType&IsRelField > 0 { | ||||
| 							mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) | ||||
| 							field.Set(mf) | ||||
| 							field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) | ||||
| 						} | ||||
| 						o.setFieldValue(field, value) | ||||
| 					} | ||||
| 				} | ||||
| 			} else { | ||||
| 				for i := 0; i < ind.NumField(); i++ { | ||||
| 					f := ind.Field(i) | ||||
| 					fe := ind.Type().Field(i) | ||||
| 					_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) | ||||
| 					var col string | ||||
| 					if col = tags["column"]; col == "" { | ||||
| 						col = nameStrategyMap[nameStrategy](fe.Name) | ||||
| 					} | ||||
| 					if v, ok := columnsMp[col]; ok { | ||||
| 						value := reflect.ValueOf(v).Elem().Interface() | ||||
| 						o.setFieldValue(f, value) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 		} else { | ||||
| 			if err := rows.Scan(refs...); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 
 | ||||
| 			nInds := make([]reflect.Value, len(sInds)) | ||||
| 			o.loopSetRefs(refs, sInds, &nInds, eTyps, true) | ||||
| 			for i, sInd := range sInds { | ||||
| 				nInd := nInds[i] | ||||
| 				sInd.Set(nInd) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 	} else { | ||||
| 		return ErrNoRows | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // query data rows and map to container | ||||
| func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { | ||||
| 	var ( | ||||
| 		refs  = make([]interface{}, 0, len(containers)) | ||||
| 		sInds []reflect.Value | ||||
| 		eTyps []reflect.Type | ||||
| 		sMi   *modelInfo | ||||
| 	) | ||||
| 	structMode := false | ||||
| 	for _, container := range containers { | ||||
| 		val := reflect.ValueOf(container) | ||||
| 		sInd := reflect.Indirect(val) | ||||
| 		if val.Kind() != reflect.Ptr || sInd.Kind() != reflect.Slice { | ||||
| 			panic(fmt.Errorf("<RawSeter.QueryRows> all args must be use ptr slice")) | ||||
| 		} | ||||
| 
 | ||||
| 		etyp := sInd.Type().Elem() | ||||
| 		typ := etyp | ||||
| 		if typ.Kind() == reflect.Ptr { | ||||
| 			typ = typ.Elem() | ||||
| 		} | ||||
| 
 | ||||
| 		sInds = append(sInds, sInd) | ||||
| 		eTyps = append(eTyps, etyp) | ||||
| 
 | ||||
| 		if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { | ||||
| 			if len(containers) > 1 { | ||||
| 				panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384")) | ||||
| 			} | ||||
| 
 | ||||
| 			structMode = true | ||||
| 			fn := getFullName(typ) | ||||
| 			if mi, ok := modelCache.getByFullName(fn); ok { | ||||
| 				sMi = mi | ||||
| 			} | ||||
| 		} else { | ||||
| 			var ref interface{} | ||||
| 			refs = append(refs, &ref) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	query := o.query | ||||
| 	o.orm.alias.DbBaser.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	args := getFlatParams(nil, o.args, o.orm.alias.TZ) | ||||
| 	rows, err := o.orm.db.Query(query, args...) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	defer rows.Close() | ||||
| 
 | ||||
| 	var cnt int64 | ||||
| 	nInds := make([]reflect.Value, len(sInds)) | ||||
| 	sInd := sInds[0] | ||||
| 
 | ||||
| 	for rows.Next() { | ||||
| 
 | ||||
| 		if structMode { | ||||
| 			columns, err := rows.Columns() | ||||
| 			if err != nil { | ||||
| 				return 0, err | ||||
| 			} | ||||
| 
 | ||||
| 			columnsMp := make(map[string]interface{}, len(columns)) | ||||
| 
 | ||||
| 			refs = make([]interface{}, 0, len(columns)) | ||||
| 			for _, col := range columns { | ||||
| 				var ref interface{} | ||||
| 				columnsMp[col] = &ref | ||||
| 				refs = append(refs, &ref) | ||||
| 			} | ||||
| 
 | ||||
| 			if err := rows.Scan(refs...); err != nil { | ||||
| 				return 0, err | ||||
| 			} | ||||
| 
 | ||||
| 			if cnt == 0 && !sInd.IsNil() { | ||||
| 				sInd.Set(reflect.New(sInd.Type()).Elem()) | ||||
| 			} | ||||
| 
 | ||||
| 			var ind reflect.Value | ||||
| 			if eTyps[0].Kind() == reflect.Ptr { | ||||
| 				ind = reflect.New(eTyps[0].Elem()) | ||||
| 			} else { | ||||
| 				ind = reflect.New(eTyps[0]) | ||||
| 			} | ||||
| 
 | ||||
| 			if ind.Kind() == reflect.Ptr { | ||||
| 				ind = ind.Elem() | ||||
| 			} | ||||
| 
 | ||||
| 			if sMi != nil { | ||||
| 				for _, col := range columns { | ||||
| 					if fi := sMi.fields.GetByColumn(col); fi != nil { | ||||
| 						value := reflect.ValueOf(columnsMp[col]).Elem().Interface() | ||||
| 						field := ind.FieldByIndex(fi.fieldIndex) | ||||
| 						if fi.fieldType&IsRelField > 0 { | ||||
| 							mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) | ||||
| 							field.Set(mf) | ||||
| 							field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) | ||||
| 						} | ||||
| 						o.setFieldValue(field, value) | ||||
| 					} | ||||
| 				} | ||||
| 			} else { | ||||
| 				// define recursive function | ||||
| 				var recursiveSetField func(rv reflect.Value) | ||||
| 				recursiveSetField = func(rv reflect.Value) { | ||||
| 					for i := 0; i < rv.NumField(); i++ { | ||||
| 						f := rv.Field(i) | ||||
| 						fe := rv.Type().Field(i) | ||||
| 
 | ||||
| 						// check if the field is a Struct | ||||
| 						// recursive the Struct type | ||||
| 						if fe.Type.Kind() == reflect.Struct { | ||||
| 							recursiveSetField(f) | ||||
| 						} | ||||
| 
 | ||||
| 						_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) | ||||
| 						var col string | ||||
| 						if col = tags["column"]; col == "" { | ||||
| 							col = nameStrategyMap[nameStrategy](fe.Name) | ||||
| 						} | ||||
| 						if v, ok := columnsMp[col]; ok { | ||||
| 							value := reflect.ValueOf(v).Elem().Interface() | ||||
| 							o.setFieldValue(f, value) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				// init call the recursive function | ||||
| 				recursiveSetField(ind) | ||||
| 			} | ||||
| 
 | ||||
| 			if eTyps[0].Kind() == reflect.Ptr { | ||||
| 				ind = ind.Addr() | ||||
| 			} | ||||
| 
 | ||||
| 			sInd = reflect.Append(sInd, ind) | ||||
| 
 | ||||
| 		} else { | ||||
| 			if err := rows.Scan(refs...); err != nil { | ||||
| 				return 0, err | ||||
| 			} | ||||
| 
 | ||||
| 			o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0) | ||||
| 		} | ||||
| 
 | ||||
| 		cnt++ | ||||
| 	} | ||||
| 
 | ||||
| 	if cnt > 0 { | ||||
| 
 | ||||
| 		if structMode { | ||||
| 			sInds[0].Set(sInd) | ||||
| 		} else { | ||||
| 			for i, sInd := range sInds { | ||||
| 				nInd := nInds[i] | ||||
| 				sInd.Set(nInd) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return cnt, nil | ||||
| } | ||||
| 
 | ||||
| func (o *rawSet) readValues(container interface{}, needCols []string) (int64, error) { | ||||
| 	var ( | ||||
| 		maps  []Params | ||||
| 		lists []ParamsList | ||||
| 		list  ParamsList | ||||
| 	) | ||||
| 
 | ||||
| 	typ := 0 | ||||
| 	switch container.(type) { | ||||
| 	case *[]Params: | ||||
| 		typ = 1 | ||||
| 	case *[]ParamsList: | ||||
| 		typ = 2 | ||||
| 	case *ParamsList: | ||||
| 		typ = 3 | ||||
| 	default: | ||||
| 		panic(fmt.Errorf("<RawSeter> unsupport read values type `%T`", container)) | ||||
| 	} | ||||
| 
 | ||||
| 	query := o.query | ||||
| 	o.orm.alias.DbBaser.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	args := getFlatParams(nil, o.args, o.orm.alias.TZ) | ||||
| 
 | ||||
| 	var rs *sql.Rows | ||||
| 	rs, err := o.orm.db.Query(query, args...) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	defer rs.Close() | ||||
| 
 | ||||
| 	var ( | ||||
| 		refs   []interface{} | ||||
| 		cnt    int64 | ||||
| 		cols   []string | ||||
| 		indexs []int | ||||
| 	) | ||||
| 
 | ||||
| 	for rs.Next() { | ||||
| 		if cnt == 0 { | ||||
| 			columns, err := rs.Columns() | ||||
| 			if err != nil { | ||||
| 				return 0, err | ||||
| 			} | ||||
| 			if len(needCols) > 0 { | ||||
| 				indexs = make([]int, 0, len(needCols)) | ||||
| 			} else { | ||||
| 				indexs = make([]int, 0, len(columns)) | ||||
| 			} | ||||
| 
 | ||||
| 			cols = columns | ||||
| 			refs = make([]interface{}, len(cols)) | ||||
| 			for i := range refs { | ||||
| 				var ref sql.NullString | ||||
| 				refs[i] = &ref | ||||
| 
 | ||||
| 				if len(needCols) > 0 { | ||||
| 					for _, c := range needCols { | ||||
| 						if c == cols[i] { | ||||
| 							indexs = append(indexs, i) | ||||
| 						} | ||||
| 					} | ||||
| 				} else { | ||||
| 					indexs = append(indexs, i) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if err := rs.Scan(refs...); err != nil { | ||||
| 			return 0, err | ||||
| 		} | ||||
| 
 | ||||
| 		switch typ { | ||||
| 		case 1: | ||||
| 			params := make(Params, len(cols)) | ||||
| 			for _, i := range indexs { | ||||
| 				ref := refs[i] | ||||
| 				value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) | ||||
| 				if value.Valid { | ||||
| 					params[cols[i]] = value.String | ||||
| 				} else { | ||||
| 					params[cols[i]] = nil | ||||
| 				} | ||||
| 			} | ||||
| 			maps = append(maps, params) | ||||
| 		case 2: | ||||
| 			params := make(ParamsList, 0, len(cols)) | ||||
| 			for _, i := range indexs { | ||||
| 				ref := refs[i] | ||||
| 				value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) | ||||
| 				if value.Valid { | ||||
| 					params = append(params, value.String) | ||||
| 				} else { | ||||
| 					params = append(params, nil) | ||||
| 				} | ||||
| 			} | ||||
| 			lists = append(lists, params) | ||||
| 		case 3: | ||||
| 			for _, i := range indexs { | ||||
| 				ref := refs[i] | ||||
| 				value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString) | ||||
| 				if value.Valid { | ||||
| 					list = append(list, value.String) | ||||
| 				} else { | ||||
| 					list = append(list, nil) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		cnt++ | ||||
| 	} | ||||
| 
 | ||||
| 	switch v := container.(type) { | ||||
| 	case *[]Params: | ||||
| 		*v = maps | ||||
| 	case *[]ParamsList: | ||||
| 		*v = lists | ||||
| 	case *ParamsList: | ||||
| 		*v = list | ||||
| 	} | ||||
| 
 | ||||
| 	return cnt, nil | ||||
| } | ||||
| 
 | ||||
| func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (int64, error) { | ||||
| 	var ( | ||||
| 		maps Params | ||||
| 		ind  *reflect.Value | ||||
| 	) | ||||
| 
 | ||||
| 	var typ int | ||||
| 	switch container.(type) { | ||||
| 	case *Params: | ||||
| 		typ = 1 | ||||
| 	default: | ||||
| 		typ = 2 | ||||
| 		vl := reflect.ValueOf(container) | ||||
| 		id := reflect.Indirect(vl) | ||||
| 		if vl.Kind() != reflect.Ptr || id.Kind() != reflect.Struct { | ||||
| 			panic(fmt.Errorf("<RawSeter> RowsTo unsupport type `%T` need ptr struct", container)) | ||||
| 		} | ||||
| 
 | ||||
| 		ind = &id | ||||
| 	} | ||||
| 
 | ||||
| 	query := o.query | ||||
| 	o.orm.alias.DbBaser.ReplaceMarks(&query) | ||||
| 
 | ||||
| 	args := getFlatParams(nil, o.args, o.orm.alias.TZ) | ||||
| 
 | ||||
| 	rs, err := o.orm.db.Query(query, args...) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	defer rs.Close() | ||||
| 
 | ||||
| 	var ( | ||||
| 		refs []interface{} | ||||
| 		cnt  int64 | ||||
| 		cols []string | ||||
| 	) | ||||
| 
 | ||||
| 	var ( | ||||
| 		keyIndex   = -1 | ||||
| 		valueIndex = -1 | ||||
| 	) | ||||
| 
 | ||||
| 	for rs.Next() { | ||||
| 		if cnt == 0 { | ||||
| 			columns, err := rs.Columns() | ||||
| 			if err != nil { | ||||
| 				return 0, err | ||||
| 			} | ||||
| 			cols = columns | ||||
| 			refs = make([]interface{}, len(cols)) | ||||
| 			for i := range refs { | ||||
| 				if keyCol == cols[i] { | ||||
| 					keyIndex = i | ||||
| 				} | ||||
| 				if typ == 1 || keyIndex == i { | ||||
| 					var ref sql.NullString | ||||
| 					refs[i] = &ref | ||||
| 				} else { | ||||
| 					var ref interface{} | ||||
| 					refs[i] = &ref | ||||
| 				} | ||||
| 				if valueCol == cols[i] { | ||||
| 					valueIndex = i | ||||
| 				} | ||||
| 			} | ||||
| 			if keyIndex == -1 || valueIndex == -1 { | ||||
| 				panic(fmt.Errorf("<RawSeter> RowsTo unknown key, value column name `%s: %s`", keyCol, valueCol)) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if err := rs.Scan(refs...); err != nil { | ||||
| 			return 0, err | ||||
| 		} | ||||
| 
 | ||||
| 		if cnt == 0 { | ||||
| 			switch typ { | ||||
| 			case 1: | ||||
| 				maps = make(Params) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		key := reflect.Indirect(reflect.ValueOf(refs[keyIndex])).Interface().(sql.NullString).String | ||||
| 
 | ||||
| 		switch typ { | ||||
| 		case 1: | ||||
| 			value := reflect.Indirect(reflect.ValueOf(refs[valueIndex])).Interface().(sql.NullString) | ||||
| 			if value.Valid { | ||||
| 				maps[key] = value.String | ||||
| 			} else { | ||||
| 				maps[key] = nil | ||||
| 			} | ||||
| 
 | ||||
| 		default: | ||||
| 			if id := ind.FieldByName(camelString(key)); id.IsValid() { | ||||
| 				o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface()) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		cnt++ | ||||
| 	} | ||||
| 
 | ||||
| 	if typ == 1 { | ||||
| 		v, _ := container.(*Params) | ||||
| 		*v = maps | ||||
| 	} | ||||
| 
 | ||||
| 	return cnt, nil | ||||
| } | ||||
| 
 | ||||
| // query data to []map[string]interface | ||||
| func (o *rawSet) Values(container *[]Params, cols ...string) (int64, error) { | ||||
| 	return o.readValues(container, cols) | ||||
| } | ||||
| 
 | ||||
| // query data to [][]interface | ||||
| func (o *rawSet) ValuesList(container *[]ParamsList, cols ...string) (int64, error) { | ||||
| 	return o.readValues(container, cols) | ||||
| } | ||||
| 
 | ||||
| // query data to []interface | ||||
| func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error) { | ||||
| 	return o.readValues(container, cols) | ||||
| } | ||||
| 
 | ||||
| // query all rows into map[string]interface with specify key and value column name. | ||||
| // keyCol = "name", valueCol = "value" | ||||
| // table data | ||||
| // name  | value | ||||
| // total | 100 | ||||
| // found | 200 | ||||
| // to map[string]interface{}{ | ||||
| // 	"total": 100, | ||||
| // 	"found": 200, | ||||
| // } | ||||
| func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { | ||||
| 	return o.queryRowsTo(result, keyCol, valueCol) | ||||
| } | ||||
| 
 | ||||
| // query all rows into struct with specify key and value column name. | ||||
| // keyCol = "name", valueCol = "value" | ||||
| // table data | ||||
| // name  | value | ||||
| // total | 100 | ||||
| // found | 200 | ||||
| // to struct { | ||||
| // 	Total int | ||||
| // 	Found int | ||||
| // } | ||||
| func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { | ||||
| 	return o.queryRowsTo(ptrStruct, keyCol, valueCol) | ||||
| } | ||||
| 
 | ||||
| // return prepared raw statement for used in times. | ||||
| func (o *rawSet) Prepare() (RawPreparer, error) { | ||||
| 	return newRawPreparer(o) | ||||
| } | ||||
| 
 | ||||
| func newRawSet(orm *orm, query string, args []interface{}) RawSeter { | ||||
| 	o := new(rawSet) | ||||
| 	o.query = query | ||||
| 	o.args = args | ||||
| 	o.orm = orm | ||||
| 	return o | ||||
| } | ||||
							
								
								
									
										2494
									
								
								pkg/orm/orm_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2494
									
								
								pkg/orm/orm_test.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										62
									
								
								pkg/orm/qb.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								pkg/orm/qb.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,62 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import "errors" | ||||
| 
 | ||||
| // QueryBuilder is the Query builder interface | ||||
| type QueryBuilder interface { | ||||
| 	Select(fields ...string) QueryBuilder | ||||
| 	ForUpdate() QueryBuilder | ||||
| 	From(tables ...string) QueryBuilder | ||||
| 	InnerJoin(table string) QueryBuilder | ||||
| 	LeftJoin(table string) QueryBuilder | ||||
| 	RightJoin(table string) QueryBuilder | ||||
| 	On(cond string) QueryBuilder | ||||
| 	Where(cond string) QueryBuilder | ||||
| 	And(cond string) QueryBuilder | ||||
| 	Or(cond string) QueryBuilder | ||||
| 	In(vals ...string) QueryBuilder | ||||
| 	OrderBy(fields ...string) QueryBuilder | ||||
| 	Asc() QueryBuilder | ||||
| 	Desc() QueryBuilder | ||||
| 	Limit(limit int) QueryBuilder | ||||
| 	Offset(offset int) QueryBuilder | ||||
| 	GroupBy(fields ...string) QueryBuilder | ||||
| 	Having(cond string) QueryBuilder | ||||
| 	Update(tables ...string) QueryBuilder | ||||
| 	Set(kv ...string) QueryBuilder | ||||
| 	Delete(tables ...string) QueryBuilder | ||||
| 	InsertInto(table string, fields ...string) QueryBuilder | ||||
| 	Values(vals ...string) QueryBuilder | ||||
| 	Subquery(sub string, alias string) string | ||||
| 	String() string | ||||
| } | ||||
| 
 | ||||
| // NewQueryBuilder return the QueryBuilder | ||||
| func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { | ||||
| 	if driver == "mysql" { | ||||
| 		qb = new(MySQLQueryBuilder) | ||||
| 	} else if driver == "tidb" { | ||||
| 		qb = new(TiDBQueryBuilder) | ||||
| 	} else if driver == "postgres" { | ||||
| 		err = errors.New("postgres query builder is not supported yet") | ||||
| 	} else if driver == "sqlite" { | ||||
| 		err = errors.New("sqlite query builder is not supported yet") | ||||
| 	} else { | ||||
| 		err = errors.New("unknown driver for query builder") | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										185
									
								
								pkg/orm/qb_mysql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										185
									
								
								pkg/orm/qb_mysql.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,185 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // CommaSpace is the separation | ||||
| const CommaSpace = ", " | ||||
| 
 | ||||
| // MySQLQueryBuilder is the SQL build | ||||
| type MySQLQueryBuilder struct { | ||||
| 	Tokens []string | ||||
| } | ||||
| 
 | ||||
| // Select will join the fields | ||||
| func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // ForUpdate add the FOR UPDATE clause | ||||
| func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "FOR UPDATE") | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // From join the tables | ||||
| func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // InnerJoin INNER JOIN the table | ||||
| func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "INNER JOIN", table) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // LeftJoin LEFT JOIN the table | ||||
| func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // RightJoin RIGHT JOIN the table | ||||
| func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // On join with on cond | ||||
| func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "ON", cond) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Where join the Where cond | ||||
| func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "WHERE", cond) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // And join the and cond | ||||
| func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "AND", cond) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Or join the or cond | ||||
| func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "OR", cond) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // In join the IN (vals) | ||||
| func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // OrderBy join the Order by fields | ||||
| func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Asc join the asc | ||||
| func (qb *MySQLQueryBuilder) Asc() QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "ASC") | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Desc join the desc | ||||
| func (qb *MySQLQueryBuilder) Desc() QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "DESC") | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Limit join the limit num | ||||
| func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Offset join the offset num | ||||
| func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // GroupBy join the Group by fields | ||||
| func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Having join the Having cond | ||||
| func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "HAVING", cond) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Update join the update table | ||||
| func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Set join the set kv | ||||
| func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Delete join the Delete tables | ||||
| func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "DELETE") | ||||
| 	if len(tables) != 0 { | ||||
| 		qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) | ||||
| 	} | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // InsertInto join the insert SQL | ||||
| func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "INSERT INTO", table) | ||||
| 	if len(fields) != 0 { | ||||
| 		fieldsStr := strings.Join(fields, CommaSpace) | ||||
| 		qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") | ||||
| 	} | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Values join the Values(vals) | ||||
| func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { | ||||
| 	valsStr := strings.Join(vals, CommaSpace) | ||||
| 	qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Subquery join the sub as alias | ||||
| func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { | ||||
| 	return fmt.Sprintf("(%s) AS %s", sub, alias) | ||||
| } | ||||
| 
 | ||||
| // String join all Tokens | ||||
| func (qb *MySQLQueryBuilder) String() string { | ||||
| 	return strings.Join(qb.Tokens, " ") | ||||
| } | ||||
							
								
								
									
										182
									
								
								pkg/orm/qb_tidb.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										182
									
								
								pkg/orm/qb_tidb.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,182 @@ | ||||
| // Copyright 2015 TiDB Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // TiDBQueryBuilder is the SQL build | ||||
| type TiDBQueryBuilder struct { | ||||
| 	Tokens []string | ||||
| } | ||||
| 
 | ||||
| // Select will join the fields | ||||
| func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // ForUpdate add the FOR UPDATE clause | ||||
| func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "FOR UPDATE") | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // From join the tables | ||||
| func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // InnerJoin INNER JOIN the table | ||||
| func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "INNER JOIN", table) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // LeftJoin LEFT JOIN the table | ||||
| func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // RightJoin RIGHT JOIN the table | ||||
| func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // On join with on cond | ||||
| func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "ON", cond) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Where join the Where cond | ||||
| func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "WHERE", cond) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // And join the and cond | ||||
| func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "AND", cond) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Or join the or cond | ||||
| func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "OR", cond) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // In join the IN (vals) | ||||
| func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // OrderBy join the Order by fields | ||||
| func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Asc join the asc | ||||
| func (qb *TiDBQueryBuilder) Asc() QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "ASC") | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Desc join the desc | ||||
| func (qb *TiDBQueryBuilder) Desc() QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "DESC") | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Limit join the limit num | ||||
| func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Offset join the offset num | ||||
| func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // GroupBy join the Group by fields | ||||
| func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Having join the Having cond | ||||
| func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "HAVING", cond) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Update join the update table | ||||
| func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Set join the set kv | ||||
| func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Delete join the Delete tables | ||||
| func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "DELETE") | ||||
| 	if len(tables) != 0 { | ||||
| 		qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) | ||||
| 	} | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // InsertInto join the insert SQL | ||||
| func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { | ||||
| 	qb.Tokens = append(qb.Tokens, "INSERT INTO", table) | ||||
| 	if len(fields) != 0 { | ||||
| 		fieldsStr := strings.Join(fields, CommaSpace) | ||||
| 		qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") | ||||
| 	} | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Values join the Values(vals) | ||||
| func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { | ||||
| 	valsStr := strings.Join(vals, CommaSpace) | ||||
| 	qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") | ||||
| 	return qb | ||||
| } | ||||
| 
 | ||||
| // Subquery join the sub as alias | ||||
| func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { | ||||
| 	return fmt.Sprintf("(%s) AS %s", sub, alias) | ||||
| } | ||||
| 
 | ||||
| // String join all Tokens | ||||
| func (qb *TiDBQueryBuilder) String() string { | ||||
| 	return strings.Join(qb.Tokens, " ") | ||||
| } | ||||
							
								
								
									
										473
									
								
								pkg/orm/types.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										473
									
								
								pkg/orm/types.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,473 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // Driver define database driver | ||||
| type Driver interface { | ||||
| 	Name() string | ||||
| 	Type() DriverType | ||||
| } | ||||
| 
 | ||||
| // Fielder define field info | ||||
| type Fielder interface { | ||||
| 	String() string | ||||
| 	FieldType() int | ||||
| 	SetRaw(interface{}) error | ||||
| 	RawValue() interface{} | ||||
| } | ||||
| 
 | ||||
| // Ormer define the orm interface | ||||
| type Ormer interface { | ||||
| 	// read data to model | ||||
| 	// for example: | ||||
| 	//	this will find User by Id field | ||||
| 	// 	u = &User{Id: user.Id} | ||||
| 	// 	err = Ormer.Read(u) | ||||
| 	//	this will find User by UserName field | ||||
| 	// 	u = &User{UserName: "astaxie", Password: "pass"} | ||||
| 	//	err = Ormer.Read(u, "UserName") | ||||
| 	Read(md interface{}, cols ...string) error | ||||
| 	// Like Read(), but with "FOR UPDATE" clause, useful in transaction. | ||||
| 	// Some databases are not support this feature. | ||||
| 	ReadForUpdate(md interface{}, cols ...string) error | ||||
| 	// Try to read a row from the database, or insert one if it doesn't exist | ||||
| 	ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) | ||||
| 	// insert model data to database | ||||
| 	// for example: | ||||
| 	//  user := new(User) | ||||
| 	//  id, err = Ormer.Insert(user) | ||||
| 	//  user must be a pointer and Insert will set user's pk field | ||||
| 	Insert(interface{}) (int64, error) | ||||
| 	// mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") | ||||
| 	// if colu type is integer : can use(+-*/), string : convert(colu,"value") | ||||
| 	// postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") | ||||
| 	// if colu type is integer : can use(+-*/), string : colu || "value" | ||||
| 	InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) | ||||
| 	// insert some models to database | ||||
| 	InsertMulti(bulk int, mds interface{}) (int64, error) | ||||
| 	// update model to database. | ||||
| 	// cols set the columns those want to update. | ||||
| 	// find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns | ||||
| 	// for example: | ||||
| 	// user := User{Id: 2} | ||||
| 	//	user.Langs = append(user.Langs, "zh-CN", "en-US") | ||||
| 	//	user.Extra.Name = "beego" | ||||
| 	//	user.Extra.Data = "orm" | ||||
| 	//	num, err = Ormer.Update(&user, "Langs", "Extra") | ||||
| 	Update(md interface{}, cols ...string) (int64, error) | ||||
| 	// delete model in database | ||||
| 	Delete(md interface{}, cols ...string) (int64, error) | ||||
| 	// load related models to md model. | ||||
| 	// args are limit, offset int and order string. | ||||
| 	// | ||||
| 	// example: | ||||
| 	// 	Ormer.LoadRelated(post,"Tags") | ||||
| 	// 	for _,tag := range post.Tags{...} | ||||
| 	//args[0] bool true useDefaultRelsDepth ; false  depth 0 | ||||
| 	//args[0] int  loadRelationDepth | ||||
| 	//args[1] int limit default limit 1000 | ||||
| 	//args[2] int offset default offset 0 | ||||
| 	//args[3] string order  for example : "-Id" | ||||
| 	// make sure the relation is defined in model struct tags. | ||||
| 	LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) | ||||
| 	// create a models to models queryer | ||||
| 	// for example: | ||||
| 	// 	post := Post{Id: 4} | ||||
| 	// 	m2m := Ormer.QueryM2M(&post, "Tags") | ||||
| 	QueryM2M(md interface{}, name string) QueryM2Mer | ||||
| 	// return a QuerySeter for table operations. | ||||
| 	// table name can be string or struct. | ||||
| 	// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), | ||||
| 	QueryTable(ptrStructOrTableName interface{}) QuerySeter | ||||
| 	// switch to another registered database driver by given name. | ||||
| 	Using(name string) error | ||||
| 	// begin transaction | ||||
| 	// for example: | ||||
| 	// 	o := NewOrm() | ||||
| 	// 	err := o.Begin() | ||||
| 	// 	... | ||||
| 	// 	err = o.Rollback() | ||||
| 	Begin() error | ||||
| 	// begin transaction with provided context and option | ||||
| 	// the provided context is used until the transaction is committed or rolled back. | ||||
| 	// if the context is canceled, the transaction will be rolled back. | ||||
| 	// the provided TxOptions is optional and may be nil if defaults should be used. | ||||
| 	// if a non-default isolation level is used that the driver doesn't support, an error will be returned. | ||||
| 	// for example: | ||||
| 	//  o := NewOrm() | ||||
| 	// 	err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) | ||||
| 	//  ... | ||||
| 	//  err = o.Rollback() | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) error | ||||
| 	// commit transaction | ||||
| 	Commit() error | ||||
| 	// rollback transaction | ||||
| 	Rollback() error | ||||
| 	// return a raw query seter for raw sql string. | ||||
| 	// for example: | ||||
| 	//	 ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() | ||||
| 	//	// update user testing's name to slene | ||||
| 	Raw(query string, args ...interface{}) RawSeter | ||||
| 	Driver() Driver | ||||
| 	DBStats() *sql.DBStats | ||||
| } | ||||
| 
 | ||||
| // Inserter insert prepared statement | ||||
| type Inserter interface { | ||||
| 	Insert(interface{}) (int64, error) | ||||
| 	Close() error | ||||
| } | ||||
| 
 | ||||
| // QuerySeter query seter | ||||
| type QuerySeter interface { | ||||
| 	// add condition expression to QuerySeter. | ||||
| 	// for example: | ||||
| 	//	filter by UserName == 'slene' | ||||
| 	//	qs.Filter("UserName", "slene") | ||||
| 	//	sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28 | ||||
| 	//	Filter("profile__Age", 28) | ||||
| 	// 	 // time compare | ||||
| 	//	qs.Filter("created", time.Now()) | ||||
| 	Filter(string, ...interface{}) QuerySeter | ||||
| 	// add raw sql to querySeter. | ||||
| 	// for example: | ||||
| 	// qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)") | ||||
| 	// //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18) | ||||
| 	FilterRaw(string, string) QuerySeter | ||||
| 	// add NOT condition to querySeter. | ||||
| 	// have the same usage as Filter | ||||
| 	Exclude(string, ...interface{}) QuerySeter | ||||
| 	// set condition to QuerySeter. | ||||
| 	// sql's where condition | ||||
| 	//	cond := orm.NewCondition() | ||||
| 	//	cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) | ||||
| 	//	//sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` >  2000 | ||||
| 	//	num, err := qs.SetCond(cond1).Count() | ||||
| 	SetCond(*Condition) QuerySeter | ||||
| 	// get condition from QuerySeter. | ||||
| 	// sql's where condition | ||||
| 	//  cond := orm.NewCondition() | ||||
| 	//  cond = cond.And("profile__isnull", false).AndNot("status__in", 1) | ||||
| 	//  qs = qs.SetCond(cond) | ||||
| 	//  cond = qs.GetCond() | ||||
| 	//  cond := cond.Or("profile__age__gt", 2000) | ||||
| 	//  //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` >  2000 | ||||
| 	//  num, err := qs.SetCond(cond).Count() | ||||
| 	GetCond() *Condition | ||||
| 	// add LIMIT value. | ||||
| 	// args[0] means offset, e.g. LIMIT num,offset. | ||||
| 	// if Limit <= 0 then Limit will be set to default limit ,eg 1000 | ||||
| 	// if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000 | ||||
| 	//  for example: | ||||
| 	//	qs.Limit(10, 2) | ||||
| 	//	// sql-> limit 10 offset 2 | ||||
| 	Limit(limit interface{}, args ...interface{}) QuerySeter | ||||
| 	// add OFFSET value | ||||
| 	// same as Limit function's args[0] | ||||
| 	Offset(offset interface{}) QuerySeter | ||||
| 	// add GROUP BY expression | ||||
| 	// for example: | ||||
| 	//	qs.GroupBy("id") | ||||
| 	GroupBy(exprs ...string) QuerySeter | ||||
| 	// add ORDER expression. | ||||
| 	// "column" means ASC, "-column" means DESC. | ||||
| 	// for example: | ||||
| 	//	qs.OrderBy("-status") | ||||
| 	OrderBy(exprs ...string) QuerySeter | ||||
| 	// set relation model to query together. | ||||
| 	// it will query relation models and assign to parent model. | ||||
| 	// for example: | ||||
| 	//	// will load all related fields use left join . | ||||
| 	// 	qs.RelatedSel().One(&user) | ||||
| 	//	// will  load related field only profile | ||||
| 	//	qs.RelatedSel("profile").One(&user) | ||||
| 	//	user.Profile.Age = 32 | ||||
| 	RelatedSel(params ...interface{}) QuerySeter | ||||
| 	// Set Distinct | ||||
| 	// for example: | ||||
| 	//  o.QueryTable("policy").Filter("Groups__Group__Users__User", user). | ||||
| 	//    Distinct(). | ||||
| 	//    All(&permissions) | ||||
| 	Distinct() QuerySeter | ||||
| 	// set FOR UPDATE to query. | ||||
| 	// for example: | ||||
| 	//  o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users) | ||||
| 	ForUpdate() QuerySeter | ||||
| 	// return QuerySeter execution result number | ||||
| 	// for example: | ||||
| 	//	num, err = qs.Filter("profile__age__gt", 28).Count() | ||||
| 	Count() (int64, error) | ||||
| 	// check result empty or not after QuerySeter executed | ||||
| 	// the same as QuerySeter.Count > 0 | ||||
| 	Exist() bool | ||||
| 	// execute update with parameters | ||||
| 	// for example: | ||||
| 	//	num, err = qs.Filter("user_name", "slene").Update(Params{ | ||||
| 	//		"Nums": ColValue(Col_Minus, 50), | ||||
| 	//	}) // user slene's Nums will minus 50 | ||||
| 	//	num, err = qs.Filter("UserName", "slene").Update(Params{ | ||||
| 	//		"user_name": "slene2" | ||||
| 	//	}) // user slene's  name will change to slene2 | ||||
| 	Update(values Params) (int64, error) | ||||
| 	// delete from table | ||||
| 	//for example: | ||||
| 	//	num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() | ||||
| 	// 	//delete two user  who's name is testing1 or testing2 | ||||
| 	Delete() (int64, error) | ||||
| 	// return a insert queryer. | ||||
| 	// it can be used in times. | ||||
| 	// example: | ||||
| 	// 	i,err := sq.PrepareInsert() | ||||
| 	// 	num, err = i.Insert(&user1) // user table will add one record user1 at once | ||||
| 	//	num, err = i.Insert(&user2) // user table will add one record user2 at once | ||||
| 	//	err = i.Close() //don't forget call Close | ||||
| 	PrepareInsert() (Inserter, error) | ||||
| 	// query all data and map to containers. | ||||
| 	// cols means the columns when querying. | ||||
| 	// for example: | ||||
| 	//	var users []*User | ||||
| 	//	qs.All(&users) // users[0],users[1],users[2] ... | ||||
| 	All(container interface{}, cols ...string) (int64, error) | ||||
| 	// query one row data and map to containers. | ||||
| 	// cols means the columns when querying. | ||||
| 	// for example: | ||||
| 	//	var user User | ||||
| 	//	qs.One(&user) //user.UserName == "slene" | ||||
| 	One(container interface{}, cols ...string) error | ||||
| 	// query all data and map to []map[string]interface. | ||||
| 	// expres means condition expression. | ||||
| 	// it converts data to []map[column]value. | ||||
| 	// for example: | ||||
| 	//	var maps []Params | ||||
| 	//	qs.Values(&maps) //maps[0]["UserName"]=="slene" | ||||
| 	Values(results *[]Params, exprs ...string) (int64, error) | ||||
| 	// query all data and map to [][]interface | ||||
| 	// it converts data to [][column_index]value | ||||
| 	// for example: | ||||
| 	//	var list []ParamsList | ||||
| 	//	qs.ValuesList(&list) // list[0][1] == "slene" | ||||
| 	ValuesList(results *[]ParamsList, exprs ...string) (int64, error) | ||||
| 	// query all data and map to []interface. | ||||
| 	// it's designed for one column record set, auto change to []value, not [][column]value. | ||||
| 	// for example: | ||||
| 	//	var list ParamsList | ||||
| 	//	qs.ValuesFlat(&list, "UserName") // list[0] == "slene" | ||||
| 	ValuesFlat(result *ParamsList, expr string) (int64, error) | ||||
| 	// query all rows into map[string]interface with specify key and value column name. | ||||
| 	// keyCol = "name", valueCol = "value" | ||||
| 	// table data | ||||
| 	// name  | value | ||||
| 	// total | 100 | ||||
| 	// found | 200 | ||||
| 	// to map[string]interface{}{ | ||||
| 	// 	"total": 100, | ||||
| 	// 	"found": 200, | ||||
| 	// } | ||||
| 	RowsToMap(result *Params, keyCol, valueCol string) (int64, error) | ||||
| 	// query all rows into struct with specify key and value column name. | ||||
| 	// keyCol = "name", valueCol = "value" | ||||
| 	// table data | ||||
| 	// name  | value | ||||
| 	// total | 100 | ||||
| 	// found | 200 | ||||
| 	// to struct { | ||||
| 	// 	Total int | ||||
| 	// 	Found int | ||||
| 	// } | ||||
| 	RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) | ||||
| } | ||||
| 
 | ||||
| // QueryM2Mer model to model query struct | ||||
| // all operations are on the m2m table only, will not affect the origin model table | ||||
| type QueryM2Mer interface { | ||||
| 	// add models to origin models when creating queryM2M. | ||||
| 	// example: | ||||
| 	// 	m2m := orm.QueryM2M(post,"Tag") | ||||
| 	// 	m2m.Add(&Tag1{},&Tag2{}) | ||||
| 	//  	for _,tag := range post.Tags{}{ ... } | ||||
| 	// param could also be any of the follow | ||||
| 	// 	[]*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}} | ||||
| 	//	&Tag{Id:5,Name: "TestTag3"} | ||||
| 	//	[]interface{}{&Tag{Id:6,Name: "TestTag4"}} | ||||
| 	// insert one or more rows to m2m table | ||||
| 	// make sure the relation is defined in post model struct tag. | ||||
| 	Add(...interface{}) (int64, error) | ||||
| 	// remove models following the origin model relationship | ||||
| 	// only delete rows from m2m table | ||||
| 	// for example: | ||||
| 	//tag3 := &Tag{Id:5,Name: "TestTag3"} | ||||
| 	//num, err = m2m.Remove(tag3) | ||||
| 	Remove(...interface{}) (int64, error) | ||||
| 	// check model is existed in relationship of origin model | ||||
| 	Exist(interface{}) bool | ||||
| 	// clean all models in related of origin model | ||||
| 	Clear() (int64, error) | ||||
| 	// count all related models of origin model | ||||
| 	Count() (int64, error) | ||||
| } | ||||
| 
 | ||||
| // RawPreparer raw query statement | ||||
| type RawPreparer interface { | ||||
| 	Exec(...interface{}) (sql.Result, error) | ||||
| 	Close() error | ||||
| } | ||||
| 
 | ||||
| // RawSeter raw query seter | ||||
| // create From Ormer.Raw | ||||
| // for example: | ||||
| //  sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) | ||||
| //  rs := Ormer.Raw(sql, 1) | ||||
| type RawSeter interface { | ||||
| 	//execute sql and get result | ||||
| 	Exec() (sql.Result, error) | ||||
| 	//query data and map to container | ||||
| 	//for example: | ||||
| 	//	var name string | ||||
| 	//	var id int | ||||
| 	//	rs.QueryRow(&id,&name) // id==2 name=="slene" | ||||
| 	QueryRow(containers ...interface{}) error | ||||
| 
 | ||||
| 	// query data rows and map to container | ||||
| 	//	var ids []int | ||||
| 	//	var names []int | ||||
| 	//	query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q) | ||||
| 	//	num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"} | ||||
| 	QueryRows(containers ...interface{}) (int64, error) | ||||
| 	SetArgs(...interface{}) RawSeter | ||||
| 	// query data to []map[string]interface | ||||
| 	// see QuerySeter's Values | ||||
| 	Values(container *[]Params, cols ...string) (int64, error) | ||||
| 	// query data to [][]interface | ||||
| 	// see QuerySeter's ValuesList | ||||
| 	ValuesList(container *[]ParamsList, cols ...string) (int64, error) | ||||
| 	// query data to []interface | ||||
| 	// see QuerySeter's ValuesFlat | ||||
| 	ValuesFlat(container *ParamsList, cols ...string) (int64, error) | ||||
| 	// query all rows into map[string]interface with specify key and value column name. | ||||
| 	// keyCol = "name", valueCol = "value" | ||||
| 	// table data | ||||
| 	// name  | value | ||||
| 	// total | 100 | ||||
| 	// found | 200 | ||||
| 	// to map[string]interface{}{ | ||||
| 	// 	"total": 100, | ||||
| 	// 	"found": 200, | ||||
| 	// } | ||||
| 	RowsToMap(result *Params, keyCol, valueCol string) (int64, error) | ||||
| 	// query all rows into struct with specify key and value column name. | ||||
| 	// keyCol = "name", valueCol = "value" | ||||
| 	// table data | ||||
| 	// name  | value | ||||
| 	// total | 100 | ||||
| 	// found | 200 | ||||
| 	// to struct { | ||||
| 	// 	Total int | ||||
| 	// 	Found int | ||||
| 	// } | ||||
| 	RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) | ||||
| 
 | ||||
| 	// return prepared raw statement for used in times. | ||||
| 	// for example: | ||||
| 	// 	pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() | ||||
| 	// 	r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`) | ||||
| 	Prepare() (RawPreparer, error) | ||||
| } | ||||
| 
 | ||||
| // stmtQuerier statement querier | ||||
| type stmtQuerier interface { | ||||
| 	Close() error | ||||
| 	Exec(args ...interface{}) (sql.Result, error) | ||||
| 	//ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) | ||||
| 	Query(args ...interface{}) (*sql.Rows, error) | ||||
| 	//QueryContext(args ...interface{}) (*sql.Rows, error) | ||||
| 	QueryRow(args ...interface{}) *sql.Row | ||||
| 	//QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row | ||||
| } | ||||
| 
 | ||||
| // db querier | ||||
| type dbQuerier interface { | ||||
| 	Prepare(query string) (*sql.Stmt, error) | ||||
| 	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) | ||||
| 	Exec(query string, args ...interface{}) (sql.Result, error) | ||||
| 	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) | ||||
| 	Query(query string, args ...interface{}) (*sql.Rows, error) | ||||
| 	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) | ||||
| 	QueryRow(query string, args ...interface{}) *sql.Row | ||||
| 	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row | ||||
| } | ||||
| 
 | ||||
| // type DB interface { | ||||
| // 	Begin() (*sql.Tx, error) | ||||
| // 	Prepare(query string) (stmtQuerier, error) | ||||
| // 	Exec(query string, args ...interface{}) (sql.Result, error) | ||||
| // 	Query(query string, args ...interface{}) (*sql.Rows, error) | ||||
| // 	QueryRow(query string, args ...interface{}) *sql.Row | ||||
| // } | ||||
| 
 | ||||
| // transaction beginner | ||||
| type txer interface { | ||||
| 	Begin() (*sql.Tx, error) | ||||
| 	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) | ||||
| } | ||||
| 
 | ||||
| // transaction ending | ||||
| type txEnder interface { | ||||
| 	Commit() error | ||||
| 	Rollback() error | ||||
| } | ||||
| 
 | ||||
| // base database struct | ||||
| type dbBaser interface { | ||||
| 	Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error | ||||
| 	Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) | ||||
| 	InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) | ||||
| 	InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) | ||||
| 	InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) | ||||
| 	InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) | ||||
| 	Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) | ||||
| 	Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) | ||||
| 	ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) | ||||
| 	SupportUpdateJoin() bool | ||||
| 	UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) | ||||
| 	DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) | ||||
| 	Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) | ||||
| 	OperatorSQL(string) string | ||||
| 	GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) | ||||
| 	GenerateOperatorLeftCol(*fieldInfo, string, *string) | ||||
| 	PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) | ||||
| 	ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) | ||||
| 	RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) | ||||
| 	MaxLimit() uint64 | ||||
| 	TableQuote() string | ||||
| 	ReplaceMarks(*string) | ||||
| 	HasReturningID(*modelInfo, *string) bool | ||||
| 	TimeFromDB(*time.Time, *time.Location) | ||||
| 	TimeToDB(*time.Time, *time.Location) | ||||
| 	DbTypes() map[string]string | ||||
| 	GetTables(dbQuerier) (map[string]bool, error) | ||||
| 	GetColumns(dbQuerier, string) (map[string][3]string, error) | ||||
| 	ShowTablesQuery() string | ||||
| 	ShowColumnsQuery(string) string | ||||
| 	IndexExists(dbQuerier, string, string) bool | ||||
| 	collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) | ||||
| 	setval(dbQuerier, *modelInfo, []string) error | ||||
| } | ||||
							
								
								
									
										319
									
								
								pkg/orm/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										319
									
								
								pkg/orm/utils.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,319 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"math/big" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| type fn func(string) string | ||||
| 
 | ||||
| var ( | ||||
| 	nameStrategyMap = map[string]fn{ | ||||
| 		defaultNameStrategy:      snakeString, | ||||
| 		SnakeAcronymNameStrategy: snakeStringWithAcronym, | ||||
| 	} | ||||
| 	defaultNameStrategy      = "snakeString" | ||||
| 	SnakeAcronymNameStrategy = "snakeStringWithAcronym" | ||||
| 	nameStrategy             = defaultNameStrategy | ||||
| ) | ||||
| 
 | ||||
| // StrTo is the target string | ||||
| type StrTo string | ||||
| 
 | ||||
| // Set string | ||||
| func (f *StrTo) Set(v string) { | ||||
| 	if v != "" { | ||||
| 		*f = StrTo(v) | ||||
| 	} else { | ||||
| 		f.Clear() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Clear string | ||||
| func (f *StrTo) Clear() { | ||||
| 	*f = StrTo(0x1E) | ||||
| } | ||||
| 
 | ||||
| // Exist check string exist | ||||
| func (f StrTo) Exist() bool { | ||||
| 	return string(f) != string(0x1E) | ||||
| } | ||||
| 
 | ||||
| // Bool string to bool | ||||
| func (f StrTo) Bool() (bool, error) { | ||||
| 	return strconv.ParseBool(f.String()) | ||||
| } | ||||
| 
 | ||||
| // Float32 string to float32 | ||||
| func (f StrTo) Float32() (float32, error) { | ||||
| 	v, err := strconv.ParseFloat(f.String(), 32) | ||||
| 	return float32(v), err | ||||
| } | ||||
| 
 | ||||
| // Float64 string to float64 | ||||
| func (f StrTo) Float64() (float64, error) { | ||||
| 	return strconv.ParseFloat(f.String(), 64) | ||||
| } | ||||
| 
 | ||||
| // Int string to int | ||||
| func (f StrTo) Int() (int, error) { | ||||
| 	v, err := strconv.ParseInt(f.String(), 10, 32) | ||||
| 	return int(v), err | ||||
| } | ||||
| 
 | ||||
| // Int8 string to int8 | ||||
| func (f StrTo) Int8() (int8, error) { | ||||
| 	v, err := strconv.ParseInt(f.String(), 10, 8) | ||||
| 	return int8(v), err | ||||
| } | ||||
| 
 | ||||
| // Int16 string to int16 | ||||
| func (f StrTo) Int16() (int16, error) { | ||||
| 	v, err := strconv.ParseInt(f.String(), 10, 16) | ||||
| 	return int16(v), err | ||||
| } | ||||
| 
 | ||||
| // Int32 string to int32 | ||||
| func (f StrTo) Int32() (int32, error) { | ||||
| 	v, err := strconv.ParseInt(f.String(), 10, 32) | ||||
| 	return int32(v), err | ||||
| } | ||||
| 
 | ||||
| // Int64 string to int64 | ||||
| func (f StrTo) Int64() (int64, error) { | ||||
| 	v, err := strconv.ParseInt(f.String(), 10, 64) | ||||
| 	if err != nil { | ||||
| 		i := new(big.Int) | ||||
| 		ni, ok := i.SetString(f.String(), 10) // octal | ||||
| 		if !ok { | ||||
| 			return v, err | ||||
| 		} | ||||
| 		return ni.Int64(), nil | ||||
| 	} | ||||
| 	return v, err | ||||
| } | ||||
| 
 | ||||
| // Uint string to uint | ||||
| func (f StrTo) Uint() (uint, error) { | ||||
| 	v, err := strconv.ParseUint(f.String(), 10, 32) | ||||
| 	return uint(v), err | ||||
| } | ||||
| 
 | ||||
| // Uint8 string to uint8 | ||||
| func (f StrTo) Uint8() (uint8, error) { | ||||
| 	v, err := strconv.ParseUint(f.String(), 10, 8) | ||||
| 	return uint8(v), err | ||||
| } | ||||
| 
 | ||||
| // Uint16 string to uint16 | ||||
| func (f StrTo) Uint16() (uint16, error) { | ||||
| 	v, err := strconv.ParseUint(f.String(), 10, 16) | ||||
| 	return uint16(v), err | ||||
| } | ||||
| 
 | ||||
| // Uint32 string to uint32 | ||||
| func (f StrTo) Uint32() (uint32, error) { | ||||
| 	v, err := strconv.ParseUint(f.String(), 10, 32) | ||||
| 	return uint32(v), err | ||||
| } | ||||
| 
 | ||||
| // Uint64 string to uint64 | ||||
| func (f StrTo) Uint64() (uint64, error) { | ||||
| 	v, err := strconv.ParseUint(f.String(), 10, 64) | ||||
| 	if err != nil { | ||||
| 		i := new(big.Int) | ||||
| 		ni, ok := i.SetString(f.String(), 10) | ||||
| 		if !ok { | ||||
| 			return v, err | ||||
| 		} | ||||
| 		return ni.Uint64(), nil | ||||
| 	} | ||||
| 	return v, err | ||||
| } | ||||
| 
 | ||||
| // String string to string | ||||
| func (f StrTo) String() string { | ||||
| 	if f.Exist() { | ||||
| 		return string(f) | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // ToStr interface to string | ||||
| func ToStr(value interface{}, args ...int) (s string) { | ||||
| 	switch v := value.(type) { | ||||
| 	case bool: | ||||
| 		s = strconv.FormatBool(v) | ||||
| 	case float32: | ||||
| 		s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) | ||||
| 	case float64: | ||||
| 		s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) | ||||
| 	case int: | ||||
| 		s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) | ||||
| 	case int8: | ||||
| 		s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) | ||||
| 	case int16: | ||||
| 		s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) | ||||
| 	case int32: | ||||
| 		s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) | ||||
| 	case int64: | ||||
| 		s = strconv.FormatInt(v, argInt(args).Get(0, 10)) | ||||
| 	case uint: | ||||
| 		s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) | ||||
| 	case uint8: | ||||
| 		s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) | ||||
| 	case uint16: | ||||
| 		s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) | ||||
| 	case uint32: | ||||
| 		s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) | ||||
| 	case uint64: | ||||
| 		s = strconv.FormatUint(v, argInt(args).Get(0, 10)) | ||||
| 	case string: | ||||
| 		s = v | ||||
| 	case []byte: | ||||
| 		s = string(v) | ||||
| 	default: | ||||
| 		s = fmt.Sprintf("%v", v) | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
| 
 | ||||
| // ToInt64 interface to int64 | ||||
| func ToInt64(value interface{}) (d int64) { | ||||
| 	val := reflect.ValueOf(value) | ||||
| 	switch value.(type) { | ||||
| 	case int, int8, int16, int32, int64: | ||||
| 		d = val.Int() | ||||
| 	case uint, uint8, uint16, uint32, uint64: | ||||
| 		d = int64(val.Uint()) | ||||
| 	default: | ||||
| 		panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func snakeStringWithAcronym(s string) string { | ||||
| 	data := make([]byte, 0, len(s)*2) | ||||
| 	num := len(s) | ||||
| 	for i := 0; i < num; i++ { | ||||
| 		d := s[i] | ||||
| 		before := false | ||||
| 		after := false | ||||
| 		if i > 0 { | ||||
| 			before = s[i-1] >= 'a' && s[i-1] <= 'z' | ||||
| 		} | ||||
| 		if i+1 < num { | ||||
| 			after = s[i+1] >= 'a' && s[i+1] <= 'z' | ||||
| 		} | ||||
| 		if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { | ||||
| 			data = append(data, '_') | ||||
| 		} | ||||
| 		data = append(data, d) | ||||
| 	} | ||||
| 	return strings.ToLower(string(data[:])) | ||||
| } | ||||
| 
 | ||||
| // snake string, XxYy to xx_yy , XxYY to xx_y_y | ||||
| func snakeString(s string) string { | ||||
| 	data := make([]byte, 0, len(s)*2) | ||||
| 	j := false | ||||
| 	num := len(s) | ||||
| 	for i := 0; i < num; i++ { | ||||
| 		d := s[i] | ||||
| 		if i > 0 && d >= 'A' && d <= 'Z' && j { | ||||
| 			data = append(data, '_') | ||||
| 		} | ||||
| 		if d != '_' { | ||||
| 			j = true | ||||
| 		} | ||||
| 		data = append(data, d) | ||||
| 	} | ||||
| 	return strings.ToLower(string(data[:])) | ||||
| } | ||||
| 
 | ||||
| // SetNameStrategy set different name strategy | ||||
| func SetNameStrategy(s string) { | ||||
| 	if SnakeAcronymNameStrategy != s { | ||||
| 		nameStrategy = defaultNameStrategy | ||||
| 	} | ||||
| 	nameStrategy = s | ||||
| } | ||||
| 
 | ||||
| // camel string, xx_yy to XxYy | ||||
| func camelString(s string) string { | ||||
| 	data := make([]byte, 0, len(s)) | ||||
| 	flag, num := true, len(s)-1 | ||||
| 	for i := 0; i <= num; i++ { | ||||
| 		d := s[i] | ||||
| 		if d == '_' { | ||||
| 			flag = true | ||||
| 			continue | ||||
| 		} else if flag { | ||||
| 			if d >= 'a' && d <= 'z' { | ||||
| 				d = d - 32 | ||||
| 			} | ||||
| 			flag = false | ||||
| 		} | ||||
| 		data = append(data, d) | ||||
| 	} | ||||
| 	return string(data[:]) | ||||
| } | ||||
| 
 | ||||
| type argString []string | ||||
| 
 | ||||
| // get string by index from string slice | ||||
| func (a argString) Get(i int, args ...string) (r string) { | ||||
| 	if i >= 0 && i < len(a) { | ||||
| 		r = a[i] | ||||
| 	} else if len(args) > 0 { | ||||
| 		r = args[0] | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| type argInt []int | ||||
| 
 | ||||
| // get int by index from int slice | ||||
| func (a argInt) Get(i int, args ...int) (r int) { | ||||
| 	if i >= 0 && i < len(a) { | ||||
| 		r = a[i] | ||||
| 	} | ||||
| 	if len(args) > 0 { | ||||
| 		r = args[0] | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // parse time to string with location | ||||
| func timeParse(dateString, format string) (time.Time, error) { | ||||
| 	tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) | ||||
| 	return tp, err | ||||
| } | ||||
| 
 | ||||
| // get pointer indirect type | ||||
| func indirectType(v reflect.Type) reflect.Type { | ||||
| 	switch v.Kind() { | ||||
| 	case reflect.Ptr: | ||||
| 		return indirectType(v.Elem()) | ||||
| 	default: | ||||
| 		return v | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										70
									
								
								pkg/orm/utils_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								pkg/orm/utils_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,70 @@ | ||||
| // Copyright 2014 beego Author. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package orm | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func TestCamelString(t *testing.T) { | ||||
| 	snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} | ||||
| 	camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} | ||||
| 
 | ||||
| 	answer := make(map[string]string) | ||||
| 	for i, v := range snake { | ||||
| 		answer[v] = camel[i] | ||||
| 	} | ||||
| 
 | ||||
| 	for _, v := range snake { | ||||
| 		res := camelString(v) | ||||
| 		if res != answer[v] { | ||||
| 			t.Error("Unit Test Fail:", v, res, answer[v]) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSnakeString(t *testing.T) { | ||||
| 	camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} | ||||
| 	snake := []string{"pic_url", "hello_world", "hello_world", "hel_l_o_word", "pic_url1", "xy_x_x"} | ||||
| 
 | ||||
| 	answer := make(map[string]string) | ||||
| 	for i, v := range camel { | ||||
| 		answer[v] = snake[i] | ||||
| 	} | ||||
| 
 | ||||
| 	for _, v := range camel { | ||||
| 		res := snakeString(v) | ||||
| 		if res != answer[v] { | ||||
| 			t.Error("Unit Test Fail:", v, res, answer[v]) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSnakeStringWithAcronym(t *testing.T) { | ||||
| 	camel := []string{"ID", "PicURL", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "XyXX"} | ||||
| 	snake := []string{"id", "pic_url", "hello_world", "hello_world", "hel_lo_word", "pic_url1", "xy_xx"} | ||||
| 
 | ||||
| 	answer := make(map[string]string) | ||||
| 	for i, v := range camel { | ||||
| 		answer[v] = snake[i] | ||||
| 	} | ||||
| 
 | ||||
| 	for _, v := range camel { | ||||
| 		res := snakeStringWithAcronym(v) | ||||
| 		if res != answer[v] { | ||||
| 			t.Error("Unit Test Fail:", v, res, answer[v]) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user