nuts-foundation/nuts-node

View on GitHub
crypto/storage/azure/keyvault.go

Summary

Maintainability
A
0 mins
Test Coverage
D
63%
/*
 * Copyright (C) 2024 Nuts community
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 *
 */

package azure

import (
    "context"
    "crypto"
    "encoding/json"
    "errors"
    "fmt"
    "github.com/Azure/azure-sdk-for-go/sdk/azcore"
    "github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
    "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys"
    "github.com/lestrrat-go/jwx/v2/jwk"
    "github.com/nuts-foundation/nuts-node/core"
    "github.com/nuts-foundation/nuts-node/crypto/log"
    "github.com/nuts-foundation/nuts-node/crypto/storage/spi"
    "golang.org/x/crypto/cryptobyte"
    "golang.org/x/crypto/cryptobyte/asn1"
    "io"
    "net/http"
    "regexp"
    "time"
)

// Supported values for the credential type passed to createCredential.
const (
    // DefaultChainCredentialType selects the credential created by azidentity.NewDefaultAzureCredential.
    DefaultChainCredentialType    string = "default"
    // ManagedIdentityCredentialType selects the credential created by azidentity.NewManagedIdentityCredential.
    ManagedIdentityCredentialType string = "managed_identity"
)

// New creates a new Azure Key Vault storage backend from the given configuration.
// It returns an error when config.URL is empty, when config.Auth.Type is not a supported
// credential type, or when the Azure Key Vault client cannot be constructed.
// If config.UseHSM is true, new keys are created with azkeys.KeyTypeECHSM, otherwise with azkeys.KeyTypeEC.
func New(config Config) (spi.Storage, error) {
    if len(config.URL) == 0 {
        return nil, errors.New("missing Azure Key Vault URL")
    }

    tokenCredential, err := createCredential(config.Auth.Type)
    if err != nil {
        return nil, err
    }

    vaultClient, err := azkeys.NewClient(config.URL, tokenCredential, nil)
    if err != nil {
        return nil, fmt.Errorf("unable to create Azure Key Vault client: %w", err)
    }

    backend := &keyvault{
        client:  vaultClient,
        timeOut: config.Timeout,
        useHSM:  config.UseHSM,
    }
    return backend, nil
}

// createCredential builds the Azure token credential for the given credential type.
// Supported types are DefaultChainCredentialType and ManagedIdentityCredentialType;
// any other value results in an error.
func createCredential(credentialType string) (azcore.TokenCredential, error) {
    if credentialType == DefaultChainCredentialType {
        return azidentity.NewDefaultAzureCredential(nil)
    }
    if credentialType == ManagedIdentityCredentialType {
        return azidentity.NewManagedIdentityCredential(nil)
    }
    return nil, fmt.Errorf("unsupported Azure Key Vault credential type: %s", credentialType)
}

// StorageType is the name of this storage type, used in health check reports and configuration.
// It is the value returned by keyvault.Name().
const StorageType = "azure-keyvault"

// keyvault is the storage backend returned by New. All key operations are delegated to Azure Key Vault
// through client.
type keyvault struct {
    // client performs the actual Azure Key Vault calls.
    client  keyVaultClient
    // timeOut is handed to the signing keys returned by GetPrivateKey, which use it to bound each Sign call.
    timeOut time.Duration
    // useHSM selects azkeys.KeyTypeECHSM (true) or azkeys.KeyTypeEC (false) when creating new keys.
    useHSM  bool
}

// Name returns the name of this storage backend, which is StorageType.
func (a keyvault) Name() string {
    return StorageType
}

// CheckHealth always returns nil: this backend does not report any health check results.
func (a keyvault) CheckHealth() map[string]core.Health {
    return nil
}

