ory-am/hydra

View on GitHub
hsm/manager_hsm_test.go

Summary

Maintainability
F
1 wk
Test Coverage
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0

//go:build hsm
// +build hsm

package hsm_test

import (
    "context"
    "crypto/ecdsa"
    "crypto/elliptic"
    "crypto/rand"
    "crypto/rsa"
    "crypto/x509"
    "fmt"
    "reflect"
    "testing"

    "github.com/ory/hydra/v2/jwk"
    "github.com/ory/x/contextx"

    "github.com/ory/hydra/v2/driver"
    "github.com/ory/hydra/v2/driver/config"
    "github.com/ory/hydra/v2/persistence/sql"
    "github.com/ory/x/configx"
    "github.com/ory/x/logrusx"

    "github.com/ThalesIgnite/crypto11"
    "github.com/go-jose/go-jose/v3"
    "github.com/go-jose/go-jose/v3/cryptosigner"
    "github.com/golang/mock/gomock"
    "github.com/miekg/pkcs11"
    "github.com/pborman/uuid"
    "github.com/pkg/errors"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"

    "github.com/ory/hydra/v2/hsm"
    "github.com/ory/hydra/v2/x"
)

// TestDefaultKeyManager_HSMEnabled verifies that, with config.HSMEnabled set,
// the registry exposes a *jwk.ManagerStrategy as its key manager and a
// *sql.Persister as its software key manager.
func TestDefaultKeyManager_HSMEnabled(t *testing.T) {
    ctx := context.Background()
    ctrl := gomock.NewController(t)
    defer ctrl.Finish()
    mockHsmContext := NewMockContext(ctrl)

    l := logrusx.New("", "")
    c := config.MustNew(ctx, l, configx.SkipValidation())
    c.MustSet(ctx, config.KeyDSN, "memory")
    c.MustSet(ctx, config.HSMEnabled, "true")

    reg, err := driver.NewRegistryWithoutInit(c, l)
    require.NoError(t, err)
    // The mocked HSM context is injected before Init so the registry picks it
    // up during initialisation.
    reg.WithHsmContext(mockHsmContext)
    // The key-manager assertions below are meaningless on a registry whose Init
    // failed, so abort the test instead of continuing.
    require.NoError(t, reg.Init(ctx, false, true, &contextx.TestContextualizer{}, nil, nil))

    assert.IsType(t, &jwk.ManagerStrategy{}, reg.KeyManager())
    assert.IsType(t, &sql.Persister{}, reg.SoftwareKeyManager())
}

