ory-am/hydra

View on GitHub
oauth2/fosite_store_helpers.go

Summary

Maintainability
F
1 wk
Test Coverage
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package oauth2

import (
    "context"
    "crypto/sha256"
    "fmt"
    "net/url"
    "slices"
    "testing"
    "time"

    "github.com/ory/x/assertx"

    "github.com/ory/hydra/v2/flow"
    "github.com/ory/hydra/v2/jwk"

    "github.com/go-jose/go-jose/v3"
    "github.com/gobuffalo/pop/v6"
    "github.com/pborman/uuid"

    "github.com/ory/fosite/handler/rfc7523"

    "github.com/ory/hydra/v2/oauth2/trust"

    "github.com/ory/hydra/v2/x"

    "github.com/ory/fosite/storage"
    "github.com/ory/x/sqlxx"

    gofrsuuid "github.com/gofrs/uuid"
    "github.com/pkg/errors"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"

    "github.com/ory/fosite"
    "github.com/ory/fosite/handler/openid"
    "github.com/ory/x/sqlcon"

    "github.com/ory/hydra/v2/client"
)

// signatureFromJTI returns the SHA-256 digest of the given JTI, encoded as a
// lowercase hexadecimal string.
func signatureFromJTI(jti string) string {
    digest := sha256.Sum256([]byte(jti))
    return fmt.Sprintf("%x", digest[:])
}

// BlacklistedJTI is the row model for a recorded JWT ID (JTI). Rows live in
// the table named by TableName.
type BlacklistedJTI struct {
    // JTI is the raw JWT ID. It is tagged `db:"-"` and therefore not mapped to a column.
    JTI    string         `db:"-"`
    // ID is stored in the "signature" column. NewBlacklistedJTI fills it with
    // signatureFromJTI(JTI).
    ID     string         `db:"signature"`
    // Expiry is stored in the "expires_at" column. AfterFind converts it to UTC.
    Expiry time.Time      `db:"expires_at"`
    // NID is stored in the "nid" column.
    // NOTE(review): presumably the network ID used for multi-tenancy — confirm.
    NID    gofrsuuid.UUID `db:"nid"`
}

// AfterFind rewrites the receiver's Expiry in UTC. The connection argument is
// ignored and the returned error is always nil.
// NOTE(review): presumably invoked by pop as a post-load callback — confirm.
func (j *BlacklistedJTI) AfterFind(_ *pop.Connection) error {
    utcExpiry := j.Expiry.UTC()
    j.Expiry = utcExpiry
    return nil
}

// TableName returns the name of the database table backing BlacklistedJTI.
func (BlacklistedJTI) TableName() string {
    const table = "hydra_oauth2_jti_blacklist"
    return table
}

// NewBlacklistedJTI builds a BlacklistedJTI for the given JTI and expiry.
// ID is derived from the JTI via signatureFromJTI; NID is left at its zero
// value.
func NewBlacklistedJTI(jti string, exp time.Time) *BlacklistedJTI {
    // Database timestamp types are less precise than time.Time, so the expiry
    // is normalized to UTC and truncated to whole seconds (which should
    // always round-trip).
    expiry := exp.UTC().Truncate(time.Second)

    entry := new(BlacklistedJTI)
    entry.JTI = jti
    entry.ID = signatureFromJTI(jti)
    entry.Expiry = expiry
    return entry
}

// AssertionJWTReader extends x.FositeStorer with direct read/write access to
// stored client-assertion JTI records. The tests in this file obtain it by
// type-asserting the value returned from OAuth2Storage().
type AssertionJWTReader interface {
    x.FositeStorer

    // GetClientAssertionJWT returns the stored record for the given JTI.
    // NOTE(review): the error returned for an unknown JTI is implementation-defined — confirm.
    GetClientAssertionJWT(ctx context.Context, jti string) (*BlacklistedJTI, error)

    // SetClientAssertionJWTRaw stores the given record.
    // NOTE(review): presumably bypasses the replay checks of the fosite-level setter — confirm.
    SetClientAssertionJWTRaw(context.Context, *BlacklistedJTI) error
}

// defaultIgnoreKeys lists the JSON keys that are excluded when a stored
// request is compared with an expected request via assertx.EqualAsJSONExcept.
var defaultIgnoreKeys = []string{
    "id",
    "session",
    "requested_scope",
    "granted_scope",
    "form",
    "created_at",
    "updated_at",
    "client.created_at",
    "client.updated_at",
    "requestedAt",
    "client.client_secret",
}

// defaultRequest is the shared fixture request used by the create/get/delete
// helpers in this file. Its ID is "blank" and its client ID is "foobar".
// RequestedAt is captured once, at package initialization, in UTC rounded to
// the second.
var defaultRequest = fosite.Request{
    ID:          "blank",
    RequestedAt: time.Now().UTC().Round(time.Second),
    Client: &client.Client{
        ID:                 "foobar",
        Contacts:           []string{},
        RedirectURIs:       []string{},
        Audience:           []string{},
        AllowedCORSOrigins: []string{},
        ResponseTypes:      []string{},
        GrantTypes:         []string{},
        JSONWebKeys:        &x.JoseJSONWebKeySet{},
        Metadata:           sqlxx.JSONRawMessage("{}"),
    },
    RequestedScope:    fosite.Arguments{"fa", "ba"},
    GrantedScope:      fosite.Arguments{"fa", "ba"},
    RequestedAudience: fosite.Arguments{"ad1", "ad2"},
    GrantedAudience:   fosite.Arguments{"ad1", "ad2"},
    Form:              url.Values{"foo": []string{"bar", "baz"}},
    Session:           NewSession("bar"),
}

// lifespan is the token lifespan used to compute the ages of the
// flushRequests fixtures below.
var lifespan = time.Hour

// flushRequests are the fixtures consumed by testHelperFlushTokens. All
// timestamps are taken at package initialization:
//   - "flush-1" was requested "now",
//   - "flush-2" was requested lifespan + 1 minute ago,
//   - "flush-3" was requested lifespan + 1 hour ago.
var flushRequests = []*fosite.Request{
    {
        ID:             "flush-1",
        RequestedAt:    time.Now().Round(time.Second),
        Client:         &client.Client{ID: "foobar"},
        RequestedScope: fosite.Arguments{"fa", "ba"},
        GrantedScope:   fosite.Arguments{"fa", "ba"},
        Form:           url.Values{"foo": []string{"bar", "baz"}},
        Session:        &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}},
    },
    {
        ID:             "flush-2",
        RequestedAt:    time.Now().Round(time.Second).Add(-(lifespan + time.Minute)),
        Client:         &client.Client{ID: "foobar"},
        RequestedScope: fosite.Arguments{"fa", "ba"},
        GrantedScope:   fosite.Arguments{"fa", "ba"},
        Form:           url.Values{"foo": []string{"bar", "baz"}},
        Session:        &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}},
    },
    {
        ID:             "flush-3",
        RequestedAt:    time.Now().Round(time.Second).Add(-(lifespan + time.Hour)),
        Client:         &client.Client{ID: "foobar"},
        RequestedScope: fosite.Arguments{"fa", "ba"},
        GrantedScope:   fosite.Arguments{"fa", "ba"},
        Form:           url.Values{"foo": []string{"bar", "baz"}},
        Session:        &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}},
    },
}

