oauth2/oauth2_refresh_token_test.go
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package oauth2_test
import (
"context"
"errors"
"fmt"
"math/rand"
"net/url"
"strings"
"sync"
"testing"
"time"
"github.com/gofrs/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/fosite"
hc "github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/driver"
"github.com/ory/hydra/v2/internal"
"github.com/ory/hydra/v2/oauth2"
"github.com/ory/x/contextx"
"github.com/ory/x/dbal"
"github.com/ory/x/networkx"
)
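// TestCreateRefreshTokenSessionStress is a sanity test to verify the fix for https://github.com/ory/hydra/issues/1719 &
// https://github.com/ory/hydra/issues/1735.
// It runs against every entry of `registries` except "memory". The original issue was based on Postgres and the default
// isolation level used by the storage engine, so the "exactly one worker succeeds" guarantee is only asserted for
// PostgreSQL and CockroachDB; the assertions are relaxed for MySQL.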
func TestCreateRefreshTokenSessionStress(t *testing.T) {
	if testing.Short() {
		// Report the test as skipped (rather than silently passing) when running with -short.
		t.Skip("skipping refresh token stress test in short mode")
	}

	const (
		// testRuns is the number of iterations this test will make to ensure everything is working as expected.
		// This test is aiming to prove correct behaviour when the handler is getting hit with the same refresh
		// token in concurrent requests. Given that problems that may occur in this scenario are "racey" in nature,
		// it is important to run this test several times so as to minimize the probability that we pass due to
		// sheer luck.
		testRuns = 5

		// workers is the number of workers that will concurrently hit the 'NewAccessResponse' method using the
		// same refresh token. Don't set this value to be too high as it will result in connection failures to the
		// DB instance. The test is designed such that it will retry in the event we get unlucky and a transaction
		// completes successfully prior to other requests getting past the first read.
		workers = 10
	)

	token := "234c678fed33c1d2025537ae464a1ebf7d23fc4a"          //nolint:gosec
	tokenSignature := "4c7c7e8b3a77ad0c3ec846a21653c48b45dbfa31" //nolint:gosec

	testClient := hc.Client{
		ID:            uuid.Must(uuid.NewV4()).String(),
		Secret:        "secret",
		ResponseTypes: []string{"id_token", "code", "token"},
		GrantTypes:    []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"},
		Scope:         "hydra offline openid",
		Audience:      []string{"https://api.ory.sh/"},
	}

	// The same access request (and therefore the same refresh token) is replayed by every worker.
	request := &fosite.AccessRequest{
		GrantTypes: []string{
			"refresh_token",
		},
		Request: fosite.Request{
			RequestedAt: time.Now(),
			ID:          uuid.Must(uuid.NewV4()).String(),
			Client: &hc.Client{
				ID: testClient.GetID(),
			},
			RequestedScope: []string{"offline"},
			GrantedScope:   []string{"offline"},
			Session:        oauth2.NewSession(""),
			Form: url.Values{
				"refresh_token": []string{fmt.Sprintf("%s.%s", token, tokenSignature)},
			},
		},
	}

	// Substrings of which at least one must appear in the debug field of the error a losing worker receives on
	// PostgreSQL or CockroachDB. The list is loop-invariant, so it is built once up front.
	expectedDebugSubstrings := []string{
		// both postgreSQL & cockroachDB return error code 40001 for consistency errors as a result of
		// using the REPEATABLE_READ isolation level
		"SQLSTATE 40001",
		// possible if one worker starts the transaction AFTER another worker has successfully
		// refreshed the token and committed the transaction
		"not_found",
		// postgres: duplicate key value violates unique constraint "hydra_oauth2_access_request_id_idx": Unable to insert or update resource because a resource with that value exists already: The request could not be completed due to concurrent access
		"duplicate key",
		// cockroach: restart transaction: TransactionRetryWithProtoRefreshError: TransactionRetryError: retry txn (RETRY_WRITE_TOO_OLD - WriteTooOld flag converted to WriteTooOldError): "sql txn" meta={id=7f069400 key=/Table/62/2/"02a55d6e-509b-4d7a-8458-5828b2f831a1"/0 pri=0.00598277 epo=0 ts=1600955431.566576173,2 min=1600955431.566576173,0 seq=6} lock=true stat=PENDING rts=1600955431.566576173,2 wto=false max=1600955431.566576173,0: Unable to serialize access due to a concurrent update in another session: The request could not be completed due to concurrent access
		// cockroach: this also happens when there is an error with the storage
		"RETRY_WRITE_TOO_OLD",
		// postgres: pq: deadlock detected
		"deadlock detected",
		// postgres: pq: could not serialize access due to concurrent update: Unable to serialize access due to a concurrent update in another session: The request could not be completed due to concurrent access
		"concurrent update",
		// refresh token reuse detection
		"token_inactive",
	}

	setupRegistries(t)

	for dbName, dbRegistry := range registries {
		if dbName == "memory" {
			// todo check why sqlite fails with "no such table: hydra_oauth2_refresh \n sqlite create"
			// should be fine though as nobody should use sqlite in production
			continue
		}

		// Pin the registry to the network stored in this database so that all storage calls below share one NID.
		net := &networkx.Network{}
		require.NoError(t, dbRegistry.Persister().Connection(context.Background()).First(net))
		dbRegistry.WithContextualizer(&contextx.Static{NID: net.ID, C: internal.NewConfigurationWithDefaults().Source(context.Background())})

		// ctx bounds the setup and the per-iteration state reset for this database; workers use their own context.
		ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
		t.Cleanup(cancel)

		require.NoError(t, dbRegistry.OAuth2Storage().(clientCreator).CreateClient(ctx, &testClient))
		require.NoError(t, dbRegistry.OAuth2Storage().CreateRefreshTokenSession(ctx, tokenSignature, request))
		_, err := dbRegistry.OAuth2Storage().GetRefreshTokenSession(ctx, tokenSignature, nil)
		require.NoError(t, err)

		provider := dbRegistry.OAuth2Provider()
		storageVersion := dbVersion(t, ctx, dbRegistry)

		var wg sync.WaitGroup
		for run := 0; run < testRuns; run++ {
			barrier := make(chan struct{})
			// buffered to the number of workers so that no worker ever blocks on reporting its result
			errorsCh := make(chan error, workers)

			// Launch every worker before the barrier is opened. Registering them with the WaitGroup up front
			// (instead of from a helper goroutine) guarantees the barrier cannot be closed while workers are
			// still being spawned.
			wg.Add(workers)
			for w := 0; w < workers; w++ {
				go func() {
					defer wg.Done()
					ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
					defer cancel()
					time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) //nolint:gosec
					// all workers will block here until the for loop above has launched all the worker go-routines
					// this is to ensure we fire all the workers off at the same time
					<-barrier
					_, err := provider.NewAccessResponse(ctx, request)
					errorsCh <- err
				}()
			}

			// let the race begin!
			// all worker go-routines will now attempt to hit the "NewAccessResponse" method
			close(barrier)

			// wait until all workers have completed their work; this cannot deadlock because errorsCh has room
			// for one result per worker
			wg.Wait()
			close(errorsCh)

			// process worker results

			// successCount is the number of workers that were able to call "NewAccessResponse" without receiving an error.
			// if the successCount at the end of a test run is bigger than one, it means that multiple access/refresh tokens
			// were issued using the same refresh token! - https://knowyourmeme.com/memes/scared-hamster
			var successCount int
			for err := range errorsCh {
				if err == nil {
					successCount++
					continue
				}

				var e *fosite.RFC6749Error
				if !errors.As(err, &e) {
					t.Errorf("expected underlying error to be of type '*fosite.RFC6749Error', but it was "+
						"actually of type %T: %+v - DB version: %s", err, err, storageVersion)
					continue
				}

				switch e.ErrorField {
				// change logic below when the refresh handler starts returning 'fosite.ErrInvalidRequest' for other reasons.
				// as of now, this error is only returned due to concurrent transactions competing to refresh using the same token.
				case fosite.ErrInvalidRequest.ErrorField, fosite.ErrServerError.ErrorField:
					// the error description copy is defined by RFC 6749 and should not be different regardless of
					// the underlying transactional aware storage backend used by hydra
					assert.Contains(t, []string{fosite.ErrInvalidRequest.DescriptionField, fosite.ErrServerError.DescriptionField}, e.DescriptionField)

					// the database error debug copy will be different depending on the underlying database used
					switch dbName {
					case dbal.DriverMySQL:
						// no assertion on the debug copy for MySQL
					case dbal.DriverPostgreSQL, dbal.DriverCockroachDB:
						var matched bool
						for _, errSubstr := range expectedDebugSubstrings {
							if strings.Contains(e.DebugField, errSubstr) {
								matched = true
								break
							}
						}

						assert.True(t, matched, "received an unexpected kind of `%s`\n"+
							"DB version: %s\n"+
							"Error description: %s\n"+
							"Error debug: %s\n"+
							"Error hint: %s\n"+
							"Raw error: %T %+v\n"+
							"Raw cause: %T %+v",
							e.ErrorField,
							storageVersion,
							e.DescriptionField,
							e.DebugField,
							e.HintField,
							err, err,
							e, e)
					}
				default:
					// unfortunately, MySQL does not offer the same behaviour under the "REPEATABLE_READ" isolation
					// level so we have to relax this assertion just for MySQL for the time being as server_errors
					// resembling the following can be returned:
					//
					//    Error 1213: Deadlock found when trying to get lock; try restarting transaction
					if dbName != dbal.DriverMySQL {
						t.Errorf("an unexpected RFC6749 error with the name %q was returned.\n"+
							"Hint: has the refresh token error handling changed in fosite? If so, you need to add further "+
							"assertions here to cover the additional errors that are being returned by the handler.\n"+
							"DB version: %s\n"+
							"Error description: %s\n"+
							"Error debug: %s\n"+
							"Error hint: %s\n"+
							"Raw error: %+v",
							e.ErrorField,
							storageVersion,
							e.DescriptionField,
							e.DebugField,
							e.HintField,
							err)
					}
				}
			}

			// IMPORTANT - skip consistency check for MySQL :(
			//
			// different DBMS's provide different consistency guarantees when using the "REPEATABLE_READ" isolation level
			// Currently, MySQL's implementation of "REPEATABLE_READ" makes it possible for multiple concurrent requests
			// to successfully utilize the same refresh token. Therefore, we skip the assertion below.
			//
			// TODO: this needs to be addressed by making it possible to use different isolation levels for various authorization
			//       flows depending on the underlying hydra storage backend. For example, if using MySQL, hydra should force
			//       the transaction isolation level to be "Serializable" when a request to the token handler is received.
			switch dbName {
			case dbal.DriverMySQL:
				// consistency check intentionally skipped, see above
			case dbal.DriverPostgreSQL, dbal.DriverCockroachDB:
				require.Equal(t, 1, successCount, "CRITICAL: in test iteration %d, %d out of %d workers "+
					"were able to use the refresh token. Exactly ONE was expected to be have been successful.",
					run,
					successCount,
					workers)
			}

			// reset state for the next test iteration
			assert.NoError(t, dbRegistry.OAuth2Storage().DeleteRefreshTokenSession(ctx, tokenSignature))
			assert.NoError(t, dbRegistry.OAuth2Storage().CreateRefreshTokenSession(ctx, tokenSignature, request))
		}
	}
}
// version is the scan target for the single-row result of the version query issued by dbVersion.
type version struct {
	// Version receives the column aliased as "version" in that query.
	Version string `db:"version"`
}
// dbVersion returns the version string reported by the database behind registry. It queries `sqlite_version()`
// when the connection's dialect is "sqlite3" and `version()` otherwise, and fails the test immediately if the
// query cannot be executed. The result is only used to enrich failure messages.
func dbVersion(t *testing.T, ctx context.Context, registry driver.Registry) string {
	// attribute failures to the calling test rather than to this helper
	t.Helper()

	c := registry.Persister().Connection(ctx)

	versionFunc := "version()"
	if c.Dialect.Name() == "sqlite3" {
		versionFunc = "sqlite_version()"
	}

	var v version
	/* #nosec G201 - versionFunc is an enum */
	require.NoError(t, c.RawQuery(fmt.Sprintf("select %s as version", versionFunc)).First(&v))
	return v.Version
}