// TestKeyManager_HsmKeySetPrefix checks that, when config.HSMKeySetPrefix is
// set, every key-set operation of the HSM key manager addresses the HSM with
// the prefixed set name (prefix + x.OpenIDConnectKeyName) instead of the bare
// name callers pass in.
//
// The subtests share one mocked HSM context and one 4096-bit key pair and run
// sequentially; each subtest registers exactly the mock calls it expects.
func TestKeyManager_HsmKeySetPrefix(t *testing.T) {
    ctx := context.Background()
    ctrl := gomock.NewController(t)
    defer ctrl.Finish()
    hsmContext := NewMockContext(ctrl)

    l := logrusx.New("", "")
    c := config.MustNew(ctx, l, configx.SkipValidation())
    keySetPrefix := "application_specific_prefix."
    c.MustSet(ctx, config.HSMKeySetPrefix, keySetPrefix)
    m := hsm.NewKeyManager(hsmContext, c)

    // The 3072-bit key is used by the subtests that expect
    // jwk.ErrMinimalRsaKeyLength; the 4096-bit key is used everywhere else.
    rsaKey3072, err := rsa.GenerateKey(rand.Reader, 3072)
    require.NoError(t, err)
    rsaKey4096, err := rsa.GenerateKey(rand.Reader, 4096)
    require.NoError(t, err)

    rsaKeyPair3072 := NewMockSignerDecrypter(ctrl)
    rsaKeyPair3072.EXPECT().Public().Return(&rsaKey3072.PublicKey).AnyTimes()

    rsaKeyPair4096 := NewMockSignerDecrypter(ctrl)
    rsaKeyPair4096.EXPECT().Public().Return(&rsaKey4096.PublicKey).AnyTimes()

    kid := uuid.New()

    expectedPrefixedOpenIDConnectKeyName := fmt.Sprintf("%s%s", keySetPrefix, x.OpenIDConnectKeyName)

    t.Run("case=GenerateAndPersistKeySet", func(t *testing.T) {
        privateAttrSet, publicAttrSet := expectedKeyAttributes(t, expectedPrefixedOpenIDConnectKeyName, kid)
        hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return(nil, nil)
        hsmContext.EXPECT().GenerateRSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(4096)).Return(rsaKeyPair4096, nil)

        got, err := m.GenerateAndPersistKeySet(ctx, x.OpenIDConnectKeyName, kid, "RS256", "sig")

        require.NoError(t, err)
        want := expectedKeySet(rsaKeyPair4096, kid, "RS256", "sig")
        if !reflect.DeepEqual(got, want) {
            t.Errorf("GenerateAndPersistKeySet() got = %v, want %v", got, want)
        }
    })
    t.Run("case=GetKey", func(t *testing.T) {
        hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return(rsaKeyPair4096, nil)
        hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair4096), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)

        got, err := m.GetKey(ctx, x.OpenIDConnectKeyName, kid)

        require.NoError(t, err)
        want := expectedKeySet(rsaKeyPair4096, kid, "RS256", "sig")
        if !reflect.DeepEqual(got, want) {
            t.Errorf("GetKey() got = %v, want %v", got, want)
        }
    })
    t.Run("case=GetKeyMinimalRsaKeyLengthError", func(t *testing.T) {
        hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return(rsaKeyPair3072, nil)

        _, err := m.GetKey(ctx, x.OpenIDConnectKeyName, kid)

        assert.ErrorIs(t, err, jwk.ErrMinimalRsaKeyLength)
    })
    t.Run("case=GetKeySet", func(t *testing.T) {
        hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return([]crypto11.Signer{rsaKeyPair4096}, nil)
        hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair4096), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(kid)), nil)
        hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair4096), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)

        got, err := m.GetKeySet(ctx, x.OpenIDConnectKeyName)

        require.NoError(t, err)
        want := expectedKeySet(rsaKeyPair4096, kid, "RS256", "sig")
        if !reflect.DeepEqual(got, want) {
            t.Errorf("GetKeySet() got = %v, want %v", got, want)
        }
    })
    t.Run("case=GetKeySetMinimalRsaKeyLengthError", func(t *testing.T) {
        hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return([]crypto11.Signer{rsaKeyPair3072}, nil)
        hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair3072), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(kid)), nil)

        _, err := m.GetKeySet(ctx, x.OpenIDConnectKeyName)

        assert.ErrorIs(t, err, jwk.ErrMinimalRsaKeyLength)
    })
    t.Run("case=DeleteKey", func(t *testing.T) {
        hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return(rsaKeyPair4096, nil)
        rsaKeyPair4096.EXPECT().Delete().Return(nil)

        require.NoError(t, m.DeleteKey(ctx, x.OpenIDConnectKeyName, kid))
    })
    t.Run("case=DeleteKeySet", func(t *testing.T) {
        hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(expectedPrefixedOpenIDConnectKeyName))).Return([]crypto11.Signer{rsaKeyPair4096}, nil)
        rsaKeyPair4096.EXPECT().Delete().Return(nil)

        require.NoError(t, m.DeleteKeySet(ctx, x.OpenIDConnectKeyName))
    })
}

