321 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			321 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright 2014 beego Author. All Rights Reserved.
 | 
						|
//
 | 
						|
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
// you may not use this file except in compliance with the License.
 | 
						|
// You may obtain a copy of the License at
 | 
						|
//
 | 
						|
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
//
 | 
						|
// Unless required by applicable law or agreed to in writing, software
 | 
						|
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
// See the License for the specific language governing permissions and
 | 
						|
// limitations under the License.
 | 
						|
 | 
						|
package orm
 | 
						|
 | 
						|
import (
 | 
						|
	"fmt"
 | 
						|
	"os"
 | 
						|
	"strings"
 | 
						|
)
 | 
						|
 | 
						|
// dbIndex describes a single index statement produced while generating the
// schema for a model table.
type dbIndex struct {
	// Table is the table the index is created on; Name is the index name.
	Table, Name string
	// SQL is the complete CREATE INDEX statement.
	SQL string
}
 | 
						|
 | 
						|
// create database drop sql.
 | 
						|
func getDbDropSQL(al *alias) (sqls []string) {
 | 
						|
	if len(modelCache.cache) == 0 {
 | 
						|
		fmt.Println("no Model found, need register your model")
 | 
						|
		os.Exit(2)
 | 
						|
	}
 | 
						|
 | 
						|
	Q := al.DbBaser.TableQuote()
 | 
						|
 | 
						|
	for _, mi := range modelCache.allOrdered() {
 | 
						|
		sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q))
 | 
						|
	}
 | 
						|
	return sqls
 | 
						|
}
 | 
						|
 | 
						|
// get database column type string.
 | 
						|
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
 | 
						|
	T := al.DbBaser.DbTypes()
 | 
						|
	fieldType := fi.fieldType
 | 
						|
	fieldSize := fi.size
 | 
						|
 | 
						|
checkColumn:
 | 
						|
	switch fieldType {
 | 
						|
	case TypeBooleanField:
 | 
						|
		col = T["bool"]
 | 
						|
	case TypeVarCharField:
 | 
						|
		if al.Driver == DRPostgres && fi.toText {
 | 
						|
			col = T["string-text"]
 | 
						|
		} else {
 | 
						|
			col = fmt.Sprintf(T["string"], fieldSize)
 | 
						|
		}
 | 
						|
	case TypeCharField:
 | 
						|
		col = fmt.Sprintf(T["string-char"], fieldSize)
 | 
						|
	case TypeTextField:
 | 
						|
		col = T["string-text"]
 | 
						|
	case TypeTimeField:
 | 
						|
		col = T["time.Time-clock"]
 | 
						|
	case TypeDateField:
 | 
						|
		col = T["time.Time-date"]
 | 
						|
	case TypeDateTimeField:
 | 
						|
		col = T["time.Time"]
 | 
						|
	case TypeBitField:
 | 
						|
		col = T["int8"]
 | 
						|
	case TypeSmallIntegerField:
 | 
						|
		col = T["int16"]
 | 
						|
	case TypeIntegerField:
 | 
						|
		col = T["int32"]
 | 
						|
	case TypeBigIntegerField:
 | 
						|
		if al.Driver == DRSqlite {
 | 
						|
			fieldType = TypeIntegerField
 | 
						|
			goto checkColumn
 | 
						|
		}
 | 
						|
		col = T["int64"]
 | 
						|
	case TypePositiveBitField:
 | 
						|
		col = T["uint8"]
 | 
						|
	case TypePositiveSmallIntegerField:
 | 
						|
		col = T["uint16"]
 | 
						|
	case TypePositiveIntegerField:
 | 
						|
		col = T["uint32"]
 | 
						|
	case TypePositiveBigIntegerField:
 | 
						|
		col = T["uint64"]
 | 
						|
	case TypeFloatField:
 | 
						|
		col = T["float64"]
 | 
						|
	case TypeDecimalField:
 | 
						|
		s := T["float64-decimal"]
 | 
						|
		if !strings.Contains(s, "%d") {
 | 
						|
			col = s
 | 
						|
		} else {
 | 
						|
			col = fmt.Sprintf(s, fi.digits, fi.decimals)
 | 
						|
		}
 | 
						|
	case TypeJSONField:
 | 
						|
		if al.Driver != DRPostgres {
 | 
						|
			fieldType = TypeVarCharField
 | 
						|
			goto checkColumn
 | 
						|
		}
 | 
						|
		col = T["json"]
 | 
						|
	case TypeJsonbField:
 | 
						|
		if al.Driver != DRPostgres {
 | 
						|
			fieldType = TypeVarCharField
 | 
						|
			goto checkColumn
 | 
						|
		}
 | 
						|
		col = T["jsonb"]
 | 
						|
	case RelForeignKey, RelOneToOne:
 | 
						|
		fieldType = fi.relModelInfo.fields.pk.fieldType
 | 
						|
		fieldSize = fi.relModelInfo.fields.pk.size
 | 
						|
		goto checkColumn
 | 
						|
	}
 | 
						|
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
// create alter sql string.
 | 
						|
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
 | 
						|
	Q := al.DbBaser.TableQuote()
 | 
						|
	typ := getColumnTyp(al, fi)
 | 
						|
 | 
						|
	if !fi.null {
 | 
						|
		typ += " " + "NOT NULL"
 | 
						|
	}
 | 
						|
 | 
						|
	return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
 | 
						|
		Q, fi.mi.table, Q,
 | 
						|
		Q, fi.column, Q,
 | 
						|
		typ, getColumnDefault(fi),
 | 
						|
	)
 | 
						|
}
 | 
						|
 | 
						|
