session: Support SessionReleaseIfPresent to avoid concurrent problem (#5685)

This commit is contained in:
Alan Xu 2024-07-31 21:44:47 +08:00 committed by GitHub
parent bdb7e7a904
commit 06d869664a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 346 additions and 165 deletions

View File

@ -51,7 +51,7 @@ func TestConn(t *testing.T) {
func TestReconnect(t *testing.T) { func TestReconnect(t *testing.T) {
// Setup connection listener // Setup connection listener
newConns := make(chan net.Conn) newConns := make(chan net.Conn)
connNum := 2 connNum := 3
ln, err := net.Listen("tcp", ":6002") ln, err := net.Listen("tcp", ":6002")
if err != nil { if err != nil {
t.Log("Error listening:", err.Error()) t.Log("Error listening:", err.Error())

View File

@ -16,13 +16,14 @@ package mock
import ( import (
"net/http" "net/http"
"net/http/httptest"
beegoCtx "github.com/beego/beego/v2/server/web/context" beegoCtx "github.com/beego/beego/v2/server/web/context"
) )
func NewMockContext(req *http.Request) (*beegoCtx.Context, *HttpResponse) { func NewMockContext(req *http.Request) (*beegoCtx.Context, *httptest.ResponseRecorder) {
ctx := beegoCtx.NewContext() ctx := beegoCtx.NewContext()
resp := NewMockHttpResponse() resp := httptest.NewRecorder()
ctx.Reset(resp, req) ctx.Reset(resp, req)
return ctx, resp return ctx, resp
} }

View File

@ -31,7 +31,7 @@ type TestController struct {
} }
func TestMockContext(t *testing.T) { func TestMockContext(t *testing.T) {
req, err := http.NewRequest("GET", "http://localhost:8080/hello?name=tom", bytes.NewReader([]byte{})) req, err := http.NewRequest("GET", "https://localhost:8080/hello?name=tom", bytes.NewReader([]byte{}))
assert.Nil(t, err) assert.Nil(t, err)
ctx, resp := NewMockContext(req) ctx, resp := NewMockContext(req)
ctrl := &TestController{ ctrl := &TestController{
@ -40,7 +40,7 @@ func TestMockContext(t *testing.T) {
}, },
} }
ctrl.HelloWorld() ctrl.HelloWorld()
result := resp.BodyToString() result := resp.Body.String()
assert.Equal(t, "name=tom", result) assert.Equal(t, "name=tom", result)
} }

View File

@ -1,69 +0,0 @@
// Copyright 2021 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mock
import (
"encoding/json"
"net/http"
)
// HttpResponse mock response, which should be used in tests
type HttpResponse struct {
body []byte
header http.Header
StatusCode int
}
// NewMockHttpResponse you should only use this in your test code
func NewMockHttpResponse() *HttpResponse {
return &HttpResponse{
body: make([]byte, 0),
header: make(http.Header),
}
}
// Header return headers
func (m *HttpResponse) Header() http.Header {
return m.header
}
// Write append the body
func (m *HttpResponse) Write(bytes []byte) (int, error) {
m.body = append(m.body, bytes...)
return len(bytes), nil
}
// WriteHeader set the status code
func (m *HttpResponse) WriteHeader(statusCode int) {
m.StatusCode = statusCode
}
// JsonUnmarshal convert the body to object
func (m *HttpResponse) JsonUnmarshal(value interface{}) error {
return json.Unmarshal(m.body, value)
}
// BodyToString return the body as the string
func (m *HttpResponse) BodyToString() string {
return string(m.body)
}
// Reset will reset the status to init status
// Usually, you want to reuse this instance you may need to call Reset
func (m *HttpResponse) Reset() {
m.body = make([]byte, 0)
m.header = make(http.Header)
m.StatusCode = 0
}

View File

@ -106,7 +106,12 @@ func (s *SessionStore) SessionID(ctx context.Context) string {
} }
// SessionRelease do nothing // SessionRelease do nothing
func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (s *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) {
// Support in the future if necessary, now I think we don't need to implement this
}
// SessionReleaseIfPresent do nothing
func (*SessionStore) SessionReleaseIfPresent(_ context.Context, _ http.ResponseWriter) {
// Support in the future if necessary, now I think we don't need to implement this // Support in the future if necessary, now I think we don't need to implement this
} }

View File

@ -37,12 +37,12 @@ func TestSessionProvider(t *testing.T) {
}, },
} }
ctrl.HelloSession() ctrl.HelloSession()
result := resp.BodyToString() result := resp.Body.String()
assert.Equal(t, "set", result) assert.Equal(t, "set", result)
resp.Reset() resp.Body.Reset()
ctrl.HelloSessionName() ctrl.HelloSessionName()
result = resp.BodyToString() result = resp.Body.String()
assert.Equal(t, "Tom", result) assert.Equal(t, "Tom", result)
} }

