diff --git a/client/cache/cache_test.go b/client/cache/cache_test.go index 85f83fc4..c02bba69 100644 --- a/client/cache/cache_test.go +++ b/client/cache/cache_test.go @@ -16,6 +16,7 @@ package cache import ( "context" + "math" "os" "sync" "testing" @@ -46,11 +47,11 @@ func TestCacheIncr(t *testing.T) { } func TestCache(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) + bm, err := NewCache("memory", `{"interval":1}`) if err != nil { t.Error("init err") } - timeoutDuration := 10 * time.Second + timeoutDuration := 5 * time.Second if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { t.Error("set Error", err) } @@ -62,7 +63,7 @@ func TestCache(t *testing.T) { t.Error("get err") } - time.Sleep(30 * time.Second) + time.Sleep(7 * time.Second) if res, _ := bm.IsExist(context.Background(), "astaxie"); res { t.Error("check err") @@ -73,7 +74,11 @@ func TestCache(t *testing.T) { } // test different integer type for incr & decr - testMultiIncrDecr(t, bm, timeoutDuration) + testMultiTypeIncrDecr(t, bm, timeoutDuration) + + // test overflow of incr&decr + testIncrOverFlow(t, bm, timeoutDuration) + testDecrOverFlow(t, bm, timeoutDuration) bm.Delete(context.Background(), "astaxie") if res, _ := bm.IsExist(context.Background(), "astaxie"); res { @@ -142,7 +147,11 @@ func TestFileCache(t *testing.T) { } // test different integer type for incr & decr - testMultiIncrDecr(t, bm, timeoutDuration) + testMultiTypeIncrDecr(t, bm, timeoutDuration) + + // test overflow of incr&decr + testIncrOverFlow(t, bm, timeoutDuration) + testDecrOverFlow(t, bm, timeoutDuration) bm.Delete(context.Background(), "astaxie") if res, _ := bm.IsExist(context.Background(), "astaxie"); res { @@ -196,7 +205,7 @@ func TestFileCache(t *testing.T) { os.RemoveAll("cache") } -func testMultiIncrDecr(t *testing.T, c Cache, timeout time.Duration) { +func testMultiTypeIncrDecr(t *testing.T, c Cache, timeout time.Duration) { testIncrDecr(t, c, 1, 2, timeout) testIncrDecr(t, c, int32(1), int32(2), timeout) testIncrDecr(t, c, int64(1), int64(2), timeout) @@ -233,3 +242,45 @@ func testIncrDecr(t *testing.T, c Cache, beforeIncr interface{}, afterIncr inter t.Error("Delete Error") } } + +func testIncrOverFlow(t *testing.T, c Cache, timeout time.Duration) { + var err error + ctx := context.Background() + key := "incKey" + + // int64 + if err = c.Put(ctx, key, int64(math.MaxInt64), timeout); err != nil { + t.Error("Put Error: ", err.Error()) + return + } + defer func() { + if err = c.Delete(ctx, key); err != nil { + t.Errorf("Delete error: %s", err.Error()) + } + }() + if err = c.Incr(ctx, key); err == nil { + t.Error("Incr error") + return + } +} + +func testDecrOverFlow(t *testing.T, c Cache, timeout time.Duration) { + var err error + ctx := context.Background() + key := "decKey" + + // int64 + if err = c.Put(ctx, key, int64(math.MinInt64), timeout); err != nil { + t.Error("Put Error: ", err.Error()) + return + } + defer func() { + if err = c.Delete(ctx, key); err != nil { + t.Errorf("Delete error: %s", err.Error()) + } + }() + if err = c.Decr(ctx, key); err == nil { + t.Error("Decr error") + return + } +} diff --git a/client/cache/calc_utils.go b/client/cache/calc_utils.go new file mode 100644 index 00000000..91d0974b --- /dev/null +++ b/client/cache/calc_utils.go @@ -0,0 +1,83 @@ +package cache + +import ( + "fmt" + "math" +) + +func incr(originVal interface{}) (interface{}, error) { + switch val := originVal.(type) { + case int: + tmp := val + 1 + if val > 0 && tmp < 0 { + return nil, fmt.Errorf("increment would overflow") + } + return tmp, nil + case int32: + if val == math.MaxInt32 { + return nil, fmt.Errorf("increment would overflow") + } + return val + 1, nil + case int64: + if val == math.MaxInt64 { + return nil, fmt.Errorf("increment would overflow") + } + return val + 1, nil + case uint: + tmp := val + 1 + if tmp < val { + return nil, fmt.Errorf("increment would overflow") + } + return tmp, nil + case uint32: + if val == math.MaxUint32 { + return nil, fmt.Errorf("increment would overflow") + } + return val + 1, nil + case uint64: + if val == math.MaxUint64 { + return nil, fmt.Errorf("increment would overflow") + } + return val + 1, nil + default: + return nil, fmt.Errorf("item val is not (u)int (u)int32 (u)int64") + } +} + +func decr(originVal interface{}) (interface{}, error) { + switch val := originVal.(type) { + case int: + tmp := val - 1 + if val < 0 && tmp > 0 { + return nil, fmt.Errorf("decrement would overflow") + } + return tmp, nil + case int32: + if val == math.MinInt32 { + return nil, fmt.Errorf("decrement would overflow") + } + return val - 1, nil + case int64: + if val == math.MinInt64 { + return nil, fmt.Errorf("decrement would overflow") + } + return val - 1, nil + case uint: + if val == 0 { + return nil, fmt.Errorf("decrement would overflow") + } + return val - 1, nil + case uint32: + if val == 0 { + return nil, fmt.Errorf("increment would overflow") + } + return val - 1, nil + case uint64: + if val == 0 { + return nil, fmt.Errorf("increment would overflow") + } + return val - 1, nil + default: + return nil, fmt.Errorf("item val is not (u)int (u)int32 (u)int64") + } +} \ No newline at end of file diff --git a/client/cache/calc_utils_test.go b/client/cache/calc_utils_test.go new file mode 100644 index 00000000..b98e71de --- /dev/null +++ b/client/cache/calc_utils_test.go @@ -0,0 +1,241 @@ +package cache + +import ( + "math" + "strconv" + "testing" +) + +func TestIncr(t *testing.T) { + // int + var originVal interface{} = int(1) + var updateVal interface{} = int(2) + val, err := incr(originVal) + if err != nil { + t.Errorf("incr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("incr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = incr(int(1 << (strconv.IntSize - 1) - 1)) + if err == nil { + t.Error("incr failed") + return + } + + // int32 + originVal = int32(1) + updateVal = int32(2) + val, err = incr(originVal) + if err != nil { + t.Errorf("incr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("incr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = incr(int32(math.MaxInt32)) + if err == nil { + t.Error("incr failed") + return + } + + // int64 + originVal = int64(1) + updateVal = int64(2) + val, err = incr(originVal) + if err != nil { + t.Errorf("incr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("incr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = incr(int64(math.MaxInt64)) + if err == nil { + t.Error("incr failed") + return + } + + // uint + originVal = uint(1) + updateVal = uint(2) + val, err = incr(originVal) + if err != nil { + t.Errorf("incr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("incr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = incr(uint(1 << (strconv.IntSize) - 1)) + if err == nil { + t.Error("incr failed") + return + } + + // uint32 + originVal = uint32(1) + updateVal = uint32(2) + val, err = incr(originVal) + if err != nil { + t.Errorf("incr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("incr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = incr(uint32(math.MaxUint32)) + if err == nil { + t.Error("incr failed") + return + } + + // uint64 + originVal = uint64(1) + updateVal = uint64(2) + val, err = incr(originVal) + if err != nil { + t.Errorf("incr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("incr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = incr(uint64(math.MaxUint64)) + if err == nil { + t.Error("incr failed") + return + } + + // other type + _, err = incr("string") + if err == nil { + t.Error("incr failed") + return + } +} + +func TestDecr(t *testing.T) { + // int + var originVal interface{} = int(2) + var updateVal interface{} = int(1) + val, err := decr(originVal) + if err != nil { + t.Errorf("decr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("decr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = decr(int(-1 << (strconv.IntSize - 1))) + if err == nil { + t.Error("decr failed") + return + } + + // int32 + originVal = int32(2) + updateVal = int32(1) + val, err = decr(originVal) + if err != nil { + t.Errorf("decr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("decr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = decr(int32(math.MinInt32)) + if err == nil { + t.Error("decr failed") + return + } + + // int64 + originVal = int64(2) + updateVal = int64(1) + val, err = decr(originVal) + if err != nil { + t.Errorf("decr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("decr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = decr(int64(math.MinInt64)) + if err == nil { + t.Error("decr failed") + return + } + + // uint + originVal = uint(2) + updateVal = uint(1) + val, err = decr(originVal) + if err != nil { + t.Errorf("decr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("decr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = decr(uint(0)) + if err == nil { + t.Error("decr failed") + return + } + + // uint32 + originVal = uint32(2) + updateVal = uint32(1) + val, err = decr(originVal) + if err != nil { + t.Errorf("decr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("decr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = decr(uint32(0)) + if err == nil { + t.Error("decr failed") + return + } + + // uint64 + originVal = uint64(2) + updateVal = uint64(1) + val, err = decr(originVal) + if err != nil { + t.Errorf("decr failed, err: %s", err.Error()) + return + } + if val != updateVal { + t.Errorf("decr failed, expect %v, but %v actually", updateVal, val) + return + } + _, err = decr(uint64(0)) + if err == nil { + t.Error("decr failed") + return + } + + // other type + _, err = decr("string") + if err == nil { + t.Error("decr failed") + return + } +} \ No newline at end of file diff --git a/client/cache/file.go b/client/cache/file.go index 043c4650..87e14b6c 100644 --- a/client/cache/file.go +++ b/client/cache/file.go @@ -199,25 +199,12 @@ func (fc *FileCache) Incr(ctx context.Context, key string) error { return err } - var res interface{} - switch val := data.(type) { - case int: - res = val + 1 - case int32: - res = val + 1 - case int64: - res = val + 1 - case uint: - res = val + 1 - case uint32: - res = val + 1 - case uint64: - res = val + 1 - default: - return errors.Errorf("data is not (u)int (u)int32 (u)int64") + val, err := incr(data) + if err != nil { + return err } - return fc.Put(context.Background(), key, res, time.Duration(fc.EmbedExpiry)) + return fc.Put(context.Background(), key, val, time.Duration(fc.EmbedExpiry)) } // Decr decreases cached int value. @@ -227,37 +214,12 @@ func (fc *FileCache) Decr(ctx context.Context, key string) error { return err } - var res interface{} - switch val := data.(type) { - case int: - res = val - 1 - case int32: - res = val - 1 - case int64: - res = val - 1 - case uint: - if val > 0 { - res = val - 1 - } else { - return errors.New("data val is less than 0") - } - case uint32: - if val > 0 { - res = val - 1 - } else { - return errors.New("data val is less than 0") - } - case uint64: - if val > 0 { - res = val - 1 - } else { - return errors.New("data val is less than 0") - } - default: - return errors.Errorf("data is not (u)int (u)int32 (u)int64") + val, err := decr(data) + if err != nil { + return err } - return fc.Put(context.Background(), key, res, time.Duration(fc.EmbedExpiry)) + return fc.Put(context.Background(), key, val, time.Duration(fc.EmbedExpiry)) } // IsExist checks if value exists. diff --git a/client/cache/memory.go b/client/cache/memory.go index 28c7d980..850326ad 100644 --- a/client/cache/memory.go +++ b/client/cache/memory.go @@ -130,22 +130,12 @@ func (bc *MemoryCache) Incr(ctx context.Context, key string) error { if !ok { return errors.New("key not exist") } - switch val := itm.val.(type) { - case int: - itm.val = val + 1 - case int32: - itm.val = val + 1 - case int64: - itm.val = val + 1 - case uint: - itm.val = val + 1 - case uint32: - itm.val = val + 1 - case uint64: - itm.val = val + 1 - default: - return errors.New("item val is not (u)int (u)int32 (u)int64") + + val, err := incr(itm.val) + if err != nil { + return err } + itm.val = val return nil } @@ -157,34 +147,12 @@ func (bc *MemoryCache) Decr(ctx context.Context, key string) error { if !ok { return errors.New("key not exist") } - switch val := itm.val.(type) { - case int: - itm.val = val - 1 - case int64: - itm.val = val - 1 - case int32: - itm.val = val - 1 - case uint: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint32: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint64: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - default: - return errors.New("item val is not int int64 int32") + + val, err := decr(itm.val) + if err != nil { + return err } + itm.val = val return nil }