diff --git a/go.mod b/go.mod index 89baa406..ce67b7d2 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect github.com/gomodule/redigo v2.0.0+incompatible github.com/google/go-cmp v0.5.0 // indirect - github.com/google/uuid v1.1.1 // indirect + github.com/google/uuid v1.1.1 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/hashicorp/golang-lru v0.5.4 github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go new file mode 100644 index 00000000..b37a9f51 --- /dev/null +++ b/server/web/filter/session/filter.go @@ -0,0 +1,65 @@ +package session + +import ( + "context" + "errors" + "fmt" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" + "github.com/google/uuid" + "sync" +) + +var ( + sessionKey string + sessionKeyOnce sync.Once +) + +func getSessionKey() string { + + sessionKeyOnce.Do(func() { + //generate an unique session store key + sessionKey = fmt.Sprintf(`sess_store:%d`, uuid.New().ID()) + }) + + return sessionKey +} + +func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { + sessionConfig := session.NewManagerConfig(options...) + sessionManager, _ := session.NewManager(string(providerType), sessionConfig) + go sessionManager.GC() + + return func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + + if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil { + logs.Warning(`init session error:%s`, err.Error()) + } else { + //release session at the end of request + defer sess.SessionRelease(context.Background(), ctx.ResponseWriter) + ctx.Input.SetData(getSessionKey(), sess) + } + + next(ctx) + + } + } +} + +func GetStore(ctx *webContext.Context) (store session.Store, err error) { + if ctx == nil { + err = errors.New(`ctx is nil`) + return + } + + if s, ok := ctx.Input.GetData(getSessionKey()).(session.Store); ok { + store = s + return + } else { + err = errors.New(`can not get a valid session store`) + return + } +} diff --git a/server/web/filter/session/filter_test.go b/server/web/filter/session/filter_test.go new file mode 100644 index 00000000..7b24f6ad --- /dev/null +++ b/server/web/filter/session/filter_test.go @@ -0,0 +1,112 @@ +package session + +import ( + "context" + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" + "net/http" + "net/http/httptest" + "testing" +) + +func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code) + } +} + +func TestSession(t *testing.T) { + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store := ctx.Input.GetData(getSessionKey()); store == nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} + +func TestGetStore(t *testing.T) { + handler := web.NewControllerRegister() + + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + var ( + checkKey = `asodiuasdk1j)AS(87` + checkValue = `ASsd-09812-3` + + store session.Store + err error + + c = context.Background() + ) + + if store, err = GetStore(ctx); err == nil { + if store == nil { + t.Error(`store should not be nil`) + } else { + _ = store.Set(c, checkKey, checkValue) + } + } else { + t.Error(err) + } + + next(ctx) + + if store != nil { + if v := store.Get(c, checkKey); v != checkValue { + t.Error(v, `is not equals to`, checkValue) + } + }else{ + t.Error(`store should not be nil`) + } + + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} diff --git a/server/web/session/session_provider_type.go b/server/web/session/session_provider_type.go index c14a3ecc..78dc116d 100644 --- a/server/web/session/session_provider_type.go +++ b/server/web/session/session_provider_type.go @@ -1,16 +1,18 @@ package session +type ProviderType string + const ( - ProviderCookie = `cookie` - ProviderFile = `file` - ProviderMemory = `memory` - ProviderCouchbase = `couchbase` - ProviderLedis = `ledis` - ProviderMemcache = `memcache` - ProviderMysql = `mysql` - ProviderPostgresql = `postgresql` - ProviderRedis = `redis` - ProviderRedisCluster = `redis_cluster` - ProviderRedisSentinel = `redis_sentinel` - ProviderSsdb = `ssdb` + ProviderCookie ProviderType = `cookie` + ProviderFile ProviderType = `file` + ProviderMemory ProviderType = `memory` + ProviderCouchbase ProviderType = `couchbase` + ProviderLedis ProviderType = `ledis` + ProviderMemcache ProviderType = `memcache` + ProviderMysql ProviderType = `mysql` + ProviderPostgresql ProviderType = `postgresql` + ProviderRedis ProviderType = `redis` + ProviderRedisCluster ProviderType = `redis_cluster` + ProviderRedisSentinel ProviderType = `redis_sentinel` + ProviderSsdb ProviderType = `ssdb` )