View File

@ -64,7 +64,7 @@ type Provider struct {
b *couchbase.Bucket b *couchbase.Bucket
} }
// Set value to couchabse session // Set value to couchbase session
func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error { func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
cs.lock.Lock() cs.lock.Lock()
defer cs.lock.Unlock() defer cs.lock.Unlock()
@ -72,7 +72,7 @@ func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error {
return nil return nil
} }
// Get value from couchabse session // Get value from couchbase session
func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} {
cs.lock.RLock() cs.lock.RLock()
defer cs.lock.RUnlock() defer cs.lock.RUnlock()
@ -104,7 +104,7 @@ func (cs *SessionStore) SessionID(context.Context) string {
} }
// SessionRelease Write couchbase session with Gob string // SessionRelease Write couchbase session with Gob string
func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (cs *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) {
defer cs.b.Close() defer cs.b.Close()
cs.lock.RLock() cs.lock.RLock()
values := cs.values values := cs.values
@ -117,6 +117,12 @@ func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
cs.b.Set(cs.sid, int(cs.maxlifetime), bo) cs.b.Set(cs.sid, int(cs.maxlifetime), bo)
} }
// SessionReleaseIfPresent is not supported now.
// If we want to use couchbase, we may refactor the code to use couchbase collection.
func (cs *SessionStore) SessionReleaseIfPresent(c context.Context, w http.ResponseWriter) {
cs.SessionRelease(c, w)
}
func (cp *Provider) getBucket() *couchbase.Bucket { func (cp *Provider) getBucket() *couchbase.Bucket {
c, err := couchbase.Connect(cp.SavePath) c, err := couchbase.Connect(cp.SavePath)
if err != nil { if err != nil {
@ -195,7 +201,7 @@ func (cp *Provider) SessionRead(ctx context.Context, sid string) (session.Store,
} }
// SessionExist Check couchbase session exist. // SessionExist Check couchbase session exist.
// it checkes sid exist or not. // it checks sid exist or not.
func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
cp.b = cp.getBucket() cp.b = cp.getBucket()
defer cp.b.Close() defer cp.b.Close()

View File

@ -68,7 +68,7 @@ func (ls *SessionStore) SessionID(context.Context) string {
} }
// SessionRelease save session values to ledis // SessionRelease save session values to ledis
func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (ls *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) {
ls.lock.RLock() ls.lock.RLock()
values := ls.values values := ls.values
ls.lock.RUnlock() ls.lock.RUnlock()
@ -80,6 +80,13 @@ func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
c.Expire([]byte(ls.sid), ls.maxlifetime) c.Expire([]byte(ls.sid), ls.maxlifetime)
} }
// SessionReleaseIfPresent is not supported now, because ledis has no this feature like SETXX or atomic operation.
// https://github.com/ledisdb/ledisdb/issues/251
// https://github.com/ledisdb/ledisdb/issues/351
func (ls *SessionStore) SessionReleaseIfPresent(c context.Context, w http.ResponseWriter) {
ls.SessionRelease(c, w)
}
// Provider ledis session provider // Provider ledis session provider
type Provider struct { type Provider struct {
maxlifetime int64 maxlifetime int64
@ -162,8 +169,8 @@ func (lp *Provider) SessionExist(ctx context.Context, sid string) (bool, error)
func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) {
count, _ := c.Exists([]byte(sid)) count, _ := c.Exists([]byte(sid))
if count == 0 { if count == 0 {
// oldsid doesn't exists, set the new sid directly // oldsid doesn't exist, set the new sid directly
// ignore error here, since if it return error // ignore error here, since if it returns error
// the existed value will be 0 // the existed value will be 0
c.Set([]byte(sid), []byte("")) c.Set([]byte(sid), []byte(""))
c.Expire([]byte(sid), lp.maxlifetime) c.Expire([]byte(sid), lp.maxlifetime)
@ -181,7 +188,7 @@ func (lp *Provider) SessionDestroy(ctx context.Context, sid string) error {
return nil return nil
} }
// SessionGC Impelment method, no used. // SessionGC Implement method, no used.
func (lp *Provider) SessionGC(context.Context) { func (lp *Provider) SessionGC(context.Context) {
} }

View File

@ -97,6 +97,15 @@ func (rs *SessionStore) SessionID(context.Context) string {
// SessionRelease save session values to memcache // SessionRelease save session values to memcache
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
rs.releaseSession(ctx, w, false)
}
// SessionReleaseIfPresent save session values to memcache when key is present
func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) {
rs.releaseSession(ctx, w, true)
}
func (rs *SessionStore) releaseSession(_ context.Context, _ http.ResponseWriter, requirePresent bool) {
rs.lock.RLock() rs.lock.RLock()
values := rs.values values := rs.values
rs.lock.RUnlock() rs.lock.RUnlock()
@ -105,7 +114,11 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
return return
} }
item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)} item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)}
client.Set(&item) if requirePresent {
client.Replace(&item)
} else {
client.Set(&item)
}
} }
// MemProvider memcache session provider // MemProvider memcache session provider
@ -176,8 +189,8 @@ func (rp *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string
} }
var contain []byte var contain []byte
if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { if item, err := client.Get(sid); err != nil || len(item.Value) == 0 {
// oldsid doesn't exists, set the new sid directly // oldsid doesn't exist, set the new sid directly
// ignore error here, since if it return error // ignore error here, since if it returns error
// the existed value will be 0 // the existed value will be 0
item.Key = sid item.Key = sid
item.Value = []byte("") item.Value = []byte("")
@ -222,7 +235,7 @@ func (rp *MemProvider) connectInit() error {
return nil return nil
} }
// SessionGC Impelment method, no used. // SessionGC Implement method, no used.
func (rp *MemProvider) SessionGC(context.Context) { func (rp *MemProvider) SessionGC(context.Context) {
} }

View File

@ -109,7 +109,7 @@ func (st *SessionStore) SessionID(context.Context) string {
// SessionRelease save mysql session values to database. // SessionRelease save mysql session values to database.
// must call this method to save values to database. // must call this method to save values to database.
func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (st *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) {
defer st.c.Close() defer st.c.Close()
st.lock.RLock() st.lock.RLock()
values := st.values values := st.values
@ -122,6 +122,11 @@ func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
b, time.Now().Unix(), st.sid) b, time.Now().Unix(), st.sid)
} }
// SessionReleaseIfPresent save mysql session values to database.
func (st *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) {
st.SessionRelease(ctx, w)
}
// Provider mysql session provider // Provider mysql session provider
type Provider struct { type Provider struct {
maxlifetime int64 maxlifetime int64

View File

@ -112,7 +112,7 @@ func (st *SessionStore) SessionID(context.Context) string {
// SessionRelease save postgresql session values to database. // SessionRelease save postgresql session values to database.
// must call this method to save values to database. // must call this method to save values to database.
func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (st *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) {
defer st.c.Close() defer st.c.Close()
st.lock.RLock() st.lock.RLock()
values := st.values values := st.values
@ -125,6 +125,11 @@ func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
b, time.Now().Format(time.RFC3339), st.sid) b, time.Now().Format(time.RFC3339), st.sid)
} }
// SessionReleaseIfPresent save postgresql session values to database when key is present
func (st *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) {
st.SessionRelease(ctx, w)
}
// Provider postgresql session provider // Provider postgresql session provider
type Provider struct { type Provider struct {
maxlifetime int64 maxlifetime int64

View File

@ -101,6 +101,15 @@ func (rs *SessionStore) SessionID(context.Context) string {
// SessionRelease save session values to redis // SessionRelease save session values to redis
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
rs.releaseSession(ctx, w, false)
}
// SessionReleaseIfPresent save session values to redis when key is present
func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) {
rs.releaseSession(ctx, w, true)
}
func (rs *SessionStore) releaseSession(ctx context.Context, _ http.ResponseWriter, requirePresent bool) {
rs.lock.RLock() rs.lock.RLock()
values := rs.values values := rs.values
rs.lock.RUnlock() rs.lock.RUnlock()
@ -109,7 +118,11 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
return return
} }
c := rs.p c := rs.p
c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) if requirePresent {
c.SetXX(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
} else {
c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
} }
// Provider redis session provider // Provider redis session provider
@ -158,12 +171,12 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr s
} }
rp.poollist = redis.NewClient(&redis.Options{ rp.poollist = redis.NewClient(&redis.Options{
Addr: rp.SavePath, Addr: rp.SavePath,
Password: rp.Password, Password: rp.Password,
PoolSize: rp.Poolsize, PoolSize: rp.Poolsize,
DB: rp.DbNum, DB: rp.DbNum,
ConnMaxIdleTime: rp.idleTimeout, ConnMaxIdleTime: rp.idleTimeout,
MaxRetries: rp.MaxRetries, MaxRetries: rp.MaxRetries,
}) })
return rp.poollist.Ping(ctx).Err() return rp.poollist.Ping(ctx).Err()

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"sync"
"testing" "testing"
"time" "time"
@ -15,25 +16,9 @@ import (
) )
func TestRedis(t *testing.T) { func TestRedis(t *testing.T) {
redisAddr := os.Getenv("REDIS_ADDR") globalSession, err := setupSessionManager(t)
if redisAddr == "" {
redisAddr = "127.0.0.1:6379"
}
redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr)
sessionConfig := session.NewManagerConfig(
session.CfgCookieName(`gosessionid`),
session.CfgSetCookie(true),
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
session.CfgProviderConfig(redisConfig),
)
globalSession, err := session.NewManager("redis", sessionConfig)
if err != nil { if err != nil {
t.Fatal("could not create manager:", err) t.Fatal(err)
} }
go globalSession.GC() go globalSession.GC()
@ -112,3 +97,65 @@ func TestProvider_SessionInit(t *testing.T) {
assert.Equal(t, 3*time.Second, cp.idleTimeout) assert.Equal(t, 3*time.Second, cp.idleTimeout)
assert.Equal(t, int64(12), cp.maxlifetime) assert.Equal(t, int64(12), cp.maxlifetime)
} }
func TestStoreSessionReleaseIfPresentAndSessionDestroy(t *testing.T) {
globalSessions, err := setupSessionManager(t)
if err != nil {
t.Fatal(err)
}
// todo test if e==nil
go globalSessions.GC()
ctx := context.Background()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("session start failed:", err)
}
if err := globalSessions.GetProvider().SessionDestroy(ctx, sess.SessionID(ctx)); err != nil {
t.Error(err)
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
sess.SessionReleaseIfPresent(ctx, httptest.NewRecorder())
}()
wg.Wait()
exist, err := globalSessions.GetProvider().SessionExist(ctx, sess.SessionID(ctx))
if err != nil {
t.Error(err)
}
if exist {
t.Fatalf("session %s should exist", sess.SessionID(ctx))
}
}
func setupSessionManager(t *testing.T) (*session.Manager, error) {
redisAddr := os.Getenv("REDIS_ADDR")
if redisAddr == "" {
redisAddr = "127.0.0.1:6379"
}
redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr)
sessionConfig := session.NewManagerConfig(
session.CfgCookieName(`gosessionid`),
session.CfgSetCookie(true),
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
session.CfgProviderConfig(redisConfig),
)
globalSessions, err := session.NewManager("redis", sessionConfig)
if err != nil {
t.Log("could not create manager: ", err)
return nil, err
}
return globalSessions, nil
}

