add template functions eq,lt to support uint and int compare.

This commit is contained in:
letu 2021-05-13 16:16:02 +08:00
parent 511714d616
commit 2fdda76882
2 changed files with 72 additions and 34 deletions

View File

@ -607,25 +607,34 @@ func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
if k1 != k2 {
return false, errBadComparison
}
truth := false truth := false
switch k1 { if k1 != k2 {
case boolKind: // Special case: Can compare integer values regardless of type's sign.
truth = v1.Bool() == v2.Bool() switch {
case complexKind: case k1 == intKind && k2 == uintKind:
truth = v1.Complex() == v2.Complex() truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
case floatKind: case k1 == uintKind && k2 == intKind:
truth = v1.Float() == v2.Float() truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
case intKind: default:
truth = v1.Int() == v2.Int() return false, errBadComparison
case stringKind: }
truth = v1.String() == v2.String() } else {
case uintKind: switch k1 {
truth = v1.Uint() == v2.Uint() case boolKind:
default: truth = v1.Bool() == v2.Bool()
panic("invalid kind") case complexKind:
truth = v1.Complex() == v2.Complex()
case floatKind:
truth = v1.Float() == v2.Float()
case intKind:
truth = v1.Int() == v2.Int()
case stringKind:
truth = v1.String() == v2.String()
case uintKind:
truth = v1.Uint() == v2.Uint()
default:
panic("invalid kind")
}
} }
if truth { if truth {
return true, nil return true, nil
@ -653,23 +662,32 @@ func lt(arg1, arg2 interface{}) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
if k1 != k2 {
return false, errBadComparison
}
truth := false truth := false
switch k1 { if k1 != k2 {
case boolKind, complexKind: // Special case: Can compare integer values regardless of type's sign.
return false, errBadComparisonType switch {
case floatKind: case k1 == intKind && k2 == uintKind:
truth = v1.Float() < v2.Float() truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
case intKind: case k1 == uintKind && k2 == intKind:
truth = v1.Int() < v2.Int() truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
case stringKind: default:
truth = v1.String() < v2.String() return false, errBadComparison
case uintKind: }
truth = v1.Uint() < v2.Uint() } else {
default: switch k1 {
panic("invalid kind") case boolKind, complexKind:
return false, errBadComparisonType
case floatKind:
truth = v1.Float() < v2.Float()
case intKind:
truth = v1.Int() < v2.Int()
case stringKind:
truth = v1.String() < v2.String()
case uintKind:
truth = v1.Uint() < v2.Uint()
default:
panic("invalid kind")
}
} }
return truth, nil return truth, nil
} }

View File

@ -378,3 +378,23 @@ func TestMapGet(t *testing.T) {
t.Errorf("Error happens %v", err) t.Errorf("Error happens %v", err)
} }
} }
func Test_eq(t *testing.T) {
var a uint = 1
var b int32 = 1
if res, err := eq(a, b); err != nil {
if !res {
t.Error("uint(1) and int32(1) should not be eq")
}
}
}
func Test_lt(t *testing.T) {
var a uint = 1
var b int32 = 2
if res, err := lt(a, b); err != nil {
if !res {
t.Error("uint(1) not lt int32(2)")
}
}
}