persistence/sql/persister_consent.go
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package sql
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"github.com/ory/fosite"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/flow"
"github.com/ory/hydra/v2/oauth2/flowctx"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/errorsx"
"github.com/ory/x/otelx"
"github.com/ory/x/sqlcon"
"github.com/ory/x/sqlxx"
)
// Compile-time assertion that *Persister satisfies the consent.Manager interface.
var _ consent.Manager = &Persister{}
// RevokeSubjectConsentSession revokes all consent sessions of the given
// subject. It selects every flow that has a consent challenge ID and belongs to
// user, and hands that filter to revokeConsentSession, which is executed through
// p.Transaction.
func (p *Persister) RevokeSubjectConsentSession(ctx context.Context, user string) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectConsentSession")
	defer otelx.End(span, &err)

	const filter = "consent_challenge_id IS NOT NULL AND subject = ?"
	revoke := p.revokeConsentSession(filter, user)
	return p.Transaction(ctx, revoke)
}
// RevokeSubjectClientConsentSession revokes the consent sessions that the given
// subject granted to one specific client. The filter matches flows that have a
// consent challenge ID, the subject, and the client ID; the revocation itself is
// performed by revokeConsentSession through p.Transaction. The client ID is
// attached to the trace span as the "client" attribute.
func (p *Persister) RevokeSubjectClientConsentSession(ctx context.Context, user, client string) (err error) {
	spanAttrs := trace.WithAttributes(attribute.String("client", client))
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectClientConsentSession", spanAttrs)
	defer otelx.End(span, &err)

	const filter = "consent_challenge_id IS NOT NULL AND subject = ? AND client_id = ?"
	revoke := p.revokeConsentSession(filter, user, client)
	return p.Transaction(ctx, revoke)
}
// revokeConsentSession returns a transaction callback that revokes every
// consent session matched by whereStmt/whereArgs.
//
// The callback
//  1. collects the consent challenge IDs of the matching flows,
//  2. deletes the access and refresh token rows whose request_id is one of
//     those challenge IDs, and
//  3. deletes the flows themselves.
//
// When no flow matches, the callback returns nil without deleting anything.
// The *pop.Connection argument is not used; all queries are issued through
// p.QueryWithNetwork(ctx).
func (p *Persister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error {
	return func(ctx context.Context, c *pop.Connection) error {
		flows := make([]*flow.Flow, 0)
		lookupErr := p.QueryWithNetwork(ctx).
			Where(whereStmt, whereArgs...).
			Select("consent_challenge_id").
			All(&flows)
		if errors.Is(lookupErr, sql.ErrNoRows) {
			return errorsx.WithStack(x.ErrNotFound)
		}
		if lookupErr != nil {
			return sqlcon.HandleError(lookupErr)
		}

		// The challenge IDs are passed as individual query arguments below,
		// hence the []interface{} element type.
		ids := make([]interface{}, len(flows))
		nid := p.NetworkID(ctx)
		for i := range flows {
			ids[i] = flows[i].ConsentChallengeID.String()
		}
		if len(ids) == 0 {
			return nil
		}

		// deleteTokens removes the token rows that reference one of the
		// collected challenge IDs. fosite.ErrNotFound is tolerated; any other
		// error is handed back unchanged.
		deleteTokens := func(model *OAuth2RequestSQL) error {
			delErr := p.QueryWithNetwork(ctx).
				Where("nid = ?", nid).
				Where("request_id IN (?)", ids...).
				Delete(model)
			if delErr == nil || errors.Is(delErr, fosite.ErrNotFound) {
				return nil
			}
			return delErr
		}
		if err := deleteTokens(&OAuth2RequestSQL{Table: sqlTableAccess}); err != nil {
			return err
		}
		if err := deleteTokens(&OAuth2RequestSQL{Table: sqlTableRefresh}); err != nil {
			return err
		}

		flowErr := p.QueryWithNetwork(ctx).
			Where("nid = ?", nid).
			Where("consent_challenge_id IN (?)", ids...).
			Delete(&flow.Flow{})
		switch {
		case errors.Is(flowErr, sql.ErrNoRows):
			return errorsx.WithStack(x.ErrNotFound)
		case flowErr != nil:
			return sqlcon.HandleError(flowErr)
		}
		return nil
	}
}
// RevokeSubjectLoginSession deletes every login session stored for the given
// subject.
//
// The number of deleted rows is intentionally not inspected: reporting
// "not found" when nothing was deleted confused users, see
// https://github.com/ory/hydra/issues/1168
func (p *Persister) RevokeSubjectLoginSession(ctx context.Context, subject string) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectLoginSession")
	defer otelx.End(span, &err)

	query := p.QueryWithNetwork(ctx).Where("subject = ?", subject)
	if delErr := query.Delete(&flow.LoginSession{}); delErr != nil {
		return sqlcon.HandleError(delErr)
	}
	return nil
}
// CreateForcedObfuscatedLoginSession stores the obfuscated subject for a
// (client, subject) pair. Within one p.Transaction callback it first deletes
// any existing row for that pair in the current network and then inserts the
// new row.
func (p *Persister) CreateForcedObfuscatedLoginSession(ctx context.Context, session *consent.ForcedObfuscatedLoginSession) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateForcedObfuscatedLoginSession")
	defer otelx.End(span, &err)

	replace := func(ctx context.Context, c *pop.Connection) error {
		nid := p.NetworkID(ctx)

		// Remove a previous mapping for the same client and subject first.
		purge := c.RawQuery(
			"DELETE FROM hydra_oauth2_obfuscated_authentication_session WHERE nid = ? AND client_id = ? AND subject = ?",
			nid,
			session.ClientID,
			session.Subject,
		)
		if purgeErr := purge.Exec(); purgeErr != nil {
			return sqlcon.HandleError(purgeErr)
		}

		insert := c.RawQuery(
			"INSERT INTO hydra_oauth2_obfuscated_authentication_session (nid, subject, client_id, subject_obfuscated) VALUES (?, ?, ?, ?)",
			nid,
			session.Subject,
			session.ClientID,
			session.SubjectObfuscated,
		)
		return sqlcon.HandleError(insert.Exec())
	}
	return p.Transaction(ctx, replace)
}
// GetForcedObfuscatedLoginSession looks up the obfuscated login session for the
// given client ID and obfuscated subject within the current network. It returns
// x.ErrNotFound when no row matches; other database errors are passed through
// sqlcon.HandleError.
func (p *Persister) GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (_ *consent.ForcedObfuscatedLoginSession, err error) {
	spanAttrs := trace.WithAttributes(attribute.String("client", client))
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetForcedObfuscatedLoginSession", spanAttrs)
	defer otelx.End(span, &err)

	query := p.Connection(ctx).Where(
		"client_id = ? AND subject_obfuscated = ? AND nid = ?",
		client,
		obfuscated,
		p.NetworkID(ctx),
	)

	found := new(consent.ForcedObfuscatedLoginSession)
	lookupErr := query.First(found)
	if errors.Is(lookupErr, sql.ErrNoRows) {
		return nil, errorsx.WithStack(x.ErrNotFound)
	}
	if lookupErr != nil {
		return nil, sqlcon.HandleError(lookupErr)
	}
	return found, nil
}
// CreateConsentRequest copies the consent-specific fields of req (challenge ID,
// skip flag, verifier and CSRF token) onto the flow f and moves the flow to
// flow.FlowStateConsentInitialized. Fields that originate from the login
// request are left alone. Nothing is written to the database here.
//
// x.ErrNotFound is returned when f is nil, when f.ID differs from the login
// challenge of req, or when f belongs to a different network.
func (p *Persister) CreateConsentRequest(ctx context.Context, f *flow.Flow, req *flow.OAuth2ConsentRequest) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateConsentRequest")
	defer otelx.End(span, &err)

	switch {
	case f == nil:
		return errorsx.WithStack(x.ErrNotFound.WithDebug("Flow is nil"))
	case f.ID != req.LoginChallenge.String() || f.NID != p.NetworkID(ctx):
		return errorsx.WithStack(x.ErrNotFound)
	}

	f.State = flow.FlowStateConsentInitialized
	f.ConsentSkip = req.Skip
	f.ConsentChallengeID = sqlxx.NullString(req.ID)
	f.ConsentVerifier = sqlxx.NullString(req.Verifier)
	f.ConsentCSRF = sqlxx.NullString(req.CSRF)
	return nil
}
// GetFlowByConsentChallenge decodes the flow that is carried inside the consent
// challenge and resolves its client via p.GetConcreteClient.
//
// It returns x.ErrNotFound when the challenge cannot be decoded or the flow
// belongs to another network, and fosite.ErrRequestUnauthorized when the flow's
// RequestedAt plus the configured consent request max age lies in the past.
func (p *Persister) GetFlowByConsentChallenge(ctx context.Context, challenge string) (_ *flow.Flow, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetFlowByConsentChallenge")
	defer otelx.End(span, &err)

	// The challenge itself is the encoded flow.
	decoded, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), challenge, flowctx.AsConsentChallenge)
	if err != nil {
		return nil, errorsx.WithStack(x.ErrNotFound)
	}
	if p.NetworkID(ctx) != decoded.NID {
		return nil, errorsx.WithStack(x.ErrNotFound)
	}

	expiresAt := decoded.RequestedAt.Add(p.config.ConsentRequestMaxAge(ctx))
	if time.Now().After(expiresAt) {
		return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The consent request has expired, please try again."))
	}

	if decoded.Client, err = p.GetConcreteClient(ctx, decoded.ClientID); err != nil {
		return nil, err
	}
	return decoded, nil
}
// GetConsentRequest resolves the flow for the given consent challenge and
// returns its consent request view. An sqlcon.ErrNoRows from the lookup is
// translated to x.ErrNotFound; every other error is returned unchanged.
func (p *Persister) GetConsentRequest(ctx context.Context, challenge string) (_ *flow.OAuth2ConsentRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConsentRequest")
	defer otelx.End(span, &err)

	f, err := p.GetFlowByConsentChallenge(ctx, challenge)
	if errors.Is(err, sqlcon.ErrNoRows) {
		return nil, errorsx.WithStack(x.ErrNotFound)
	}
	if err != nil {
		return nil, err
	}

	// Report the encoded flow (the challenge) as the ID so that the value the
	// client sees matches the one it sent.
	f.ConsentChallengeID = sqlxx.NullString(challenge)
	return f.GetConsentRequest(), nil
}
// CreateLoginRequest builds a new flow from the login request and assigns it
// the network ID of the context. It returns x.ErrNotFound when that network ID
// is uuid.Nil. Nothing is written to the database here.
func (p *Persister) CreateLoginRequest(ctx context.Context, req *flow.LoginRequest) (_ *flow.Flow, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLoginRequest")
	defer otelx.End(span, &err)

	created := flow.NewFlow(req)

	networkID := p.NetworkID(ctx)
	if networkID == uuid.Nil {
		return nil, errorsx.WithStack(x.ErrNotFound)
	}

	created.NID = networkID
	return created, nil
}
// GetFlow loads the flow whose login_challenge column equals loginChallenge,
// queried through p.QueryWithNetwork. A missing row yields x.ErrNotFound; other
// database errors are passed through sqlcon.HandleError.
func (p *Persister) GetFlow(ctx context.Context, loginChallenge string) (_ *flow.Flow, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetFlow")
	defer otelx.End(span, &err)

	found := new(flow.Flow)
	lookupErr := p.QueryWithNetwork(ctx).Where("login_challenge = ?", loginChallenge).First(found)
	switch {
	case lookupErr == nil:
		return found, nil
	case errors.Is(lookupErr, sql.ErrNoRows):
		return nil, errorsx.WithStack(x.ErrNotFound)
	default:
		return nil, sqlcon.HandleError(lookupErr)
	}
}
// GetLoginRequest decodes the flow carried inside the login challenge, resolves
// its client via p.GetConcreteClient and returns the flow's login request view.
//
// It returns x.ErrNotFound (wrapping the decode error) when the challenge cannot
// be decoded, x.ErrNotFound when the flow belongs to another network, and
// fosite.ErrRequestUnauthorized when the flow's RequestedAt plus the configured
// consent request max age lies in the past.
func (p *Persister) GetLoginRequest(ctx context.Context, loginChallenge string) (_ *flow.LoginRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetLoginRequest")
	defer otelx.End(span, &err)

	decoded, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), loginChallenge, flowctx.AsLoginChallenge)
	if err != nil {
		return nil, errorsx.WithStack(x.ErrNotFound.WithWrap(err))
	}
	if p.NetworkID(ctx) != decoded.NID {
		return nil, errorsx.WithStack(x.ErrNotFound)
	}

	expiresAt := decoded.RequestedAt.Add(p.config.ConsentRequestMaxAge(ctx))
	if time.Now().After(expiresAt) {
		return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The login request has expired, please try again."))
	}

	if decoded.Client, err = p.GetConcreteClient(ctx, decoded.ClientID); err != nil {
		return nil, err
	}

	// Hand back the challenge exactly as it was supplied, so that the ID of
	// the returned request matches the parameter.
	request := decoded.GetLoginRequest()
	request.ID = loginChallenge
	return request, nil
}
// HandleConsentRequest applies the accepted consent request r to the flow f and
// returns the flow's consent request view.
//
// It returns fosite.ErrInvalidRequest when f is nil and x.ErrNotFound when f
// belongs to another network. An error from f.HandleConsentRequest is returned
// with a stack trace attached.
func (p *Persister) HandleConsentRequest(ctx context.Context, f *flow.Flow, r *flow.AcceptOAuth2ConsentRequest) (_ *flow.OAuth2ConsentRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.HandleConsentRequest")
	defer otelx.End(span, &err)

	switch {
	case f == nil:
		return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Flow was nil"))
	case f.NID != p.NetworkID(ctx):
		return nil, errorsx.WithStack(x.ErrNotFound)
	}

	// The handled request must carry the flow's own consent challenge ID.
	r.ID = f.ConsentChallengeID.String()

	handleErr := f.HandleConsentRequest(r)
	if handleErr != nil {
		return nil, errorsx.WithStack(handleErr)
	}
	return f.GetConsentRequest(), nil
}
// VerifyAndInvalidateConsentRequest decodes the flow carried inside the consent
// verifier, invalidates its consent request, persists the flow and returns the
// handled consent request.
//
// Errors:
//   - fosite.ErrAccessDenied when the verifier cannot be decoded,
//   - sqlcon.ErrNoRows when the flow belongs to another network,
//   - the error of p.GetConcreteClient, unchanged,
//   - fosite.ErrInvalidRequest when f.InvalidateConsentRequest fails,
//   - the sqlcon.HandleError result when the flow cannot be created.
func (p *Persister) VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (_ *flow.AcceptOAuth2ConsentRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateConsentRequest")
	defer otelx.End(span, &err)

	decoded, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), verifier, flowctx.AsConsentVerifier)
	if err != nil {
		return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The consent verifier has already been used, has not been granted, or is invalid."))
	}
	if p.NetworkID(ctx) != decoded.NID {
		return nil, errorsx.WithStack(sqlcon.ErrNoRows)
	}

	if decoded.Client, err = p.GetConcreteClient(ctx, decoded.ClientID); err != nil {
		return nil, err
	}

	if err = decoded.InvalidateConsentRequest(); err != nil {
		return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error()))
	}

	// Replace the consent challenge ID with a fresh UUID: it can serve as a
	// foreign key in the database without carrying the whole encoded flow.
	challengeID := uuid.Must(uuid.NewV4()).String()
	decoded.ConsentChallengeID = sqlxx.NullString(challengeID)

	if err = p.Connection(ctx).Create(decoded); err != nil {
		return nil, sqlcon.HandleError(err)
	}
	return decoded.GetHandledConsentRequest(), nil
}
// HandleLoginRequest applies the handled login request r to the flow f and then
// returns the login request obtained from p.GetLoginRequest for challenge.
//
// It returns fosite.ErrInvalidRequest when f is nil and x.ErrNotFound when f
// belongs to another network. An error from f.HandleLoginRequest is returned
// unchanged.
func (p *Persister) HandleLoginRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledLoginRequest) (lr *flow.LoginRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.HandleLoginRequest")
	defer otelx.End(span, &err)

	switch {
	case f == nil:
		return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Flow was nil"))
	case f.NID != p.NetworkID(ctx):
		return nil, errorsx.WithStack(x.ErrNotFound)
	}

	// The handled request must carry the flow's own ID.
	r.ID = f.ID
	if handleErr := f.HandleLoginRequest(r); handleErr != nil {
		return nil, handleErr
	}
	return p.GetLoginRequest(ctx, challenge)
}
// VerifyAndInvalidateLoginRequest decodes the flow carried inside the login
// verifier, invalidates its login request and returns the handled login
// request.
//
// It returns sqlcon.ErrNoRows when the verifier cannot be decoded or the flow
// belongs to another network, the error of p.GetConcreteClient unchanged, and
// fosite.ErrInvalidRequest when f.InvalidateLoginRequest fails.
func (p *Persister) VerifyAndInvalidateLoginRequest(ctx context.Context, verifier string) (_ *flow.HandledLoginRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateLoginRequest")
	defer otelx.End(span, &err)

	decoded, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), verifier, flowctx.AsLoginVerifier)
	if err != nil {
		return nil, errorsx.WithStack(sqlcon.ErrNoRows)
	}
	if p.NetworkID(ctx) != decoded.NID {
		return nil, errorsx.WithStack(sqlcon.ErrNoRows)
	}

	if decoded.Client, err = p.GetConcreteClient(ctx, decoded.ClientID); err != nil {
		return nil, err
	}

	if invalidateErr := decoded.InvalidateLoginRequest(); invalidateErr != nil {
		return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(invalidateErr.Error()))
	}

	handled := decoded.GetHandledLoginRequest()
	return &handled, nil
}
// GetRememberedLoginSession returns the remembered login session with the given
// ID.
//
// When loginSessionFromCookie is non-nil, belongs to the current network, has
// the requested ID and is marked as remembered, it is returned as-is and the
// database is not queried. Otherwise the session is looked up with
// "remember = TRUE"; a missing row yields x.ErrNotFound.
func (p *Persister) GetRememberedLoginSession(ctx context.Context, loginSessionFromCookie *flow.LoginSession, id string) (_ *flow.LoginSession, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRememberedLoginSession")
	defer otelx.End(span, &err)

	if loginSessionFromCookie != nil {
		fromCookie := loginSessionFromCookie
		if fromCookie.NID == p.NetworkID(ctx) && fromCookie.ID == id && fromCookie.Remember {
			return fromCookie, nil
		}
	}

	stored := new(flow.LoginSession)
	lookupErr := p.QueryWithNetwork(ctx).Where("remember = TRUE").Find(stored, id)
	if errors.Is(lookupErr, sql.ErrNoRows) {
		return nil, errorsx.WithStack(x.ErrNotFound)
	}
	if lookupErr != nil {
		return nil, sqlcon.HandleError(lookupErr)
	}
	return stored, nil
}
// ConfirmLoginSession creates or updates the login session. The NID is set to
// the network ID of the context and AuthenticatedAt is truncated to whole
// seconds before the session is written.
//
// On MySQL the work is delegated to mySQLConfirmLoginSession. On every other
// dialect a single INSERT ... ON CONFLICT(id) DO UPDATE statement is executed
// inside a transaction; x.ErrNotFound is returned when that statement affects
// no rows.
func (p *Persister) ConfirmLoginSession(ctx context.Context, loginSession *flow.LoginSession) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ConfirmLoginSession")
	defer otelx.End(span, &err)

	loginSession.NID = p.NetworkID(ctx)
	authenticatedAt := time.Time(loginSession.AuthenticatedAt).Truncate(time.Second)
	loginSession.AuthenticatedAt = sqlxx.NullTime(authenticatedAt)

	if dialect := p.Connection(ctx).Dialect.Name(); dialect == "mysql" {
		// MySQL does not support UPSERT.
		return p.mySQLConfirmLoginSession(ctx, loginSession)
	}

	const upsertStmt = `
INSERT INTO hydra_oauth2_authentication_session (id, nid, authenticated_at, subject, remember, identity_provider_session_id)
VALUES (:id, :nid, :authenticated_at, :subject, :remember, :identity_provider_session_id)
ON CONFLICT(id) DO
UPDATE SET
authenticated_at = :authenticated_at,
subject = :subject,
remember = :remember,
identity_provider_session_id = :identity_provider_session_id
WHERE hydra_oauth2_authentication_session.id = :id AND hydra_oauth2_authentication_session.nid = :nid
`

	upsert := func(tx *pop.Connection) error {
		result, execErr := tx.TX.NamedExec(upsertStmt, loginSession)
		if execErr != nil {
			return sqlcon.HandleError(execErr)
		}

		affected, countErr := result.RowsAffected()
		if countErr != nil {
			return sqlcon.HandleError(countErr)
		}
		if affected == 0 {
			return errorsx.WithStack(x.ErrNotFound)
		}
		return nil
	}

	if txErr := p.Connection(ctx).Transaction(upsert); txErr != nil {
		return errors.WithStack(txErr)
	}
	return nil
}
// CreateLoginSession assigns the network ID of the context to the session. It
// returns x.ErrNotFound, leaving the session untouched, when that network ID is
// uuid.Nil. Nothing is written to the database here.
func (p *Persister) CreateLoginSession(ctx context.Context, session *flow.LoginSession) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLoginSession")
	defer otelx.End(span, &err)

	networkID := p.NetworkID(ctx)
	if networkID != uuid.Nil {
		session.NID = networkID
		return nil
	}
	return errorsx.WithStack(x.ErrNotFound)
}
// DeleteLoginSession deletes the login session with the given ID in the current
// network and returns the deleted row.
//
// On MySQL the work is delegated to mySQLDeleteLoginSession. On every other
// dialect a single DELETE ... RETURNING statement is used. Errors are passed
// through sqlcon.HandleError.
func (p *Persister) DeleteLoginSession(ctx context.Context, id string) (deletedSession *flow.LoginSession, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteLoginSession")
	defer otelx.End(span, &err)

	if dialect := p.Connection(ctx).Dialect.Name(); dialect == "mysql" {
		// MySQL does not support RETURNING.
		return p.mySQLDeleteLoginSession(ctx, id)
	}

	const deleteStmt = `DELETE FROM hydra_oauth2_authentication_session
WHERE id = ? AND nid = ?
RETURNING *`

	removed := new(flow.LoginSession)
	query := p.Connection(ctx).RawQuery(deleteStmt, id, p.NetworkID(ctx))
	if queryErr := query.First(removed); queryErr != nil {
		return nil, sqlcon.HandleError(queryErr)
	}
	return removed, nil
}
// mySQLDeleteLoginSession deletes the login session with the given ID in the
// current network and returns the row as it was before deletion. It emulates
// DELETE ... RETURNING with a SELECT followed by a DELETE.
//
// Both statements are issued on the transactional connection tx, so that the
// row that is returned is the row that is deleted. Errors, including a missing
// row on the SELECT, are passed through sqlcon.HandleError.
func (p *Persister) mySQLDeleteLoginSession(ctx context.Context, id string) (*flow.LoginSession, error) {
	var session flow.LoginSession
	nid := p.NetworkID(ctx)

	err := p.Connection(ctx).Transaction(func(tx *pop.Connection) error {
		err := tx.RawQuery(`
SELECT * FROM hydra_oauth2_authentication_session
WHERE id = ? AND nid = ?`,
			id,
			nid,
		).First(&session)
		if err != nil {
			return err
		}

		// Use tx, not p.Connection(ctx): the callback's connection is the one
		// bound to the transaction opened above.
		return tx.RawQuery(`
DELETE FROM hydra_oauth2_authentication_session
WHERE id = ? AND nid = ?`,
			id,
			nid,
		).Exec()
	})
	if err != nil {
		return nil, sqlcon.HandleError(err)
	}
	return &session, nil
}
// FindGrantedAndRememberedConsentRequests returns the most recently requested
// consent that the subject granted to the client and asked to be remembered.
//
// Only flows in state FlowStateConsentUsed or FlowStateConsentUnused with
// consent_skip=FALSE, consent_error='{}' and consent_remember=TRUE in the
// current network are considered. consent.ErrNoPreviousConsentFound is returned
// when no such flow exists or when filterExpiredConsentRequests drops the one
// that was found.
func (p *Persister) FindGrantedAndRememberedConsentRequests(ctx context.Context, client, subject string) (rs []flow.AcceptOAuth2ConsentRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindGrantedAndRememberedConsentRequests")
	defer otelx.End(span, &err)

	filter := strings.TrimSpace(fmt.Sprintf(`
(state = %d OR state = %d) AND
subject = ? AND
client_id = ? AND
consent_skip=FALSE AND
consent_error='{}' AND
consent_remember=TRUE AND
nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused,
	))

	latest := new(flow.Flow)
	err = p.Connection(ctx).
		Where(filter, subject, client, p.NetworkID(ctx)).
		Order("requested_at DESC").
		Limit(1).
		First(latest)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, errorsx.WithStack(consent.ErrNoPreviousConsentFound)
	}
	if err != nil {
		return nil, sqlcon.HandleError(err)
	}

	candidates := []flow.AcceptOAuth2ConsentRequest{*latest.GetHandledConsentRequest()}
	return p.filterExpiredConsentRequests(ctx, candidates)
}
// FindSubjectsGrantedConsentRequests returns one page of the consent requests
// that the subject has granted, newest first, with expired entries removed by
// filterExpiredConsentRequests.
//
// Only flows in state FlowStateConsentUsed or FlowStateConsentUnused with
// consent_skip=FALSE and consent_error='{}' in the current network are
// considered. The page number is derived from offset and limit as
// offset/limit+1.
//
// A non-positive limit selects the first page instead of dividing by zero; the
// limit itself is still handed to Paginate unchanged.
// NOTE(review): pop is expected to substitute its default page size for a
// non-positive per-page value — confirm against the pop version in use.
//
// consent.ErrNoPreviousConsentFound is returned when no consent request
// remains.
func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subject string, limit, offset int) (_ []flow.AcceptOAuth2ConsentRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsGrantedConsentRequests",
		trace.WithAttributes(attribute.Int("limit", limit), attribute.Int("offset", offset)))
	defer otelx.End(span, &err)

	// Guard the division: limit == 0 would panic.
	page := 1
	if limit > 0 {
		page = offset/limit + 1
	}

	var fs []flow.Flow
	c := p.Connection(ctx)

	if err := c.
		Where(
			strings.TrimSpace(fmt.Sprintf(`
(state = %d OR state = %d) AND
subject = ? AND
consent_skip=FALSE AND
consent_error='{}' AND
nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused,
			)),
			subject, p.NetworkID(ctx)).
		Order("requested_at DESC").
		Paginate(page, limit).
		All(&fs); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, errorsx.WithStack(consent.ErrNoPreviousConsentFound)
		}
		return nil, sqlcon.HandleError(err)
	}

	rs := make([]flow.AcceptOAuth2ConsentRequest, 0, len(fs))
	for i := range fs {
		rs = append(rs, *fs[i].GetHandledConsentRequest())
	}

	return p.filterExpiredConsentRequests(ctx, rs)
}
// FindSubjectsSessionGrantedConsentRequests returns one page of the consent
// requests that the subject has granted within the login session sid, newest
// first, with expired entries removed by filterExpiredConsentRequests.
//
// Only flows in state FlowStateConsentUsed or FlowStateConsentUnused with
// consent_skip=FALSE and consent_error='{}' in the current network are
// considered. The page number is derived from offset and limit as
// offset/limit+1.
//
// A non-positive limit selects the first page instead of dividing by zero; the
// limit itself is still handed to Paginate unchanged.
// NOTE(review): pop is expected to substitute its default page size for a
// non-positive per-page value — confirm against the pop version in use.
//
// consent.ErrNoPreviousConsentFound is returned when no consent request
// remains.
func (p *Persister) FindSubjectsSessionGrantedConsentRequests(ctx context.Context, subject, sid string, limit, offset int) (_ []flow.AcceptOAuth2ConsentRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsSessionGrantedConsentRequests",
		trace.WithAttributes(attribute.String("sid", sid), attribute.Int("limit", limit), attribute.Int("offset", offset)))
	defer otelx.End(span, &err)

	// Guard the division: limit == 0 would panic.
	page := 1
	if limit > 0 {
		page = offset/limit + 1
	}

	var fs []flow.Flow
	c := p.Connection(ctx)

	if err := c.
		Where(
			strings.TrimSpace(fmt.Sprintf(`
(state = %d OR state = %d) AND
subject = ? AND
login_session_id = ? AND
consent_skip=FALSE AND
consent_error='{}' AND
nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused,
			)),
			subject, sid, p.NetworkID(ctx)).
		Order("requested_at DESC").
		Paginate(page, limit).
		All(&fs); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, errorsx.WithStack(consent.ErrNoPreviousConsentFound)
		}
		return nil, sqlcon.HandleError(err)
	}

	rs := make([]flow.AcceptOAuth2ConsentRequest, 0, len(fs))
	for i := range fs {
		rs = append(rs, *fs[i].GetHandledConsentRequest())
	}

	return p.filterExpiredConsentRequests(ctx, rs)
}
// CountSubjectsGrantedConsentRequests counts the flows of the subject that are
// in state FlowStateConsentUsed or FlowStateConsentUnused with
// consent_skip=FALSE and consent_error='{}' in the current network. The result
// is also recorded on the trace span as the "count" attribute. Errors are
// passed through sqlcon.HandleError.
func (p *Persister) CountSubjectsGrantedConsentRequests(ctx context.Context, subject string) (n int, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CountSubjectsGrantedConsentRequests")
	defer otelx.End(span, &err)
	// Registered after otelx.End, so it runs first and sees the final n.
	defer func() {
		span.SetAttributes(attribute.Int("count", n))
	}()

	filter := strings.TrimSpace(fmt.Sprintf(`
(state = %d OR state = %d) AND
subject = ? AND
consent_skip=FALSE AND
consent_error='{}' AND
nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused,
	))

	total, countErr := p.Connection(ctx).
		Where(filter, subject, p.NetworkID(ctx)).
		Count(&flow.Flow{})
	return total, sqlcon.HandleError(countErr)
}
// filterExpiredConsentRequests returns the requests that have not expired, in
// their original order. A request counts as expired when RememberFor is
// positive and RequestedAt plus RememberFor seconds lies before the current UTC
// time; requests with a non-positive RememberFor are always kept.
// consent.ErrNoPreviousConsentFound is returned when nothing is left.
func (p *Persister) filterExpiredConsentRequests(ctx context.Context, requests []flow.AcceptOAuth2ConsentRequest) (_ []flow.AcceptOAuth2ConsentRequest, err error) {
	_, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.filterExpiredConsentRequests")
	defer otelx.End(span, &err)

	var kept []flow.AcceptOAuth2ConsentRequest
	for i := range requests {
		candidate := requests[i]
		neverExpires := candidate.RememberFor <= 0
		if neverExpires || !candidate.RequestedAt.Add(time.Duration(candidate.RememberFor)*time.Second).Before(time.Now().UTC()) {
			kept = append(kept, candidate)
		}
	}

	if len(kept) > 0 {
		return kept, nil
	}
	return nil, errorsx.WithStack(consent.ErrNoPreviousConsentFound)
}
// ListUserAuthenticatedClientsWithFrontChannelLogout lists the clients that the
// subject used within login session sid and that have a front-channel logout
// URI. It delegates to listUserAuthenticatedClients with the "front" channel.
func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithFrontChannelLogout")
	defer otelx.End(span, &err)

	const channel = "front"
	return p.listUserAuthenticatedClients(ctx, subject, sid, channel)
}
// ListUserAuthenticatedClientsWithBackChannelLogout lists the clients that the
// subject used within login session sid and that have a back-channel logout
// URI. It delegates to listUserAuthenticatedClients with the "back" channel.
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) (_ []client.Client, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListUserAuthenticatedClientsWithBackChannelLogout")
	defer otelx.End(span, &err)

	const channel = "back"
	return p.listUserAuthenticatedClients(ctx, subject, sid, channel)
}
// listUserAuthenticatedClients returns the distinct clients that are joined to
// a flow of the given subject and login session sid in the current network and
// whose "<channel>channel_logout_uri" column is neither empty nor NULL.
//
// channel is interpolated into the SQL text as part of a column name, so it
// must only ever be one of the fixed values "front" or "back" — the two callers
// in this file pass exactly those.
func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) (cs []client.Client, err error) {
	spanAttrs := trace.WithAttributes(attribute.String("sid", sid))
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.listUserAuthenticatedClients", spanAttrs)
	defer otelx.End(span, &err)

	/* #nosec G201 - channel can either be "front" or "back" */
	stmt := fmt.Sprintf(`
SELECT DISTINCT c.* FROM hydra_client as c
JOIN hydra_oauth2_flow as f ON (c.id = f.client_id AND c.nid = f.nid)
WHERE
f.subject=? AND
c.%schannel_logout_uri != '' AND
c.%schannel_logout_uri IS NOT NULL AND
f.login_session_id = ? AND
f.nid = ? AND
c.nid = ?`,
		channel,
		channel,
	)

	query := p.Connection(ctx).RawQuery(stmt, subject, sid, p.NetworkID(ctx), p.NetworkID(ctx))
	if queryErr := query.All(&cs); queryErr != nil {
		return nil, sqlcon.HandleError(queryErr)
	}
	return cs, nil
}
// CreateLogoutRequest stores the logout request via p.CreateWithNetwork and
// returns any resulting error with a stack trace attached.
func (p *Persister) CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLogoutRequest")
	defer otelx.End(span, &err)

	createErr := p.CreateWithNetwork(ctx, request)
	return errorsx.WithStack(createErr)
}
// AcceptLogoutRequest marks the logout request with the given challenge in the
// current network as accepted (and not rejected), then returns the request as
// read back by p.GetLogoutRequest. An error from the UPDATE is passed through
// sqlcon.HandleError.
func (p *Persister) AcceptLogoutRequest(ctx context.Context, challenge string) (_ *flow.LogoutRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AcceptLogoutRequest")
	defer otelx.End(span, &err)

	accept := p.Connection(ctx).RawQuery(
		"UPDATE hydra_oauth2_logout_request SET accepted=true, rejected=false WHERE challenge=? AND nid = ?",
		challenge,
		p.NetworkID(ctx),
	)
	if execErr := accept.Exec(); execErr != nil {
		return nil, sqlcon.HandleError(execErr)
	}

	return p.GetLogoutRequest(ctx, challenge)
}
// RejectLogoutRequest marks the logout request with the given challenge in the
// current network as rejected (and not accepted).
//
// A failure of the UPDATE itself is passed through sqlcon.HandleError.
// x.ErrNotFound is returned only when the statement succeeded but affected no
// rows.
func (p *Persister) RejectLogoutRequest(ctx context.Context, challenge string) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RejectLogoutRequest")
	defer otelx.End(span, &err)

	count, err := p.Connection(ctx).
		RawQuery("UPDATE hydra_oauth2_logout_request SET rejected=true, accepted=false WHERE challenge=? AND nid = ?", challenge, p.NetworkID(ctx)).
		ExecWithCount()
	// Check the error before the count: a failed statement also reports zero
	// affected rows and must not be mistaken for "not found".
	if err != nil {
		return sqlcon.HandleError(err)
	}
	if count == 0 {
		return errorsx.WithStack(x.ErrNotFound)
	}
	return nil
}
// GetLogoutRequest loads the logout request with the given challenge that has
// not been rejected, queried through p.QueryWithNetwork. On failure it returns
// a nil request together with the error as translated by sqlcon.HandleError.
func (p *Persister) GetLogoutRequest(ctx context.Context, challenge string) (_ *flow.LogoutRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetLogoutRequest")
	defer otelx.End(span, &err)

	var lr flow.LogoutRequest
	if err := p.QueryWithNetwork(ctx).Where("challenge = ? AND rejected = FALSE", challenge).First(&lr); err != nil {
		// Do not hand out a zero-valued request alongside an error.
		return nil, sqlcon.HandleError(err)
	}
	return &lr, nil
}
// VerifyAndInvalidateLogoutRequest marks the accepted, non-rejected logout
// request with the given verifier in the current network as used and returns
// it.
//
// x.ErrNotFound is returned when the UPDATE succeeds but affects no rows.
// flow.ErrorLogoutFlowExpired is returned when the request has a non-zero
// ExpiresAt that lies before the current UTC time. Database errors are passed
// through sqlcon.HandleError.
func (p *Persister) VerifyAndInvalidateLogoutRequest(ctx context.Context, verifier string) (_ *flow.LogoutRequest, err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateLogoutRequest")
	defer otelx.End(span, &err)

	affected, execErr := p.Connection(ctx).RawQuery(`
UPDATE hydra_oauth2_logout_request
SET was_used = TRUE
WHERE nid = ?
AND verifier = ?
AND accepted = TRUE
AND rejected = FALSE`,
		p.NetworkID(ctx),
		verifier,
	).ExecWithCount()
	if execErr != nil {
		return nil, sqlcon.HandleError(execErr)
	}
	if affected == 0 {
		return nil, errorsx.WithStack(x.ErrNotFound)
	}

	var request flow.LogoutRequest
	err = sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("verifier = ?", verifier).First(&request))
	if err != nil {
		return nil, err
	}

	// If the expiry is unset, we are in a legacy use case (allow logout).
	// TODO: Remove this in the future.
	expiry := time.Time(request.ExpiresAt)
	if !expiry.IsZero() && expiry.Before(time.Now().UTC()) {
		return nil, errorsx.WithStack(flow.ErrorLogoutFlowExpired)
	}

	return &request, nil
}
// FlushInactiveLoginConsentRequests deletes up to limit flows of the current
// network that can be safely removed, in batches of batchSize.
//
// A flow qualifies when it was requested before the cut-off and at least one of
// the following holds:
//   - its state is anything other than FlowStateConsentUsed (unhandled),
//   - login_error holds an error (login rejected),
//   - consent_error holds an error (consent rejected).
//
// The cut-off is the earlier of notAfter and "now minus the configured consent
// request max age".
//
// A non-positive batchSize deletes all selected flows in a single batch. An
// sql.ErrNoRows from the selection is reported as fosite.ErrNotFound; every
// other database error is passed through sqlcon.HandleError.
func (p *Persister) FlushInactiveLoginConsentRequests(ctx context.Context, notAfter time.Time, limit int, batchSize int) (err error) {
	ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveLoginConsentRequests")
	defer otelx.End(span, &err)

	// The value of notAfter should be the minimum between input parameter and
	// request max expire based on its configured age.
	requestMaxExpire := time.Now().Add(-p.config.ConsentRequestMaxAge(ctx))
	if requestMaxExpire.Before(notAfter) {
		notAfter = requestMaxExpire
	}

	nid := p.NetworkID(ctx)
	challenges := []string{}

	queryFormat := `
SELECT login_challenge
FROM hydra_oauth2_flow
WHERE (
(state != ?)
OR (login_error IS NOT NULL AND login_error <> '{}' AND login_error <> '')
OR (consent_error IS NOT NULL AND consent_error <> '{}' AND consent_error <> '')
)
AND requested_at < ?
AND nid = ?
ORDER BY login_challenge
LIMIT %[1]d
`

	/* #nosec G201 - only the integer limit is formatted into the statement */
	q := p.Connection(ctx).RawQuery(fmt.Sprintf(queryFormat, limit), flow.FlowStateConsentUsed, notAfter, nid)
	if err := q.All(&challenges); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return errors.Wrap(fosite.ErrNotFound, "")
		}
		// Do not continue with an empty selection when the query failed.
		return sqlcon.HandleError(err)
	}

	// A non-positive batch size would never advance the loop below (or would
	// index out of range); fall back to one batch holding everything.
	if batchSize <= 0 {
		batchSize = len(challenges)
	}

	/* #nosec G201 table is static */
	deleteStmt := fmt.Sprintf("DELETE FROM %s WHERE login_challenge in (?) AND nid = ?", (&flow.Flow{}).TableName())

	// Delete in batch consent requests and their references in cascade
	for i := 0; i < len(challenges); i += batchSize {
		j := i + batchSize
		if j > len(challenges) {
			j = len(challenges)
		}

		if err := p.Connection(ctx).RawQuery(deleteStmt, challenges[i:j], nid).Exec(); err != nil {
			return sqlcon.HandleError(err)
		}
	}

	return nil
}
// mySQLConfirmLoginSession emulates an upsert of the login session: it first
// tries to create the row and, only when that fails with
// sqlcon.ErrUniqueViolation, updates authenticated_at, subject,
// identity_provider_session_id and remember of the row with the same id and
// nid. x.ErrNotFound is returned when that update affects no rows.
func (p *Persister) mySQLConfirmLoginSession(ctx context.Context, session *flow.LoginSession) error {
	createErr := sqlcon.HandleError(p.Connection(ctx).Create(session))
	switch {
	case createErr == nil:
		return nil
	case !errors.Is(createErr, sqlcon.ErrUniqueViolation):
		return createErr
	}

	// The row already exists: fall back to updating it in place.
	updated, updateErr := p.Connection(ctx).
		Where("id = ? and nid = ?", session.ID, session.NID).
		UpdateQuery(session, "authenticated_at", "subject", "identity_provider_session_id", "remember")
	if updateErr != nil {
		return errors.WithStack(sqlcon.HandleError(updateErr))
	}
	if updated == 0 {
		return errorsx.WithStack(x.ErrNotFound)
	}
	return nil
}