// mockRequestForeignKey prepares the rows that a stored request with the
// given id refers to: it creates the "foobar" client if GetClient reports
// sqlcon.ErrNoRows, then creates a login request, a consent request, and an
// accepted consent decision that all use id as their identifier.
// NOTE(review): presumably needed to satisfy foreign-key constraints of the
// token tables, as the name suggests — confirm against the schema.
//
// Any failing step aborts the calling test via require.
func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry) {
    // Report failures at the caller's line rather than inside this helper.
    t.Helper()

    cl := &client.Client{ID: "foobar"}
    cr := &flow.OAuth2ConsentRequest{
        Client:               cl,
        OpenIDConnectContext: new(flow.OAuth2ConsentRequestOpenIDConnectContext),
        LoginChallenge:       sqlxx.NullString(id),
        ID:                   id,
        Verifier:             id,
        CSRF:                 id,
        AuthenticatedAt:      sqlxx.NullTime(time.Now()),
        RequestedAt:          time.Now(),
    }

    ctx := context.Background()
    // Only create the client when it is missing; the helper is called many
    // times with the same client ID.
    if _, err := x.ClientManager().GetClient(ctx, cl.ID); errors.Is(err, sqlcon.ErrNoRows) {
        require.NoError(t, x.ClientManager().CreateClient(ctx, cl))
    }

    f, err := x.ConsentManager().CreateLoginRequest(
        ctx, &flow.LoginRequest{
            Client:               cl,
            OpenIDConnectContext: new(flow.OAuth2ConsentRequestOpenIDConnectContext),
            ID:                   id,
            Verifier:             id,
            AuthenticatedAt:      sqlxx.NullTime(time.Now()),
            RequestedAt:          time.Now(),
        })
    require.NoError(t, err)
    err = x.ConsentManager().CreateConsentRequest(ctx, f, cr)
    require.NoError(t, err)

    // The accepted consent is identified by the flow's encoded consent verifier.
    encodedFlow, err := f.ToConsentVerifier(ctx, x)
    require.NoError(t, err)

    _, err = x.ConsentManager().HandleConsentRequest(ctx, f, &flow.AcceptOAuth2ConsentRequest{
        ConsentRequest:  cr,
        Session:         new(flow.AcceptOAuth2ConsentRequestSession),
        AuthenticatedAt: sqlxx.NullTime(time.Now()),
        ID:              encodedFlow,
        RequestedAt:     time.Now(),
        HandledAt:       sqlxx.NullTime(time.Now()),
    })

    require.NoError(t, err)
}

// TestHelperRunner is used to run the database suite of tests in this package.
// KEEP EXPORTED AND AVAILABLE FOR THIRD PARTIES TO TEST PLUGINS!
//
// k names the backing store. The unique-constraint and transaction subtests
// are skipped when k is "memory"; every other subtest always runs.
func TestHelperRunner(t *testing.T, store InternalRegistry, k string) {
    t.Helper()

    // withDB substitutes the store name into a subtest name pattern.
    withDB := func(pattern string) string {
        return fmt.Sprintf(pattern, k)
    }

    sqlBacked := k != "memory"
    if sqlBacked {
        t.Run(withDB("case=testHelperUniqueConstraints/db=%s"), testHelperRequestIDMultiples(store, k))
        t.Run("case=testFositeSqlStoreTransactionsCommitAccessToken", testFositeSqlStoreTransactionCommitAccessToken(store))
        t.Run("case=testFositeSqlStoreTransactionsRollbackAccessToken", testFositeSqlStoreTransactionRollbackAccessToken(store))
        t.Run("case=testFositeSqlStoreTransactionCommitRefreshToken", testFositeSqlStoreTransactionCommitRefreshToken(store))
        t.Run("case=testFositeSqlStoreTransactionRollbackRefreshToken", testFositeSqlStoreTransactionRollbackRefreshToken(store))
        t.Run("case=testFositeSqlStoreTransactionCommitAuthorizeCode", testFositeSqlStoreTransactionCommitAuthorizeCode(store))
        t.Run("case=testFositeSqlStoreTransactionRollbackAuthorizeCode", testFositeSqlStoreTransactionRollbackAuthorizeCode(store))
        t.Run("case=testFositeSqlStoreTransactionCommitPKCERequest", testFositeSqlStoreTransactionCommitPKCERequest(store))
        t.Run("case=testFositeSqlStoreTransactionRollbackPKCERequest", testFositeSqlStoreTransactionRollbackPKCERequest(store))
        t.Run("case=testFositeSqlStoreTransactionCommitOpenIdConnectSession", testFositeSqlStoreTransactionCommitOpenIdConnectSession(store))
        t.Run("case=testFositeSqlStoreTransactionRollbackOpenIdConnectSession", testFositeSqlStoreTransactionRollbackOpenIdConnectSession(store))
    }

    t.Run(withDB("case=testHelperCreateGetDeleteAuthorizeCodes/db=%s"), testHelperCreateGetDeleteAuthorizeCodes(store))
    t.Run(withDB("case=testHelperExpiryFields/db=%s"), testHelperExpiryFields(store))
    t.Run(withDB("case=testHelperCreateGetDeleteAccessTokenSession/db=%s"), testHelperCreateGetDeleteAccessTokenSession(store))
    t.Run(withDB("case=testHelperNilAccessToken/db=%s"), testHelperNilAccessToken(store))
    t.Run(withDB("case=testHelperCreateGetDeleteOpenIDConnectSession/db=%s"), testHelperCreateGetDeleteOpenIDConnectSession(store))
    t.Run(withDB("case=testHelperCreateGetDeleteRefreshTokenSession/db=%s"), testHelperCreateGetDeleteRefreshTokenSession(store))
    t.Run(withDB("case=testHelperRevokeRefreshToken/db=%s"), testHelperRevokeRefreshToken(store))
    t.Run(withDB("case=testHelperCreateGetDeletePKCERequestSession/db=%s"), testHelperCreateGetDeletePKCERequestSession(store))
    t.Run(withDB("case=testHelperFlushTokens/db=%s"), testHelperFlushTokens(store, time.Hour))
    t.Run(withDB("case=testHelperFlushTokensWithLimitAndBatchSize/db=%s"), testHelperFlushTokensWithLimitAndBatchSize(store, 3, 2))
    t.Run(withDB("case=testFositeStoreSetClientAssertionJWT/db=%s"), testFositeStoreSetClientAssertionJWT(store))
    t.Run(withDB("case=testFositeStoreClientAssertionJWTValid/db=%s"), testFositeStoreClientAssertionJWTValid(store))
    t.Run(withDB("case=testHelperDeleteAccessTokens/db=%s"), testHelperDeleteAccessTokens(store))
    t.Run(withDB("case=testHelperRevokeAccessToken/db=%s"), testHelperRevokeAccessToken(store))
    t.Run(withDB("case=testFositeJWTBearerGrantStorage/db=%s"), testFositeJWTBearerGrantStorage(store))
}

// testHelperRequestIDMultiples returns a test that stores four rounds of
// refresh-token, access-token, OpenID Connect, PKCE and authorize-code
// sessions for one and the same request, using a fresh signature per round,
// and asserts that none of the writes fails. The second parameter is unused.
func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing.T) {
    return func(t *testing.T) {
        requestID := uuid.New()
        mockRequestForeignKey(t, requestID, m)

        shared := &fosite.Request{
            ID:          requestID,
            Client:      &client.Client{ID: "foobar"},
            RequestedAt: time.Now().UTC().Round(time.Second),
            Session:     NewSession("bar"),
        }

        for remaining := 4; remaining > 0; remaining-- {
            sig := uuid.New()
            assert.NoError(t, m.OAuth2Storage().CreateRefreshTokenSession(context.TODO(), sig, shared))
            assert.NoError(t, m.OAuth2Storage().CreateAccessTokenSession(context.TODO(), sig, shared))
            assert.NoError(t, m.OAuth2Storage().CreateOpenIDConnectSession(context.TODO(), sig, shared))
            assert.NoError(t, m.OAuth2Storage().CreatePKCERequestSession(context.TODO(), sig, shared))
            assert.NoError(t, m.OAuth2Storage().CreateAuthorizeCodeSession(context.TODO(), sig, shared))
        }
    }
}