// TestKeyManager_GenerateAndPersistKeySet drives KeyManager.GenerateAndPersistKeySet
// through a table of algorithms and HSM failure modes.
//
// Each case registers the mock calls it expects in setup and specifies exactly
// one outcome:
//   - wantErr: a sentinel error that the returned error must match via errors.Is,
//   - wantErrMsg: the exact error message, or
//   - neither: success, with a key set equal to want.
func TestKeyManager_GenerateAndPersistKeySet(t *testing.T) {
    ctrl := gomock.NewController(t)
    hsmContext := NewMockContext(ctrl)
    defer ctrl.Finish()
    l := logrusx.New("", "")
    c := config.MustNew(context.Background(), l, configx.SkipValidation())
    m := hsm.NewKeyManager(hsmContext, c)

    rsaKey, err := rsa.GenerateKey(rand.Reader, 4096)
    require.NoError(t, err)

    // One P-521 key backs both the ES256 and ES512 cases. The expected key sets
    // are built from the same mock key pair.
    ecdsaKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
    require.NoError(t, err)

    rsaKeyPair := NewMockSignerDecrypter(ctrl)
    rsaKeyPair.EXPECT().Public().Return(&rsaKey.PublicKey).AnyTimes()

    ecdsaKeyPair := NewMockSignerDecrypter(ctrl)
    ecdsaKeyPair.EXPECT().Public().Return(&ecdsaKey.PublicKey).AnyTimes()

    var kid = uuid.New()

    type args struct {
        ctx context.Context
        set string
        kid string
        alg string
        use string
    }
    tests := []struct {
        name       string
        setup      func(t *testing.T)
        args       args
        want       *jose.JSONWebKeySet
        wantErrMsg string
        wantErr    error // sentinel, compared with errors.Is
    }{
        {
            name: "Generate RS256",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
                alg: "RS256",
                use: "sig",
            },
            setup: func(t *testing.T) {
                privateAttrSet, publicAttrSet := expectedKeyAttributes(t, x.OpenIDConnectKeyName, kid)
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil)
                hsmContext.EXPECT().GenerateRSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(4096)).Return(rsaKeyPair, nil)
            },
            want: expectedKeySet(rsaKeyPair, kid, "RS256", "sig"),
        },
        {
            name: "Generate RS256 with GenerateRSAKeyPairWithAttributes Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
                alg: "RS256",
                use: "sig",
            },
            setup: func(t *testing.T) {
                privateAttrSet, publicAttrSet := expectedKeyAttributes(t, x.OpenIDConnectKeyName, kid)
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil)
                hsmContext.EXPECT().GenerateRSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(4096)).Return(nil, errors.New("GenerateRSAKeyPairWithAttributesError"))
            },
            wantErrMsg: "GenerateRSAKeyPairWithAttributesError",
        },
        {
            name: "Generate ES256",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
                alg: "ES256",
                use: "sig",
            },
            setup: func(t *testing.T) {
                privateAttrSet, publicAttrSet := expectedKeyAttributes(t, x.OpenIDConnectKeyName, kid)
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil)
                hsmContext.EXPECT().GenerateECDSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(elliptic.P256())).Return(ecdsaKeyPair, nil)
            },
            want: expectedKeySet(ecdsaKeyPair, kid, "ES256", "sig"),
        },
        {
            name: "Generate ES256 with GenerateECDSAKeyPairWithAttributes Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
                alg: "ES256",
                use: "sig",
            },
            setup: func(t *testing.T) {
                privateAttrSet, publicAttrSet := expectedKeyAttributes(t, x.OpenIDConnectKeyName, kid)
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil)
                hsmContext.EXPECT().GenerateECDSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(elliptic.P256())).Return(nil, errors.New("GenerateECDSAKeyPairWithAttributesError"))
            },
            wantErrMsg: "GenerateECDSAKeyPairWithAttributesError",
        },
        {
            name: "Generate ES512",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
                alg: "ES512",
                use: "sig",
            },
            setup: func(t *testing.T) {
                privateAttrSet, publicAttrSet := expectedKeyAttributes(t, x.OpenIDConnectKeyName, kid)
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil)
                hsmContext.EXPECT().GenerateECDSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(elliptic.P521())).Return(ecdsaKeyPair, nil)
            },
            want: expectedKeySet(ecdsaKeyPair, kid, "ES512", "sig"),
        },
        {
            name: "Generate ES512 GenerateECDSAKeyPairWithAttributes Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
                alg: "ES512",
                use: "sig",
            },
            setup: func(t *testing.T) {
                privateAttrSet, publicAttrSet := expectedKeyAttributes(t, x.OpenIDConnectKeyName, kid)
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil)
                hsmContext.EXPECT().GenerateECDSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(elliptic.P521())).Return(nil, errors.New("GenerateECDSAKeyPairWithAttributesError"))
            },
            wantErrMsg: "GenerateECDSAKeyPairWithAttributesError",
        },
        {
            name: "Generate unsupported",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
                alg: "ES384",
                use: "sig",
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil)
            },
            wantErr: jwk.ErrUnsupportedKeyAlgorithm,
        },
        {
            name: "Generate with FindKeyPair Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
                alg: "RS256",
                use: "sig",
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, errors.New("FindKeyPairError"))
            },
            wantErrMsg: "FindKeyPairError",
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tt.setup(t)
            got, err := m.GenerateAndPersistKeySet(tt.args.ctx, tt.args.set, tt.args.kid, tt.args.alg, tt.args.use)
            switch {
            case tt.wantErr != nil:
                require.Nil(t, got)
                // ErrorIs unwraps the stack wrapper and pins the exact sentinel.
                // A bare type check would accept any WithStack-wrapped error.
                require.ErrorIs(t, err, tt.wantErr)
            case tt.wantErrMsg != "":
                require.Nil(t, got)
                require.EqualError(t, err, tt.wantErrMsg)
            default:
                require.NoError(t, err)
                if !reflect.DeepEqual(got, tt.want) {
                    t.Errorf("GenerateAndPersistKeySet() got = %v, want %v", got, tt.want)
                }
            }
        })
    }
}

