Skip to content

Commit

Permalink
fix session mw
Browse files Browse the repository at this point in the history
  • Loading branch information
seefs001 committed Sep 26, 2024
1 parent 5ffb2dc commit e3a12cc
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 23 deletions.
69 changes: 52 additions & 17 deletions examples/xmw_example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,74 @@ import (
func main() {
// Create multiple handler functions
helloHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sm := xmw.GetSessionManager(r)
if sm != nil {
if name, ok := sm.Get(r, "name"); ok {
fmt.Fprintf(w, "Hello, %s!", name)
return
}
}
fmt.Fprintf(w, "Hello, World!")
})

timeHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Current time: %s", time.Now().Format(time.RFC3339))
})

setNameHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sm := xmw.GetSessionManager(r)
if sm != nil {
name := r.URL.Query().Get("name")
if name != "" {
sm.Set(r, "name", name)
fmt.Fprintf(w, "Name set to: %s", name)
} else {
fmt.Fprintf(w, "Please provide a name using the 'name' query parameter")
return
}
}
})

clearSessionHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sm := xmw.GetSessionManager(r)
if sm != nil {
sm.Clear(r)
fmt.Fprintf(w, "Session cleared")
}
})

// Create a mux (multiplexer) to route requests
mux := http.NewServeMux()
mux.Handle("/", helloHandler)
mux.Handle("/time", timeHandler)
mux.Handle("/setname", setNameHandler)
mux.Handle("/clear", clearSessionHandler)

// Define middleware stack
middlewareStack := []xmw.Middleware{
xmw.Logger(),
xmw.Recover(),
xmw.Timeout(xmw.TimeoutConfig{
Timeout: 5 * time.Second,
}),
xmw.CORS(xmw.CORSConfig{
AllowOrigins: []string{"*"},
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
}),
xmw.Compress(),
xmw.BasicAuth(xmw.BasicAuthConfig{
Users: map[string]string{"user": "password"},
Realm: "Restricted",
}),
}
// middlewareStack := []xmw.Middleware{
// xmw.Logger(),
// xmw.Recover(),
// xmw.Timeout(xmw.TimeoutConfig{
// Timeout: 5 * time.Second,
// }),
// xmw.CORS(xmw.CORSConfig{
// AllowOrigins: []string{"*"},
// AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
// }),
// xmw.Compress(),
// xmw.Session(xmw.SessionConfig{
// Store: xmw.NewMemoryStore(),
// CookieName: "session_id",
// MaxAge: 3600, // 1 hour
// SessionName: "custom_session", // You can specify a custom session name here
// }),
// }

// Apply middleware to the mux
finalHandler := xmw.Use(mux, middlewareStack...)
// finalHandler := xmw.Use(mux, middlewareStack...)
finalHandler := xmw.Use(mux, xmw.DefaultMiddlewareSet...)

// Start the server
fmt.Println("Server is running on http://localhost:8080")
http.ListenAndServe(":8080", finalHandler)
}
80 changes: 74 additions & 6 deletions xmw/xmw.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type LoggerConfig struct {
func Logger(config ...LoggerConfig) Middleware {
cfg := LoggerConfig{
Next: nil,
Format: "[${time}] ${status} - ${latency} ${method} ${path}?${query} ${ip} ${user_agent} ${body_size}\n",
Format: "[${time}] ${status} - ${latency} ${method} ${path} ${query} ${ip} ${user_agent} ${body_size}\n",
TimeFormat: "2006-01-02 15:04:05",
TimeZone: "Local",
TimeInterval: 500 * time.Millisecond,
Expand Down Expand Up @@ -600,6 +600,9 @@ func (m *MemoryStore) Delete(sessionID string) error {
return nil
}

// DefaultSessionName is the default name for the session in the context
const DefaultSessionName = "ctx_session"

// SessionConfig defines the config for Session middleware
type SessionConfig struct {
Next func(c *http.Request) bool
Expand All @@ -609,14 +612,76 @@ type SessionConfig struct {
SessionName string
}

// SessionManager handles session operations
type SessionManager struct {
store SessionStore
cookieName string
maxAge int
sessionName string
}

// NewSessionManager creates a new SessionManager
func NewSessionManager(store SessionStore, cookieName string, maxAge int, sessionName string) *SessionManager {
return &SessionManager{
store: store,
cookieName: cookieName,
maxAge: maxAge,
sessionName: sessionName,
}
}

// Get retrieves a value from the session
func (sm *SessionManager) Get(r *http.Request, key string) (interface{}, bool) {
session := sm.getSession(r)
if session == nil {
return nil, false
}
value, ok := session[key]
return value, ok
}

// Set sets a value in the session
func (sm *SessionManager) Set(r *http.Request, key string, value interface{}) {
session := sm.getSession(r)
if session != nil {
session[key] = value
}
}

// Delete removes a value from the session
func (sm *SessionManager) Delete(r *http.Request, key string) {
session := sm.getSession(r)
if session != nil {
delete(session, key)
}
}

// Clear removes all values from the session
func (sm *SessionManager) Clear(r *http.Request) {
session := sm.getSession(r)
if session != nil {
for key := range session {
delete(session, key)
}
}
}

// getSession retrieves the session from the request context
func (sm *SessionManager) getSession(r *http.Request) map[string]interface{} {
if session, ok := r.Context().Value(sm.sessionName).(map[string]interface{}); ok {
return session
}
return nil
}

// Session returns a middleware that handles session management
func Session(config ...SessionConfig) Middleware {
cfg := SessionConfig{
Next: nil,
Store: NewMemoryStore(),
CookieName: "session_id",
MaxAge: 86400, // 1 day
SessionName: "session",
SessionName: DefaultSessionName,
}

if len(config) > 0 {
Expand All @@ -637,6 +702,8 @@ func Session(config ...SessionConfig) Middleware {
}
}

sessionManager := NewSessionManager(cfg.Store, cfg.CookieName, cfg.MaxAge, cfg.SessionName)

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if cfg.Next != nil && cfg.Next(r) {
Expand Down Expand Up @@ -665,6 +732,7 @@ func Session(config ...SessionConfig) Middleware {
}

ctx := context.WithValue(r.Context(), cfg.SessionName, session)
ctx = context.WithValue(ctx, "sessionManager", sessionManager)
r = r.WithContext(ctx)

next.ServeHTTP(w, r)
Expand All @@ -678,10 +746,10 @@ func generateSessionID() string {
return fmt.Sprintf("%d", time.Now().UnixNano())
}

// GetSession retrieves the session from the request context
func GetSession(r *http.Request, sessionName string) map[string]interface{} {
if session, ok := r.Context().Value(sessionName).(map[string]interface{}); ok {
return session
// GetSessionManager retrieves the SessionManager from the request context
func GetSessionManager(r *http.Request) *SessionManager {
if sm, ok := r.Context().Value("sessionManager").(*SessionManager); ok {
return sm
}
return nil
}
Expand Down

0 comments on commit e3a12cc

Please sign in to comment.