// testHelperCreateGetDeleteOpenIDConnectSession returns a test covering the
// OpenID Connect session lifecycle for code "4321": lookup fails before
// creation, succeeds after creation with the expected fields, and fails again
// after deletion.
func testHelperCreateGetDeleteOpenIDConnectSession(x InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        const code = "4321"
        store := x.OAuth2Storage()
        ctx := context.Background()

        // Nothing has been stored yet.
        _, lookupErr := store.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: NewSession("bar")})
        assert.NotNil(t, lookupErr)

        require.NoError(t, store.CreateOpenIDConnectSession(ctx, code, &defaultRequest))

        stored, lookupErr := store.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: NewSession("bar")})
        require.NoError(t, lookupErr)
        AssertObjectKeysEqual(t, &defaultRequest, stored, "RequestedScope", "GrantedScope", "Form", "Session")

        require.NoError(t, store.DeleteOpenIDConnectSession(ctx, code))

        // The session must be gone after deletion.
        _, lookupErr = store.GetOpenIDConnectSession(ctx, code, &fosite.Request{Session: NewSession("bar")})
        assert.NotNil(t, lookupErr)
    }
}

// testHelperCreateGetDeleteRefreshTokenSession returns a test covering the
// refresh-token session lifecycle for signature "4321": lookup fails before
// creation, succeeds after creation with the expected fields, and fails again
// after deletion.
func testHelperCreateGetDeleteRefreshTokenSession(x InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        const sig = "4321"
        store := x.OAuth2Storage()
        ctx := context.Background()

        // Nothing has been stored yet.
        _, lookupErr := store.GetRefreshTokenSession(ctx, sig, NewSession("bar"))
        assert.NotNil(t, lookupErr)

        require.NoError(t, store.CreateRefreshTokenSession(ctx, sig, &defaultRequest))

        stored, lookupErr := store.GetRefreshTokenSession(ctx, sig, NewSession("bar"))
        require.NoError(t, lookupErr)
        AssertObjectKeysEqual(t, &defaultRequest, stored, "RequestedScope", "GrantedScope", "Form", "Session")

        require.NoError(t, store.DeleteRefreshTokenSession(ctx, sig))

        // The session must be gone after deletion.
        _, lookupErr = store.GetRefreshTokenSession(ctx, sig, NewSession("bar"))
        assert.NotNil(t, lookupErr)
    }
}

// testHelperRevokeRefreshToken returns a test that stores two refresh tokens
// ("1111" and "1122") under separate request IDs, revokes both by request ID,
// and checks that each lookup afterwards still yields a request together
// with fosite.ErrInactiveToken.
func testHelperRevokeRefreshToken(x InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        store := x.OAuth2Storage()
        ctx := context.Background()

        _, err := store.GetRefreshTokenSession(ctx, "1111", NewSession("bar"))
        assert.Error(t, err)

        firstID := uuid.New()
        secondID := uuid.New()
        mockRequestForeignKey(t, firstID, x)
        mockRequestForeignKey(t, secondID, x)

        tokens := []struct {
            signature string
            requestID string
        }{
            {signature: "1111", requestID: firstID},
            {signature: "1122", requestID: secondID},
        }

        for _, tok := range tokens {
            require.NoError(t, store.CreateRefreshTokenSession(ctx, tok.signature, &fosite.Request{
                ID:          tok.requestID,
                Client:      &client.Client{ID: "foobar"},
                RequestedAt: time.Now().UTC().Round(time.Second),
                Session:     NewSession("user"),
            }))
        }

        // The first token is readable while it is still active.
        _, err = store.GetRefreshTokenSession(ctx, "1111", NewSession("bar"))
        require.NoError(t, err)

        for _, tok := range tokens {
            require.NoError(t, store.RevokeRefreshToken(ctx, tok.requestID))
        }

        // Revoked tokens are still returned, but flagged as inactive.
        for _, tok := range tokens {
            revoked, getErr := store.GetRefreshTokenSession(ctx, tok.signature, NewSession("bar"))
            assert.NotNil(t, revoked)
            assert.EqualError(t, getErr, fosite.ErrInactiveToken.Error())
        }
    }
}

// testHelperCreateGetDeleteAuthorizeCodes returns a test covering the
// authorize-code lifecycle for code "4321": lookup fails before creation,
// succeeds after creation, and after invalidation returns the request
// together with fosite.ErrInvalidatedAuthorizeCode.
func testHelperCreateGetDeleteAuthorizeCodes(x InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        const code = "4321"
        store := x.OAuth2Storage()

        mockRequestForeignKey(t, "blank", x)

        ctx := context.Background()

        // Unknown code: error and no request.
        found, err := store.GetAuthorizeCodeSession(ctx, code, NewSession("bar"))
        assert.Error(t, err)
        assert.Nil(t, found)

        require.NoError(t, store.CreateAuthorizeCodeSession(ctx, code, &defaultRequest))

        found, err = store.GetAuthorizeCodeSession(ctx, code, NewSession("bar"))
        require.NoError(t, err)
        AssertObjectKeysEqual(t, &defaultRequest, found, "RequestedScope", "GrantedScope", "Form", "Session")

        require.NoError(t, store.InvalidateAuthorizeCodeSession(ctx, code))

        // Invalidated code: the request is still returned alongside the error.
        found, err = store.GetAuthorizeCodeSession(ctx, code, NewSession("bar"))
        require.Error(t, err)
        assert.EqualError(t, err, fosite.ErrInvalidatedAuthorizeCode.Error())
        assert.NotNil(t, found)
    }
}

// testHelperExpiryFieldsResult receives the "expires_at" column of a row from
// one of the "hydra_oauth2_*" tables; see TableName for how the table is
// chosen.
type testHelperExpiryFieldsResult struct {
    // ExpiresAt is read from the "expires_at" column.
    ExpiresAt time.Time `db:"expires_at"`
    // name is the table-name suffix appended to "hydra_oauth2_" by TableName.
    name      string
}

// TableName returns "hydra_oauth2_" followed by the receiver's name suffix.
func (r testHelperExpiryFieldsResult) TableName() string {
    const prefix = "hydra_oauth2_"
    return prefix + r.name
}

// testHelperExpiryFields returns a parallel test that stores one session of
// each kind and verifies the persisted "expires_at" column against the
// session's expiry. Access and refresh tokens are checked against their own
// expiries; authorize-code, PKCE and OpenID Connect rows are all checked
// against the authorize-code expiry.
func testHelperExpiryFields(reg InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        m := reg.OAuth2Storage()
        t.Parallel()

        mockRequestForeignKey(t, "blank", reg)

        ctx := context.Background()

        s := NewSession("bar")
        s.SetExpiresAt(fosite.AccessToken, time.Now().Add(time.Hour).Round(time.Minute))
        s.SetExpiresAt(fosite.RefreshToken, time.Now().Add(time.Hour*2).Round(time.Minute))
        s.SetExpiresAt(fosite.AuthorizeCode, time.Now().Add(time.Hour*3).Round(time.Minute))
        request := fosite.Request{
            ID:          uuid.New(),
            RequestedAt: time.Now().UTC().Round(time.Second),
            Client: &client.Client{
                ID:       "foobar",
                Metadata: sqlxx.JSONRawMessage("{}"),
            },
            RequestedScope:    fosite.Arguments{"fa", "ba"},
            GrantedScope:      fosite.Arguments{"fa", "ba"},
            RequestedAudience: fosite.Arguments{"ad1", "ad2"},
            GrantedAudience:   fosite.Arguments{"ad1", "ad2"},
            Form:              url.Values{"foo": []string{"bar", "baz"}},
            Session:           s,
        }

        // Most tables are queried by the raw signature; only access tokens
        // are queried by the hashed signature.
        rawSignature := func(sig string) interface{} { return sig }
        hashedSignature := func(sig string) interface{} { return x.SignatureHash(sig) }

        cases := []struct {
            name      string                     // subtest name
            table     string                     // suffix of the hydra_oauth2_* table
            tokenType fosite.TokenType           // expiry the stored row is compared with
            create    func(sig string) error     // stores the session under sig
            lookupKey func(sig string) interface{} // value matched against the signature column
        }{
            {
                name:      "case=CreateAccessTokenSession",
                table:     "access",
                tokenType: fosite.AccessToken,
                create:    func(sig string) error { return m.CreateAccessTokenSession(ctx, sig, &request) },
                lookupKey: hashedSignature,
            },
            {
                name:      "case=CreateRefreshTokenSession",
                table:     "refresh",
                tokenType: fosite.RefreshToken,
                create:    func(sig string) error { return m.CreateRefreshTokenSession(ctx, sig, &request) },
                lookupKey: rawSignature,
            },
            {
                name:      "case=CreateAuthorizeCodeSession",
                table:     "code",
                tokenType: fosite.AuthorizeCode,
                create:    func(sig string) error { return m.CreateAuthorizeCodeSession(ctx, sig, &request) },
                lookupKey: rawSignature,
            },
            {
                name:      "case=CreatePKCERequestSession",
                table:     "pkce",
                tokenType: fosite.AuthorizeCode,
                create:    func(sig string) error { return m.CreatePKCERequestSession(ctx, sig, &request) },
                lookupKey: rawSignature,
            },
            {
                name:      "case=CreateOpenIDConnectSession",
                table:     "oidc",
                tokenType: fosite.AuthorizeCode,
                create:    func(sig string) error { return m.CreateOpenIDConnectSession(ctx, sig, &request) },
                lookupKey: rawSignature,
            },
        }

        // Subtests run sequentially, in table order.
        for _, tc := range cases {
            tc := tc
            t.Run(tc.name, func(t *testing.T) {
                id := uuid.New()
                require.NoError(t, tc.create(id))

                row := testHelperExpiryFieldsResult{name: tc.table}
                require.NoError(t, reg.Persister().Connection(ctx).Select("expires_at").Where("signature = ?", tc.lookupKey(id)).First(&row))

                assert.EqualValues(t, s.GetExpiresAt(tc.tokenType).UTC(), row.ExpiresAt.UTC())
            })
        }
    }
}

