add template functions eq,lt to support uint and int compare.

This commit is contained in:
letu 2021-05-13 16:16:02 +08:00
parent 511714d616
commit 2fdda76882
2 changed files with 72 additions and 34 deletions

View File

@ -607,10 +607,18 @@ func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
truth := false
if k1 != k2 { if k1 != k2 {
// Special case: Can compare integer values regardless of type's sign.
switch {
case k1 == intKind && k2 == uintKind:
truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
case k1 == uintKind && k2 == intKind:
truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
default:
return false, errBadComparison return false, errBadComparison
} }
truth := false } else {
switch k1 { switch k1 {
case boolKind: case boolKind:
truth = v1.Bool() == v2.Bool() truth = v1.Bool() == v2.Bool()
@ -627,6 +635,7 @@ func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
default: default:
panic("invalid kind") panic("invalid kind")
} }
}
if truth { if truth {
return true, nil return true, nil
} }
@ -653,10 +662,18 @@ func lt(arg1, arg2 interface{}) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
truth := false
if k1 != k2 { if k1 != k2 {
// Special case: Can compare integer values regardless of type's sign.
switch {
case k1 == intKind && k2 == uintKind:
truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
case k1 == uintKind && k2 == intKind:
truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
default:
return false, errBadComparison return false, errBadComparison
} }
truth := false } else {
switch k1 { switch k1 {
case boolKind, complexKind: case boolKind, complexKind:
return false, errBadComparisonType return false, errBadComparisonType
@ -671,6 +688,7 @@ func lt(arg1, arg2 interface{}) (bool, error) {
default: default:
panic("invalid kind") panic("invalid kind")
} }
}
return truth, nil return truth, nil
} }

View File

@ -378,3 +378,23 @@ func TestMapGet(t *testing.T) {
t.Errorf("Error happens %v", err) t.Errorf("Error happens %v", err)
} }
} }
func Test_eq(t *testing.T) {
var a uint = 1
var b int32 = 1
if res, err := eq(a, b); err != nil {
if !res {
t.Error("uint(1) and int32(1) should not be eq")
}
}
}
func Test_lt(t *testing.T) {
var a uint = 1
var b int32 = 2
if res, err := lt(a, b); err != nil {
if !res {
t.Error("uint(1) not lt int32(2)")
}
}
}