// create database creation string.
 | 
						|
func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
 | 
						|
	if len(modelCache.cache) == 0 {
 | 
						|
		fmt.Println("no Model found, need register your model")
 | 
						|
		os.Exit(2)
 | 
						|
	}
 | 
						|
 | 
						|
	Q := al.DbBaser.TableQuote()
 | 
						|
	T := al.DbBaser.DbTypes()
 | 
						|
	sep := fmt.Sprintf("%s, %s", Q, Q)
 | 
						|
 | 
						|
	tableIndexes = make(map[string][]dbIndex)
 | 
						|
 | 
						|
	for _, mi := range modelCache.allOrdered() {
 | 
						|
		sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
 | 
						|
		sql += fmt.Sprintf("--  Table Structure for `%s`\n", mi.fullName)
 | 
						|
		sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
 | 
						|
 | 
						|
		sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q)
 | 
						|
 | 
						|
		columns := make([]string, 0, len(mi.fields.fieldsDB))
 | 
						|
 | 
						|
		sqlIndexes := [][]string{}
 | 
						|
 | 
						|
		for _, fi := range mi.fields.fieldsDB {
 | 
						|
 | 
						|
			column := fmt.Sprintf("    %s%s%s ", Q, fi.column, Q)
 | 
						|
			col := getColumnTyp(al, fi)
 | 
						|
 | 
						|
			if fi.auto {
 | 
						|
				switch al.Driver {
 | 
						|
				case DRSqlite, DRPostgres:
 | 
						|
					column += T["auto"]
 | 
						|
				default:
 | 
						|
					column += col + " " + T["auto"]
 | 
						|
				}
 | 
						|
			} else if fi.pk {
 | 
						|
				column += col + " " + T["pk"]
 | 
						|
			} else {
 | 
						|
				column += col
 | 
						|
 | 
						|
				if !fi.null {
 | 
						|
					column += " " + "NOT NULL"
 | 
						|
				}
 | 
						|
 | 
						|
				//if fi.initial.String() != "" {
 | 
						|
				//	column += " DEFAULT " + fi.initial.String()
 | 
						|
				//}
 | 
						|
 | 
						|
				// Append attribute DEFAULT
 | 
						|
				column += getColumnDefault(fi)
 | 
						|
 | 
						|
				if fi.unique {
 | 
						|
					column += " " + "UNIQUE"
 | 
						|
				}
 | 
						|
 | 
						|
				if fi.index {
 | 
						|
					sqlIndexes = append(sqlIndexes, []string{fi.column})
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			if strings.Contains(column, "%COL%") {
 | 
						|
				column = strings.Replace(column, "%COL%", fi.column, -1)
 | 
						|
			}
 | 
						|
 | 
						|
			if fi.description != "" && al.Driver != DRSqlite {
 | 
						|
				column += " " + fmt.Sprintf("COMMENT '%s'", fi.description)
 | 
						|
			}
 | 
						|
 | 
						|
			columns = append(columns, column)
 | 
						|
		}
 | 
						|
 | 
						|
		if mi.model != nil {
 | 
						|
			allnames := getTableUnique(mi.addrField)
 | 
						|
			if !mi.manual && len(mi.uniques) > 0 {
 | 
						|
				allnames = append(allnames, mi.uniques)
 | 
						|
			}
 | 
						|
			for _, names := range allnames {
 | 
						|
				cols := make([]string, 0, len(names))
 | 
						|
				for _, name := range names {
 | 
						|
					if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
 | 
						|
						cols = append(cols, fi.column)
 | 
						|
					} else {
 | 
						|
						panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName))
 | 
						|
					}
 | 
						|
				}
 | 
						|
				column := fmt.Sprintf("    UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q)
 | 
						|
				columns = append(columns, column)
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		sql += strings.Join(columns, ",\n")
 | 
						|
		sql += "\n)"
 | 
						|
 | 
						|
		if al.Driver == DRMySQL {
 | 
						|
			var engine string
 | 
						|
			if mi.model != nil {
 | 
						|
				engine = getTableEngine(mi.addrField)
 | 
						|
			}
 | 
						|
			if engine == "" {
 | 
						|
				engine = al.Engine
 | 
						|
			}
 | 
						|
			sql += " ENGINE=" + engine
 | 
						|
		}
 | 
						|
 | 
						|
		sql += ";"
 | 
						|
		sqls = append(sqls, sql)
 | 
						|
 | 
						|
		if mi.model != nil {
 | 
						|
			for _, names := range getTableIndex(mi.addrField) {
 | 
						|
				cols := make([]string, 0, len(names))
 | 
						|
				for _, name := range names {
 | 
						|
					if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
 | 
						|
						cols = append(cols, fi.column)
 | 
						|
					} else {
 | 
						|
						panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName))
 | 
						|
					}
 | 
						|
				}
 | 
						|
				sqlIndexes = append(sqlIndexes, cols)
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		for _, names := range sqlIndexes {
 | 
						|
			name := mi.table + "_" + strings.Join(names, "_")
 | 
						|
			cols := strings.Join(names, sep)
 | 
						|
			sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
 | 
						|
 | 
						|
			index := dbIndex{}
 | 
						|
			index.Table = mi.table
 | 
						|
			index.Name = name
 | 
						|
			index.SQL = sql
 | 
						|
 | 
						|
			tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
 | 
						|
		}
 | 
						|
 | 
						|
	}
 | 
						|
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
 | 
						|
