session: Support SessionReleaseIfPresent to avoid concurrent problem (#5685)
This commit is contained in:
		
							parent
							
								
									bdb7e7a904
								
							
						
					
					
						commit
						06d869664a
					
				| @ -51,7 +51,7 @@ func TestConn(t *testing.T) { | ||||
| func TestReconnect(t *testing.T) { | ||||
| 	// Setup connection listener | ||||
| 	newConns := make(chan net.Conn) | ||||
| 	connNum := 2 | ||||
| 	connNum := 3 | ||||
| 	ln, err := net.Listen("tcp", ":6002") | ||||
| 	if err != nil { | ||||
| 		t.Log("Error listening:", err.Error()) | ||||
|  | ||||
| @ -16,13 +16,14 @@ package mock | ||||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 
 | ||||
| 	beegoCtx "github.com/beego/beego/v2/server/web/context" | ||||
| ) | ||||
| 
 | ||||
| func NewMockContext(req *http.Request) (*beegoCtx.Context, *HttpResponse) { | ||||
| func NewMockContext(req *http.Request) (*beegoCtx.Context, *httptest.ResponseRecorder) { | ||||
| 	ctx := beegoCtx.NewContext() | ||||
| 	resp := NewMockHttpResponse() | ||||
| 	resp := httptest.NewRecorder() | ||||
| 	ctx.Reset(resp, req) | ||||
| 	return ctx, resp | ||||
| } | ||||
|  | ||||
| @ -31,7 +31,7 @@ type TestController struct { | ||||
| } | ||||
| 
 | ||||
| func TestMockContext(t *testing.T) { | ||||
| 	req, err := http.NewRequest("GET", "http://localhost:8080/hello?name=tom", bytes.NewReader([]byte{})) | ||||
| 	req, err := http.NewRequest("GET", "https://localhost:8080/hello?name=tom", bytes.NewReader([]byte{})) | ||||
| 	assert.Nil(t, err) | ||||
| 	ctx, resp := NewMockContext(req) | ||||
| 	ctrl := &TestController{ | ||||
| @ -40,7 +40,7 @@ func TestMockContext(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
| 	ctrl.HelloWorld() | ||||
| 	result := resp.BodyToString() | ||||
| 	result := resp.Body.String() | ||||
| 	assert.Equal(t, "name=tom", result) | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -1,69 +0,0 @@ | ||||
| // Copyright 2021 beego | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package mock | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"net/http" | ||||
| ) | ||||
| 
 | ||||
| // HttpResponse mock response, which should be used in tests | ||||
| type HttpResponse struct { | ||||
| 	body       []byte | ||||
| 	header     http.Header | ||||
| 	StatusCode int | ||||
| } | ||||
| 
 | ||||
| // NewMockHttpResponse you should only use this in your test code | ||||
| func NewMockHttpResponse() *HttpResponse { | ||||
| 	return &HttpResponse{ | ||||
| 		body:   make([]byte, 0), | ||||
| 		header: make(http.Header), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Header return headers | ||||
| func (m *HttpResponse) Header() http.Header { | ||||
| 	return m.header | ||||
| } | ||||
| 
 | ||||
| // Write append the body | ||||
| func (m *HttpResponse) Write(bytes []byte) (int, error) { | ||||
| 	m.body = append(m.body, bytes...) | ||||
| 	return len(bytes), nil | ||||
| } | ||||
| 
 | ||||
| // WriteHeader set the status code | ||||
| func (m *HttpResponse) WriteHeader(statusCode int) { | ||||
| 	m.StatusCode = statusCode | ||||
| } | ||||
| 
 | ||||
| // JsonUnmarshal convert the body to object | ||||
| func (m *HttpResponse) JsonUnmarshal(value interface{}) error { | ||||
| 	return json.Unmarshal(m.body, value) | ||||
| } | ||||
| 
 | ||||
| // BodyToString return the body as the string | ||||
| func (m *HttpResponse) BodyToString() string { | ||||
| 	return string(m.body) | ||||
| } | ||||
| 
 | ||||
| // Reset will reset the status to init status | ||||
| // Usually, you want to reuse this instance you may need to call Reset | ||||
| func (m *HttpResponse) Reset() { | ||||
| 	m.body = make([]byte, 0) | ||||
| 	m.header = make(http.Header) | ||||
| 	m.StatusCode = 0 | ||||
| } | ||||
| @ -106,7 +106,12 @@ func (s *SessionStore) SessionID(ctx context.Context) string { | ||||
| } | ||||
| 
 | ||||
| // SessionRelease do nothing | ||||
| func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| func (s *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { | ||||
| 	// Support in the future if necessary, now I think we don't need to implement this | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent do nothing | ||||
| func (*SessionStore) SessionReleaseIfPresent(_ context.Context, _ http.ResponseWriter) { | ||||
| 	// Support in the future if necessary, now I think we don't need to implement this | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -37,12 +37,12 @@ func TestSessionProvider(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
| 	ctrl.HelloSession() | ||||
| 	result := resp.BodyToString() | ||||
| 	result := resp.Body.String() | ||||
| 	assert.Equal(t, "set", result) | ||||
| 
 | ||||
| 	resp.Reset() | ||||
| 	resp.Body.Reset() | ||||
| 	ctrl.HelloSessionName() | ||||
| 	result = resp.BodyToString() | ||||
| 	result = resp.Body.String() | ||||
| 
 | ||||
| 	assert.Equal(t, "Tom", result) | ||||
| } | ||||
|  | ||||
| @ -64,7 +64,7 @@ type Provider struct { | ||||
| 	b           *couchbase.Bucket | ||||
| } | ||||
| 
 | ||||
| // Set value to couchabse session | ||||
| // Set value to couchbase session | ||||
| func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error { | ||||
| 	cs.lock.Lock() | ||||
| 	defer cs.lock.Unlock() | ||||
| @ -72,7 +72,7 @@ func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Get value from couchabse session | ||||
| // Get value from couchbase session | ||||
| func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { | ||||
| 	cs.lock.RLock() | ||||
| 	defer cs.lock.RUnlock() | ||||
| @ -104,7 +104,7 @@ func (cs *SessionStore) SessionID(context.Context) string { | ||||
| } | ||||
| 
 | ||||
| // SessionRelease Write couchbase session with Gob string | ||||
| func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| func (cs *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { | ||||
| 	defer cs.b.Close() | ||||
| 	cs.lock.RLock() | ||||
| 	values := cs.values | ||||
| @ -117,6 +117,12 @@ func (cs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite | ||||
| 	cs.b.Set(cs.sid, int(cs.maxlifetime), bo) | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent is not supported now. | ||||
| // If we want to use couchbase, we may refactor the code to use couchbase collection. | ||||
| func (cs *SessionStore) SessionReleaseIfPresent(c context.Context, w http.ResponseWriter) { | ||||
| 	cs.SessionRelease(c, w) | ||||
| } | ||||
| 
 | ||||
| func (cp *Provider) getBucket() *couchbase.Bucket { | ||||
| 	c, err := couchbase.Connect(cp.SavePath) | ||||
| 	if err != nil { | ||||
| @ -195,7 +201,7 @@ func (cp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, | ||||
| } | ||||
| 
 | ||||
| // SessionExist Check couchbase session exist. | ||||
| // it checkes sid exist or not. | ||||
| // it checks sid exist or not. | ||||
| func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { | ||||
| 	cp.b = cp.getBucket() | ||||
| 	defer cp.b.Close() | ||||
|  | ||||
| @ -68,7 +68,7 @@ func (ls *SessionStore) SessionID(context.Context) string { | ||||
| } | ||||
| 
 | ||||
| // SessionRelease save session values to ledis | ||||
| func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| func (ls *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { | ||||
| 	ls.lock.RLock() | ||||
| 	values := ls.values | ||||
| 	ls.lock.RUnlock() | ||||
| @ -80,6 +80,13 @@ func (ls *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite | ||||
| 	c.Expire([]byte(ls.sid), ls.maxlifetime) | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent is not supported now, because ledis has no this feature like SETXX or atomic operation. | ||||
| // https://github.com/ledisdb/ledisdb/issues/251 | ||||
| // https://github.com/ledisdb/ledisdb/issues/351 | ||||
| func (ls *SessionStore) SessionReleaseIfPresent(c context.Context, w http.ResponseWriter) { | ||||
| 	ls.SessionRelease(c, w) | ||||
| } | ||||
| 
 | ||||
| // Provider ledis session provider | ||||
| type Provider struct { | ||||
| 	maxlifetime int64 | ||||
| @ -162,8 +169,8 @@ func (lp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) | ||||
| func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { | ||||
| 	count, _ := c.Exists([]byte(sid)) | ||||
| 	if count == 0 { | ||||
| 		// oldsid doesn't exists, set the new sid directly | ||||
| 		// ignore error here, since if it return error | ||||
| 		// oldsid doesn't exist, set the new sid directly | ||||
| 		// ignore error here, since if it returns error | ||||
| 		// the existed value will be 0 | ||||
| 		c.Set([]byte(sid), []byte("")) | ||||
| 		c.Expire([]byte(sid), lp.maxlifetime) | ||||
| @ -181,7 +188,7 @@ func (lp *Provider) SessionDestroy(ctx context.Context, sid string) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // SessionGC Impelment method, no used. | ||||
| // SessionGC Implement method, no used. | ||||
| func (lp *Provider) SessionGC(context.Context) { | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -97,6 +97,15 @@ func (rs *SessionStore) SessionID(context.Context) string { | ||||
| 
 | ||||
| // SessionRelease save session values to memcache | ||||
| func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| 	rs.releaseSession(ctx, w, false) | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent save session values to memcache when key is present | ||||
| func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { | ||||
| 	rs.releaseSession(ctx, w, true) | ||||
| } | ||||
| 
 | ||||
| func (rs *SessionStore) releaseSession(_ context.Context, _ http.ResponseWriter, requirePresent bool) { | ||||
| 	rs.lock.RLock() | ||||
| 	values := rs.values | ||||
| 	rs.lock.RUnlock() | ||||
| @ -105,7 +114,11 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite | ||||
| 		return | ||||
| 	} | ||||
| 	item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)} | ||||
| 	client.Set(&item) | ||||
| 	if requirePresent { | ||||
| 		client.Replace(&item) | ||||
| 	} else { | ||||
| 		client.Set(&item) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // MemProvider memcache session provider | ||||
| @ -176,8 +189,8 @@ func (rp *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string | ||||
| 	} | ||||
| 	var contain []byte | ||||
| 	if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { | ||||
| 		// oldsid doesn't exists, set the new sid directly | ||||
| 		// ignore error here, since if it return error | ||||
| 		// oldsid doesn't exist, set the new sid directly | ||||
| 		// ignore error here, since if it returns error | ||||
| 		// the existed value will be 0 | ||||
| 		item.Key = sid | ||||
| 		item.Value = []byte("") | ||||
| @ -222,7 +235,7 @@ func (rp *MemProvider) connectInit() error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // SessionGC Impelment method, no used. | ||||
| // SessionGC Implement method, no used. | ||||
| func (rp *MemProvider) SessionGC(context.Context) { | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -109,7 +109,7 @@ func (st *SessionStore) SessionID(context.Context) string { | ||||
| 
 | ||||
| // SessionRelease save mysql session values to database. | ||||
| // must call this method to save values to database. | ||||
| func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| func (st *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { | ||||
| 	defer st.c.Close() | ||||
| 	st.lock.RLock() | ||||
| 	values := st.values | ||||
| @ -122,6 +122,11 @@ func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite | ||||
| 		b, time.Now().Unix(), st.sid) | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent save mysql session values to database. | ||||
| func (st *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { | ||||
| 	st.SessionRelease(ctx, w) | ||||
| } | ||||
| 
 | ||||
| // Provider mysql session provider | ||||
| type Provider struct { | ||||
| 	maxlifetime int64 | ||||
|  | ||||
| @ -112,7 +112,7 @@ func (st *SessionStore) SessionID(context.Context) string { | ||||
| 
 | ||||
| // SessionRelease save postgresql session values to database. | ||||
| // must call this method to save values to database. | ||||
| func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| func (st *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { | ||||
| 	defer st.c.Close() | ||||
| 	st.lock.RLock() | ||||
| 	values := st.values | ||||
| @ -125,6 +125,11 @@ func (st *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite | ||||
| 		b, time.Now().Format(time.RFC3339), st.sid) | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent save postgresql session values to database when key is present | ||||
| func (st *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { | ||||
| 	st.SessionRelease(ctx, w) | ||||
| } | ||||
| 
 | ||||
| // Provider postgresql session provider | ||||
| type Provider struct { | ||||
| 	maxlifetime int64 | ||||
|  | ||||
| @ -101,6 +101,15 @@ func (rs *SessionStore) SessionID(context.Context) string { | ||||
| 
 | ||||
| // SessionRelease save session values to redis | ||||
| func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| 	rs.releaseSession(ctx, w, false) | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent save session values to redis when key is present | ||||
| func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { | ||||
| 	rs.releaseSession(ctx, w, true) | ||||
| } | ||||
| 
 | ||||
| func (rs *SessionStore) releaseSession(ctx context.Context, _ http.ResponseWriter, requirePresent bool) { | ||||
| 	rs.lock.RLock() | ||||
| 	values := rs.values | ||||
| 	rs.lock.RUnlock() | ||||
| @ -109,7 +118,11 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite | ||||
| 		return | ||||
| 	} | ||||
| 	c := rs.p | ||||
| 	c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) | ||||
| 	if requirePresent { | ||||
| 		c.SetXX(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) | ||||
| 	} else { | ||||
| 		c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Provider redis session provider | ||||
| @ -158,12 +171,12 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr s | ||||
| 	} | ||||
| 
 | ||||
| 	rp.poollist = redis.NewClient(&redis.Options{ | ||||
| 		Addr:               rp.SavePath, | ||||
| 		Password:           rp.Password, | ||||
| 		PoolSize:           rp.Poolsize, | ||||
| 		DB:                 rp.DbNum, | ||||
| 		ConnMaxIdleTime:    rp.idleTimeout, | ||||
| 		MaxRetries:         rp.MaxRetries, | ||||
| 		Addr:            rp.SavePath, | ||||
| 		Password:        rp.Password, | ||||
| 		PoolSize:        rp.Poolsize, | ||||
| 		DB:              rp.DbNum, | ||||
| 		ConnMaxIdleTime: rp.idleTimeout, | ||||
| 		MaxRetries:      rp.MaxRetries, | ||||
| 	}) | ||||
| 
 | ||||
| 	return rp.poollist.Ping(ctx).Err() | ||||
|  | ||||
| @ -6,6 +6,7 @@ import ( | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| @ -15,25 +16,9 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| func TestRedis(t *testing.T) { | ||||
| 	redisAddr := os.Getenv("REDIS_ADDR") | ||||
| 	if redisAddr == "" { | ||||
| 		redisAddr = "127.0.0.1:6379" | ||||
| 	} | ||||
| 	redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr) | ||||
| 
 | ||||
| 	sessionConfig := session.NewManagerConfig( | ||||
| 		session.CfgCookieName(`gosessionid`), | ||||
| 		session.CfgSetCookie(true), | ||||
| 		session.CfgGcLifeTime(3600), | ||||
| 		session.CfgMaxLifeTime(3600), | ||||
| 		session.CfgSecure(false), | ||||
| 		session.CfgCookieLifeTime(3600), | ||||
| 		session.CfgProviderConfig(redisConfig), | ||||
| 	) | ||||
| 
 | ||||
| 	globalSession, err := session.NewManager("redis", sessionConfig) | ||||
| 	globalSession, err := setupSessionManager(t) | ||||
| 	if err != nil { | ||||
| 		t.Fatal("could not create manager:", err) | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 
 | ||||
| 	go globalSession.GC() | ||||
| @ -112,3 +97,65 @@ func TestProvider_SessionInit(t *testing.T) { | ||||
| 	assert.Equal(t, 3*time.Second, cp.idleTimeout) | ||||
| 	assert.Equal(t, int64(12), cp.maxlifetime) | ||||
| } | ||||
| 
 | ||||
| func TestStoreSessionReleaseIfPresentAndSessionDestroy(t *testing.T) { | ||||
| 	globalSessions, err := setupSessionManager(t) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	// todo test if e==nil | ||||
| 	go globalSessions.GC() | ||||
| 
 | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	r, _ := http.NewRequest("GET", "/", nil) | ||||
| 	w := httptest.NewRecorder() | ||||
| 
 | ||||
| 	sess, err := globalSessions.SessionStart(w, r) | ||||
| 	if err != nil { | ||||
| 		t.Fatal("session start failed:", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := globalSessions.GetProvider().SessionDestroy(ctx, sess.SessionID(ctx)); err != nil { | ||||
| 		t.Error(err) | ||||
| 		return | ||||
| 	} | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	wg.Add(1) | ||||
| 	go func() { | ||||
| 		defer wg.Done() | ||||
| 		sess.SessionReleaseIfPresent(ctx, httptest.NewRecorder()) | ||||
| 	}() | ||||
| 	wg.Wait() | ||||
| 	exist, err := globalSessions.GetProvider().SessionExist(ctx, sess.SessionID(ctx)) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 	if exist { | ||||
| 		t.Fatalf("session %s should exist", sess.SessionID(ctx)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func setupSessionManager(t *testing.T) (*session.Manager, error) { | ||||
| 	redisAddr := os.Getenv("REDIS_ADDR") | ||||
| 	if redisAddr == "" { | ||||
| 		redisAddr = "127.0.0.1:6379" | ||||
| 	} | ||||
| 	redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr) | ||||
| 
 | ||||
| 	sessionConfig := session.NewManagerConfig( | ||||
| 		session.CfgCookieName(`gosessionid`), | ||||
| 		session.CfgSetCookie(true), | ||||
| 		session.CfgGcLifeTime(3600), | ||||
| 		session.CfgMaxLifeTime(3600), | ||||
| 		session.CfgSecure(false), | ||||
| 		session.CfgCookieLifeTime(3600), | ||||
| 		session.CfgProviderConfig(redisConfig), | ||||
| 	) | ||||
| 	globalSessions, err := session.NewManager("redis", sessionConfig) | ||||
| 	if err != nil { | ||||
| 		t.Log("could not create manager: ", err) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return globalSessions, nil | ||||
| } | ||||
|  | ||||
| @ -101,6 +101,15 @@ func (rs *SessionStore) SessionID(context.Context) string { | ||||
| 
 | ||||
| // SessionRelease save session values to redis_cluster | ||||
| func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| 	rs.releaseSession(ctx, w, false) | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent save session values to redis_cluster when key is present | ||||
| func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { | ||||
| 	rs.releaseSession(ctx, w, true) | ||||
| } | ||||
| 
 | ||||
| func (rs *SessionStore) releaseSession(ctx context.Context, _ http.ResponseWriter, requirePresent bool) { | ||||
| 	rs.lock.RLock() | ||||
| 	values := rs.values | ||||
| 	rs.lock.RUnlock() | ||||
| @ -109,7 +118,11 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite | ||||
| 		return | ||||
| 	} | ||||
| 	c := rs.p | ||||
| 	c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) | ||||
| 	if requirePresent { | ||||
| 		c.SetXX(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) | ||||
| 	} else { | ||||
| 		c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Provider redis_cluster session provider | ||||
| @ -156,11 +169,11 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr s | ||||
| 	} | ||||
| 
 | ||||
| 	rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ | ||||
| 		Addrs:              strings.Split(rp.SavePath, ";"), | ||||
| 		Password:           rp.Password, | ||||
| 		PoolSize:           rp.Poolsize, | ||||
| 		ConnMaxIdleTime:    rp.idleTimeout, | ||||
| 		MaxRetries:         rp.MaxRetries, | ||||
| 		Addrs:           strings.Split(rp.SavePath, ";"), | ||||
| 		Password:        rp.Password, | ||||
| 		PoolSize:        rp.Poolsize, | ||||
| 		ConnMaxIdleTime: rp.idleTimeout, | ||||
| 		MaxRetries:      rp.MaxRetries, | ||||
| 	}) | ||||
| 	return rp.poollist.Ping(ctx).Err() | ||||
| } | ||||
|  | ||||
| @ -103,6 +103,15 @@ func (rs *SessionStore) SessionID(context.Context) string { | ||||
| 
 | ||||
| // SessionRelease save session values to redis_sentinel | ||||
| func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| 	rs.releaseSession(ctx, w, false) | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent save session values to redis_sentinel when key is present | ||||
| func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { | ||||
| 	rs.releaseSession(ctx, w, true) | ||||
| } | ||||
| 
 | ||||
| func (rs *SessionStore) releaseSession(ctx context.Context, _ http.ResponseWriter, requirePresent bool) { | ||||
| 	rs.lock.RLock() | ||||
| 	values := rs.values | ||||
| 	rs.lock.RUnlock() | ||||
| @ -111,7 +120,11 @@ func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWrite | ||||
| 		return | ||||
| 	} | ||||
| 	c := rs.p | ||||
| 	c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) | ||||
| 	if requirePresent { | ||||
| 		c.SetXX(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) | ||||
| 	} else { | ||||
| 		c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Provider redis_sentinel session provider | ||||
| @ -159,13 +172,13 @@ func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr s | ||||
| 	} | ||||
| 
 | ||||
| 	rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ | ||||
| 		SentinelAddrs:      strings.Split(rp.SavePath, ";"), | ||||
| 		Password:           rp.Password, | ||||
| 		PoolSize:           rp.Poolsize, | ||||
| 		DB:                 rp.DbNum, | ||||
| 		MasterName:         rp.MasterName, | ||||
| 		ConnMaxIdleTime:    rp.idleTimeout, | ||||
| 		MaxRetries:         rp.MaxRetries, | ||||
| 		SentinelAddrs:   strings.Split(rp.SavePath, ";"), | ||||
| 		Password:        rp.Password, | ||||
| 		PoolSize:        rp.Poolsize, | ||||
| 		DB:              rp.DbNum, | ||||
| 		MasterName:      rp.MasterName, | ||||
| 		ConnMaxIdleTime: rp.idleTimeout, | ||||
| 		MaxRetries:      rp.MaxRetries, | ||||
| 	}) | ||||
| 
 | ||||
| 	return rp.poollist.Ping(ctx).Err() | ||||
|  | ||||
| @ -4,6 +4,7 @@ import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| @ -13,18 +14,9 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| func TestRedisSentinel(t *testing.T) { | ||||
| 	sessionConfig := session.NewManagerConfig( | ||||
| 		session.CfgCookieName(`gosessionid`), | ||||
| 		session.CfgSetCookie(true), | ||||
| 		session.CfgGcLifeTime(3600), | ||||
| 		session.CfgMaxLifeTime(3600), | ||||
| 		session.CfgSecure(false), | ||||
| 		session.CfgCookieLifeTime(3600), | ||||
| 		session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"), | ||||
| 	) | ||||
| 	globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) | ||||
| 	if e != nil { | ||||
| 		t.Log(e) | ||||
| 	globalSessions, err := setupSessionManager(t) | ||||
| 	if err != nil { | ||||
| 		t.Log(err) | ||||
| 		return | ||||
| 	} | ||||
| 	// todo test if e==nil | ||||
| @ -104,3 +96,60 @@ func TestProvider_SessionInit(t *testing.T) { | ||||
| 	assert.Equal(t, 3*time.Second, cp.idleTimeout) | ||||
| 	assert.Equal(t, int64(12), cp.maxlifetime) | ||||
| } | ||||
| 
 | ||||
| func TestStoreSessionReleaseIfPresentAndSessionDestroy(t *testing.T) { | ||||
| 	globalSessions, e := setupSessionManager(t) | ||||
| 	if e != nil { | ||||
| 		t.Log(e) | ||||
| 		return | ||||
| 	} | ||||
| 	// todo test if e==nil | ||||
| 	go globalSessions.GC() | ||||
| 
 | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	r, _ := http.NewRequest("GET", "/", nil) | ||||
| 	w := httptest.NewRecorder() | ||||
| 
 | ||||
| 	sess, err := globalSessions.SessionStart(w, r) | ||||
| 	if err != nil { | ||||
| 		t.Fatal("session start failed:", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := globalSessions.GetProvider().SessionDestroy(ctx, sess.SessionID(ctx)); err != nil { | ||||
| 		t.Error(err) | ||||
| 		return | ||||
| 	} | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	wg.Add(1) | ||||
| 	go func() { | ||||
| 		defer wg.Done() | ||||
| 		sess.SessionReleaseIfPresent(context.Background(), httptest.NewRecorder()) | ||||
| 	}() | ||||
| 	wg.Wait() | ||||
| 	exist, err := globalSessions.GetProvider().SessionExist(ctx, sess.SessionID(ctx)) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 	if exist { | ||||
| 		t.Fatalf("session %s should exist", sess.SessionID(ctx)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func setupSessionManager(t *testing.T) (*session.Manager, error) { | ||||
| 	sessionConfig := session.NewManagerConfig( | ||||
| 		session.CfgCookieName(`gosessionid`), | ||||
| 		session.CfgSetCookie(true), | ||||
| 		session.CfgGcLifeTime(3600), | ||||
| 		session.CfgMaxLifeTime(3600), | ||||
| 		session.CfgSecure(false), | ||||
| 		session.CfgCookieLifeTime(3600), | ||||
| 		session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"), | ||||
| 	) | ||||
| 	globalSessions, err := session.NewManager("redis_sentinel", sessionConfig) | ||||
| 	if err != nil { | ||||
| 		t.Log(err) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return globalSessions, nil | ||||
| } | ||||
|  | ||||
| @ -74,7 +74,7 @@ func (st *CookieSessionStore) SessionID(context.Context) string { | ||||
| } | ||||
| 
 | ||||
| // SessionRelease Write cookie session to http response cookie | ||||
| func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| func (st *CookieSessionStore) SessionRelease(_ context.Context, w http.ResponseWriter) { | ||||
| 	st.lock.RLock() | ||||
| 	values := st.values | ||||
| 	st.lock.RUnlock() | ||||
| @ -93,6 +93,12 @@ func (st *CookieSessionStore) SessionRelease(ctx context.Context, w http.Respons | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent Write cookie session to http response cookie when it is present | ||||
| // This is a no-op for cookie sessions, because they are always present. | ||||
| func (st *CookieSessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { | ||||
| 	st.SessionRelease(ctx, w) | ||||
| } | ||||
| 
 | ||||
| type cookieConfig struct { | ||||
| 	SecurityKey  string `json:"securityKey"` | ||||
| 	BlockKey     string `json:"blockKey"` | ||||
|  | ||||
| @ -59,7 +59,7 @@ func TestCookie(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestDestorySessionCookie(t *testing.T) { | ||||
| func TestDestroySessionCookie(t *testing.T) { | ||||
| 	config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` | ||||
| 	conf := new(ManagerConfig) | ||||
| 	if err := json.Unmarshal([]byte(config), conf); err != nil { | ||||
|  | ||||
| @ -80,6 +80,15 @@ func (fs *FileSessionStore) SessionID(context.Context) string { | ||||
| 
 | ||||
| // SessionRelease Write file session to local file with Gob string | ||||
| func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| 	fs.releaseSession(ctx, w, true) | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent Write file session to local file with Gob string when session exists | ||||
| func (fs *FileSessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { | ||||
| 	fs.releaseSession(ctx, w, false) | ||||
| } | ||||
| 
 | ||||
| func (fs *FileSessionStore) releaseSession(_ context.Context, _ http.ResponseWriter, createIfNotExist bool) { | ||||
| 	filepder.lock.Lock() | ||||
| 	defer filepder.lock.Unlock() | ||||
| 	b, err := EncodeGob(fs.values) | ||||
| @ -95,7 +104,7 @@ func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseW | ||||
| 			SLogger.Println(err) | ||||
| 			return | ||||
| 		} | ||||
| 	} else if os.IsNotExist(err) { | ||||
| 	} else if os.IsNotExist(err) && createIfNotExist { | ||||
| 		f, err = os.Create(filepath.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) | ||||
| 		if err != nil { | ||||
| 			SLogger.Println(err) | ||||
| @ -228,7 +237,7 @@ func (fp *FileProvider) SessionAll(context.Context) int { | ||||
| } | ||||
| 
 | ||||
| // SessionRegenerate Generate new sid for file session. | ||||
| // it delete old file and create new file named from new sid. | ||||
| // it deletes old file and create new file named from new sid. | ||||
| func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { | ||||
| 	filepder.lock.Lock() | ||||
| 	defer filepder.lock.Unlock() | ||||
|  | ||||
| @ -17,6 +17,7 @@ package session | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| @ -334,15 +335,15 @@ func TestFileSessionStoreDelete(t *testing.T) { | ||||
| 	_ = fp.SessionInit(context.Background(), 180, sessionPath) | ||||
| 
 | ||||
| 	s, _ := fp.SessionRead(context.Background(), sid) | ||||
| 	s.Set(nil, "1", 1) | ||||
| 	s.Set(context.Background(), "1", 1) | ||||
| 
 | ||||
| 	if s.Get(nil, "1") == nil { | ||||
| 	if s.Get(context.Background(), "1") == nil { | ||||
| 		t.Error() | ||||
| 	} | ||||
| 
 | ||||
| 	s.Delete(nil, "1") | ||||
| 	s.Delete(context.Background(), "1") | ||||
| 
 | ||||
| 	if s.Get(nil, "1") != nil { | ||||
| 	if s.Get(context.Background(), "1") != nil { | ||||
| 		t.Error() | ||||
| 	} | ||||
| } | ||||
| @ -387,13 +388,21 @@ func TestFileSessionStoreSessionID(t *testing.T) { | ||||
| 		if err != nil { | ||||
| 			t.Error(err) | ||||
| 		} | ||||
| 		if s.SessionID(nil) != fmt.Sprintf("%s_%d", sid, i) { | ||||
| 		if s.SessionID(context.Background()) != fmt.Sprintf("%s_%d", sid, i) { | ||||
| 			t.Error(err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFileSessionStoreSessionRelease(t *testing.T) { | ||||
| 	releaseSession(t, false) | ||||
| } | ||||
| 
 | ||||
| func TestFileSessionStoreSessionReleaseIfPresent(t *testing.T) { | ||||
| 	releaseSession(t, true) | ||||
| } | ||||
| 
 | ||||
| func releaseSession(t *testing.T, requirePresent bool) { | ||||
| 	mutex.Lock() | ||||
| 	defer mutex.Unlock() | ||||
| 	os.RemoveAll(sessionPath) | ||||
| @ -410,8 +419,13 @@ func TestFileSessionStoreSessionRelease(t *testing.T) { | ||||
| 			t.Error(err) | ||||
| 		} | ||||
| 
 | ||||
| 		s.Set(nil, i, i) | ||||
| 		s.SessionRelease(nil, nil) | ||||
| 		s.Set(context.Background(), i, i) | ||||
| 		if requirePresent { | ||||
| 			s.SessionReleaseIfPresent(context.Background(), httptest.NewRecorder()) | ||||
| 		} else { | ||||
| 			s.SessionRelease(context.Background(), httptest.NewRecorder()) | ||||
| 		} | ||||
| 
 | ||||
| 	} | ||||
| 
 | ||||
| 	for i := 1; i <= sessionCount; i++ { | ||||
| @ -425,3 +439,36 @@ func TestFileSessionStoreSessionRelease(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestFileSessionStoreSessionReleaseIfPresentAndSessionDestroy(t *testing.T) { | ||||
| 	mutex.Lock() | ||||
| 	defer mutex.Unlock() | ||||
| 	os.RemoveAll(sessionPath) | ||||
| 	defer os.RemoveAll(sessionPath) | ||||
| 	fp := &FileProvider{} | ||||
| 	s, err := fp.SessionRead(context.Background(), sid) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	_ = fp.SessionInit(context.Background(), 180, sessionPath) | ||||
| 	filepder.savePath = sessionPath | ||||
| 	if err := fp.SessionDestroy(context.Background(), sid); err != nil { | ||||
| 		t.Error(err) | ||||
| 		return | ||||
| 	} | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	wg.Add(1) | ||||
| 	go func() { | ||||
| 		defer wg.Done() | ||||
| 		s.SessionReleaseIfPresent(context.Background(), httptest.NewRecorder()) | ||||
| 	}() | ||||
| 	wg.Wait() | ||||
| 	exist, err := fp.SessionExist(context.Background(), sid) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 	if exist { | ||||
| 		t.Fatalf("session %s should exist", sid) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -73,7 +73,11 @@ func (st *MemSessionStore) SessionID(context.Context) string { | ||||
| } | ||||
| 
 | ||||
| // SessionRelease Implement method, no used. | ||||
| func (st *MemSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| func (st *MemSessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent Implement method, no used. | ||||
| func (*MemSessionStore) SessionReleaseIfPresent(_ context.Context, _ http.ResponseWriter) { | ||||
| } | ||||
| 
 | ||||
| // MemProvider Implement the provider interface | ||||
|  | ||||
| @ -44,12 +44,13 @@ import ( | ||||
| 
 | ||||
| // Store contains all data for one session process with specific id. | ||||
| type Store interface { | ||||
| 	Set(ctx context.Context, key, value interface{}) error     // set session value | ||||
| 	Get(ctx context.Context, key interface{}) interface{}      // get session value | ||||
| 	Delete(ctx context.Context, key interface{}) error         // delete session value | ||||
| 	SessionID(ctx context.Context) string                      // back current sessionID | ||||
| 	SessionRelease(ctx context.Context, w http.ResponseWriter) // release the resource & save data to provider & return the data | ||||
| 	Flush(ctx context.Context) error                           // delete all data | ||||
| 	Set(ctx context.Context, key, value interface{}) error              // Set set session value | ||||
| 	Get(ctx context.Context, key interface{}) interface{}               // Get get session value | ||||
| 	Delete(ctx context.Context, key interface{}) error                  // Delete delete session value | ||||
| 	SessionID(ctx context.Context) string                               // SessionID return current sessionID | ||||
| 	SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) // SessionReleaseIfPresent release the resource & save data to provider & return the data when the session is present, not all implementation support this feature, you need to check if the specific implementation if support this feature. | ||||
| 	SessionRelease(ctx context.Context, w http.ResponseWriter)          // SessionRelease release the resource & save data to provider & return the data | ||||
| 	Flush(ctx context.Context) error                                    // Flush delete all data | ||||
| } | ||||
| 
 | ||||
| // Provider contains global session methods and saved SessionStores. | ||||
| @ -153,7 +154,7 @@ func (manager *Manager) GetProvider() Provider { | ||||
| // | ||||
| // error is not nil when there is anything wrong. | ||||
| // sid is empty when need to generate a new session id | ||||
| // otherwise return an valid session id. | ||||
| // otherwise return a valid session id. | ||||
| func (manager *Manager) getSid(r *http.Request) (string, error) { | ||||
| 	cookie, errs := r.Cookie(manager.config.CookieName) | ||||
| 	if errs != nil || cookie.Value == "" { | ||||
| @ -279,7 +280,7 @@ func (manager *Manager) GC() { | ||||
| 	time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) | ||||
| } | ||||
| 
 | ||||
| // SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. | ||||
| // SessionRegenerateID Regenerate a session id for this SessionStore whose id is saving in http request. | ||||
| func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (Store, error) { | ||||
| 	sid, err := manager.sessionID() | ||||
| 	if err != nil { | ||||
|  | ||||
| @ -85,7 +85,7 @@ func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, | ||||
| 	return rs, nil | ||||
| } | ||||
| 
 | ||||
| // SessionExist judged whether sid is exist in session | ||||
| // SessionExist judged whether sid existed in session | ||||
| func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { | ||||
| 	if p.client == nil { | ||||
| 		if err := p.connectInit(); err != nil { | ||||
| @ -204,7 +204,7 @@ func (s *SessionStore) SessionID(context.Context) string { | ||||
| } | ||||
| 
 | ||||
| // SessionRelease Store the keyvalues into ssdb | ||||
| func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { | ||||
| func (s *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { | ||||
| 	s.lock.RLock() | ||||
| 	values := s.values | ||||
| 	s.lock.RUnlock() | ||||
| @ -215,6 +215,12 @@ func (s *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter | ||||
| 	s.client.Do("setx", s.sid, string(b), s.maxLifetime) | ||||
| } | ||||
| 
 | ||||
| // SessionReleaseIfPresent is not supported now | ||||
| // Because ssdb does not support lua script or SETXX command | ||||
| func (s *SessionStore) SessionReleaseIfPresent(c context.Context, w http.ResponseWriter) { | ||||
| 	s.SessionRelease(c, w) | ||||
| } | ||||
| 
 | ||||
| func init() { | ||||
| 	session.Register("ssdb", ssdbProvider) | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user