oauth2/oauth2_auth_code_bench_test.go
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package oauth2_test
import (
"context"
"flag"
"net/http"
"os"
"runtime"
"runtime/pprof"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/pborman/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
"go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
semconv "go.opentelemetry.io/otel/semconv/v1.12.0"
"golang.org/x/oauth2"
hydra "github.com/ory/hydra-client-go/v2"
hc "github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/internal"
"github.com/ory/hydra/v2/internal/testhelpers"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/contextx"
"github.com/ory/x/pointerx"
"github.com/ory/x/stringsx"
)
// prof is the name of the file that profile writes a CPU profile to; CPU
// profiling is skipped when it is empty.
var prof = flag.String("profile", "", "write a CPU profile to this filename")

// conc is the number of requests BenchmarkAuthCode aims to dispatch
// concurrently; it is divided by GOMAXPROCS to derive the benchmark
// parallelism.
var conc = flag.Int("conc", 100, "dispatch this many requests concurrently")

// tracing, when set, makes BenchmarkAuthCode additionally export its
// OpenTelemetry spans to an OTLP/HTTP endpoint at localhost:4318.
var tracing = flag.Bool("tracing", false, "send OpenTelemetry traces to localhost:4318")
func BenchmarkAuthCode(b *testing.B) {
flag.Parse()
ctx := context.Background()
spans := tracetest.NewSpanRecorder()
opts := []trace.TracerProviderOption{
trace.WithSpanProcessor(spans),
trace.WithResource(resource.NewWithAttributes(
semconv.SchemaURL, attribute.String(string(semconv.ServiceNameKey), "BenchmarkAuthCode"),
)),
}
if *tracing {
exporter, err := otlptracehttp.New(ctx, otlptracehttp.WithInsecure(), otlptracehttp.WithEndpoint("localhost:4318"))
require.NoError(b, err)
opts = append(opts, trace.WithSpanProcessor(trace.NewSimpleSpanProcessor(exporter)))
}
provider := trace.NewTracerProvider(opts...)
tracer := provider.Tracer("BenchmarkAuthCode")
otel.SetTextMapPropagator(propagation.TraceContext{})
otel.SetTracerProvider(provider)
ctx, span := tracer.Start(ctx, "BenchmarkAuthCode")
defer span.End()
ctx = context.WithValue(ctx, oauth2.HTTPClient, otelhttp.DefaultClient)
dsn := stringsx.Coalesce(os.Getenv("DSN"), "postgres://postgres:secret@127.0.0.1:3445/postgres?sslmode=disable&max_conns=20&max_idle_conns=20")
// dsn := "mysql://root:secret@tcp(localhost:3444)/mysql?max_conns=16&max_idle_conns=16"
// dsn := "cockroach://root@localhost:3446/defaultdb?sslmode=disable&max_conns=16&max_idle_conns=16"
reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer)
reg.Config().MustSet(ctx, config.KeyLogLevel, "error")
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
reg.Config().MustSet(ctx, config.KeyRefreshTokenHook, "")
oauth2Keys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OAuth2JWTKeyName, "sig")
require.NoError(b, err)
oidcKeys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OpenIDConnectKeyName, "sig")
require.NoError(b, err)
_, _ = oauth2Keys, oidcKeys
require.NoError(b, reg.KeyManager().UpdateKeySet(ctx, x.OAuth2JWTKeyName, oauth2Keys))
require.NoError(b, reg.KeyManager().UpdateKeySet(ctx, x.OpenIDConnectKeyName, oidcKeys))
_, adminTS := testhelpers.NewOAuth2Server(ctx, b, reg)
var (
authURL = reg.Config().OAuth2AuthURL(ctx).String()
tokenURL = reg.Config().OAuth2TokenURL(ctx).String()
nonce = uuid.New()
)
newOAuth2Client := func(b *testing.B, cb string) (*hc.Client, *oauth2.Config) {
secret := uuid.New()
c := &hc.Client{
Secret: secret,
RedirectURIs: []string{cb},
ResponseTypes: []string{"id_token", "code", "token"},
GrantTypes: []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"},
Scope: "hydra offline openid",
Audience: []string{"https://api.ory.sh/"},
}
require.NoError(b, reg.ClientManager().CreateClient(ctx, c))
return c, &oauth2.Config{
ClientID: c.GetID(),
ClientSecret: secret,
Endpoint: oauth2.Endpoint{
AuthURL: authURL,
TokenURL: tokenURL,
AuthStyle: oauth2.AuthStyleInHeader,
},
Scopes: strings.Split(c.Scope, " "),
}
}
cfg := hydra.NewConfiguration()
cfg.HTTPClient = otelhttp.DefaultClient
adminClient := hydra.NewAPIClient(cfg)
adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}}
getAuthorizeCode := func(ctx context.Context, b *testing.B, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) {
if c == nil {
c = testhelpers.NewEmptyJarClient(b)
}
state := uuid.New()
req, err := http.NewRequestWithContext(ctx, "GET", conf.AuthCodeURL(state, params...), nil)
require.NoError(b, err)
resp, err := c.Do(req)
require.NoError(b, err)
defer resp.Body.Close()
q := resp.Request.URL.Query()
require.EqualValues(b, state, q.Get("state"))
return q.Get("code"), resp
}
acceptLoginHandler := func(b *testing.B, c *hc.Client, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc {
return otelhttp.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
rr, _, err := adminClient.OAuth2API.GetOAuth2LoginRequest(ctx).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute()
require.NoError(b, err)
assert.EqualValues(b, c.GetID(), pointerx.Deref(rr.Client.ClientId))
assert.Empty(b, pointerx.Deref(rr.Client.ClientSecret))
assert.EqualValues(b, c.GrantTypes, rr.Client.GrantTypes)
assert.EqualValues(b, c.LogoURI, pointerx.Deref(rr.Client.LogoUri))
assert.EqualValues(b, c.RedirectURIs, rr.Client.RedirectUris)
assert.EqualValues(b, r.URL.Query().Get("login_challenge"), rr.Challenge)
assert.EqualValues(b, []string{"hydra", "offline", "openid"}, rr.RequestedScope)
assert.Contains(b, rr.RequestUrl, authURL)
acceptBody := hydra.AcceptOAuth2LoginRequest{
Subject: uuid.New(),
Remember: pointerx.Ptr(!rr.Skip),
Acr: pointerx.Ptr("1"),
Amr: []string{"pwd"},
Context: map[string]interface{}{"context": "bar"},
}
if checkRequestPayload != nil {
if b := checkRequestPayload(rr); b != nil {
acceptBody = *b
}
}
v, _, err := adminClient.OAuth2API.AcceptOAuth2LoginRequest(ctx).
LoginChallenge(r.URL.Query().Get("login_challenge")).
AcceptOAuth2LoginRequest(acceptBody).
Execute()
require.NoError(b, err)
require.NotEmpty(b, v.RedirectTo)
http.Redirect(w, r, v.RedirectTo, http.StatusFound)
}), "acceptLoginHandler").ServeHTTP
}
acceptConsentHandler := func(b *testing.B, c *hc.Client, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc {
return otelhttp.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
rr, _, err := adminClient.OAuth2API.GetOAuth2ConsentRequest(ctx).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute()
require.NoError(b, err)
assert.EqualValues(b, c.GetID(), pointerx.Deref(rr.Client.ClientId))
assert.Empty(b, pointerx.Deref(rr.Client.ClientSecret))
assert.EqualValues(b, c.GrantTypes, rr.Client.GrantTypes)
assert.EqualValues(b, c.LogoURI, pointerx.Deref(rr.Client.LogoUri))
assert.EqualValues(b, c.RedirectURIs, rr.Client.RedirectUris)
// assert.EqualValues(b, subject, pointerx.Deref(rr.Subject))
assert.EqualValues(b, []string{"hydra", "offline", "openid"}, rr.RequestedScope)
assert.EqualValues(b, r.URL.Query().Get("consent_challenge"), rr.Challenge)
assert.Contains(b, *rr.RequestUrl, authURL)
if checkRequestPayload != nil {
checkRequestPayload(rr)
}
assert.Equal(b, map[string]interface{}{"context": "bar"}, rr.Context)
v, _, err := adminClient.OAuth2API.AcceptOAuth2ConsentRequest(ctx).
ConsentChallenge(r.URL.Query().Get("consent_challenge")).
AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{
GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0),
GrantAccessTokenAudience: rr.RequestedAccessTokenAudience,
Session: &hydra.AcceptOAuth2ConsentRequestSession{
AccessToken: map[string]interface{}{"foo": "bar"},
IdToken: map[string]interface{}{"bar": "baz"},
},
}).
Execute()
require.NoError(b, err)
require.NotEmpty(b, v.RedirectTo)
http.Redirect(w, r, v.RedirectTo, http.StatusFound)
}), "acceptConsentHandler").ServeHTTP
}
run := func(b *testing.B, strategy string) func(*testing.B) {
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
c, conf := newOAuth2Client(b, testhelpers.NewCallbackURL(b, "callback", testhelpers.HTTPServerNotImplementedHandler))
testhelpers.NewLoginConsentUI(b, reg.Config(),
acceptLoginHandler(b, c, nil),
acceptConsentHandler(b, c, nil),
)
return func(b *testing.B) {
//pop.Debug = true
code, _ := getAuthorizeCode(ctx, b, conf, nil, oauth2.SetAuthURLParam("nonce", nonce))
require.NotEmpty(b, code)
_, err := conf.Exchange(ctx, code)
//pop.Debug = false
require.NoError(b, err)
}
}
b.ResetTimer()
b.SetParallelism(*conc / runtime.GOMAXPROCS(0))
b.Run("strategy=jwt", func(b *testing.B) {
initialDBSpans := dbSpans(spans)
B := run(b, "jwt")
stop := profile(b)
defer stop()
var totalMS int64 = 0
b.RunParallel(func(p *testing.PB) {
defer func(t0 time.Time) {
atomic.AddInt64(&totalMS, int64(time.Since(t0).Milliseconds()))
}(time.Now())
for p.Next() {
B(b)
}
})
b.ReportMetric(0, "ns/op")
b.ReportMetric(float64(atomic.LoadInt64(&totalMS))/float64(b.N), "ms/op")
b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op")
b.ReportMetric(float64(b.N)/b.Elapsed().Seconds(), "ops/s")
})
b.Run("strategy=opaque", func(b *testing.B) {
initialDBSpans := dbSpans(spans)
B := run(b, "opaque")
stop := profile(b)
defer stop()
var totalMS int64 = 0
b.RunParallel(func(p *testing.PB) {
defer func(t0 time.Time) {
atomic.AddInt64(&totalMS, int64(time.Since(t0).Milliseconds()))
}(time.Now())
for p.Next() {
B(b)
}
})
b.ReportMetric(0, "ns/op")
b.ReportMetric(float64(atomic.LoadInt64(&totalMS))/float64(b.N), "ms/op")
b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op")
b.ReportMetric(float64(b.N)/b.Elapsed().Seconds(), "ops/s")
})
}
// profile starts CPU profiling into the file named by the -profile flag. It
// returns a function that stops profiling, closes the file and logs its name.
// When the flag is empty, nothing is profiled and the returned function does
// nothing.
//
// The file is created (truncated) on every call. If profile runs more than
// once, for example from a benchmark function the framework invokes
// repeatedly, only the last profile is kept.
func profile(t testing.TB) (stop func()) {
	t.Helper()
	if *prof == "" {
		return func() {} // profiling disabled
	}

	f, err := os.Create(*prof)
	require.NoError(t, err)
	if err := pprof.StartCPUProfile(f); err != nil {
		// Close the file before failing; otherwise the descriptor leaks.
		// StartCPUProfile fails e.g. when a CPU profile is already running.
		// The close error is secondary to the reported start error.
		_ = f.Close()
		require.NoError(t, err)
	}

	return func() {
		pprof.StopCPUProfile()
		require.NoError(t, f.Close())
		t.Log("Wrote profile to", f.Name())
	}
}