// TestKeyManager_GetKey drives KeyManager.GetKey through RSA and ECDSA keys,
// both signing ("sig") and decrypting ("enc") usage, and HSM failure modes.
//
// Each case registers the mock calls it expects in setup and specifies exactly
// one outcome:
//   - wantErr: a sentinel error that the returned error must match via errors.Is,
//   - wantErrMsg: the exact error message, or
//   - neither: success, with a key set equal to want.
func TestKeyManager_GetKey(t *testing.T) {
    ctrl := gomock.NewController(t)
    hsmContext := NewMockContext(ctrl)
    defer ctrl.Finish()
    l := logrusx.New("", "")
    c := config.MustNew(context.Background(), l, configx.SkipValidation())
    m := hsm.NewKeyManager(hsmContext, c)

    rsaKey, err := rsa.GenerateKey(rand.Reader, 4096)
    require.NoError(t, err)
    rsaKeyPair := NewMockSignerDecrypter(ctrl)
    rsaKeyPair.EXPECT().Public().Return(&rsaKey.PublicKey).AnyTimes()

    ecdsaP256Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    require.NoError(t, err)
    ecdsaP256KeyPair := NewMockSignerDecrypter(ctrl)
    ecdsaP256KeyPair.EXPECT().Public().Return(&ecdsaP256Key.PublicKey).AnyTimes()

    ecdsaP521Key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
    require.NoError(t, err)
    ecdsaP521KeyPair := NewMockSignerDecrypter(ctrl)
    ecdsaP521KeyPair.EXPECT().Public().Return(&ecdsaP521Key.PublicKey).AnyTimes()

    // P-224 is used as the curve the manager is expected to reject.
    ecdsaP224Key, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
    require.NoError(t, err)
    ecdsaP224KeyPair := NewMockSignerDecrypter(ctrl)
    ecdsaP224KeyPair.EXPECT().Public().Return(&ecdsaP224Key.PublicKey).AnyTimes()

    var kid = uuid.New()

    type args struct {
        ctx context.Context
        set string
        kid string
    }
    tests := []struct {
        name       string
        setup      func(t *testing.T)
        args       args
        want       *jose.JSONWebKeySet
        wantErrMsg string
        wantErr    error // sentinel, compared with errors.Is
    }{
        {
            name: "Get RS256 sig",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(rsaKeyPair, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)
            },
            want: expectedKeySet(rsaKeyPair, kid, "RS256", "sig"),
        },
        {
            name: "Get RS256 enc",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(rsaKeyPair, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), nil)
            },
            want: expectedKeySet(rsaKeyPair, kid, "RS256", "enc"),
        },
        {
            // A failing CKA_DECRYPT lookup is expected to fall back to "sig".
            name: "Key usage attribute error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(rsaKeyPair, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, errors.New("GetAttributeError"))
            },
            want: expectedKeySet(rsaKeyPair, kid, "RS256", "sig"),
        },
        {
            name: "Get ES256 sig",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(ecdsaP256KeyPair, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP256KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)
            },
            want: expectedKeySet(ecdsaP256KeyPair, kid, "ES256", "sig"),
        },
        {
            name: "Get ES256 enc",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(ecdsaP256KeyPair, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP256KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), nil)
            },
            want: expectedKeySet(ecdsaP256KeyPair, kid, "ES256", "enc"),
        },
        {
            name: "Get ES512 sig",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(ecdsaP521KeyPair, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP521KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)
            },
            want: expectedKeySet(ecdsaP521KeyPair, kid, "ES512", "sig"),
        },
        {
            name: "Get ES512 enc",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(ecdsaP521KeyPair, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP521KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), nil)
            },
            want: expectedKeySet(ecdsaP521KeyPair, kid, "ES512", "enc"),
        },
        {
            name: "Key not found",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil)
            },
            wantErrMsg: "Not Found",
        },
        {
            name: "FindKeyPair Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, errors.New("FindKeyPairError"))
            },
            wantErrMsg: "FindKeyPairError",
        },
        {
            name: "Unsupported elliptic curve",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(ecdsaP224KeyPair, nil)
            },
            wantErr: jwk.ErrUnsupportedEllipticCurve,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tt.setup(t)
            got, err := m.GetKey(tt.args.ctx, tt.args.set, tt.args.kid)
            switch {
            case tt.wantErr != nil:
                require.Nil(t, got)
                // ErrorIs unwraps the stack wrapper and pins the exact sentinel.
                // A bare type check would accept any WithStack-wrapped error.
                require.ErrorIs(t, err, tt.wantErr)
            case tt.wantErrMsg != "":
                require.Nil(t, got)
                require.EqualError(t, err, tt.wantErrMsg)
            default:
                require.NoError(t, err)
                if !reflect.DeepEqual(got, tt.want) {
                    t.Errorf("GetKey() got = %v, want %v", got, tt.want)
                }
            }
        })
    }
}