View File

@ -101,6 +101,15 @@ func (rs *SessionStore) SessionID(context.Context) string {
// SessionRelease save session values to redis_cluster // SessionRelease save session values to redis_cluster
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
rs.releaseSession(ctx, w, false)
}
// SessionReleaseIfPresent save session values to redis_cluster when key is present
func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) {
rs.releaseSession(ctx, w, true)
}
func (rs *SessionStore) releaseSession(ctx context.Context, _ http.ResponseWriter, requirePresent bool) {
rs.lock.RLock() rs.lock.RLock()
values := rs.values values := rs.values
rs.lock.RUnlock() rs.lock.RUnlock()
@ -109,7 +118,11 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
return return
} }
c := rs.p c := rs.p
c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) if requirePresent {
c.SetXX(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
} else {
c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
} }
// Provider redis_cluster session provider // Provider redis_cluster session provider
@ -156,11 +169,11 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr s
} }
rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{
Addrs: strings.Split(rp.SavePath, ";"), Addrs: strings.Split(rp.SavePath, ";"),
Password: rp.Password, Password: rp.Password,
PoolSize: rp.Poolsize, PoolSize: rp.Poolsize,
ConnMaxIdleTime: rp.idleTimeout, ConnMaxIdleTime: rp.idleTimeout,
MaxRetries: rp.MaxRetries, MaxRetries: rp.MaxRetries,
}) })
return rp.poollist.Ping(ctx).Err() return rp.poollist.Ping(ctx).Err()
} }