// testHelperNilAccessToken returns a test that creates a dedicated client and
// verifies that an access-token session whose request ID is the empty string
// can be stored without error.
func testHelperNilAccessToken(x InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        store := x.OAuth2Storage()

        owner := &client.Client{ID: "nil-request-client-id-123"}
        require.NoError(t, x.ClientManager().CreateClient(context.Background(), owner))

        // The request deliberately carries an empty ID.
        emptyIDRequest := &fosite.Request{
            ID:                "",
            RequestedAt:       time.Now().UTC().Round(time.Second),
            Client:            owner,
            RequestedScope:    fosite.Arguments{"fa", "ba"},
            GrantedScope:      fosite.Arguments{"fa", "ba"},
            RequestedAudience: fosite.Arguments{"ad1", "ad2"},
            GrantedAudience:   fosite.Arguments{"ad1", "ad2"},
            Form:              url.Values{"foo": []string{"bar", "baz"}},
            Session:           NewSession("bar"),
        }
        require.NoError(t, store.CreateAccessTokenSession(context.TODO(), "nil-request-id", emptyIDRequest))
    }
}

// testHelperCreateGetDeleteAccessTokenSession returns a test covering the
// access-token session lifecycle for signature "4321": lookup fails before
// creation, succeeds after creation with the expected fields, and fails again
// after deletion.
func testHelperCreateGetDeleteAccessTokenSession(x InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        const sig = "4321"
        store := x.OAuth2Storage()
        ctx := context.Background()

        // Nothing has been stored yet.
        _, lookupErr := store.GetAccessTokenSession(ctx, sig, NewSession("bar"))
        assert.Error(t, lookupErr)

        require.NoError(t, store.CreateAccessTokenSession(ctx, sig, &defaultRequest))

        stored, lookupErr := store.GetAccessTokenSession(ctx, sig, NewSession("bar"))
        require.NoError(t, lookupErr)
        AssertObjectKeysEqual(t, &defaultRequest, stored, "RequestedScope", "GrantedScope", "Form", "Session")

        require.NoError(t, store.DeleteAccessTokenSession(ctx, sig))

        // The session must be gone after deletion.
        _, lookupErr = store.GetAccessTokenSession(ctx, sig, NewSession("bar"))
        assert.Error(t, lookupErr)
    }
}

// testHelperDeleteAccessTokens returns a test that stores an access token for
// defaultRequest, deletes all access tokens of that request's client, and
// checks that the token lookup then yields no request and fosite.ErrNotFound.
func testHelperDeleteAccessTokens(x InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        const sig = "4321"
        store := x.OAuth2Storage()
        ctx := context.Background()

        require.NoError(t, store.CreateAccessTokenSession(ctx, sig, &defaultRequest))

        _, getErr := store.GetAccessTokenSession(ctx, sig, NewSession("bar"))
        require.NoError(t, getErr)

        // Deletion is keyed by client ID.
        require.NoError(t, store.DeleteAccessTokens(ctx, defaultRequest.Client.GetID()))

        leftover, getErr := store.GetAccessTokenSession(ctx, sig, NewSession("bar"))
        assert.Nil(t, leftover)
        assert.EqualError(t, getErr, fosite.ErrNotFound.Error())
    }
}

// testHelperRevokeAccessToken returns a test that stores an access token for
// defaultRequest, revokes it by request ID, and checks that the token lookup
// then yields no request and fosite.ErrNotFound.
func testHelperRevokeAccessToken(x InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        const sig = "4321"
        store := x.OAuth2Storage()
        ctx := context.Background()

        require.NoError(t, store.CreateAccessTokenSession(ctx, sig, &defaultRequest))

        _, getErr := store.GetAccessTokenSession(ctx, sig, NewSession("bar"))
        require.NoError(t, getErr)

        // Revocation is keyed by request ID.
        require.NoError(t, store.RevokeAccessToken(ctx, defaultRequest.GetID()))

        leftover, getErr := store.GetAccessTokenSession(ctx, sig, NewSession("bar"))
        assert.Nil(t, leftover)
        assert.EqualError(t, getErr, fosite.ErrNotFound.Error())
    }
}

// testHelperCreateGetDeletePKCERequestSession returns a test covering the
// PKCE request session lifecycle for signature "4321": lookup fails before
// creation, succeeds after creation with the expected fields, and fails again
// after deletion.
func testHelperCreateGetDeletePKCERequestSession(x InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        const sig = "4321"
        store := x.OAuth2Storage()
        ctx := context.Background()

        // Nothing has been stored yet.
        _, lookupErr := store.GetPKCERequestSession(ctx, sig, NewSession("bar"))
        assert.NotNil(t, lookupErr)

        require.NoError(t, store.CreatePKCERequestSession(ctx, sig, &defaultRequest))

        stored, lookupErr := store.GetPKCERequestSession(ctx, sig, NewSession("bar"))
        require.NoError(t, lookupErr)
        AssertObjectKeysEqual(t, &defaultRequest, stored, "RequestedScope", "GrantedScope", "Form", "Session")

        require.NoError(t, store.DeletePKCERequestSession(ctx, sig))

        // The session must be gone after deletion.
        _, lookupErr = store.GetPKCERequestSession(ctx, sig, NewSession("bar"))
        assert.NotNil(t, lookupErr)
    }
}