// TestKeyManager_GetKeySet drives KeyManager.GetKeySet over a set holding
// several key types, plus attribute-lookup, lookup and key-type failures.
//
// Each case registers the mock calls it expects in setup and specifies exactly
// one outcome:
//   - wantErr: a sentinel error that the returned error must match via errors.Is,
//   - wantErrMsg: the exact error message, or
//   - neither: success, with a key set equal to want.
func TestKeyManager_GetKeySet(t *testing.T) {
    ctrl := gomock.NewController(t)
    hsmContext := NewMockContext(ctrl)
    defer ctrl.Finish()
    l := logrusx.New("", "")
    c := config.MustNew(context.Background(), l, configx.SkipValidation())
    m := hsm.NewKeyManager(hsmContext, c)

    rsaKey, err := rsa.GenerateKey(rand.Reader, 4096)
    require.NoError(t, err)
    rsaKid := uuid.New()
    rsaKeyPair := NewMockSignerDecrypter(ctrl)
    rsaKeyPair.EXPECT().Public().Return(&rsaKey.PublicKey).AnyTimes()

    ecdsaP256Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    require.NoError(t, err)
    ecdsaP256Kid := uuid.New()
    ecdsaP256KeyPair := NewMockSignerDecrypter(ctrl)
    ecdsaP256KeyPair.EXPECT().Public().Return(&ecdsaP256Key.PublicKey).AnyTimes()

    ecdsaP521Key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
    require.NoError(t, err)
    ecdsaP521Kid := uuid.New()
    ecdsaP521KeyPair := NewMockSignerDecrypter(ctrl)
    ecdsaP521KeyPair.EXPECT().Public().Return(&ecdsaP521Key.PublicKey).AnyTimes()

    // P-224 is used as the curve the manager is expected to reject.
    ecdsaP224Key, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
    require.NoError(t, err)
    ecdsaP224Kid := uuid.New()
    ecdsaP224KeyPair := NewMockSignerDecrypter(ctrl)
    ecdsaP224KeyPair.EXPECT().Public().Return(&ecdsaP224Key.PublicKey).AnyTimes()

    allKeys := []crypto11.Signer{rsaKeyPair, ecdsaP256KeyPair, ecdsaP521KeyPair}

    // Expected JWKs, in the same order as allKeys.
    var keys []jose.JSONWebKey
    keys = append(keys, createJSONWebKeys(rsaKeyPair, rsaKid, "RS256", "sig")...)
    keys = append(keys, createJSONWebKeys(ecdsaP256KeyPair, ecdsaP256Kid, "ES256", "sig")...)
    keys = append(keys, createJSONWebKeys(ecdsaP521KeyPair, ecdsaP521Kid, "ES512", "sig")...)

    type args struct {
        ctx context.Context
        set string
    }
    tests := []struct {
        name       string
        setup      func(t *testing.T)
        args       args
        want       *jose.JSONWebKeySet
        wantErrMsg string
        wantErr    error // sentinel, compared with errors.Is
    }{
        {
            name: "With multiple keys per set",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(allKeys, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(rsaKid)), nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP256KeyPair), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(ecdsaP256Kid)), nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP256KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP521KeyPair), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(ecdsaP521Kid)), nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP521KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil)
            },
            want: &jose.JSONWebKeySet{Keys: keys},
        },
        {
            name: "GetCkaIdAttributeError Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(allKeys, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaId)).Return(nil, errors.New("GetCkaIdAttributeError"))
            },
            wantErrMsg: "GetCkaIdAttributeError",
        },
        {
            name: "Key set not found",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil)
            },
            wantErrMsg: "Not Found",
        },
        {
            name: "FindKeyPairs Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, errors.New("FindKeyPairsError"))
            },
            wantErrMsg: "FindKeyPairsError",
        },
        {
            name: "Unsupported elliptic curve",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return([]crypto11.Signer{ecdsaP224KeyPair}, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP224KeyPair), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(ecdsaP224Kid)), nil)
            },
            wantErr: jwk.ErrUnsupportedEllipticCurve,
        },
        {
            // A key pair whose Public() yields nil has no recognizable algorithm.
            name: "Invalid key type Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
            },
            setup: func(t *testing.T) {
                keyPair := NewMockSignerDecrypter(ctrl)
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return([]crypto11.Signer{keyPair}, nil)
                hsmContext.EXPECT().GetAttribute(gomock.Eq(keyPair), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(rsaKid)), nil)
                keyPair.EXPECT().Public().Return(nil).Times(1)
            },
            wantErr: jwk.ErrUnsupportedKeyAlgorithm,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tt.setup(t)
            got, err := m.GetKeySet(tt.args.ctx, tt.args.set)
            switch {
            case tt.wantErr != nil:
                require.Nil(t, got)
                // ErrorIs unwraps the stack wrapper and pins the exact sentinel.
                // A bare type check would accept any WithStack-wrapped error.
                require.ErrorIs(t, err, tt.wantErr)
            case tt.wantErrMsg != "":
                require.Nil(t, got)
                require.EqualError(t, err, tt.wantErrMsg)
            default:
                require.NoError(t, err)
                if !reflect.DeepEqual(got, tt.want) {
                    t.Errorf("GetKeySet() got = %v, want %v", got, tt.want)
                }
            }
        })
    }
}

