client/manager_test_helpers.go
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package client
import (
"context"
"crypto/x509"
"fmt"
"testing"
"time"
"github.com/go-faker/faker/v4"
"github.com/go-jose/go-jose/v3"
"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/fosite"
testhelpersuuid "github.com/ory/hydra/v2/internal/testhelpers/uuid"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/assertx"
"github.com/ory/x/contextx"
"github.com/ory/x/sqlcon"
)
func TestHelperClientAutoGenerateKey(k string, m Storage) func(t *testing.T) {
return func(t *testing.T) {
ctx := context.TODO()
c := &Client{
Secret: "secret",
RedirectURIs: []string{"http://redirect"},
TermsOfServiceURI: "foo",
}
require.NoError(t, m.CreateClient(ctx, c))
dbClient, err := m.GetClient(ctx, c.GetID())
require.NoError(t, err)
dbClientConcrete, ok := dbClient.(*Client)
require.True(t, ok)
testhelpersuuid.AssertUUID(t, dbClientConcrete.ID)
assert.NoError(t, m.DeleteClient(ctx, c.GetID()))
}
}
func TestHelperClientAuthenticate(k string, m Manager) func(t *testing.T) {
return func(t *testing.T) {
ctx := context.TODO()
require.NoError(t, m.CreateClient(ctx, &Client{
ID: "1234321",
Secret: "secret",
RedirectURIs: []string{"http://redirect"},
}))
c, err := m.AuthenticateClient(ctx, "1234321", []byte("secret1"))
require.Error(t, err)
require.Nil(t, c)
c, err = m.AuthenticateClient(ctx, "1234321", []byte("secret"))
require.NoError(t, err)
assert.Equal(t, "1234321", c.GetID())
}
}
func TestHelperUpdateTwoClients(_ string, m Manager) func(t *testing.T) {
return func(t *testing.T) {
c1, c2 := &Client{Name: "test client 1"}, &Client{Name: "test client 2"}
require.NoError(t, m.CreateClient(context.Background(), c1))
require.NoError(t, m.CreateClient(context.Background(), c2))
c1.Name, c2.Name = "updated client 1", "updated client 2"
assert.NoError(t, m.UpdateClient(context.Background(), c1))
assert.NoError(t, m.UpdateClient(context.Background(), c2))
}
}
func testHelperUpdateClient(t *testing.T, ctx context.Context, network Storage, k string) {
d, err := network.GetClient(ctx, "1234")
assert.NoError(t, err)
err = network.UpdateClient(ctx, &Client{
ID: "2-1234",
Name: "name-new",
Secret: "secret-new",
RedirectURIs: []string{"http://redirect/new"},
TermsOfServiceURI: "bar",
JSONWebKeys: new(x.JoseJSONWebKeySet),
})
require.NoError(t, err)
nc, err := network.GetConcreteClient(ctx, "2-1234")
require.NoError(t, err)
if k != "http" {
// http always returns an empty secret
assert.NotEqual(t, d.GetHashedSecret(), nc.GetHashedSecret())
}
assert.Equal(t, "bar", nc.TermsOfServiceURI)
assert.Equal(t, "name-new", nc.Name)
assert.EqualValues(t, []string{"http://redirect/new"}, nc.GetRedirectURIs())
assert.Zero(t, len(nc.Contacts))
}
func TestHelperCreateGetUpdateDeleteClientNext(t *testing.T, m Storage, networks []uuid.UUID) {
ctx := context.Background()
resources := map[uuid.UUID][]Client{}
for k := range networks {
nid := networks[k]
resources[nid] = []Client{}
ctx := contextx.SetNIDContext(ctx, nid)
t.Run(fmt.Sprintf("nid=%s", nid), func(t *testing.T) {
var client Client
require.NoError(t, faker.FakeData(&client))
client.CreatedAt = time.Now().Truncate(time.Second).UTC()
t.Run("lifecycle=does not exist", func(t *testing.T) {
_, err := m.GetClient(ctx, "1234")
require.Error(t, err)
})
t.Run("lifecycle=exists", func(t *testing.T) {
require.NoError(t, m.CreateClient(ctx, &client))
c, err := m.GetClient(ctx, client.GetID())
require.NoError(t, err)
assertx.EqualAsJSONExcept(t, &client, c, []string{
"registration_access_token",
"registration_client_uri",
"updated_at",
})
n, err := m.CountClients(ctx)
assert.NoError(t, err)
assert.Equal(t, 1, n)
copy := client
require.Error(t, m.CreateClient(ctx, ©))
})
t.Run("lifecycle=update", func(t *testing.T) {
client.Name = "updated" + nid.String()
require.NoError(t, m.UpdateClient(ctx, &client))
c, err := m.GetClient(ctx, client.GetID())
require.NoError(t, err)
assertx.EqualAsJSONExcept(t, &client, c, []string{
"registration_access_token",
"registration_client_uri",
"updated_at",
})
resources[nid] = append(resources[nid], client)
})
})
}
for k := range resources {
original := k
clients := resources[k]
for i := range networks {
check := networks[i]
t.Run("network="+original.String(), func(t *testing.T) {
ctx := contextx.SetNIDContext(ctx, check)
for _, expected := range clients {
c, err := m.GetClient(ctx, expected.GetID())
if check != original {
t.Run(fmt.Sprintf("case=must not find client %s", expected.GetID()), func(t *testing.T) {
require.ErrorIs(t, err, sqlcon.ErrNoRows)
})
} else {
t.Run("case=updates must not override each other", func(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "updated"+original.String(), c.(*Client).Name)
})
}
}
})
}
}
for k := range resources {
clients := resources[k]
ctx := contextx.SetNIDContext(ctx, k)
t.Run("network="+k.String(), func(t *testing.T) {
for _, client := range clients {
t.Run("lifecycle=cleanup", func(t *testing.T) {
assert.NoError(t, m.DeleteClient(ctx, client.GetID()))
_, err := m.GetClient(ctx, client.GetID())
assert.ErrorIs(t, err, sqlcon.ErrNoRows)
n, err := m.CountClients(ctx)
assert.NoError(t, err)
assert.Equal(t, 0, n)
assert.Error(t, m.DeleteClient(ctx, client.GetID()))
})
}
})
}
}
func TestHelperCreateGetUpdateDeleteClient(k string, connection *pop.Connection, t1 Storage, t2 Storage) func(t *testing.T) {
return func(t *testing.T) {
ctx := context.Background()
_, err := t1.GetClient(ctx, "1234")
require.Error(t, err)
t1c1 := &Client{
ID: "1234",
Name: "name",
Secret: "secret",
RedirectURIs: []string{"http://redirect", "http://redirect1"},
GrantTypes: []string{"implicit", "refresh_token"},
ResponseTypes: []string{"code token", "token id_token", "code"},
Scope: "scope-a scope-b",
Owner: "aeneas",
PolicyURI: "http://policy",
TermsOfServiceURI: "http://tos",
ClientURI: "http://client",
LogoURI: "http://logo",
Contacts: []string{"aeneas1", "aeneas2"},
SecretExpiresAt: 0,
SectorIdentifierURI: "https://sector",
JSONWebKeys: &x.JoseJSONWebKeySet{JSONWebKeySet: &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{{KeyID: "foo", Key: []byte("asdf"), Certificates: []*x509.Certificate{}, CertificateThumbprintSHA1: []uint8{}, CertificateThumbprintSHA256: []uint8{}}}}},
JSONWebKeysURI: "https://...",
TokenEndpointAuthMethod: "none",
TokenEndpointAuthSigningAlgorithm: "RS256",
RequestURIs: []string{"foo", "bar"},
AllowedCORSOrigins: []string{"foo", "bar"},
RequestObjectSigningAlgorithm: "rs256",
UserinfoSignedResponseAlg: "RS256",
CreatedAt: time.Now().Add(-time.Hour).Round(time.Second).UTC(),
UpdatedAt: time.Now().Add(-time.Minute).Round(time.Second).UTC(),
FrontChannelLogoutURI: "http://fc-logout",
FrontChannelLogoutSessionRequired: true,
PostLogoutRedirectURIs: []string{"hello", "mister"},
BackChannelLogoutURI: "http://bc-logout",
BackChannelLogoutSessionRequired: true,
}
require.NoError(t, t1.CreateClient(ctx, t1c1))
{
t2c1 := *t1c1
require.Error(t, connection.Create(&t2c1), "should not be able to create the same client in other manager/network; are they backed by the same database?")
require.NoError(t, t2.CreateClient(ctx, &t2c1), "we should be able to create a client with the same ID in other network")
}
t2c3 := *t1c1
{
t2c3.ID = "t2c2-1234"
require.NoError(t, t2.CreateClient(ctx, &t2c3))
require.Error(t, t2.CreateClient(ctx, &t2c3))
}
assert.Equal(t, t1c1.GetID(), "1234")
if k != "http" {
assert.NotEmpty(t, t1c1.GetHashedSecret())
}
c2Template := &Client{
ID: "2-1234",
Name: "name2",
Secret: "secret",
RedirectURIs: []string{"http://redirect"},
TermsOfServiceURI: "foo",
SecretExpiresAt: 1,
}
assert.NoError(t, t1.CreateClient(ctx, c2Template))
assert.NoError(t, t2.CreateClient(ctx, c2Template))
d, err := t1.GetClient(ctx, "1234")
require.NoError(t, err)
cc := d.(*Client)
testhelpersuuid.AssertUUID(t, cc.NID)
compare(t, t1c1, d, k)
ds, err := t1.GetClients(ctx, Filter{Limit: 100, Offset: 0})
assert.NoError(t, err)
assert.Len(t, ds, 2)
assert.NotEqual(t, ds[0].GetID(), ds[1].GetID())
assert.NotEqual(t, ds[0].GetID(), ds[1].GetID())
// test if SecretExpiresAt was set properly
assert.Equal(t, ds[0].SecretExpiresAt, 0)
assert.Equal(t, ds[1].SecretExpiresAt, 1)
ds, err = t1.GetClients(ctx, Filter{Limit: 1, Offset: 0})
assert.NoError(t, err)
assert.Len(t, ds, 1)
ds, err = t1.GetClients(ctx, Filter{Limit: 100, Offset: 100})
assert.NoError(t, err)
assert.Empty(t, ds)
// get by name
ds, err = t1.GetClients(ctx, Filter{Limit: 100, Offset: 0, Name: "name"})
assert.NoError(t, err)
assert.Len(t, ds, 1)
assert.Equal(t, ds[0].Name, "name")
// get by name not exist
ds, err = t1.GetClients(ctx, Filter{Limit: 100, Offset: 0, Name: "bad name"})
assert.NoError(t, err)
assert.Len(t, ds, 0)
// get by owner
ds, err = t1.GetClients(ctx, Filter{Limit: 100, Offset: 0, Owner: "aeneas"})
assert.NoError(t, err)
assert.Len(t, ds, 1)
assert.Equal(t, ds[0].Owner, "aeneas")
testHelperUpdateClient(t, ctx, t1, k)
testHelperUpdateClient(t, ctx, t2, k)
err = t1.DeleteClient(ctx, "1234")
assert.NoError(t, err)
err = t1.DeleteClient(ctx, t2c3.GetID())
assert.Error(t, err)
_, err = t1.GetClient(ctx, "1234")
assert.NotNil(t, err)
n, err := t1.CountClients(ctx)
assert.NoError(t, err)
assert.Equal(t, 1, n)
}
}
func compare(t *testing.T, expected *Client, actual fosite.Client, k string) {
assert.EqualValues(t, expected.GetID(), actual.GetID())
if k != "http" {
assert.EqualValues(t, expected.GetHashedSecret(), actual.GetHashedSecret())
}
assert.EqualValues(t, expected.GetRedirectURIs(), actual.GetRedirectURIs())
assert.EqualValues(t, expected.GetGrantTypes(), actual.GetGrantTypes())
assert.EqualValues(t, expected.GetResponseTypes(), actual.GetResponseTypes())
assert.EqualValues(t, expected.GetScopes(), actual.GetScopes())
assert.EqualValues(t, expected.IsPublic(), actual.IsPublic())
if actual, ok := actual.(*Client); ok {
assert.EqualValues(t, expected.Owner, actual.Owner)
assert.EqualValues(t, expected.Name, actual.Name)
assert.EqualValues(t, expected.PolicyURI, actual.PolicyURI)
assert.EqualValues(t, expected.TermsOfServiceURI, actual.TermsOfServiceURI)
assert.EqualValues(t, expected.ClientURI, actual.ClientURI)
assert.EqualValues(t, expected.LogoURI, actual.LogoURI)
assert.EqualValues(t, expected.Contacts, actual.Contacts)
assert.EqualValues(t, expected.SecretExpiresAt, actual.SecretExpiresAt)
assert.EqualValues(t, expected.SectorIdentifierURI, actual.SectorIdentifierURI)
assert.EqualValues(t, expected.UserinfoSignedResponseAlg, actual.UserinfoSignedResponseAlg)
assert.EqualValues(t, expected.CreatedAt.UTC().Unix(), actual.CreatedAt.UTC().Unix())
// these values are not the same because of https://github.com/gobuffalo/pop/issues/591
//assert.EqualValues(t, expected.UpdatedAt.UTC().Unix(), actual.UpdatedAt.UTC().Unix(), "%s\n%s", expected.UpdatedAt.String(), actual.UpdatedAt.String())
assert.EqualValues(t, expected.FrontChannelLogoutURI, actual.FrontChannelLogoutURI)
assert.EqualValues(t, expected.FrontChannelLogoutSessionRequired, actual.FrontChannelLogoutSessionRequired)
assert.EqualValues(t, expected.PostLogoutRedirectURIs, actual.PostLogoutRedirectURIs)
assert.EqualValues(t, expected.BackChannelLogoutURI, actual.BackChannelLogoutURI)
assert.EqualValues(t, expected.BackChannelLogoutSessionRequired, actual.BackChannelLogoutSessionRequired)
}
if actual, ok := actual.(fosite.OpenIDConnectClient); ok {
require.NotNil(t, expected.JSONWebKeys)
for k, v := range expected.JSONWebKeys.JSONWebKeySet.Keys {
if v.CertificateThumbprintSHA1 == nil {
v.CertificateThumbprintSHA1 = make([]byte, 0)
}
if v.CertificateThumbprintSHA256 == nil {
v.CertificateThumbprintSHA256 = make([]byte, 0)
}
expected.JSONWebKeys.JSONWebKeySet.Keys[k] = v
}
assert.EqualValues(t, expected.JSONWebKeys.JSONWebKeySet, actual.GetJSONWebKeys())
assert.EqualValues(t, expected.JSONWebKeysURI, actual.GetJSONWebKeysURI())
assert.EqualValues(t, expected.TokenEndpointAuthMethod, actual.GetTokenEndpointAuthMethod())
assert.EqualValues(t, expected.RequestURIs, actual.GetRequestURIs())
assert.EqualValues(t, expected.RequestObjectSigningAlgorithm, actual.GetRequestObjectSigningAlgorithm())
}
}