// testHelperFlushTokens returns a test that stores the three flushRequests
// access tokens and flushes inactive tokens with three successively later
// cut-offs, checking after each flush which tokens are still retrievable:
//
//   - cut-off 24h ago: all three tokens remain,
//   - cut-off lifespan + 30min ago: "flush-3" is removed,
//   - cut-off now: "flush-2" is removed as well, "flush-1" remains.
//
// Finally all access tokens of client "foobar" are deleted.
func testHelperFlushTokens(x InternalRegistry, lifespan time.Duration) func(t *testing.T) {
    m := x.OAuth2Storage()
    ds := &Session{}

    return func(t *testing.T) {
        ctx := context.Background()

        // mustExist fails the test unless the token can be loaded.
        mustExist := func(signature string) {
            t.Helper()
            _, err := m.GetAccessTokenSession(ctx, signature, ds)
            require.NoError(t, err)
        }
        // mustBeGone fails the test unless loading the token errors.
        mustBeGone := func(signature string) {
            t.Helper()
            _, err := m.GetAccessTokenSession(ctx, signature, ds)
            require.Error(t, err)
        }
        // flush removes inactive access tokens up to the given cut-off.
        flush := func(notAfter time.Time) {
            t.Helper()
            require.NoError(t, m.FlushInactiveAccessTokens(ctx, notAfter, 100, 10))
        }

        for _, fixture := range flushRequests {
            mockRequestForeignKey(t, fixture.ID, x)
            require.NoError(t, m.CreateAccessTokenSession(ctx, fixture.ID, fixture))
            mustExist(fixture.ID)
        }

        flush(time.Now().Add(-time.Hour * 24))
        mustExist("flush-1")
        mustExist("flush-2")
        mustExist("flush-3")

        flush(time.Now().Add(-(lifespan + time.Hour/2)))
        mustExist("flush-1")
        mustExist("flush-2")
        mustBeGone("flush-3")

        flush(time.Now())
        mustExist("flush-1")
        mustBeGone("flush-2")
        mustBeGone("flush-3")

        require.NoError(t, m.DeleteAccessTokens(ctx, "foobar"))
    }
}

// testHelperFlushTokensWithLimitAndBatchSize returns a test that stores five
// access tokens requested two hours ago, flushes inactive tokens with the
// given limit and batchSize, and asserts that exactly limit tokens were
// removed while the remaining ones are still retrievable.
func testHelperFlushTokensWithLimitAndBatchSize(x InternalRegistry, limit int, batchSize int) func(t *testing.T) {
    m := x.OAuth2Storage()
    ds := &Session{}

    return func(t *testing.T) {
        ctx := context.Background()

        // Create five expired requests sharing a random ID prefix.
        totalCount := 5
        prefix := uuid.New()
        var requests []*fosite.Request
        for n := 1; n <= totalCount; n++ {
            expired := createTestRequest(fmt.Sprintf("%s-%d", prefix, n))
            expired.RequestedAt = time.Now().Add(-2 * time.Hour)
            mockRequestForeignKey(t, expired.ID, x)
            require.NoError(t, m.CreateAccessTokenSession(ctx, expired.ID, expired))
            _, err := m.GetAccessTokenSession(ctx, expired.ID, ds)
            require.NoError(t, err)
            requests = append(requests, expired)
        }

        require.NoError(t, m.FlushInactiveAccessTokens(ctx, time.Now(), limit, batchSize))

        // Tally which tokens survived the flush; any lookup failure must be
        // a not-found error.
        var foundCount, notFoundCount int
        for _, candidate := range requests {
            _, err := m.GetAccessTokenSession(ctx, candidate.ID, ds)
            if err != nil {
                require.ErrorIs(t, err, fosite.ErrNotFound)
                notFoundCount++
                continue
            }
            foundCount++
        }
        assert.Equal(t, limit, notFoundCount, "should have deleted %d tokens", limit)
        assert.Equal(t, totalCount-limit, foundCount, "should have found %d tokens", totalCount-limit)
    }
}

// testFositeSqlStoreTransactionCommitAccessToken returns a test that runs
// doTestCommit for access tokens twice: once with revocation and once with
// deletion as the removal operation.
func testFositeSqlStoreTransactionCommitAccessToken(m InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        store := m.OAuth2Storage()
        doTestCommit(m, t, store.CreateAccessTokenSession, store.GetAccessTokenSession, store.RevokeAccessToken)
        doTestCommit(m, t, store.CreateAccessTokenSession, store.GetAccessTokenSession, store.DeleteAccessTokenSession)
    }
}

// testFositeSqlStoreTransactionRollbackAccessToken returns a test that runs
// doTestRollback for access tokens twice: once with revocation and once with
// deletion as the removal operation.
func testFositeSqlStoreTransactionRollbackAccessToken(m InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        store := m.OAuth2Storage()
        doTestRollback(m, t, store.CreateAccessTokenSession, store.GetAccessTokenSession, store.RevokeAccessToken)
        doTestRollback(m, t, store.CreateAccessTokenSession, store.GetAccessTokenSession, store.DeleteAccessTokenSession)
    }
}

// testFositeSqlStoreTransactionCommitRefreshToken returns a test that runs
// doTestCommit for refresh tokens twice: once with revocation and once with
// deletion as the removal operation.
func testFositeSqlStoreTransactionCommitRefreshToken(m InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        store := m.OAuth2Storage()
        doTestCommit(m, t, store.CreateRefreshTokenSession, store.GetRefreshTokenSession, store.RevokeRefreshToken)
        doTestCommit(m, t, store.CreateRefreshTokenSession, store.GetRefreshTokenSession, store.DeleteRefreshTokenSession)
    }
}

// testFositeSqlStoreTransactionRollbackRefreshToken returns a test that runs
// doTestRollback for refresh tokens twice: once with revocation and once with
// deletion as the removal operation.
func testFositeSqlStoreTransactionRollbackRefreshToken(m InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        store := m.OAuth2Storage()
        doTestRollback(m, t, store.CreateRefreshTokenSession, store.GetRefreshTokenSession, store.RevokeRefreshToken)
        doTestRollback(m, t, store.CreateRefreshTokenSession, store.GetRefreshTokenSession, store.DeleteRefreshTokenSession)
    }
}

// testFositeSqlStoreTransactionCommitAuthorizeCode returns a test that runs
// doTestCommit for authorize codes, using invalidation as the removal
// operation.
func testFositeSqlStoreTransactionCommitAuthorizeCode(m InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        store := m.OAuth2Storage()
        doTestCommit(m, t, store.CreateAuthorizeCodeSession, store.GetAuthorizeCodeSession, store.InvalidateAuthorizeCodeSession)
    }
}

// testFositeSqlStoreTransactionRollbackAuthorizeCode returns a test that runs
// doTestRollback for authorize codes, using invalidation as the removal
// operation.
func testFositeSqlStoreTransactionRollbackAuthorizeCode(m InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        store := m.OAuth2Storage()
        doTestRollback(m, t, store.CreateAuthorizeCodeSession, store.GetAuthorizeCodeSession, store.InvalidateAuthorizeCodeSession)
    }
}

// testFositeSqlStoreTransactionCommitPKCERequest returns a test that runs
// doTestCommit for PKCE request sessions.
func testFositeSqlStoreTransactionCommitPKCERequest(m InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        store := m.OAuth2Storage()
        doTestCommit(m, t, store.CreatePKCERequestSession, store.GetPKCERequestSession, store.DeletePKCERequestSession)
    }
}

// testFositeSqlStoreTransactionRollbackPKCERequest returns a test that runs
// doTestRollback for PKCE request sessions.
func testFositeSqlStoreTransactionRollbackPKCERequest(m InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        store := m.OAuth2Storage()
        doTestRollback(m, t, store.CreatePKCERequestSession, store.GetPKCERequestSession, store.DeletePKCERequestSession)
    }
}

