cmd/server/helper_cert_test.go
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package server_test
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"os"
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/google/uuid"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
"github.com/ory/x/configx"
"github.com/ory/x/logrusx"
"github.com/ory/x/tlsx"
"github.com/ory/hydra/v2/cmd/server"
"github.com/ory/hydra/v2/driver"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/internal/testhelpers"
"github.com/ory/hydra/v2/jwk"
)
func TestGetOrCreateTLSCertificate(t *testing.T) {
certPath, keyPath, cert, priv := testhelpers.GenerateTLSCertificateFilesForTests(t)
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { t.Fatalf("Logger called os.Exit(%v)", code) }
hook := test.NewLocal(logger.Logger)
cfg := config.MustNew(
context.Background(),
logger,
configx.WithValues(map[string]interface{}{
"dsn": config.DSNMemory,
"serve.tls.enabled": true,
"serve.tls.cert.path": certPath,
"serve.tls.key.path": keyPath,
}),
)
d, err := driver.NewRegistryWithoutInit(cfg, logger)
require.NoError(t, err)
getCert := server.GetOrCreateTLSCertificate(context.Background(), d, config.AdminInterface, nil)
require.NotNil(t, getCert)
tlsCert, err := getCert(nil)
require.NoError(t, err)
require.NotNil(t, tlsCert)
if tlsCert.Leaf == nil {
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
require.NoError(t, err)
}
require.True(t, tlsCert.Leaf.Equal(cert))
require.True(t, priv.Equal(tlsCert.PrivateKey))
// generate new cert+key
newCertPath, newKeyPath, newCert, newPriv := testhelpers.GenerateTLSCertificateFilesForTests(t)
require.False(t, cert.Equal(newCert))
require.False(t, priv.Equal(newPriv))
require.NotEqual(t, certPath, newCertPath)
require.NotEqual(t, keyPath, newKeyPath)
// move them into place
require.NoError(t, os.Rename(newKeyPath, keyPath))
require.NoError(t, os.Rename(newCertPath, certPath))
// give it some time and check we're reloaded
time.Sleep(150 * time.Millisecond)
require.Nil(t, hook.LastEntry())
// request another certificate: it should be the new one
tlsCert, err = getCert(nil)
require.NoError(t, err)
if tlsCert.Leaf == nil {
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
require.NoError(t, err)
}
require.True(t, tlsCert.Leaf.Equal(newCert))
require.True(t, newPriv.Equal(tlsCert.PrivateKey))
require.NoError(t, os.WriteFile(certPath, []byte{'j', 'u', 'n', 'k'}, 0))
timeout := time.After(500 * time.Millisecond)
for {
if hook.LastEntry() != nil {
break
}
select {
case <-timeout:
require.FailNow(t, "expected error log entry")
default:
}
}
require.Contains(t, hook.LastEntry().Message, "Failed to reload TLS certificates. Using the previously loaded certificates.")
}
func TestGetOrCreateTLSCertificateBase64(t *testing.T) {
certPath, keyPath, cert, priv := testhelpers.GenerateTLSCertificateFilesForTests(t)
certPEM, err := os.ReadFile(certPath)
require.NoError(t, err)
certBase64 := base64.StdEncoding.EncodeToString(certPEM)
keyPEM, err := os.ReadFile(keyPath)
require.NoError(t, err)
keyBase64 := base64.StdEncoding.EncodeToString(keyPEM)
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { t.Fatalf("Logger called os.Exit(%v)", code) }
hook := test.NewLocal(logger.Logger)
_ = hook
cfg := config.MustNew(
context.Background(),
logger,
configx.WithValues(map[string]interface{}{
"dsn": config.DSNMemory,
"serve.tls.enabled": true,
"serve.tls.cert.base64": certBase64,
"serve.tls.key.base64": keyBase64,
}),
)
d, err := driver.NewRegistryWithoutInit(cfg, logger)
require.NoError(t, err)
getCert := server.GetOrCreateTLSCertificate(context.Background(), d, config.AdminInterface, nil)
require.NotNil(t, getCert)
tlsCert, err := getCert(nil)
require.NoError(t, err)
require.NotNil(t, tlsCert)
if tlsCert.Leaf == nil {
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
require.NoError(t, err)
}
require.True(t, tlsCert.Leaf.Equal(cert))
require.True(t, priv.Equal(tlsCert.PrivateKey))
}
func TestCreateSelfSignedCertificate(t *testing.T) {
keys, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New().String(), "sig")
require.NoError(t, err)
private := keys.Keys[0]
cert, err := tlsx.CreateSelfSignedCertificate(private.Key)
require.NoError(t, err)
server.AttachCertificate(&private, cert)
var actual jose.JSONWebKeySet
var b bytes.Buffer
require.NoError(t, json.NewEncoder(&b).Encode(keys))
require.NoError(t, json.NewDecoder(&b).Decode(&actual))
}