driver/registry_sql.go
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package driver
import (
"context"
"io/fs"
"strings"
"time"
"github.com/gobuffalo/pop/v6"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/luna-duclos/instrumentedsql"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/consent"
"github.com/ory/hydra/v2/hsm"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/oauth2/trust"
"github.com/ory/hydra/v2/persistence/sql"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/contextx"
"github.com/ory/x/dbal"
"github.com/ory/x/errorsx"
otelsql "github.com/ory/x/otelx/sql"
"github.com/ory/x/popx"
"github.com/ory/x/resilience"
"github.com/ory/x/sqlcon"
)
// RegistrySQL is the Registry implementation whose storage is a SQL database
// reached through a pop connection (opened in Init).
type RegistrySQL struct {
	*RegistryBase

	// defaultKeyManager is returned by KeyManager. Init assigns it either the
	// persister itself or an HSM/persister strategy when HSM is enabled.
	defaultKeyManager jwk.Manager

	// initialPing is called by Init to confirm the database is reachable.
	// NewRegistrySQL sets it to defaultInitialPing; tests may replace it.
	initialPing func(r *RegistrySQL) error
}

// Compile-time assertion that *RegistrySQL implements Registry.
var _ Registry = (*RegistrySQL)(nil)
// defaultInitialPing is the reachability check RegistrySQL.Init runs by default.
// It calls m.Ping through resilience.Retry (5s / 5m bounds), logs the final
// failure, and returns it with a stack trace attached. Tests can substitute a
// different check by assigning RegistrySQL.initialPing.
var defaultInitialPing = func(m *RegistrySQL) error {
	pingErr := resilience.Retry(m.l, 5*time.Second, 5*time.Minute, m.Ping)
	if pingErr == nil {
		return nil
	}
	m.Logger().Print("Could not ping database: ", pingErr)
	return errorsx.WithStack(pingErr)
}
// init registers RegistrySQL with dbal as a driver factory; every invocation
// of the factory yields a new registry from NewRegistrySQL.
func init() {
	dbal.RegisterDriver(func() dbal.Driver { return NewRegistrySQL() })
}
// NewRegistrySQL allocates a RegistrySQL on top of a fresh RegistryBase, using
// defaultInitialPing as its startup database check, and hands the new registry
// back to the base via with.
func NewRegistrySQL() *RegistrySQL {
	base := new(RegistryBase)
	reg := &RegistrySQL{RegistryBase: base, initialPing: defaultInitialPing}
	base.with(reg)
	return reg
}
// Init wires the registry to its SQL database. It is a no-op when a persister
// is already configured. Otherwise it:
//   - opens a pop connection for the configured DSN,
//   - creates the persister and verifies connectivity via m.initialPing,
//   - runs migrations when migrate is true, and always for in-memory SQLite,
//   - unless skipNetworkInit is set, determines the network and scopes the
//     persister to it via WithFallbackNetworkID,
//   - selects the default key manager (HSM strategy or the plain persister).
//
// If any step after the connection is opened fails, the persister and key
// manager are cleared and the connection is closed. A later call to Init then
// starts over instead of returning nil with a half-configured registry.
func (m *RegistrySQL) Init(
	ctx context.Context,
	skipNetworkInit bool,
	migrate bool,
	ctxer contextx.Contextualizer,
	extraMigrations []fs.FS,
	goMigrations []popx.Migration,
) error {
	if m.persister != nil {
		return nil
	}
	m.WithContextualizer(ctxer)

	c, err := m.openSQLConnection(ctx)
	if err != nil {
		return err
	}

	if err := m.initSQLPersister(ctx, c, skipNetworkInit, migrate, extraMigrations, goMigrations); err != nil {
		// Roll back so the "already initialized" guard above does not fire on a retry.
		m.persister = nil
		m.defaultKeyManager = nil
		// Best effort: the initialization error is what the caller needs to see.
		_ = c.Close()
		return err
	}
	return nil
}

// openSQLConnection builds a pop connection from the configured DSN and opens
// it, retrying through resilience.Retry (5s / 5m bounds).
func (m *RegistrySQL) openSQLConnection(ctx context.Context) (*pop.Connection, error) {
	tracing := m.Tracer(ctx).IsLoaded()

	var opts []instrumentedsql.Opt
	if tracing {
		opts = []instrumentedsql.Opt{
			instrumentedsql.WithTracer(otelsql.NewTracer()),
			instrumentedsql.WithOmitArgs(), // don't risk leaking PII or secrets
			instrumentedsql.WithOpsExcluded(instrumentedsql.OpSQLRowsNext),
		}
	}

	pool, idlePool, connMaxLifetime, connMaxIdleTime, cleanedDSN := sqlcon.ParseConnectionOptions(
		m.l, m.Config().DSN(),
	)
	c, err := pop.NewConnection(
		&pop.ConnectionDetails{
			URL:                       sqlcon.FinalizeDSN(m.l, cleanedDSN),
			IdlePool:                  idlePool,
			ConnMaxLifetime:           connMaxLifetime,
			ConnMaxIdleTime:           connMaxIdleTime,
			Pool:                      pool,
			UseInstrumentedDriver:     tracing,
			InstrumentedDriverOptions: opts,
			Unsafe:                    m.Config().DbIgnoreUnknownTableColumns(),
		},
	)
	if err != nil {
		return nil, errorsx.WithStack(err)
	}
	if err := resilience.Retry(m.l, 5*time.Second, 5*time.Minute, c.Open); err != nil {
		return nil, errorsx.WithStack(err)
	}
	return c, nil
}