// OpenIdConnect tests can't use the helper functions, due to the signature of GetOpenIdConnectSession being
// different from the other getter methods.
//
// testFositeSqlStoreTransactionCommitOpenIdConnectSession returns a test that
// creates an OpenID Connect session inside a committed transaction and
// expects it to be readable afterwards, then deletes it inside a second
// committed transaction and expects the lookup to fail.
func testFositeSqlStoreTransactionCommitOpenIdConnectSession(m InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        oidcStore := m.OAuth2Storage()
        txnStore, isTransactional := oidcStore.(storage.Transactional)
        require.True(t, isTransactional)

        // Phase 1: create within a transaction and commit.
        txCtx, err := txnStore.BeginTX(context.Background())
        require.NoError(t, err)
        signature := uuid.New()
        testRequest := createTestRequest(signature)
        require.NoError(t, oidcStore.CreateOpenIDConnectSession(txCtx, signature, testRequest))
        require.NoError(t, txnStore.Commit(txCtx))

        // Require a new context, since the old one contains the transaction.
        stored, err := oidcStore.GetOpenIDConnectSession(context.Background(), signature, testRequest)
        // session should have been created successfully because Commit did not return an error
        require.NoError(t, err)
        assertx.EqualAsJSONExcept(t, &defaultRequest, stored, defaultIgnoreKeys)

        // Phase 2: delete within a transaction and commit.
        txCtx, err = txnStore.BeginTX(context.Background())
        require.NoError(t, err)
        require.NoError(t, oidcStore.DeleteOpenIDConnectSession(txCtx, signature))
        require.NoError(t, txnStore.Commit(txCtx))

        // Require a new context, since the old one contains the transaction.
        _, err = oidcStore.GetOpenIDConnectSession(context.Background(), signature, testRequest)
        // Since commit worked for delete, we should get an error here.
        require.Error(t, err)
    }
}

// testFositeSqlStoreTransactionRollbackOpenIdConnectSession returns a test verifying that
// rolling back a transaction undoes both the creation and the deletion of an OpenID
// Connect session.
func testFositeSqlStoreTransactionRollbackOpenIdConnectSession(m InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        txnStore, ok := m.OAuth2Storage().(storage.Transactional)
        require.True(t, ok)

        // Part 1: create inside a transaction, then roll the transaction back.
        txCtx, err := txnStore.BeginTX(context.Background())
        require.NoError(t, err)

        sig := uuid.New()
        req := createTestRequest(sig)
        require.NoError(t, m.OAuth2Storage().CreateOpenIDConnectSession(txCtx, sig, req))
        require.NoError(t, txnStore.Rollback(txCtx))

        // Use a context without the transaction from here on. The rolled back
        // session must not exist, so fetching it has to return an error.
        plainCtx := context.Background()
        _, err = m.OAuth2Storage().GetOpenIDConnectSession(plainCtx, sig, req)
        require.Error(t, err)

        // Part 2: create a second session outside of any transaction, delete it inside
        // a transaction, roll that back, and expect the session to still be there.
        sig2 := uuid.New()
        req2 := createTestRequest(sig2)
        require.NoError(t, m.OAuth2Storage().CreateOpenIDConnectSession(plainCtx, sig2, req2))
        _, err = m.OAuth2Storage().GetOpenIDConnectSession(plainCtx, sig2, req2)
        require.NoError(t, err)

        txCtx, err = txnStore.BeginTX(context.Background())
        require.NoError(t, err)
        require.NoError(t, m.OAuth2Storage().DeleteOpenIDConnectSession(txCtx, sig2))
        require.NoError(t, txnStore.Rollback(txCtx))

        _, err = m.OAuth2Storage().GetOpenIDConnectSession(context.Background(), sig2, req2)
        require.NoError(t, err)
    }
}

// testFositeStoreSetClientAssertionJWT returns a test for SetClientAssertionJWT on the
// registry's OAuth2 storage, which must implement AssertionJWTReader. It covers:
//   - storing a new JTI and reading it back,
//   - rejecting a JTI that is already stored and not expired (fosite.ErrJTIKnown),
//   - removal of expired JTIs as a side effect of storing a new one,
//   - re-inserting a JTI whose previous entry has expired.
func testFositeStoreSetClientAssertionJWT(m InternalRegistry) func(*testing.T) {
    return func(t *testing.T) {
        t.Run("case=basic setting works", func(t *testing.T) {
            store, ok := m.OAuth2Storage().(AssertionJWTReader)
            require.True(t, ok)
            jti := NewBlacklistedJTI("basic jti", time.Now().Add(time.Minute))

            require.NoError(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry))

            cmp, err := store.GetClientAssertionJWT(context.Background(), jti.JTI)
            // Check the error before touching cmp: on failure cmp may be nil and
            // dereferencing it would panic instead of reporting the real problem.
            require.NoError(t, err)
            require.NotEqual(t, cmp.NID, gofrsuuid.Nil)
            // The expected value was built without a network ID, so clear it before comparing.
            cmp.NID = gofrsuuid.Nil
            assert.Equal(t, jti, cmp)
        })

        t.Run("case=errors when the JTI is blacklisted", func(t *testing.T) {
            store, ok := m.OAuth2Storage().(AssertionJWTReader)
            require.True(t, ok)
            jti := NewBlacklistedJTI("already set jti", time.Now().Add(time.Minute))
            require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti))

            assert.ErrorIs(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry), fosite.ErrJTIKnown)
        })

        t.Run("case=deletes expired JTIs", func(t *testing.T) {
            store, ok := m.OAuth2Storage().(AssertionJWTReader)
            require.True(t, ok)
            expiredJTI := NewBlacklistedJTI("expired jti", time.Now().Add(-time.Minute))
            require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), expiredJTI))
            newJTI := NewBlacklistedJTI("some new jti", time.Now().Add(time.Minute))

            require.NoError(t, store.SetClientAssertionJWT(context.Background(), newJTI.JTI, newJTI.Expiry))

            // The expired entry must be gone after the new JTI was stored.
            _, err := store.GetClientAssertionJWT(context.Background(), expiredJTI.JTI)
            assert.True(t, errors.Is(err, sqlcon.ErrNoRows))
            cmp, err := store.GetClientAssertionJWT(context.Background(), newJTI.JTI)
            require.NoError(t, err)
            require.NotEqual(t, cmp.NID, gofrsuuid.Nil)
            cmp.NID = gofrsuuid.Nil
            assert.Equal(t, newJTI, cmp)
        })

        t.Run("case=inserts same JTI if expired", func(t *testing.T) {
            store, ok := m.OAuth2Storage().(AssertionJWTReader)
            require.True(t, ok)
            jti := NewBlacklistedJTI("going to be reused jti", time.Now().Add(-time.Minute))
            require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), jti))

            // Move the expiry one minute into the future and store the same JTI again.
            jti.Expiry = jti.Expiry.Add(2 * time.Minute)
            assert.NoError(t, store.SetClientAssertionJWT(context.Background(), jti.JTI, jti.Expiry))
            cmp, err := store.GetClientAssertionJWT(context.Background(), jti.JTI)
            assert.NoError(t, err)
            // NOTE(review): unlike the cases above, NID is not cleared before comparing;
            // presumably SetClientAssertionJWTRaw fills in jti.NID - confirm.
            assert.Equal(t, jti, cmp)
        })
    }
}

// testFositeStoreClientAssertionJWTValid returns a test for ClientAssertionJWTValid on the
// registry's OAuth2 storage: unknown and expired JTIs are accepted, while a stored,
// unexpired JTI is reported as fosite.ErrJTIKnown.
func testFositeStoreClientAssertionJWTValid(m InternalRegistry) func(*testing.T) {
    return func(t *testing.T) {
        t.Run("case=returns valid on unknown JTI", func(t *testing.T) {
            store, ok := m.OAuth2Storage().(AssertionJWTReader)
            require.True(t, ok)

            err := store.ClientAssertionJWTValid(context.Background(), "unknown jti")
            assert.NoError(t, err)
        })

        t.Run("case=returns invalid on known JTI", func(t *testing.T) {
            store, ok := m.OAuth2Storage().(AssertionJWTReader)
            require.True(t, ok)

            // Store a JTI that expires in the future, then validate the same JTI.
            known := NewBlacklistedJTI("known jti", time.Now().Add(time.Minute))
            require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), known))

            err := store.ClientAssertionJWTValid(context.Background(), known.JTI)
            assert.True(t, errors.Is(err, fosite.ErrJTIKnown))
        })

        t.Run("case=returns valid on expired JTI", func(t *testing.T) {
            store, ok := m.OAuth2Storage().(AssertionJWTReader)
            require.True(t, ok)

            // Store a JTI that expired a minute ago; it must not block reuse.
            expired := NewBlacklistedJTI("expired jti 2", time.Now().Add(-time.Minute))
            require.NoError(t, store.SetClientAssertionJWTRaw(context.Background(), expired))

            err := store.ClientAssertionJWTValid(context.Background(), expired.JTI)
            assert.NoError(t, err)
        })
    }
}

