ory-am/hydra

View on GitHub
persistence/sql/persister_grant_jwk.go

Summary

Maintainability
A
1 hr
Test Coverage
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package sql

import (
    "context"
    "strings"
    "time"

    "github.com/ory/hydra/v2/jwk"

    "github.com/pkg/errors"

    "github.com/go-jose/go-jose/v3"
    "github.com/gobuffalo/pop/v6"

    "github.com/ory/hydra/v2/oauth2/trust"
    "github.com/ory/x/otelx"
    "github.com/ory/x/stringsx"

    "github.com/ory/x/sqlcon"
)

// Compile-time assertion that *Persister implements trust.GrantManager.
var _ trust.GrantManager = (*Persister)(nil)

// CreateGrant stores the trust relationship g in a single transaction.
//
// It first looks up the key g.PublicKey.Set / g.PublicKey.KeyID.
// publicKey is added only when that lookup reports sqlcon.ErrNoRows.
// Any other lookup error aborts the transaction. Afterwards the grant row
// itself is inserted for the current network.
func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jose.JSONWebKey) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateGrant")
	defer otelx.End(span, &err)

	return p.Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error {
		_, lookupErr := p.GetKey(ctx, g.PublicKey.Set, g.PublicKey.KeyID)
		switch {
		case lookupErr == nil:
			// The key is already stored; nothing to add.
		case errors.Is(lookupErr, sqlcon.ErrNoRows):
			if addErr := p.AddKey(ctx, g.PublicKey.Set, &publicKey); addErr != nil {
				return sqlcon.HandleError(addErr)
			}
		default:
			return sqlcon.HandleError(lookupErr)
		}

		row := p.sqlDataFromJWTGrant(g)
		return sqlcon.HandleError(p.CreateWithNetwork(ctx, &row))
	})
}

func (p *Persister) GetConcreteGrant(ctx context.Context, id string) (_ trust.Grant, err error) {
    ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConcreteGrant")
    defer otelx.End(span, &err)

    var data trust.SQLData
    if err := p.QueryWithNetwork(ctx).Where("id = ?", id).First(&data); err != nil {
        return trust.Grant{}, sqlcon.HandleError(err)
    }

    return p.jwtGrantFromSQlData(data), nil
}

// DeleteGrant removes the grant identified by id and then the public key
// it references, inside one transaction.
//
// The grant is loaded first, so a missing grant makes the call fail. The
// loaded grant is also how the referenced key set and key id are found.
func (p *Persister) DeleteGrant(ctx context.Context, id string) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteGrant")
	defer otelx.End(span, &err)

	return p.Transaction(ctx, func(ctx context.Context, _ *pop.Connection) error {
		existing, lookupErr := p.GetConcreteGrant(ctx, id)
		if lookupErr != nil {
			return sqlcon.HandleError(lookupErr)
		}

		deleteErr := p.QueryWithNetwork(ctx).
			Where("id = ?", existing.ID).
			Delete(&trust.SQLData{})
		if deleteErr != nil {
			return sqlcon.HandleError(deleteErr)
		}

		key := existing.PublicKey
		return p.DeleteKey(ctx, key.Set, key.KeyID)
	})
}

// GetGrants lists the grants of the current network ordered by id,
// optionally restricted to a single issuer (an empty optionalIssuer means
// "all issuers").
//
// Pagination is page-based: offset is converted into page offset/limit+1,
// so offset is expected to be a multiple of limit (any remainder is dropped).
// A non-positive limit no longer triggers an integer divide-by-zero panic.
// The first page is requested instead, and the page size is left to pop's
// Paginate. NOTE(review): pop's NewPaginator appears to substitute a default
// perPage when perPage < 1 — confirm against the pinned pop version.
func (p *Persister) GetGrants(ctx context.Context, limit, offset int, optionalIssuer string) (_ []trust.Grant, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetGrants")
	defer otelx.End(span, &err)

	// Guard the division: offset/limit panics when limit == 0.
	page := 1
	if limit > 0 {
		page = offset/limit + 1
	}

	query := p.QueryWithNetwork(ctx).
		Paginate(page, limit).
		Order("id")
	if optionalIssuer != "" {
		query = query.Where("issuer = ?", optionalIssuer)
	}

	grantsData := make([]trust.SQLData, 0)
	if err := query.All(&grantsData); err != nil {
		return nil, sqlcon.HandleError(err)
	}

	grants := make([]trust.Grant, 0, len(grantsData))
	for _, data := range grantsData {
		grants = append(grants, p.jwtGrantFromSQlData(data))
	}

	return grants, nil
}