View File

@ -103,6 +103,15 @@ func (rs *SessionStore) SessionID(context.Context) string {
// SessionRelease save session values to redis_sentinel // SessionRelease save session values to redis_sentinel
func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
rs.releaseSession(ctx, w, false)
}
// SessionReleaseIfPresent save session values to redis_sentinel when key is present
func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) {
rs.releaseSession(ctx, w, true)
}
func (rs *SessionStore) releaseSession(ctx context.Context, _ http.ResponseWriter, requirePresent bool) {
rs.lock.RLock() rs.lock.RLock()
values := rs.values values := rs.values
rs.lock.RUnlock() rs.lock.RUnlock()
@ -111,7 +120,11 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite
return return
} }
c := rs.p c := rs.p
c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) if requirePresent {
c.SetXX(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
} else {
c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second)
}
} }
// Provider redis_sentinel session provider // Provider redis_sentinel session provider
@ -159,13 +172,13 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr s
} }
rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{
SentinelAddrs: strings.Split(rp.SavePath, ";"), SentinelAddrs: strings.Split(rp.SavePath, ";"),
Password: rp.Password, Password: rp.Password,
PoolSize: rp.Poolsize, PoolSize: rp.Poolsize,
DB: rp.DbNum, DB: rp.DbNum,
MasterName: rp.MasterName, MasterName: rp.MasterName,
ConnMaxIdleTime: rp.idleTimeout, ConnMaxIdleTime: rp.idleTimeout,
MaxRetries: rp.MaxRetries, MaxRetries: rp.MaxRetries,
}) })
return rp.poollist.Ping(ctx).Err() return rp.poollist.Ping(ctx).Err()

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"time" "time"
@ -13,18 +14,9 @@ import (
) )
func TestRedisSentinel(t *testing.T) { func TestRedisSentinel(t *testing.T) {
sessionConfig := session.NewManagerConfig( globalSessions, err := setupSessionManager(t)
session.CfgCookieName(`gosessionid`), if err != nil {
session.CfgSetCookie(true), t.Log(err)
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"),
)
globalSessions, e := session.NewManager("redis_sentinel", sessionConfig)
if e != nil {
t.Log(e)
return return
} }
// todo test if e==nil // todo test if e==nil
@ -104,3 +96,60 @@ func TestProvider_SessionInit(t *testing.T) {
assert.Equal(t, 3*time.Second, cp.idleTimeout) assert.Equal(t, 3*time.Second, cp.idleTimeout)
assert.Equal(t, int64(12), cp.maxlifetime) assert.Equal(t, int64(12), cp.maxlifetime)
} }
func TestStoreSessionReleaseIfPresentAndSessionDestroy(t *testing.T) {
globalSessions, e := setupSessionManager(t)
if e != nil {
t.Log(e)
return
}
// todo test if e==nil
go globalSessions.GC()
ctx := context.Background()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess, err := globalSessions.SessionStart(w, r)
if err != nil {
t.Fatal("session start failed:", err)
}
if err := globalSessions.GetProvider().SessionDestroy(ctx, sess.SessionID(ctx)); err != nil {
t.Error(err)
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
sess.SessionReleaseIfPresent(context.Background(), httptest.NewRecorder())
}()
wg.Wait()
exist, err := globalSessions.GetProvider().SessionExist(ctx, sess.SessionID(ctx))
if err != nil {
t.Error(err)
}
if exist {
t.Fatalf("session %s should exist", sess.SessionID(ctx))
}
}
func setupSessionManager(t *testing.T) (*session.Manager, error) {
sessionConfig := session.NewManagerConfig(
session.CfgCookieName(`gosessionid`),
session.CfgSetCookie(true),
session.CfgGcLifeTime(3600),
session.CfgMaxLifeTime(3600),
session.CfgSecure(false),
session.CfgCookieLifeTime(3600),
session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"),
)
globalSessions, err := session.NewManager("redis_sentinel", sessionConfig)
if err != nil {
t.Log(err)
return nil, err
}
return globalSessions, nil
}

View File

@ -74,7 +74,7 @@ func (st *CookieSessionStore) SessionID(context.Context) string {
} }
// SessionRelease Write cookie session to http response cookie // SessionRelease Write cookie session to http response cookie
func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (st *CookieSessionStore) SessionRelease(_ context.Context, w http.ResponseWriter) {
st.lock.RLock() st.lock.RLock()
values := st.values values := st.values
st.lock.RUnlock() st.lock.RUnlock()
@ -93,6 +93,12 @@ func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.Respons
} }
} }
// SessionReleaseIfPresent Write cookie session to http response cookie when it is present
// This is a no-op for cookie sessions, because they are always present.
func (st *CookieSessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) {
st.SessionRelease(ctx, w)
}
type cookieConfig struct { type cookieConfig struct {
SecurityKey string `json:"securityKey"` SecurityKey string `json:"securityKey"`
BlockKey string `json:"blockKey"` BlockKey string `json:"blockKey"`

View File

@ -59,7 +59,7 @@ func TestCookie(t *testing.T) {
} }
} }
func TestDestorySessionCookie(t *testing.T) { func TestDestroySessionCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
conf := new(ManagerConfig) conf := new(ManagerConfig)
if err := json.Unmarshal([]byte(config), conf); err != nil { if err := json.Unmarshal([]byte(config), conf); err != nil {

View File

@ -80,6 +80,15 @@ func (fs *FileSessionStore) SessionID(context.Context) string {
// SessionRelease Write file session to local file with Gob string // SessionRelease Write file session to local file with Gob string
func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) {
fs.releaseSession(ctx, w, true)
}
// SessionReleaseIfPresent Write file session to local file with Gob string when session exists
func (fs *FileSessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) {
fs.releaseSession(ctx, w, false)
}
func (fs *FileSessionStore) releaseSession(_ context.Context, _ http.ResponseWriter, createIfNotExist bool) {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()
b, err := EncodeGob(fs.values) b, err := EncodeGob(fs.values)
@ -95,7 +104,7 @@ func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseW
SLogger.Println(err) SLogger.Println(err)
return return
} }
} else if os.IsNotExist(err) { } else if os.IsNotExist(err) && createIfNotExist {
f, err = os.Create(filepath.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) f, err = os.Create(filepath.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid))
if err != nil { if err != nil {
SLogger.Println(err) SLogger.Println(err)
@ -228,7 +237,7 @@ func (fp *FileProvider) SessionAll(context.Context) int {
} }
// SessionRegenerate Generate new sid for file session. // SessionRegenerate Generate new sid for file session.
// it delete old file and create new file named from new sid. // it deletes old file and create new file named from new sid.
func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()

View File

@ -17,6 +17,7 @@ package session
import ( import (
"context" "context"
"fmt" "fmt"
"net/http/httptest"
"os" "os"
"sync" "sync"
"testing" "testing"
@ -334,15 +335,15 @@ func TestFileSessionStoreDelete(t *testing.T) {
_ = fp.SessionInit(context.Background(), 180, sessionPath) _ = fp.SessionInit(context.Background(), 180, sessionPath)
s, _ := fp.SessionRead(context.Background(), sid) s, _ := fp.SessionRead(context.Background(), sid)
s.Set(nil, "1", 1) s.Set(context.Background(), "1", 1)
if s.Get(nil, "1") == nil { if s.Get(context.Background(), "1") == nil {
t.Error() t.Error()
} }
s.Delete(nil, "1") s.Delete(context.Background(), "1")
if s.Get(nil, "1") != nil { if s.Get(context.Background(), "1") != nil {
t.Error() t.Error()
} }
} }
@ -387,13 +388,21 @@ func TestFileSessionStoreSessionID(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if s.SessionID(nil) != fmt.Sprintf("%s_%d", sid, i) { if s.SessionID(context.Background()) != fmt.Sprintf("%s_%d", sid, i) {
t.Error(err) t.Error(err)
} }
} }
} }
func TestFileSessionStoreSessionRelease(t *testing.T) { func TestFileSessionStoreSessionRelease(t *testing.T) {
releaseSession(t, false)
}
func TestFileSessionStoreSessionReleaseIfPresent(t *testing.T) {
releaseSession(t, true)
}
func releaseSession(t *testing.T, requirePresent bool) {
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
os.RemoveAll(sessionPath) os.RemoveAll(sessionPath)
@ -410,8 +419,13 @@ func TestFileSessionStoreSessionRelease(t *testing.T) {
t.Error(err) t.Error(err)
} }
s.Set(nil, i, i) s.Set(context.Background(), i, i)
s.SessionRelease(nil, nil) if requirePresent {
s.SessionReleaseIfPresent(context.Background(), httptest.NewRecorder())
} else {
s.SessionRelease(context.Background(), httptest.NewRecorder())
}
} }
for i := 1; i <= sessionCount; i++ { for i := 1; i <= sessionCount; i++ {
@ -425,3 +439,36 @@ func TestFileSessionStoreSessionRelease(t *testing.T) {
} }
} }
} }
func TestFileSessionStoreSessionReleaseIfPresentAndSessionDestroy(t *testing.T) {
mutex.Lock()
defer mutex.Unlock()
os.RemoveAll(sessionPath)
defer os.RemoveAll(sessionPath)
fp := &FileProvider{}
s, err := fp.SessionRead(context.Background(), sid)
if err != nil {
return
}
_ = fp.SessionInit(context.Background(), 180, sessionPath)
filepder.savePath = sessionPath
if err := fp.SessionDestroy(context.Background(), sid); err != nil {
t.Error(err)
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
s.SessionReleaseIfPresent(context.Background(), httptest.NewRecorder())
}()
wg.Wait()
exist, err := fp.SessionExist(context.Background(), sid)
if err != nil {
t.Error(err)
}
if exist {
t.Fatalf("session %s should exist", sid)
}
}