// NewPrivateKey creates a new, non-exportable P-256 EC key in Azure Key Vault and returns its public key
// together with the key ID produced by namingFunc.
// namingFunc is invoked with a nil public key, since the key is only created after it has been named.
// The key ID is stored on the vault key in the "originalKID" tag; the vault key name itself is derived
// from the key ID through keyIDToKeyName.
// It returns spi.ErrKeyAlreadyExists when a key for the key ID is already present.
func (a keyvault) NewPrivateKey(ctx context.Context, namingFunc func(crypto.PublicKey) (string, error)) (crypto.PublicKey, string, error) {
    keyID, err := namingFunc(nil)
    if err != nil {
        return nil, "", err
    }

    // Azure Key Vault creates a new version of the same key when the name is already taken,
    // so existing keys have to be rejected before creating.
    alreadyPresent, err := a.PrivateKeyExists(ctx, keyID)
    switch {
    case err != nil:
        return nil, "", err
    case alreadyPresent:
        return nil, "", spi.ErrKeyAlreadyExists
    }

    vaultKeyName := keyIDToKeyName(keyID)

    // Software-protected EC key by default, HSM-protected when configured.
    keyType := azkeys.KeyTypeEC
    if a.useHSM {
        keyType = azkeys.KeyTypeECHSM
    }

    parameters := azkeys.CreateKeyParameters{
        Kty:   to.Ptr(keyType),
        Curve: to.Ptr(azkeys.CurveNameP256),
        KeyAttributes: &azkeys.KeyAttributes{
            Enabled:    to.Ptr(true),
            Exportable: to.Ptr(false),
        },
        Tags: map[string]*string{
            "originalKID": to.Ptr(keyID),
        },
    }
    created, err := a.client.CreateKey(ctx, vaultKeyName, parameters, nil)
    if err != nil {
        return nil, "", fmt.Errorf("unable to create key in Azure Key Vault (name=%s): %w", vaultKeyName, err)
    }

    publicKey, _, err := parseKey(created.Key)
    if err != nil {
        return nil, "", err
    }
    return publicKey, keyID, nil
}

// GetPrivateKey returns a crypto.Signer for the key with the given key ID.
// The private key itself is not retrieved: the returned signer holds the public key and
// performs signing through the Azure Key Vault client.
// It returns spi.ErrNotFound when the key does not exist.
func (a keyvault) GetPrivateKey(ctx context.Context, kid string) (crypto.Signer, error) {
    vaultKeyName := keyIDToKeyName(kid)

    found, err := a.getPrivateKey(ctx, vaultKeyName)
    if err != nil {
        return nil, err
    }

    pub, algorithm, err := parseKey(found.Key)
    if err != nil {
        return nil, err
    }

    signer := azureSigningKey{
        client:           a.client,
        timeOut:          a.timeOut,
        keyName:          vaultKeyName,
        publicKey:        pub,
        signingAlgorithm: algorithm,
    }
    return &signer, nil
}

// PrivateKeyExists reports whether a key for the given key ID is present in Azure Key Vault.
// A spi.ErrNotFound from the lookup is translated to (false, nil); any other error is returned as is.
func (a keyvault) PrivateKeyExists(ctx context.Context, kid string) (bool, error) {
    _, err := a.getPrivateKey(ctx, keyIDToKeyName(kid))
    switch {
    case err == nil:
        return true, nil
    case errors.Is(err, spi.ErrNotFound):
        return false, nil
    default:
        return false, err
    }
}

// DeletePrivateKey deletes the key for the given key ID from Azure Key Vault.
// It returns spi.ErrNotFound when Azure Key Vault responds with HTTP 404,
// and a wrapped error for any other failure.
func (a keyvault) DeletePrivateKey(ctx context.Context, kid string) error {
    vaultKeyName := keyIDToKeyName(kid)

    _, err := a.client.DeleteKey(ctx, vaultKeyName, nil)
    if err == nil {
        return nil
    }

    var azureErr *azcore.ResponseError
    if errors.As(err, &azureErr) && azureErr.StatusCode == http.StatusNotFound {
        return spi.ErrNotFound
    }
    return fmt.Errorf("unable to delete key from Azure Key Vault (name=%s): %w", vaultKeyName, err)
}