// CountGrants returns how many grant rows exist for the current network.
// Database errors are passed through sqlcon.HandleError.
func (p *Persister) CountGrants(ctx context.Context) (n int, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CountGrants")
	defer otelx.End(span, &err)

	total, countErr := p.QueryWithNetwork(ctx).Count(&trust.SQLData{})
	return total, sqlcon.HandleError(countErr)
}

// GetPublicKey returns the key keyId that a grant from issuer permits for
// subject. A grant permits subject when it names that subject exactly or
// has allow_any_subject set.
//
// The matching grant row is loaded first. Its key set (KeySet column) and
// keyId are then used to fetch the key material through GetKey. Only the
// first key of the returned set is used.
//
// If GetKey succeeds but yields no keys, an error wrapping sqlcon.ErrNoRows
// is returned instead of panicking on the index.
func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject string, keyId string) (_ *jose.JSONWebKey, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKey")
	defer otelx.End(span, &err)

	var data trust.SQLData
	query := p.QueryWithNetwork(ctx).
		Where("issuer = ?", issuer).
		Where("(subject = ? OR allow_any_subject IS TRUE)", subject).
		Where("key_id = ?", keyId).
		Where("nid = ?", p.NetworkID(ctx))
	if err := query.First(&data); err != nil {
		return nil, sqlcon.HandleError(err)
	}

	keySet, err := p.GetKey(ctx, data.KeySet, keyId)
	if err != nil {
		return nil, err
	}
	// Never index into an empty (or nil) key set.
	if keySet == nil || len(keySet.Keys) == 0 {
		return nil, errors.WithStack(sqlcon.ErrNoRows)
	}

	return &keySet.Keys[0], nil
}

// GetPublicKeys returns the public keys that issuer may use to sign
// assertions for subject. Only grants that have not yet expired are
// considered. A grant matches subject by exact match or via
// allow_any_subject.
//
// The lookup has two steps. First, the key ids of up to 100 matching grants
// are collected, newest (created_at DESC) first. Second, the corresponding
// rows are loaded from the JWK table and converted with ToJWK. An empty key
// set is returned when no grant matches.
func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject string) (_ *jose.JSONWebKeySet, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKeys")
	defer otelx.End(span, &err)

	// Grants store ExpiresAt as a UTC value computed by the application (see
	// the grant construction quoted below). We therefore compare against a
	// UTC "now" bound as a query parameter, as FlushInactiveGrants does for
	// the same column. The database's NOW() is interpreted in the DB
	// session's time zone on MySQL/PostgreSQL, and the old approach also
	// needed per-dialect SQL.
	now := time.Now().UTC()

	grantsData := make([]trust.SQLData, 0)
	query := p.QueryWithNetwork(ctx).
		Select("key_id").
		Where("expires_at > ?", now).
		Where("issuer = ?", issuer).
		Where("(subject = ? OR allow_any_subject IS TRUE)", subject).
		Order("created_at DESC").
		Limit(100) // Load maximum of 100 keys

	if err := query.All(&grantsData); err != nil {
		return nil, sqlcon.HandleError(err)
	}

	if len(grantsData) == 0 {
		return &jose.JSONWebKeySet{}, nil
	}

	keyIDs := make([]interface{}, len(grantsData))
	for k, d := range grantsData {
		keyIDs[k] = d.KeyID
	}

	var js jwk.SQLDataRows
	if err := p.QueryWithNetwork(ctx).
		// key_set and issuer are set to the same value on creation:
		//
		//    grant := Grant{
		//        ID:              uuid.New().String(),
		//        Issuer:          grantRequest.Issuer,
		//        Subject:         grantRequest.Subject,
		//        AllowAnySubject: grantRequest.AllowAnySubject,
		//        Scope:           grantRequest.Scope,
		//        PublicKey: PublicKey{
		//            Set:   grantRequest.Issuer, // group all keys by issuer, so set=issuer
		//            KeyID: grantRequest.PublicKeyJWK.KeyID,
		//        },
		//        CreatedAt: time.Now().UTC().Round(time.Second),
		//        ExpiresAt: grantRequest.ExpiresAt.UTC().Round(time.Second),
		//    }
		//
		// Therefore it is fine if we only look for the issuer here instead of the key set id.
		Where("sid = ?", issuer).
		Where("kid IN (?)", keyIDs).
		Order("created_at DESC").
		All(&js); err != nil {
		return nil, sqlcon.HandleError(err)
	}

	return js.ToJWK(ctx, p.r)
}

