oauth2-proxy/oauth2-proxy

View on GitHub
pkg/middleware/stored_session.go

Summary

Maintainability
A
50 mins
Test Coverage
A
90%
package middleware

import (
    "context"
    "errors"
    "fmt"
    "net/http"
    "time"

    "github.com/justinas/alice"
    middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
    sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
    "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
    "github.com/oauth2-proxy/oauth2-proxy/v7/providers"
)

// Tuning knobs for the lock that serialises concurrent refreshes of a single
// session.
// TODO: These should probably be configurable by the end user.
const (
    // sessionRefreshLockDuration caps how long a single refresh attempt may
    // hold the lock; if the refresh has not finished by then, the lock is
    // released.
    sessionRefreshLockDuration = 2 * time.Second

    // sessionRefreshObtainTimeout bounds the time spent trying to acquire the
    // lock; once it elapses the refresh attempt is abandoned as a failure.
    sessionRefreshObtainTimeout = 5 * time.Second

    // sessionRefreshRetryPeriod is the pause between unsuccessful attempts to
    // acquire the lock.
    sessionRefreshRetryPeriod = 10 * time.Millisecond
)

// StoredSessionLoaderOptions contains all of the requirements to construct
// a stored session loader.
// All options must be provided.
type StoredSessionLoaderOptions struct {
    // Session storage backend used to load, save and clear sessions.
    SessionStore sessionsapi.SessionStore

    // How often sessions should be refreshed. A loaded session older than
    // this is refreshed; a value <= 0 disables refreshing entirely.
    RefreshPeriod time.Duration

    // Provider based session refreshing.
    // Returns whether the session was refreshed. A providers.ErrNotImplemented
    // error is treated as a successful refresh so the refresh timer is reset.
    RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)

    // Provider based session validation.
    // Called after every refresh attempt, whether or not the refresh
    // succeeded; returning false marks the session as invalid.
    ValidateSession func(context.Context, *sessionsapi.SessionState) bool
}

// NewStoredSessionLoader builds middleware that populates the request scope
// with a session read from the configured session store.
// Requests without a stored session are still passed on to the next handler,
// and a session already placed in the scope by an earlier handler is never
// overwritten.
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor {
    return (&storedSessionLoader{
        store:            opts.SessionStore,
        refreshPeriod:    opts.RefreshPeriod,
        sessionRefresher: opts.RefreshSession,
        sessionValidator: opts.ValidateSession,
    }).loadSession
}

// storedSessionLoader is responsible for loading sessions from cookie
// identified sessions in the session store.
// It is built by NewStoredSessionLoader from StoredSessionLoaderOptions.
type storedSessionLoader struct {
    // Backend used to load, save and clear sessions.
    store            sessionsapi.SessionStore
    // Sessions older than this are refreshed; a value <= 0 disables refreshing.
    refreshPeriod    time.Duration
    // Provider refresh hook; reports whether the session was refreshed.
    sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error)
    // Provider validation hook; false means the session is invalid.
    sessionValidator func(context.Context, *sessionsapi.SessionState) bool
}

// loadSession wraps next with a handler that looks up the session identified
// by the request cookies and stores it in the request scope.
// A session already present in the scope (set by an earlier handler) is left
// alone. When loading or validating the stored session fails for any reason
// other than a missing cookie, the stored session is cleared. The request is
// always forwarded to next.
func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler {
    return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
        // A scope must have been injected by earlier middleware; a nil scope
        // panics on the field access below.
        scope := middlewareapi.GetRequestScope(req)

        if scope.Session == nil {
            session, loadErr := s.getValidatedSession(rw, req)
            if loadErr != nil && !errors.Is(loadErr, http.ErrNoCookie) {
                // A broken or invalid session must not linger in the store.
                logger.Errorf("Error loading cookied session: %v, removing session", loadErr)
                if clearErr := s.store.Clear(rw, req); clearErr != nil {
                    logger.Errorf("Error removing session: %v", clearErr)
                }
            }

            // session is nil when nothing usable was found.
            scope.Session = session
        }

        next.ServeHTTP(rw, req)
    })
}

// getValidatedSession loads the session identified by the request and passes
// it through refreshSessionIfNeeded, which refreshes and validates it once it
// is older than the refresh period.
//
// It returns (nil, nil) when the store reports no session. It returns
// (nil, err) when loading fails, with the store error returned unchanged so
// callers can test for http.ErrNoCookie. It also returns (nil, err) when the
// refresh/validation step fails; that error wraps the underlying cause.
func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
    session, err := s.store.Load(req)
    if err != nil || session == nil {
        // No session was found in the storage or an error occurred, nothing more to do.
        return nil, err
    }

    if err := s.refreshSessionIfNeeded(rw, req, session); err != nil {
        // %w keeps the cause inspectable with errors.Is / errors.As.
        return nil, fmt.Errorf("error refreshing access token for session (%s): %w", session, err)
    }

    return session, nil
}