func getColumnDefault(fi *fieldInfo) string {
 | 
						|
	var (
 | 
						|
		v, t, d string
 | 
						|
	)
 | 
						|
 | 
						|
	// Skip default attribute if field is in relations
 | 
						|
	if fi.rel || fi.reverse {
 | 
						|
		return v
 | 
						|
	}
 | 
						|
 | 
						|
	t = " DEFAULT '%s' "
 | 
						|
 | 
						|
	// These defaults will be useful if there no config value orm:"default" and NOT NULL is on
 | 
						|
	switch fi.fieldType {
 | 
						|
	case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField:
 | 
						|
		return v
 | 
						|
 | 
						|
	case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
 | 
						|
		TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
 | 
						|
		TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField,
 | 
						|
		TypeDecimalField:
 | 
						|
		t = " DEFAULT %s "
 | 
						|
		d = "0"
 | 
						|
	case TypeBooleanField:
 | 
						|
		t = " DEFAULT %s "
 | 
						|
		d = "FALSE"
 | 
						|
	case TypeJSONField, TypeJsonbField:
 | 
						|
		d = "{}"
 | 
						|
	}
 | 
						|
 | 
						|
	if fi.colDefault {
 | 
						|
		if !fi.initial.Exist() {
 | 
						|
			v = fmt.Sprintf(t, "")
 | 
						|
		} else {
 | 
						|
			v = fmt.Sprintf(t, fi.initial.String())
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		if !fi.null {
 | 
						|
			v = fmt.Sprintf(t, d)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return v
 | 
						|
}
 |