View File

@ -73,7 +73,11 @@ func (st *MemSessionStore) SessionID(context.Context) string {
} }
// SessionRelease Implement method, no used. // SessionRelease Implement method, no used.
func (st *MemSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (st *MemSessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) {
}
// SessionReleaseIfPresent Implement method, no used.
func (*MemSessionStore) SessionReleaseIfPresent(_ context.Context, _ http.ResponseWriter) {
} }
// MemProvider Implement the provider interface // MemProvider Implement the provider interface

View File

@ -44,12 +44,13 @@ import (
// Store contains all data for one session process with specific id. // Store contains all data for one session process with specific id.
type Store interface { type Store interface {
Set(ctx context.Context, key, value interface{}) error // set session value Set(ctx context.Context, key, value interface{}) error // Set set session value
Get(ctx context.Context, key interface{}) interface{} // get session value Get(ctx context.Context, key interface{}) interface{} // Get get session value
Delete(ctx context.Context, key interface{}) error // delete session value Delete(ctx context.Context, key interface{}) error // Delete delete session value
SessionID(ctx context.Context) string // back current sessionID SessionID(ctx context.Context) string // SessionID return current sessionID
SessionRelease(ctx context.Context, w http.ResponseWriter) // release the resource & save data to provider & return the data SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) // SessionReleaseIfPresent release the resource & save data to provider & return the data when the session is present, not all implementation support this feature, you need to check if the specific implementation if support this feature.
Flush(ctx context.Context) error // delete all data SessionRelease(ctx context.Context, w http.ResponseWriter) // SessionRelease release the resource & save data to provider & return the data
Flush(ctx context.Context) error // Flush delete all data
} }
// Provider contains global session methods and saved SessionStores. // Provider contains global session methods and saved SessionStores.
@ -153,7 +154,7 @@ func (manager *Manager) GetProvider() Provider {
// //
// error is not nil when there is anything wrong. // error is not nil when there is anything wrong.
// sid is empty when need to generate a new session id // sid is empty when need to generate a new session id
// otherwise return an valid session id. // otherwise return a valid session id.
func (manager *Manager) getSid(r *http.Request) (string, error) { func (manager *Manager) getSid(r *http.Request) (string, error) {
cookie, errs := r.Cookie(manager.config.CookieName) cookie, errs := r.Cookie(manager.config.CookieName)
if errs != nil || cookie.Value == "" { if errs != nil || cookie.Value == "" {
@ -279,7 +280,7 @@ func (manager *Manager) GC() {
time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
} }
// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. // SessionRegenerateID Regenerate a session id for this SessionStore whose id is saving in http request.
func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (Store, error) { func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (Store, error) {
sid, err := manager.sessionID() sid, err := manager.sessionID()
if err != nil { if err != nil {

View File

@ -85,7 +85,7 @@ func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store,
return rs, nil return rs, nil
} }
// SessionExist judged whether sid is exist in session // SessionExist judged whether sid existed in session
func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) {
if p.client == nil { if p.client == nil {
if err := p.connectInit(); err != nil { if err := p.connectInit(); err != nil {
@ -204,7 +204,7 @@ func (s *SessionStore) SessionID(context.Context) string {
} }
// SessionRelease Store the keyvalues into ssdb // SessionRelease Store the keyvalues into ssdb
func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { func (s *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) {
s.lock.RLock() s.lock.RLock()
values := s.values values := s.values
s.lock.RUnlock() s.lock.RUnlock()
@ -215,6 +215,12 @@ func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter
s.client.Do("setx", s.sid, string(b), s.maxLifetime) s.client.Do("setx", s.sid, string(b), s.maxLifetime)
} }
// SessionReleaseIfPresent is not supported now
// Because ssdb does not support lua script or SETXX command
func (s *SessionStore) SessionReleaseIfPresent(c context.Context, w http.ResponseWriter) {
s.SessionRelease(c, w)
}
func init() { func init() {
session.Register("ssdb", ssdbProvider) session.Register("ssdb", ssdbProvider)
} }