// GetPublicKeyScopes returns the scopes of the grant that lets issuer use
// key keyId for subject. The grant must name subject exactly or have
// allow_any_subject set. Scopes are decoded from the stored "|"-separated
// column by jwtGrantFromSQlData.
func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyId string) (_ []string, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKeyScopes")
	defer otelx.End(span, &err)

	var row trust.SQLData
	findErr := p.QueryWithNetwork(ctx).
		Where("issuer = ?", issuer).
		Where("(subject = ? OR allow_any_subject IS TRUE)", subject).
		Where("key_id = ?", keyId).
		Where("nid = ?", p.NetworkID(ctx)).
		First(&row)
	if findErr != nil {
		return nil, sqlcon.HandleError(findErr)
	}

	grant := p.jwtGrantFromSQlData(row)
	return grant.Scope, nil
}

// IsJWTUsed reports whether the JWT identified by jti counts as already used.
//
// Every non-nil result from ClientAssertionJWTValid is treated as "used",
// and the error itself is swallowed, so this method never returns an error.
// NOTE(review): this errs toward treating the token as used, but it also
// hides storage failures from the caller — confirm that is intended.
func (p *Persister) IsJWTUsed(ctx context.Context, jti string) (ok bool, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.IsJWTUsed")
	defer otelx.End(span, &err)

	validityErr := p.ClientAssertionJWTValid(ctx, jti)
	used := validityErr != nil
	return used, nil
}

// MarkJWTUsedForTime records jti together with the expiry exp by delegating
// to SetClientAssertionJWT, wrapped in a tracing span.
func (p *Persister) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MarkJWTUsedForTime")
	defer otelx.End(span, &err)

	err = p.SetClientAssertionJWT(ctx, jti, exp)
	return err
}

func (p *Persister) sqlDataFromJWTGrant(g trust.Grant) trust.SQLData {
    return trust.SQLData{
        ID:              g.ID,
        Issuer:          g.Issuer,
        Subject:         g.Subject,
        AllowAnySubject: g.AllowAnySubject,
        Scope:           strings.Join(g.Scope, "|"),
        KeySet:          g.PublicKey.Set,
        KeyID:           g.PublicKey.KeyID,
        CreatedAt:       g.CreatedAt,
        ExpiresAt:       g.ExpiresAt,
    }
}

// jwtGrantFromSQlData is the inverse of sqlDataFromJWTGrant. It turns a
// database row back into a trust.Grant, splitting the "|"-separated scope
// column with stringsx.Splitx and regrouping KeySet/KeyID into the grant's
// PublicKey.
func (p *Persister) jwtGrantFromSQlData(data trust.SQLData) trust.Grant {
	var grant trust.Grant
	grant.ID = data.ID
	grant.Issuer = data.Issuer
	grant.Subject = data.Subject
	grant.AllowAnySubject = data.AllowAnySubject
	grant.Scope = stringsx.Splitx(data.Scope, "|")
	grant.PublicKey = trust.PublicKey{
		Set:   data.KeySet,
		KeyID: data.KeyID,
	}
	grant.CreatedAt = data.CreatedAt
	grant.ExpiresAt = data.ExpiresAt
	return grant
}

// FlushInactiveGrants deletes the grants of the current network whose
// expires_at lies before a cutoff. The cutoff is the earlier of the current
// UTC time and notAfter. The two trailing int parameters are ignored, and all
// matching rows are removed with one query call.
func (p *Persister) FlushInactiveGrants(ctx context.Context, notAfter time.Time, _ int, _ int) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveGrants")
	defer otelx.End(span, &err)

	cutoff := time.Now().UTC()
	if notAfter.Before(cutoff) {
		cutoff = notAfter
	}

	deleteErr := p.QueryWithNetwork(ctx).
		Where("expires_at < ?", cutoff).
		Delete(&trust.SQLData{})
	return sqlcon.HandleError(deleteErr)
}