// getPrivateKey fetches the key with the given Azure Key Vault key name (not the key ID),
// passing an empty version to the client.
// It returns spi.ErrNotFound when Azure Key Vault responds with HTTP 404,
// and a wrapped error for any other failure.
func (a keyvault) getPrivateKey(ctx context.Context, keyName string) (*azkeys.GetKeyResponse, error) {
    response, err := a.client.GetKey(ctx, keyName, "", nil)
    if err == nil {
        return &response, nil
    }

    var azureErr *azcore.ResponseError
    if errors.As(err, &azureErr) && azureErr.StatusCode == http.StatusNotFound {
        return nil, spi.ErrNotFound
    }
    // Anything other than "not found" is reported with the key name for context.
    return nil, fmt.Errorf("unable to get key from Azure Key Vault (name=%s): %w", keyName, err)
}

// SavePrivateKey is not supported by this backend and always returns an error.
func (a keyvault) SavePrivateKey(ctx context.Context, kid string, key crypto.PrivateKey) error {
    // Only used for migrating to a new storage backend, which is not implemented yet for Azure Key Vault
    return errors.New("SavePrivateKey() is not supported for Azure Key Vault")
}

// ListPrivateKeys returns the key IDs of the keys in Azure Key Vault, as recorded in each key's
// "originalKID" tag. Keys that lack the tag, or carry it without a value, are skipped.
// If a page cannot be retrieved, the error is logged and nil is returned (partial results are discarded).
func (a keyvault) ListPrivateKeys(ctx context.Context) []string {
    pager := a.client.NewListKeyPropertiesPager(nil)
    result := make([]string, 0)
    for pager.More() {
        page, err := pager.NextPage(ctx)
        if err != nil {
            log.Logger().WithError(err).Error("unable to list keys from Azure Key Vault")
            return nil
        }
        for _, keyProperties := range page.Value {
            // Entries are pointers: guard against a nil entry before touching its tags.
            if keyProperties == nil {
                continue
            }
            // Tag values are pointers too; a tag that is present with a nil value must not be dereferenced.
            // Indexing a nil Tags map is safe and yields nil.
            if kid := keyProperties.Tags["originalKID"]; kid != nil {
                result = append(result, *kid)
            }
        }
    }
    return result
}

// parseKey parses an Azure Key Vault key into a crypto.PublicKey and selects the azkeys.SignatureAlgorithm.
// The key is converted by marshalling it to JSON and parsing that as a JWK.
// Only EC keys (azkeys.KeyTypeEC or azkeys.KeyTypeECHSM) on curve P-256 are accepted, for which
// azkeys.SignatureAlgorithmES256 is returned; any other key results in an error.
func parseKey(key *azkeys.JSONWebKey) (crypto.PublicKey, azkeys.SignatureAlgorithm, error) {
    jwkData, err := json.Marshal(key)
    if err != nil {
        return nil, "", fmt.Errorf("unable to marshal key from Azure Key Vault as JWK: %w", err)
    }
    keyAsJWK, err := jwk.ParseKey(jwkData)
    if err != nil {
        return nil, "", fmt.Errorf("unable to parse key from Azure Key Vault as JWK: %w", err)
    }
    var publicKey crypto.PublicKey
    if err := keyAsJWK.Raw(&publicKey); err != nil {
        return nil, "", fmt.Errorf("unable to convert key from Azure Key Vault Key to crypto.PublicKey: %w", err)
    }
    // Kty and Crv are optional pointer fields (Crv is absent on non-EC keys): check them explicitly
    // instead of relying on the JWK parsing above to have rejected keys that lack them.
    if key == nil || key.Kty == nil || key.Crv == nil {
        return nil, "", errors.New("only ES256 keys are supported")
    }
    isEC := *key.Kty == azkeys.KeyTypeEC || *key.Kty == azkeys.KeyTypeECHSM
    if !isEC || *key.Crv != azkeys.CurveNameP256 {
        return nil, "", errors.New("only ES256 keys are supported")
    }
    return publicKey, azkeys.SignatureAlgorithmES256, nil
}

// allowedKeyNameRegex matches every run of characters that is NOT an ASCII letter, digit or dash,
// i.e. the parts of a key ID that have to be replaced to obtain a key name.
var allowedKeyNameRegex = regexp.MustCompile(`[^0-9a-zA-Z-]+`)