// TestKeyManager_DeleteKey drives KeyManager.DeleteKey through a successful
// deletion, a missing key, a lookup failure and a failing Delete on the key
// pair. Cases with an empty wantErrMsg must succeed with a nil error.
func TestKeyManager_DeleteKey(t *testing.T) {
    ctrl := gomock.NewController(t)
    hsmContext := NewMockContext(ctrl)
    defer ctrl.Finish()
    l := logrusx.New("", "")
    c := config.MustNew(context.Background(), l, configx.SkipValidation())
    m := hsm.NewKeyManager(hsmContext, c)

    rsaKeyPair := NewMockSignerDecrypter(ctrl)

    kid := uuid.New()

    type args struct {
        ctx context.Context
        set string
        kid string
    }
    tests := []struct {
        name       string
        setup      func(t *testing.T)
        args       args
        wantErrMsg string
    }{
        {
            name: "Existing key",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(rsaKeyPair, nil)
                rsaKeyPair.EXPECT().Delete().Return(nil)
            },
        },
        {
            name: "Key not found",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil)
            },
            wantErrMsg: "Not Found",
        },
        {
            name: "FindKeyPair Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, errors.New("FindKeyPairError"))
            },
            wantErrMsg: "FindKeyPairError",
        },
        {
            name: "Delete Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
                kid: kid,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(rsaKeyPair, nil)
                rsaKeyPair.EXPECT().Delete().Return(errors.New("DeleteError"))
            },
            wantErrMsg: "DeleteError",
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tt.setup(t)
            err := m.DeleteKey(tt.args.ctx, tt.args.set, tt.args.kid)
            if tt.wantErrMsg != "" {
                require.EqualError(t, err, tt.wantErrMsg)
                return
            }
            // Success cases must not return an error. Previously the error was
            // silently ignored here.
            require.NoError(t, err)
        })
    }
}

