THIS IS A TEST INSTANCE ONLY! REPOSITORIES CAN BE DELETED AT ANY TIME!

Browse Source

Fix wrong dbmetas (#1442)

* add tests for db metas

* add more tests

* fix bug on mssql
tags/v0.7.9
Lunny Xiao GitHub 1 month ago
parent
commit
c5ee68faa1
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 99 additions and 32 deletions
  1. +4
    -3
      dialect_mssql.go
  2. +39
    -28
      dialect_sqlite3.go
  3. +49
    -0
      tag_test.go
  4. +7
    -1
      xorm_test.go

+ 4
- 3
dialect_mssql.go View File

@@ -340,7 +340,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
"default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END),
replace(replace(isnull(c.text,''),'(',''),')','') as vdefault,
ISNULL(i.is_primary_key, 0)
ISNULL(i.is_primary_key, 0), a.is_identity as is_identity
from sys.columns a
left join sys.types b on a.user_type_id=b.user_type_id
left join sys.syscomments c on a.default_object_id=c.id
@@ -362,8 +362,8 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
for rows.Next() {
var name, ctype, vdefault string
var maxLen, precision, scale int
var nullable, isPK, defaultIsNull bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK)
var nullable, isPK, defaultIsNull, isIncrement bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement)
if err != nil {
return nil, nil, err
}
@@ -377,6 +377,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
col.Default = vdefault
}
col.IsPrimaryKey = isPK
col.IsAutoIncrement = isIncrement
ct := strings.ToUpper(ctype)
if ct == "DECIMAL" {
col.Length = precision


+ 39
- 28
dialect_sqlite3.go View File

@@ -298,6 +298,40 @@ func splitColStr(colStr string) []string {
return results
}

func parseString(colStr string) (*core.Column, error) {
fields := splitColStr(colStr)
col := new(core.Column)
col.Indexes = make(map[string]int)
col.Nullable = true
col.DefaultIsEmpty = true

for idx, field := range fields {
if idx == 0 {
col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`)
continue
} else if idx == 1 {
col.SQLType = core.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0}
continue
}
switch field {
case "PRIMARY":
col.IsPrimaryKey = true
case "AUTOINCREMENT":
col.IsAutoIncrement = true
case "NULL":
if fields[idx-1] == "NOT" {
col.Nullable = false
} else {
col.Nullable = true
}
case "DEFAULT":
col.Default = fields[idx+1]
col.DefaultIsEmpty = false
}
}
return col, nil
}

func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
@@ -327,6 +361,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
colCreates := reg.FindAllString(name[nStart+1:nEnd], -1)
cols := make(map[string]*core.Column)
colSeq := make([]string, 0)

for _, colStr := range colCreates {
reg = regexp.MustCompile(`,\s`)
colStr = reg.ReplaceAllString(colStr, ",")
@@ -343,35 +378,11 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
continue
}

fields := splitColStr(colStr)
col := new(core.Column)
col.Indexes = make(map[string]int)
col.Nullable = true
col.DefaultIsEmpty = true

for idx, field := range fields {
if idx == 0 {
col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`)
continue
} else if idx == 1 {
col.SQLType = core.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0}
}
switch field {
case "PRIMARY":
col.IsPrimaryKey = true
case "AUTOINCREMENT":
col.IsAutoIncrement = true
case "NULL":
if fields[idx-1] == "NOT" {
col.Nullable = false
} else {
col.Nullable = true
}
case "DEFAULT":
col.Default = fields[idx+1]
col.DefaultIsEmpty = false
}
col, err := parseString(colStr)
if err != nil {
return colSeq, cols, err
}

cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}


+ 49
- 0
tag_test.go View File

@@ -549,3 +549,52 @@ func TestSplitTag(t *testing.T) {
}
}
}

func TestTagAutoIncr(t *testing.T) {
assert.NoError(t, prepareEngine())

type TagAutoIncr struct {
Id int64
Name string
}

assertSync(t, new(TagAutoIncr))

tables, err := testEngine.DBMetas()
assert.NoError(t, err)
assert.EqualValues(t, 1, len(tables))
assert.EqualValues(t, tableMapper.Obj2Table("TagAutoIncr"), tables[0].Name)
col := tables[0].GetColumn(colMapper.Obj2Table("Id"))
assert.NotNil(t, col)
assert.True(t, col.IsPrimaryKey)
assert.True(t, col.IsAutoIncrement)

col2 := tables[0].GetColumn(colMapper.Obj2Table("Name"))
assert.NotNil(t, col2)
assert.False(t, col2.IsPrimaryKey)
assert.False(t, col2.IsAutoIncrement)
}

func TestTagPrimarykey(t *testing.T) {
assert.NoError(t, prepareEngine())
type TagPrimaryKey struct {
Id int64 `xorm:"pk"`
Name string `xorm:"VARCHAR(20) pk"`
}

assertSync(t, new(TagPrimaryKey))

tables, err := testEngine.DBMetas()
assert.NoError(t, err)
assert.EqualValues(t, 1, len(tables))
assert.EqualValues(t, tableMapper.Obj2Table("TagPrimaryKey"), tables[0].Name)
col := tables[0].GetColumn(colMapper.Obj2Table("Id"))
assert.NotNil(t, col)
assert.True(t, col.IsPrimaryKey)
assert.False(t, col.IsAutoIncrement)

col2 := tables[0].GetColumn(colMapper.Obj2Table("Name"))
assert.NotNil(t, col2)
assert.True(t, col2.IsPrimaryKey)
assert.False(t, col2.IsAutoIncrement)
}

+ 7
- 1
xorm_test.go View File

@@ -15,10 +15,10 @@ import (

_ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql"
"xorm.io/core"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
"xorm.io/core"
)

var (
@@ -35,6 +35,9 @@ var (
splitter = flag.String("splitter", ";", "the splitter on connstr for cluster")
schema = flag.String("schema", "", "specify the schema")
ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb")

tableMapper core.IMapper
colMapper core.IMapper
)

func createEngine(dbType, connStr string) error {
@@ -122,6 +125,9 @@ func createEngine(dbType, connStr string) error {
}
}

tableMapper = testEngine.GetTableMapper()
colMapper = testEngine.GetColumnMapper()

tables, err := testEngine.DBMetas()
if err != nil {
return err


Loading…
Cancel
Save