// initSQLPersister creates the persister on c, pings, migrates, scopes the
// persister to the network, and installs the default key manager. It may leave
// m.persister set on failure; Init is responsible for rolling that back.
func (m *RegistrySQL) initSQLPersister(
	ctx context.Context,
	c *pop.Connection,
	skipNetworkInit bool,
	migrate bool,
	extraMigrations []fs.FS,
	goMigrations []popx.Migration,
) error {
	p, err := sql.NewPersister(ctx, c, m, m.Config(), extraMigrations, goMigrations)
	if err != nil {
		return err
	}
	// Ping goes through m.Persister(), so the persister must be registered first.
	m.persister = p
	if err := m.initialPing(m); err != nil {
		return err
	}

	// Construct the HSM key manager at most once; it is combined with the
	// persister both now and again after the persister is network-scoped.
	var hardwareKeyManager jwk.Manager
	if m.Config().HSMEnabled() {
		hardwareKeyManager = hsm.NewKeyManager(m.HSMContext(), m.Config())
	}
	// Preserves the original ordering: a key manager is available while
	// migrations and network detection run.
	// NOTE(review): confirm whether anything in those steps relies on it.
	m.installDefaultKeyManager(hardwareKeyManager)

	// An in-memory SQLite DSN starts empty on every launch, so it is always
	// migrated. See https://sqlite.org/inmemorydb.html
	memory := dbal.IsMemorySQLite(m.Config().DSN())
	if memory {
		m.Logger().Print("Hydra is running migrations on every startup as DSN is memory.\n")
		m.Logger().Print("This means your data is lost when Hydra terminates.\n")
	}
	if memory || migrate {
		if err := p.MigrateUp(ctx); err != nil {
			return err
		}
	}

	if !skipNetworkInit {
		net, err := p.DetermineNetwork(ctx)
		if err != nil {
			m.Logger().WithError(err).Warnf("Unable to determine network, retrying.")
			return err
		}
		m.persister = p.WithFallbackNetworkID(net.ID)
	}

	// Rebuild on top of the final (possibly network-scoped) persister.
	m.installDefaultKeyManager(hardwareKeyManager)
	return nil
}

// installDefaultKeyManager sets m.defaultKeyManager. If hardware is nil it uses
// the current persister; otherwise it uses a strategy over hardware and the
// persister.
func (m *RegistrySQL) installDefaultKeyManager(hardware jwk.Manager) {
	if hardware == nil {
		m.defaultKeyManager = m.persister
		return
	}
	m.defaultKeyManager = jwk.NewManagerStrategy(hardware, m.persister)
}
// alwaysCanHandle reports whether the DSN's scheme (the text before the first
// "://", or the whole DSN if absent) canonicalizes to MySQL, PostgreSQL or
// CockroachDB.
func (m *RegistrySQL) alwaysCanHandle(dsn string) bool {
	scheme, _, _ := strings.Cut(dsn, "://")
	switch dbal.Canonicalize(scheme) {
	case dbal.DriverMySQL, dbal.DriverPostgreSQL, dbal.DriverCockroachDB:
		return true
	default:
		return false
	}
}
// Ping checks database connectivity by delegating to the persister's Ping.
func (m *RegistrySQL) Ping() error {
	p := m.Persister()
	return p.Ping()
}

// ClientManager exposes the persister as the OAuth2 client store.
func (m *RegistrySQL) ClientManager() client.Manager {
	return m.Persister()
}

// ConsentManager exposes the persister as the consent store.
func (m *RegistrySQL) ConsentManager() consent.Manager {
	return m.Persister()
}

// OAuth2Storage exposes the persister as the fosite storage backend.
func (m *RegistrySQL) OAuth2Storage() x.FositeStorer {
	return m.Persister()
}

// KeyManager returns the key manager selected during Init: an HSM/persister
// strategy when HSM is enabled, otherwise the persister. It is nil until Init
// has assigned it.
func (m *RegistrySQL) KeyManager() jwk.Manager {
	return m.defaultKeyManager
}

// SoftwareKeyManager always returns the persister, bypassing any HSM.
func (m *RegistrySQL) SoftwareKeyManager() jwk.Manager {
	return m.Persister()
}

// GrantManager exposes the persister as the trust-relationship grant store.
func (m *RegistrySQL) GrantManager() trust.GrantManager {
	return m.Persister()
}