// TestKeyManager_DeleteKeySet exercises KeyManager.DeleteKeySet against a mocked
// HSM context. The key manager is expected to look up every key pair labelled
// with the set name (FindKeyPairs with a nil ID) and delete each one. Lookup
// errors, a nil lookup result ("Not Found") and delete errors are expected to
// surface with their original message.
func TestKeyManager_DeleteKeySet(t *testing.T) {
    ctrl := gomock.NewController(t)
    hsmContext := NewMockContext(ctrl)
    defer ctrl.Finish()
    l := logrusx.New("", "")
    c := config.MustNew(context.Background(), l, configx.SkipValidation())
    m := hsm.NewKeyManager(hsmContext, c)

    rsaKeyPair1 := NewMockSignerDecrypter(ctrl)
    rsaKeyPair2 := NewMockSignerDecrypter(ctrl)
    allKeys := []crypto11.Signer{rsaKeyPair1, rsaKeyPair2}

    type args struct {
        ctx context.Context
        set string
    }
    tests := []struct {
        name       string
        setup      func(t *testing.T)
        args       args
        wantErrMsg string // empty means DeleteKeySet must succeed
    }{
        {
            name: "Existing key",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(allKeys, nil)
                rsaKeyPair1.EXPECT().Delete().Return(nil)
                rsaKeyPair2.EXPECT().Delete().Return(nil)
            },
        },
        {
            name: "Key not found",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil)
            },
            wantErrMsg: "Not Found",
        },
        {
            name: "FindKeyPairs Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, errors.New("FindKeyPairsError"))
            },
            wantErrMsg: "FindKeyPairsError",
        },
        {
            name: "Delete Error",
            args: args{
                ctx: context.TODO(),
                set: x.OpenIDConnectKeyName,
            },
            setup: func(t *testing.T) {
                hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(allKeys, nil)
                // No Delete expectation on rsaKeyPair2: gomock fails on an
                // unexpected call, so this also checks that deletion stops at
                // the first failure.
                rsaKeyPair1.EXPECT().Delete().Return(errors.New("DeleteError"))
            },
            wantErrMsg: "DeleteError",
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tt.setup(t)
            err := m.DeleteKeySet(tt.args.ctx, tt.args.set)
            if len(tt.wantErrMsg) == 0 {
                // A case without an expected message must succeed; assert it so
                // an unexpected error cannot slip through.
                require.NoError(t, err)
                return
            }
            require.EqualError(t, err, tt.wantErrMsg)
        })
    }
}

// TestKeyManager_AddKey checks that adding a single key to an HSM-backed key
// manager is refused with hsm.ErrPreGeneratedKeys. No HSM context is set up
// because the call is expected to fail before touching it.
func TestKeyManager_AddKey(t *testing.T) {
    manager := new(hsm.KeyManager)
    assert.ErrorIs(t, manager.AddKey(context.TODO(), x.OpenIDConnectKeyName, &jose.JSONWebKey{}), hsm.ErrPreGeneratedKeys)
}