// keyIDToKeyName derives the Azure Key Vault key name from a key ID by replacing each run of
// characters other than ASCII letters, digits and dashes with a single dash.
func keyIDToKeyName(keyID string) string {
    const replacement = "-"
    return allowedKeyNameRegex.ReplaceAllLiteralString(keyID, replacement)
}

// Compile-time check that azureSigningKey implements crypto.Signer.
var _ crypto.Signer = &azureSigningKey{}

// azureSigningKey is a crypto.Signer that holds only the public key locally and delegates
// signing to Azure Key Vault through client.
type azureSigningKey struct {
    // client performs the Sign call on Azure Key Vault.
    client           keyVaultClient
    // timeOut bounds each Sign call.
    timeOut          time.Duration
    // keyName is the Azure Key Vault key name (not the key ID) of the signing key.
    keyName          string
    // publicKey is returned by Public().
    publicKey        crypto.PublicKey
    // signingAlgorithm is the algorithm passed to Azure Key Vault when signing.
    signingAlgorithm azkeys.SignatureAlgorithm
}

// Public returns the public key that was stored in this signing key when it was created.
func (a azureSigningKey) Public() crypto.PublicKey {
    return a.publicKey
}

// Sign lets Azure Key Vault sign the given digest with the key's signing algorithm and returns the
// signature as an ASN.1-encoded sequence of r and s.
// The io.Reader argument is ignored. digest must already be hashed: opts, when non-nil, must report
// a non-zero hash function. The remote call is bounded by the key's configured timeout.
// An error is returned when the call fails or when the returned signature cannot be split into
// two equally sized halves.
func (a azureSigningKey) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
    // Sanity check, done before setting up the context so invalid calls fail without any further work.
    if opts != nil && opts.HashFunc() == 0 {
        return nil, errors.New("hashing should've been done")
    }
    ctx, cancel := context.WithTimeout(context.Background(), a.timeOut)
    defer cancel()
    response, err := a.client.Sign(ctx, a.keyName, "", azkeys.SignParameters{
        Algorithm: to.Ptr(a.signingAlgorithm),
        Value:     digest,
    }, nil)
    if err != nil {
        return nil, fmt.Errorf("unable to sign with Azure Key Vault: %w", err)
    }
    // Azure Key Vault returns the signature in a []byte with r and s components concatenated.
    // We need to convert it to an ASN.1-encoded signature. The first half of the signature is r, the second half is s.
    // An empty or odd-length result can't be split into r and s, so reject it rather than mis-encode it.
    raw := response.Result
    if len(raw) == 0 || len(raw)%2 != 0 {
        return nil, fmt.Errorf("unexpected signature length from Azure Key Vault: %d bytes", len(raw))
    }
    half := len(raw) / 2
    return encodeSignature(raw[:half], raw[half:])
}

// encodeSignature encodes the big-endian integers r and s as an ASN.1 SEQUENCE of two INTEGERs.
// It was copied from ecdsa/ecdsa.go#encodeSignature().
func encodeSignature(r, s []byte) ([]byte, error) {
    builder := new(cryptobyte.Builder)
    builder.AddASN1(asn1.SEQUENCE, func(sequence *cryptobyte.Builder) {
        // Order matters: r first, then s.
        for _, component := range [][]byte{r, s} {
            addASN1IntBytes(sequence, component)
        }
    })
    return builder.Bytes()
}

// addASN1IntBytes appends the given big-endian unsigned integer to b as an ASN.1 INTEGER.
// Leading zero bytes are stripped first; if nothing remains, an "invalid integer" error is set on b.
// It was copied from ecdsa/ecdsa.go#addASN1IntBytes().
func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
    // Skip all leading zero bytes.
    start := 0
    for start < len(bytes) && bytes[start] == 0 {
        start++
    }
    trimmed := bytes[start:]
    if len(trimmed) == 0 {
        b.SetError(errors.New("invalid integer"))
        return
    }
    // ASN.1 INTEGERs are signed: when the top bit is set, a zero byte must be prepended
    // to keep the value positive.
    needsSignPadding := trimmed[0]&0x80 != 0
    b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) {
        if needsSignPadding {
            c.AddUint8(0)
        }
        c.AddBytes(trimmed)
    })
}