beego/client/orm/ddl.go

196 lines
5.1 KiB
Go

// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"errors"
"fmt"
"strings"
imodels "github.com/beego/beego/v2/client/orm/internal/models"
)
// getDbDropSQL returns the statements that drop the table of every model
// registered in mc, in the order reported by mc.AllOrdered().
//
// Each statement has the form DROP TABLE IF EXISTS <q>table<q>, where <q> is
// the quote string supplied by al.DbBaser.TableQuote(). An error is returned
// when no model has been registered.
func getDbDropSQL(mc *imodels.ModelCache, al *alias) (queries []string, err error) {
	if mc.Empty() {
		return nil, errors.New("no Model found, need Register your model")
	}

	quote := al.DbBaser.TableQuote()
	for _, model := range mc.AllOrdered() {
		dropStmt := fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, quote, model.Table, quote)
		queries = append(queries, dropStmt)
	}
	return queries, nil
}
// getDbCreateSQL returns the schema creation statements for every model
// registered in mc, in the order reported by mc.AllOrdered().
//
// queries holds one entry per model: a comment banner, the
// CREATE TABLE IF NOT EXISTS statement and, for Postgres, the
// COMMENT ON COLUMN statements for described fields appended after it.
//
// tableIndexes maps a table name to its CREATE INDEX statements, built from
// fields flagged as indexed and from imodels.GetTableIndex. Tables without
// indexes have no entry in the map.
//
// An error is returned when no model has been registered. The function panics
// when a name listed for a UNIQUE constraint or an index cannot be resolved to
// a database column of the model.
func getDbCreateSQL(mc *imodels.ModelCache, al *alias) (queries []string, tableIndexes map[string][]dbIndex, err error) {
	if mc.Empty() {
		return nil, nil, errors.New("no Model found, need Register your model")
	}

	Q := al.DbBaser.TableQuote()
	T := al.DbBaser.DbTypes()
	sep := fmt.Sprintf("%s, %s", Q, Q)
	banner := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))

	// escapeLiteral doubles single quotes so free-form text (field
	// descriptions) cannot terminate the single-quoted SQL literal it is
	// embedded in. Backslashes are left untouched because their meaning
	// depends on the server's configuration.
	escapeLiteral := func(s string) string {
		return strings.Replace(s, "'", "''", -1)
	}

	tableIndexes = make(map[string][]dbIndex)

	for _, mi := range mc.AllOrdered() {
		var sql strings.Builder
		sql.WriteString(banner)
		fmt.Fprintf(&sql, "-- Table Structure for `%s`\n", mi.FullName)
		sql.WriteString(banner)
		fmt.Fprintf(&sql, "CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.Table, Q)

		columns := make([]string, 0, len(mi.Fields.FieldsDB))
		var sqlIndexes [][]string
		var commentIndexes []int // positions in FieldsDB of described fields (Postgres only)

		for i, fi := range mi.Fields.FieldsDB {
			column := fmt.Sprintf(" %s%s%s ", Q, fi.Column, Q)
			col := getColumnTyp(al, fi)

			switch {
			case fi.Auto:
				switch al.Driver {
				case DRSqlite, DRPostgres:
					// For these drivers the "auto" entry is used on its own,
					// without the column type in front of it.
					column += T["auto"]
				default:
					column += col + " " + T["auto"]
				}
			case fi.Pk:
				column += col + " " + T["pk"]
			default:
				column += col
				if !fi.Null {
					column += " " + "NOT NULL"
				}
				// Append attribute DEFAULT
				column += getColumnDefault(fi)
				if fi.Unique {
					column += " " + "UNIQUE"
				}
				if fi.Index {
					sqlIndexes = append(sqlIndexes, []string{fi.Column})
				}
			}

			// Type templates may carry a %COL% placeholder for the column name.
			column = strings.Replace(column, "%COL%", fi.Column, -1)

			if fi.Description != "" && al.Driver != DRSqlite {
				if al.Driver == DRPostgres {
					// Postgres comments are emitted as separate statements
					// after the CREATE TABLE, so only remember the position.
					commentIndexes = append(commentIndexes, i)
				} else {
					column += " " + fmt.Sprintf("COMMENT '%s'", escapeLiteral(fi.Description))
				}
			}

			columns = append(columns, column)
		}

		// Table-level UNIQUE constraints.
		if mi.Model != nil {
			allnames := imodels.GetTableUnique(mi.AddrField)
			if !mi.Manual && len(mi.Uniques) > 0 {
				allnames = append(allnames, mi.Uniques)
			}
			for _, names := range allnames {
				cols := make([]string, 0, len(names))
				for _, name := range names {
					if fi, ok := mi.Fields.GetByAny(name); ok && fi.DBcol {
						cols = append(cols, fi.Column)
					} else {
						panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.FullName))
					}
				}
				column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q)
				columns = append(columns, column)
			}
		}

		sql.WriteString(strings.Join(columns, ",\n"))
		sql.WriteString("\n)")

		if al.Driver == DRMySQL {
			// A per-model engine takes precedence over the alias-wide one.
			var engine string
			if mi.Model != nil {
				engine = imodels.GetTableEngine(mi.AddrField)
			}
			if engine == "" {
				engine = al.Engine
			}
			sql.WriteString(" ENGINE=" + engine)
		}

		sql.WriteString(";")

		// commentIndexes is only populated for Postgres, so this loop is a
		// no-op for every other driver.
		for _, index := range commentIndexes {
			fi := mi.Fields.FieldsDB[index]
			fmt.Fprintf(&sql, "\nCOMMENT ON COLUMN %s%s%s.%s%s%s is '%s';",
				Q, mi.Table, Q,
				Q, fi.Column, Q,
				escapeLiteral(fi.Description))
		}

		queries = append(queries, sql.String())

		// Multi-column indexes declared for the model.
		if mi.Model != nil {
			for _, names := range imodels.GetTableIndex(mi.AddrField) {
				cols := make([]string, 0, len(names))
				for _, name := range names {
					if fi, ok := mi.Fields.GetByAny(name); ok && fi.DBcol {
						cols = append(cols, fi.Column)
					} else {
						panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.FullName))
					}
				}
				sqlIndexes = append(sqlIndexes, cols)
			}
		}

		for _, names := range sqlIndexes {
			name := mi.Table + "_" + strings.Join(names, "_")
			cols := strings.Join(names, sep)
			stmt := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.Table, Q, Q, cols, Q)

			tableIndexes[mi.Table] = append(tableIndexes[mi.Table], dbIndex{
				Table: mi.Table,
				Name:  name,
				SQL:   stmt,
			})
		}
	}

	return queries, tableIndexes, nil
}