Merge pull request #4345 from AllenX2018/fix-issue-4344

Fix issue #4344
This commit is contained in:
Ming Deng 2020-12-16 00:25:55 +08:00 committed by GitHub
commit 884677f3a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 399 additions and 94 deletions

View File

@ -16,6 +16,7 @@ package cache
import ( import (
"context" "context"
"math"
"os" "os"
"sync" "sync"
"testing" "testing"
@ -46,11 +47,11 @@ func TestCacheIncr(t *testing.T) {
} }
func TestCache(t *testing.T) { func TestCache(t *testing.T) {
bm, err := NewCache("memory", `{"interval":20}`) bm, err := NewCache("memory", `{"interval":1}`)
if err != nil { if err != nil {
t.Error("init err") t.Error("init err")
} }
timeoutDuration := 10 * time.Second timeoutDuration := 5 * time.Second
if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil {
t.Error("set Error", err) t.Error("set Error", err)
} }
@ -62,7 +63,7 @@ func TestCache(t *testing.T) {
t.Error("get err") t.Error("get err")
} }
time.Sleep(30 * time.Second) time.Sleep(7 * time.Second)
if res, _ := bm.IsExist(context.Background(), "astaxie"); res { if res, _ := bm.IsExist(context.Background(), "astaxie"); res {
t.Error("check err") t.Error("check err")
@ -73,7 +74,11 @@ func TestCache(t *testing.T) {
} }
// test different integer type for incr & decr // test different integer type for incr & decr
testMultiIncrDecr(t, bm, timeoutDuration) testMultiTypeIncrDecr(t, bm, timeoutDuration)
// test overflow of incr&decr
testIncrOverFlow(t, bm, timeoutDuration)
testDecrOverFlow(t, bm, timeoutDuration)
bm.Delete(context.Background(), "astaxie") bm.Delete(context.Background(), "astaxie")
if res, _ := bm.IsExist(context.Background(), "astaxie"); res { if res, _ := bm.IsExist(context.Background(), "astaxie"); res {
@ -142,7 +147,11 @@ func TestFileCache(t *testing.T) {
} }
// test different integer type for incr & decr // test different integer type for incr & decr
testMultiIncrDecr(t, bm, timeoutDuration) testMultiTypeIncrDecr(t, bm, timeoutDuration)
// test overflow of incr&decr
testIncrOverFlow(t, bm, timeoutDuration)
testDecrOverFlow(t, bm, timeoutDuration)
bm.Delete(context.Background(), "astaxie") bm.Delete(context.Background(), "astaxie")
if res, _ := bm.IsExist(context.Background(), "astaxie"); res { if res, _ := bm.IsExist(context.Background(), "astaxie"); res {
@ -196,7 +205,7 @@ func TestFileCache(t *testing.T) {
os.RemoveAll("cache") os.RemoveAll("cache")
} }
func testMultiIncrDecr(t *testing.T, c Cache, timeout time.Duration) { func testMultiTypeIncrDecr(t *testing.T, c Cache, timeout time.Duration) {
testIncrDecr(t, c, 1, 2, timeout) testIncrDecr(t, c, 1, 2, timeout)
testIncrDecr(t, c, int32(1), int32(2), timeout) testIncrDecr(t, c, int32(1), int32(2), timeout)
testIncrDecr(t, c, int64(1), int64(2), timeout) testIncrDecr(t, c, int64(1), int64(2), timeout)
@ -233,3 +242,45 @@ func testIncrDecr(t *testing.T, c Cache, beforeIncr interface{}, afterIncr inter
t.Error("Delete Error") t.Error("Delete Error")
} }
} }
func testIncrOverFlow(t *testing.T, c Cache, timeout time.Duration) {
var err error
ctx := context.Background()
key := "incKey"
// int64
if err = c.Put(ctx, key, int64(math.MaxInt64), timeout); err != nil {
t.Error("Put Error: ", err.Error())
return
}
defer func() {
if err = c.Delete(ctx, key); err != nil {
t.Errorf("Delete error: %s", err.Error())
}
}()
if err = c.Incr(ctx, key); err == nil {
t.Error("Incr error")
return
}
}
func testDecrOverFlow(t *testing.T, c Cache, timeout time.Duration) {
var err error
ctx := context.Background()
key := "decKey"
// int64
if err = c.Put(ctx, key, int64(math.MinInt64), timeout); err != nil {
t.Error("Put Error: ", err.Error())
return
}
defer func() {
if err = c.Delete(ctx, key); err != nil {
t.Errorf("Delete error: %s", err.Error())
}
}()
if err = c.Decr(ctx, key); err == nil {
t.Error("Decr error")
return
}
}

83
client/cache/calc_utils.go vendored Normal file
View File

@ -0,0 +1,83 @@
package cache
import (
"fmt"
"math"
)
func incr(originVal interface{}) (interface{}, error) {
switch val := originVal.(type) {
case int:
tmp := val + 1
if val > 0 && tmp < 0 {
return nil, fmt.Errorf("increment would overflow")
}
return tmp, nil
case int32:
if val == math.MaxInt32 {
return nil, fmt.Errorf("increment would overflow")
}
return val + 1, nil
case int64:
if val == math.MaxInt64 {
return nil, fmt.Errorf("increment would overflow")
}
return val + 1, nil
case uint:
tmp := val + 1
if tmp < val {
return nil, fmt.Errorf("increment would overflow")
}
return tmp, nil
case uint32:
if val == math.MaxUint32 {
return nil, fmt.Errorf("increment would overflow")
}
return val + 1, nil
case uint64:
if val == math.MaxUint64 {
return nil, fmt.Errorf("increment would overflow")
}
return val + 1, nil
default:
return nil, fmt.Errorf("item val is not (u)int (u)int32 (u)int64")
}
}
func decr(originVal interface{}) (interface{}, error) {
switch val := originVal.(type) {
case int:
tmp := val - 1
if val < 0 && tmp > 0 {
return nil, fmt.Errorf("decrement would overflow")
}
return tmp, nil
case int32:
if val == math.MinInt32 {
return nil, fmt.Errorf("decrement would overflow")
}
return val - 1, nil
case int64:
if val == math.MinInt64 {
return nil, fmt.Errorf("decrement would overflow")
}
return val - 1, nil
case uint:
if val == 0 {
return nil, fmt.Errorf("decrement would overflow")
}
return val - 1, nil
case uint32:
if val == 0 {
return nil, fmt.Errorf("increment would overflow")
}
return val - 1, nil
case uint64:
if val == 0 {
return nil, fmt.Errorf("increment would overflow")
}
return val - 1, nil
default:
return nil, fmt.Errorf("item val is not (u)int (u)int32 (u)int64")
}
}

241
client/cache/calc_utils_test.go vendored Normal file
View File

@ -0,0 +1,241 @@
package cache
import (
"math"
"strconv"
"testing"
)
func TestIncr(t *testing.T) {
// int
var originVal interface{} = int(1)
var updateVal interface{} = int(2)
val, err := incr(originVal)
if err != nil {
t.Errorf("incr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("incr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = incr(int(1 << (strconv.IntSize - 1) - 1))
if err == nil {
t.Error("incr failed")
return
}
// int32
originVal = int32(1)
updateVal = int32(2)
val, err = incr(originVal)
if err != nil {
t.Errorf("incr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("incr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = incr(int32(math.MaxInt32))
if err == nil {
t.Error("incr failed")
return
}
// int64
originVal = int64(1)
updateVal = int64(2)
val, err = incr(originVal)
if err != nil {
t.Errorf("incr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("incr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = incr(int64(math.MaxInt64))
if err == nil {
t.Error("incr failed")
return
}
// uint
originVal = uint(1)
updateVal = uint(2)
val, err = incr(originVal)
if err != nil {
t.Errorf("incr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("incr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = incr(uint(1 << (strconv.IntSize) - 1))
if err == nil {
t.Error("incr failed")
return
}
// uint32
originVal = uint32(1)
updateVal = uint32(2)
val, err = incr(originVal)
if err != nil {
t.Errorf("incr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("incr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = incr(uint32(math.MaxUint32))
if err == nil {
t.Error("incr failed")
return
}
// uint64
originVal = uint64(1)
updateVal = uint64(2)
val, err = incr(originVal)
if err != nil {
t.Errorf("incr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("incr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = incr(uint64(math.MaxUint64))
if err == nil {
t.Error("incr failed")
return
}
// other type
_, err = incr("string")
if err == nil {
t.Error("incr failed")
return
}
}
func TestDecr(t *testing.T) {
// int
var originVal interface{} = int(2)
var updateVal interface{} = int(1)
val, err := decr(originVal)
if err != nil {
t.Errorf("decr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("decr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = decr(int(-1 << (strconv.IntSize - 1)))
if err == nil {
t.Error("decr failed")
return
}
// int32
originVal = int32(2)
updateVal = int32(1)
val, err = decr(originVal)
if err != nil {
t.Errorf("decr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("decr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = decr(int32(math.MinInt32))
if err == nil {
t.Error("decr failed")
return
}
// int64
originVal = int64(2)
updateVal = int64(1)
val, err = decr(originVal)
if err != nil {
t.Errorf("decr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("decr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = decr(int64(math.MinInt64))
if err == nil {
t.Error("decr failed")
return
}
// uint
originVal = uint(2)
updateVal = uint(1)
val, err = decr(originVal)
if err != nil {
t.Errorf("decr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("decr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = decr(uint(0))
if err == nil {
t.Error("decr failed")
return
}
// uint32
originVal = uint32(2)
updateVal = uint32(1)
val, err = decr(originVal)
if err != nil {
t.Errorf("decr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("decr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = decr(uint32(0))
if err == nil {
t.Error("decr failed")
return
}
// uint64
originVal = uint64(2)
updateVal = uint64(1)
val, err = decr(originVal)
if err != nil {
t.Errorf("decr failed, err: %s", err.Error())
return
}
if val != updateVal {
t.Errorf("decr failed, expect %v, but %v actually", updateVal, val)
return
}
_, err = decr(uint64(0))
if err == nil {
t.Error("decr failed")
return
}
// other type
_, err = decr("string")
if err == nil {
t.Error("decr failed")
return
}
}

54
client/cache/file.go vendored
View File

@ -199,25 +199,12 @@ func (fc *FileCache) Incr(ctx context.Context, key string) error {
return err return err
} }
var res interface{} val, err := incr(data)
switch val := data.(type) { if err != nil {
case int: return err
res = val + 1
case int32:
res = val + 1
case int64:
res = val + 1
case uint:
res = val + 1
case uint32:
res = val + 1
case uint64:
res = val + 1
default:
return errors.Errorf("data is not (u)int (u)int32 (u)int64")
} }
return fc.Put(context.Background(), key, res, time.Duration(fc.EmbedExpiry)) return fc.Put(context.Background(), key, val, time.Duration(fc.EmbedExpiry))
} }
// Decr decreases cached int value. // Decr decreases cached int value.
@ -227,37 +214,12 @@ func (fc *FileCache) Decr(ctx context.Context, key string) error {
return err return err
} }
var res interface{} val, err := decr(data)
switch val := data.(type) { if err != nil {
case int: return err
res = val - 1
case int32:
res = val - 1
case int64:
res = val - 1
case uint:
if val > 0 {
res = val - 1
} else {
return errors.New("data val is less than 0")
}
case uint32:
if val > 0 {
res = val - 1
} else {
return errors.New("data val is less than 0")
}
case uint64:
if val > 0 {
res = val - 1
} else {
return errors.New("data val is less than 0")
}
default:
return errors.Errorf("data is not (u)int (u)int32 (u)int64")
} }
return fc.Put(context.Background(), key, res, time.Duration(fc.EmbedExpiry)) return fc.Put(context.Background(), key, val, time.Duration(fc.EmbedExpiry))
} }
// IsExist checks if value exists. // IsExist checks if value exists.

View File

@ -130,22 +130,12 @@ func (bc *MemoryCache) Incr(ctx context.Context, key string) error {
if !ok { if !ok {
return errors.New("key not exist") return errors.New("key not exist")
} }
switch val := itm.val.(type) {
case int: val, err := incr(itm.val)
itm.val = val + 1 if err != nil {
case int32: return err
itm.val = val + 1
case int64:
itm.val = val + 1
case uint:
itm.val = val + 1
case uint32:
itm.val = val + 1
case uint64:
itm.val = val + 1
default:
return errors.New("item val is not (u)int (u)int32 (u)int64")
} }
itm.val = val
return nil return nil
} }
@ -157,34 +147,12 @@ func (bc *MemoryCache) Decr(ctx context.Context, key string) error {
if !ok { if !ok {
return errors.New("key not exist") return errors.New("key not exist")
} }
switch val := itm.val.(type) {
case int: val, err := decr(itm.val)
itm.val = val - 1 if err != nil {
case int64: return err
itm.val = val - 1
case int32:
itm.val = val - 1
case uint:
if val > 0 {
itm.val = val - 1
} else {
return errors.New("item val is less than 0")
}
case uint32:
if val > 0 {
itm.val = val - 1
} else {
return errors.New("item val is less than 0")
}
case uint64:
if val > 0 {
itm.val = val - 1
} else {
return errors.New("item val is less than 0")
}
default:
return errors.New("item val is not int int64 int32")
} }
itm.val = val
return nil return nil
} }