// TestKeyManager_AddKeySet checks that adding a whole key set to an HSM-backed
// key manager is refused with hsm.ErrPreGeneratedKeys. The manager has no HSM
// context because the call is expected to fail before using one.
func TestKeyManager_AddKeySet(t *testing.T) {
    manager := new(hsm.KeyManager)
    assert.ErrorIs(t, manager.AddKeySet(context.TODO(), x.OpenIDConnectKeyName, &jose.JSONWebKeySet{}), hsm.ErrPreGeneratedKeys)
}

// TestKeyManager_UpdateKey checks that updating a single key through an
// HSM-backed key manager is refused with hsm.ErrPreGeneratedKeys. No HSM
// context is needed for this path.
func TestKeyManager_UpdateKey(t *testing.T) {
    manager := new(hsm.KeyManager)
    assert.ErrorIs(t, manager.UpdateKey(context.TODO(), x.OpenIDConnectKeyName, &jose.JSONWebKey{}), hsm.ErrPreGeneratedKeys)
}

// TestKeyManager_UpdateKeySet checks that replacing a whole key set through an
// HSM-backed key manager is refused with hsm.ErrPreGeneratedKeys. No HSM
// context is needed for this path.
func TestKeyManager_UpdateKeySet(t *testing.T) {
    manager := new(hsm.KeyManager)
    assert.ErrorIs(t, manager.UpdateKeySet(context.TODO(), x.OpenIDConnectKeyName, &jose.JSONWebKeySet{}), hsm.ErrPreGeneratedKeys)
}

// expectedKeyAttributes builds the PKCS#11 attribute sets for a key pair with
// ID kid and label set, returned as (private, public).
//
// The private set additionally carries CKA_SIGN=true and CKA_DECRYPT=false.
// The public set carries CKA_VERIFY=true and CKA_ENCRYPT=false. Callers
// presumably compare these against the templates the key manager passes to
// the HSM context when generating a key pair.
func expectedKeyAttributes(t *testing.T, set, kid string) (crypto11.AttributeSet, crypto11.AttributeSet) {
    // Report require failures at the calling test rather than inside this helper.
    t.Helper()

    privateAttrSet, err := crypto11.NewAttributeSetWithIDAndLabel([]byte(kid), []byte(set))
    require.NoError(t, err)
    publicAttrSet, err := crypto11.NewAttributeSetWithIDAndLabel([]byte(kid), []byte(set))
    require.NoError(t, err)
    publicAttrSet.AddIfNotPresent([]*pkcs11.Attribute{
        pkcs11.NewAttribute(pkcs11.CKA_VERIFY, true),
        pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, false),
    })
    privateAttrSet.AddIfNotPresent([]*pkcs11.Attribute{
        pkcs11.NewAttribute(pkcs11.CKA_SIGN, true),
        pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, false),
    })
    return privateAttrSet, publicAttrSet
}

// expectedKeySet wraps the keys produced by createJSONWebKeys for keyPair
// (with the given kid, alg and use) in a *jose.JSONWebKeySet.
func expectedKeySet(keyPair *MockSignerDecrypter, kid, alg, use string) *jose.JSONWebKeySet {
    keys := createJSONWebKeys(keyPair, kid, alg, use)
    set := jose.JSONWebKeySet{Keys: keys}
    return &set
}

// createJSONWebKeys returns a one-element slice holding the JSON Web Key that
// represents keyPair. The signer is wrapped with cryptosigner.Opaque, and kid,
// alg and use become the key's KeyID, Algorithm and Use.
func createJSONWebKeys(keyPair *MockSignerDecrypter, kid, alg, use string) []jose.JSONWebKey {
    key := jose.JSONWebKey{
        KeyID:     kid,
        Algorithm: alg,
        Use:       use,
        Key:       cryptosigner.Opaque(keyPair),
    }
    // Keep these as empty, non-nil slices: reflect.DeepEqual-based
    // comparisons distinguish nil from empty.
    key.Certificates = []*x509.Certificate{}
    key.CertificateThumbprintSHA1 = []uint8{}
    key.CertificateThumbprintSHA256 = []uint8{}
    return []jose.JSONWebKey{key}
}