// testFositeJWTBearerGrantStorage returns a test for the RFC 7523 key storage implemented
// by the registry's OAuth2 storage, in combination with the grant manager and the key
// manager. It checks that keys are stored with their grant, that lookups are scoped to
// issuer and subject (unless any subject is allowed), that deleting a grant or its key
// removes the other, and that keys of expired grants are not returned.
//
// The registry parameter is named m (as in the sibling helpers) so that it does not
// shadow the imported package x.
func testFositeJWTBearerGrantStorage(m InternalRegistry) func(t *testing.T) {
    return func(t *testing.T) {
        ctx := context.Background()
        grantManager := m.GrantManager()
        keyManager := m.KeyManager()
        // Use the checked form of the type assertion so that a storage without RFC 7523
        // support fails the test with a message instead of panicking.
        grantStorage, ok := m.OAuth2Storage().(rfc7523.RFC7523KeyStorage)
        require.True(t, ok, "OAuth2 storage must implement rfc7523.RFC7523KeyStorage")

        // newGrant builds a grant with a fresh ID for the given issuer/subject and key ID.
        // The key set name equals the issuer. The grant is created "now" (UTC, rounded to
        // seconds) and expires yearsValid years later; a negative value yields an already
        // expired grant.
        newGrant := func(issuer, subject string, allowAnySubject bool, scope []string, keyID string, yearsValid int) trust.Grant {
            createdAt := time.Now().UTC().Round(time.Second)
            return trust.Grant{
                ID:              uuid.New(),
                Issuer:          issuer,
                Subject:         subject,
                AllowAnySubject: allowAnySubject,
                Scope:           scope,
                PublicKey:       trust.PublicKey{Set: issuer, KeyID: keyID},
                CreatedAt:       createdAt,
                ExpiresAt:       createdAt.AddDate(yearsValid, 0, 0),
            }
        }

        t.Run("case=associated key added with grant", func(t *testing.T) {
            keySet, err := jwk.GenerateJWK(ctx, jose.RS256, "token-service-key", "sig")
            require.NoError(t, err)

            publicKey := keySet.Keys[0].Public()
            issuer := "token-service"
            subject := "bob@example.com"
            grant := newGrant(issuer, subject, false, []string{"openid", "offline"}, publicKey.KeyID, 1)

            // Nothing is stored for this issuer/subject before the grant exists.
            storedKeySet, err := grantStorage.GetPublicKeys(ctx, issuer, subject)
            require.NoError(t, err)
            require.Len(t, storedKeySet.Keys, 0)

            require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey))

            storedKeySet, err = grantStorage.GetPublicKeys(ctx, issuer, subject)
            require.NoError(t, err)
            assert.Len(t, storedKeySet.Keys, 1)

            storedKey, err := grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID)
            require.NoError(t, err)
            assert.Equal(t, publicKey.KeyID, storedKey.KeyID)
            assert.Equal(t, publicKey.Use, storedKey.Use)
            assert.Equal(t, publicKey.Key, storedKey.Key)

            storedScopes, err := grantStorage.GetPublicKeyScopes(ctx, issuer, subject, publicKey.KeyID)
            require.NoError(t, err)
            assert.Equal(t, grant.Scope, storedScopes)

            // The key must also be reachable through the key manager under the issuer's set.
            storedKeySet, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID)
            require.NoError(t, err)
            // Guard the index below: an empty set should fail the test, not panic.
            require.NotEmpty(t, storedKeySet.Keys)
            assert.Equal(t, publicKey.KeyID, storedKeySet.Keys[0].KeyID)
            assert.Equal(t, publicKey.Use, storedKeySet.Keys[0].Use)
            assert.Equal(t, publicKey.Key, storedKeySet.Keys[0].Key)
        })

        t.Run("case=only associated key returns", func(t *testing.T) {
            // A key in an unrelated set that must never show up in the results below.
            keySetToNotReturn, err := jwk.GenerateJWK(ctx, jose.ES256, "some-key", "sig")
            require.NoError(t, err)
            require.NoError(t, keyManager.AddKeySet(ctx, "some-set", keySetToNotReturn), "adding a random key should not fail")

            issuer := "maria"
            subject := "maria@example.com"

            keySet1ToReturn, err := jwk.GenerateJWK(ctx, jose.ES256, "maria-key-1", "sig")
            require.NoError(t, err)
            key1 := keySet1ToReturn.Keys[0].Public()
            require.NoError(t, grantManager.CreateGrant(ctx,
                newGrant(issuer, subject, false, []string{"openid"}, key1.KeyID, 1), key1))

            keySet2ToReturn, err := jwk.GenerateJWK(ctx, jose.ES256, "maria-key-2", "sig")
            require.NoError(t, err)
            key2 := keySet2ToReturn.Keys[0].Public()
            require.NoError(t, grantManager.CreateGrant(ctx,
                newGrant(issuer, subject, false, []string{"openid"}, key2.KeyID, 1), key2))

            storedKeySet, err := grantStorage.GetPublicKeys(ctx, issuer, subject)
            require.NoError(t, err)
            require.Len(t, storedKeySet.Keys, 2)

            // Cannot rely on sort order because the created_at timestamps may alias,
            // so each expected key is located by its key ID.
            for _, want := range []jose.JSONWebKey{key1, key2} {
                idx := slices.IndexFunc(storedKeySet.Keys, func(k jose.JSONWebKey) bool {
                    return k.KeyID == want.KeyID
                })
                require.GreaterOrEqual(t, idx, 0)

                got := storedKeySet.Keys[idx]
                assert.Equal(t, want.KeyID, got.KeyID)
                assert.Equal(t, want.Use, got.Use)
                assert.Equal(t, want.Key, got.Key)
            }

            storedKeySet, err = grantStorage.GetPublicKeys(ctx, issuer, "non-existing-subject")
            require.NoError(t, err)
            assert.Len(t, storedKeySet.Keys, 0)

            _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, "non-existing-subject", key2.KeyID)
            require.Error(t, err)
        })

        t.Run("case=associated key is deleted, when granted is deleted", func(t *testing.T) {
            keySet, err := jwk.GenerateJWK(ctx, jose.RS256, "hackerman-key", "sig")
            require.NoError(t, err)

            publicKey := keySet.Keys[0].Public()
            issuer := "aeneas"
            subject := "aeneas@example.com"
            grant := newGrant(issuer, subject, false, []string{"openid", "offline"}, publicKey.KeyID, 1)

            require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey))

            _, err = grantStorage.GetPublicKey(ctx, issuer, subject, grant.PublicKey.KeyID)
            require.NoError(t, err)

            _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID)
            require.NoError(t, err)

            require.NoError(t, grantManager.DeleteGrant(ctx, grant.ID))

            // After the grant is gone, neither storage may return the key any more.
            _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID)
            assert.Error(t, err)

            _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID)
            assert.Error(t, err)
        })

        t.Run("case=associated grant is deleted, when key is deleted", func(t *testing.T) {
            keySet, err := jwk.GenerateJWK(ctx, jose.RS256, "vladimir-key", "sig")
            require.NoError(t, err)

            publicKey := keySet.Keys[0].Public()
            issuer := "vladimir"
            subject := "vladimir@example.com"
            grant := newGrant(issuer, subject, false, []string{"openid", "offline"}, publicKey.KeyID, 1)

            require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey))

            _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID)
            require.NoError(t, err)

            _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID)
            require.NoError(t, err)

            require.NoError(t, keyManager.DeleteKey(ctx, issuer, publicKey.KeyID))

            // Deleting the key must remove the key itself and the grant that referenced it.
            _, err = keyManager.GetKey(ctx, issuer, publicKey.KeyID)
            assert.Error(t, err)

            _, err = grantManager.GetConcreteGrant(ctx, grant.ID)
            assert.Error(t, err)
        })

        t.Run("case=only returns the key when subject matches", func(t *testing.T) {
            keySet, err := jwk.GenerateJWK(ctx, jose.RS256, "issuer-key", "sig")
            require.NoError(t, err)

            publicKey := keySet.Keys[0].Public()
            issuer := "limited-issuer"
            subject := "jagoba"
            grant := newGrant(issuer, subject, false, []string{"openid", "offline"}, publicKey.KeyID, 1)

            require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey))

            // All three get methods should only return the public key when using the valid subject
            _, err = grantStorage.GetPublicKey(ctx, issuer, "any-subject-1", publicKey.KeyID)
            require.Error(t, err)
            _, err = grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID)
            require.NoError(t, err)

            _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, "any-subject-2", publicKey.KeyID)
            require.Error(t, err)
            _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, subject, publicKey.KeyID)
            require.NoError(t, err)

            // GetPublicKeys reports a mismatch as an empty, non-nil set rather than an error.
            jwks, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3")
            require.NoError(t, err)
            require.NotNil(t, jwks)
            require.Empty(t, jwks.Keys)
            jwks, err = grantStorage.GetPublicKeys(ctx, issuer, subject)
            require.NoError(t, err)
            require.NotNil(t, jwks)
            require.NotEmpty(t, jwks.Keys)
        })

        t.Run("case=returns the key when any subject is allowed", func(t *testing.T) {
            keySet, err := jwk.GenerateJWK(ctx, jose.RS256, "issuer-key", "sig")
            require.NoError(t, err)

            publicKey := keySet.Keys[0].Public()
            issuer := "unlimited-issuer"
            // Empty subject combined with AllowAnySubject: the grant is not bound to a subject.
            grant := newGrant(issuer, "", true, []string{"openid", "offline"}, publicKey.KeyID, 1)

            require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey))

            // All three get methods should always return the public key
            _, err = grantStorage.GetPublicKey(ctx, issuer, "any-subject-1", publicKey.KeyID)
            require.NoError(t, err)

            _, err = grantStorage.GetPublicKeyScopes(ctx, issuer, "any-subject-2", publicKey.KeyID)
            require.NoError(t, err)

            jwks, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3")
            require.NoError(t, err)
            require.NotNil(t, jwks)
            require.NotEmpty(t, jwks.Keys)
        })

        t.Run("case=does not return expired values", func(t *testing.T) {
            keySet, err := jwk.GenerateJWK(ctx, jose.RS256, "issuer-expired-key", "sig")
            require.NoError(t, err)

            publicKey := keySet.Keys[0].Public()
            issuer := "expired-issuer"
            // yearsValid of -1 puts the expiry one year in the past.
            grant := newGrant(issuer, "", true, []string{"openid", "offline"}, publicKey.KeyID, -1)

            require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey))

            keys, err := grantStorage.GetPublicKeys(ctx, issuer, "any-subject-3")
            require.NoError(t, err)
            assert.Len(t, keys.Keys, 0)
        })
    }
}