// refreshSessionIfNeeded refreshes the session when it is older than the
// refresh period, then validates it.
//
// Refreshes are serialised across concurrent requests via the session lock:
//  1. The lock is acquired, waiting at most sessionRefreshObtainTimeout.
//  2. The session is re-read from the store in case another request already
//     refreshed it.
//  3. Only then is the provider asked to refresh.
//
// Once obtained, the lock is released before returning.
//
// session is updated in place so the caller observes the reloaded/refreshed
// state. A non-nil error means the lock could not be obtained, the session
// could not be reloaded, or validation failed.
func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
    if !needsRefresh(s.refreshPeriod, session) {
        // Refresh is disabled or the session is not old enough, do nothing.
        return nil
    }

    if err := waitForSessionLock(req, session); err != nil {
        return err
    }

    // Everything below runs while holding the lock; release it on every exit
    // path. session is a non-nil pointer here (it was dereferenced above).
    defer func() {
        if err := session.ReleaseLock(req.Context()); err != nil {
            logger.Errorf("unable to release lock: %v", err)
        }
    }()

    // Reload the session in case it was changed while we waited for the lock.
    freshSession, err := s.store.Load(req)
    if err != nil {
        return fmt.Errorf("could not load session: %w", err)
    }
    if freshSession == nil {
        return errors.New("session no longer exists, it may have been removed by another request")
    }

    // Copy the fresh state into the caller's pointer so changes are passed up
    // to the parent scope. Loading from the session store creates a new lock
    // in the session, so restore the lock we actually hold.
    heldLock := session.Lock
    *session = *freshSession
    session.Lock = heldLock

    if !needsRefresh(s.refreshPeriod, session) {
        // The session was already refreshed while we were waiting for the lock.
        return nil
    }

    // We are holding the lock and the session needs a refresh.
    logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
    if err := s.refreshSession(rw, req, session); err != nil {
        // A failed preemptive refresh is not fatal: the session is kept if it
        // still passes validation below.
        logger.Errorf("Unable to refresh session: %v", err)
    }

    // Validate after every refresh attempt, whether it succeeded or not.
    return s.validateSession(req.Context(), session)
}

// waitForSessionLock repeatedly tries to obtain the refresh lock on session.
// It stops when the lock is obtained, when a non-retryable error occurs, when
// sessionRefreshObtainTimeout elapses, or when the request is cancelled.
func waitForSessionLock(req *http.Request, session *sessionsapi.SessionState) error {
    // Derive from the request context so a disconnected client stops waiting.
    ctx, cancel := context.WithTimeout(req.Context(), sessionRefreshObtainTimeout)
    defer cancel()

    for {
        err := session.ObtainLock(req.Context(), sessionRefreshLockDuration)
        if err == nil {
            return nil
        }
        if !errors.Is(err, sessionsapi.ErrLockNotObtained) {
            return fmt.Errorf("error occurred while trying to obtain lock: %w", err)
        }

        // Another request holds the lock. Back off before retrying, but give
        // up as soon as the deadline passes or the request is cancelled,
        // instead of sleeping unconditionally.
        select {
        case <-ctx.Done():
            if errors.Is(ctx.Err(), context.Canceled) {
                return fmt.Errorf("request cancelled while obtaining session lock: %w", ctx.Err())
            }
            return errors.New("timeout obtaining session lock")
        case <-time.After(sessionRefreshRetryPeriod):
        }
    }
}

// needsRefresh reports whether session is old enough to be refreshed.
// A refreshPeriod of zero or less turns refreshing off; in that case the
// session age is never consulted.
func needsRefresh(refreshPeriod time.Duration, session *sessionsapi.SessionState) bool {
    if refreshPeriod <= time.Duration(0) {
        return false
    }
    return session.Age() > refreshPeriod
}

// refreshSession asks the provider to refresh the session and persists it to
// the store when a refresh happened.
//
// A providers.ErrNotImplemented result is treated as a successful refresh
// (see the HACK note below). On a refresh, the session's creation time is
// reset and the session is saved. Returned errors wrap the underlying
// refresh or save failure.
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
    refreshed, err := s.sessionRefresher(req.Context(), session)
    switch {
    case errors.Is(err, providers.ErrNotImplemented):
        // HACK:
        // Providers that don't implement `RefreshSession` use the default
        // implementation which returns `ErrNotImplemented`.
        // Pretend it refreshed to reset the refresh timer so that `ValidateSession`
        // isn't triggered every subsequent request and is only called once during
        // this request.
        refreshed = true
    case err != nil:
        return fmt.Errorf("error refreshing tokens: %w", err)
    }

    // Session not refreshed, nothing to persist.
    if !refreshed {
        return nil
    }

    // If we refreshed, update the `CreatedAt` time to reset the refresh timer
    // (in case underlying provider implementations forget).
    session.CreatedAtNow()

    // Because the session was refreshed, make sure to save it.
    if err := s.store.Save(rw, req, session); err != nil {
        logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err)
        return fmt.Errorf("error saving session: %w", err)
    }
    return nil
}

// validateSession rejects sessions that have expired or that the provider
// no longer accepts. Expiry is checked first, so the provider validator is
// only consulted for unexpired sessions.
// A nil return means the session is still valid.
func (s *storedSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState) error {
    switch {
    case session.IsExpired():
        return errors.New("session is expired")
    case !s.sessionValidator(ctx, session):
        return errors.New("session is invalid")
    default:
        return nil
    }
}