// doTestCommit is a shared scenario for storage methods that follow the
// create/get/revoke shape. It creates a request inside a transaction and commits,
// expects the request to be readable afterwards, then revokes it inside a second
// committed transaction and expects the read to fail.
//
// m must provide an OAuth2 storage implementing storage.Transactional. createFn, getFn
// and revokeFn are the storage methods under test; each is called with a freshly
// generated signature.
func doTestCommit(m InternalRegistry, t *testing.T,
    createFn func(context.Context, string, fosite.Requester) error,
    getFn func(context.Context, string, fosite.Session) (fosite.Requester, error),
    revokeFn func(context.Context, string) error,
) {
    txnStore, ok := m.OAuth2Storage().(storage.Transactional)
    require.True(t, ok)

    // Create within a transaction and commit it.
    txCtx, err := txnStore.BeginTX(context.Background())
    require.NoError(t, err)
    sig := uuid.New()
    require.NoError(t, createFn(txCtx, sig, createTestRequest(sig)))
    require.NoError(t, txnStore.Commit(txCtx))

    // A new context is needed for reading, since txCtx contains the transaction.
    // Commit returned no error, so the token must have been stored.
    stored, err := getFn(context.Background(), sig, NewSession("bar"))
    require.NoError(t, err)
    assertx.EqualAsJSONExcept(t, &defaultRequest, stored, defaultIgnoreKeys)

    // Revoke within a second transaction and commit that too.
    txCtx, err = txnStore.BeginTX(context.Background())
    require.NoError(t, err)
    require.NoError(t, revokeFn(txCtx, sig))
    require.NoError(t, txnStore.Commit(txCtx))

    // Fresh context again; the committed revoke means the lookup has to fail.
    _, err = getFn(context.Background(), sig, NewSession("bar"))
    require.Error(t, err)
}

// doTestRollback is a shared scenario for storage methods that follow the
// create/get/revoke shape. It verifies that a create performed inside a rolled back
// transaction leaves nothing behind, and that a revoke performed inside a rolled back
// transaction leaves the stored request intact.
//
// m must provide an OAuth2 storage implementing storage.Transactional. createFn, getFn
// and revokeFn are the storage methods under test.
func doTestRollback(m InternalRegistry, t *testing.T,
    createFn func(context.Context, string, fosite.Requester) error,
    getFn func(context.Context, string, fosite.Session) (fosite.Requester, error),
    revokeFn func(context.Context, string) error,
) {
    txnStore, ok := m.OAuth2Storage().(storage.Transactional)
    require.True(t, ok)

    // Part 1: create inside a transaction, then roll back.
    txCtx, err := txnStore.BeginTX(context.Background())
    require.NoError(t, err)
    sig := uuid.New()
    require.NoError(t, createFn(txCtx, sig, createTestRequest(sig)))
    require.NoError(t, txnStore.Rollback(txCtx))

    // Continue with a context that does not carry the transaction. Because of the
    // rollback the token must not exist, so getting it has to return an error.
    plainCtx := context.Background()
    _, err = getFn(plainCtx, sig, NewSession("bar"))
    require.Error(t, err)

    // Part 2: create a second token outside of any transaction, revoke it inside a
    // transaction and roll that back. The token must still be retrievable.
    sig2 := uuid.New()
    require.NoError(t, createFn(plainCtx, sig2, createTestRequest(sig2)))
    _, err = getFn(plainCtx, sig2, NewSession("bar"))
    require.NoError(t, err)

    txCtx, err = txnStore.BeginTX(context.Background())
    require.NoError(t, err)
    require.NoError(t, revokeFn(txCtx, sig2))
    require.NoError(t, txnStore.Rollback(txCtx))

    _, err = getFn(context.Background(), sig2, NewSession("bar"))
    require.NoError(t, err)
}

// createTestRequest builds the fixture request used by the storage tests. Only the ID
// varies; RequestedAt is the current UTC time rounded to seconds, and the client,
// scopes, audiences, form values and session subject are fixed.
func createTestRequest(id string) *fosite.Request {
    req := new(fosite.Request)

    req.ID = id
    req.RequestedAt = time.Now().UTC().Round(time.Second)
    req.Client = &client.Client{ID: "foobar"}

    // Requested and granted values are separate slices on purpose, so that changing
    // one of them never affects the other.
    req.RequestedScope = fosite.Arguments{"fa", "ba"}
    req.GrantedScope = fosite.Arguments{"fa", "ba"}
    req.RequestedAudience = fosite.Arguments{"ad1", "ad2"}
    req.GrantedAudience = fosite.Arguments{"ad1", "ad2"}

    req.Form = url.Values{"foo": []string{"bar", "baz"}}
    req.Session = &Session{DefaultSession: &openid.DefaultSession{Subject: "bar"}}

    return req
}