/*
 * Nuts node
 * Copyright (C) 2021 Nuts community
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 *
 */

package crypto

import (


// SignJWT creates a JWT from the given claims and signs it with the given key.
func (client *Crypto) SignJWT(ctx context.Context, claims map[string]interface{}, headers map[string]interface{}, key interface{}) (string, error) {
    // copy headers so we don't change the input
    headersLocal := make(map[string]interface{})
    maps.Copy(headersLocal, headers)

    privateKey, kid, err := client.getPrivateKey(ctx, key)
    if err != nil {
        return "", err

    audit.Log(ctx, log.Logger(), audit.CryptoSignJWTEvent).Infof("Signing a JWT with key: %s (issuer: %s, subject: %s)", kid, claims["iss"], claims["sub"])

    alg, err := signingAlg(privateKey.Public())
    if err != nil {
        return "", err

    headersLocal["kid"] = kid
    return signJWT(privateKey, alg, claims, headersLocal)

// SignJWS creates a signed JWS using the indicated key and map of headers and payload as bytes.
func (client *Crypto) SignJWS(ctx context.Context, payload []byte, headers map[string]interface{}, key interface{}, detached bool) (string, error) {
    privateKey, kid, err := client.getPrivateKey(ctx, key)
    if err != nil {
        return "", err
    alg, err := SignatureAlgorithm(privateKey.Public())
    if err != nil {
        return "", err

    audit.Log(ctx, log.Logger(), audit.CryptoSignJWSEvent).Infof("Signing a JWS with key: %s", kid)

    return signJWS(payload, headers, privateKey, alg, detached)

// EncryptJWE encrypts a payload using the provided public key and key identifier.
func (client *Crypto) EncryptJWE(ctx context.Context, payload []byte, headers map[string]interface{}, publicKey interface{}) (string, error) {
    audit.Log(ctx, log.Logger(), audit.CryptoEncryptJWEEvent).Info("Encrypting a JWE")
    return EncryptJWE(payload, headers, publicKey)

// DecryptJWE decrypts a message using the associated private key from the kid header.
func (client *Crypto) DecryptJWE(ctx context.Context, message string) (body []byte, headers map[string]interface{}, err error) {
    msg, err := jwe.Parse([]byte(message))
    if err != nil {
        return nil, nil, err

    protectedHeaders := msg.ProtectedHeaders()
    kid := protectedHeaders.KeyID()
    if len(kid) == 0 {
        return nil, nil, errors.New("kid header not found")
    privateKey, kid, err := client.getPrivateKey(ctx, kid)
    if err != nil {
        return nil, nil, err

    audit.Log(ctx, log.Logger(), audit.CryptoDecryptJWEEvent).Infof("Decrypting a JWE with kid: %s", kid)

    keyJWK, err := jwk.FromRaw(privateKey)
    if err != nil {
        return nil, nil, fmt.Errorf("keys stored in '%s' do not support JWE decryption",
    body, err = jwe.Decrypt([]byte(message), jwe.WithKey(protectedHeaders.Algorithm(), keyJWK))
    if err != nil {
        return nil, nil, err
    headers, err = msg.ProtectedHeaders().AsMap(ctx)
    if err != nil {
        return nil, nil, err
    return body, headers, err

// signJWT signs claims with the signer and returns the compacted token. The headers param can be used to add additional headers
func signJWT(key crypto.Signer, alg jwa.SignatureAlgorithm, claims map[string]interface{}, headers map[string]interface{}) (token string, err error) {
    var sig []byte
    t := jwt.New()

    for k, v := range claims {
        if err := t.Set(k, v); err != nil {
            return "", err
    hdr, err := convertHeaders(headers)
    if err != nil {
        return "", fmt.Errorf("invalid JWT headers: %w", err)

    sig, err = jwt.Sign(t, jwt.WithKey(jwa.SignatureAlgorithm(alg.String()), key, jws.WithProtectedHeaders(hdr)))
    token = string(sig)


// JWTKidAlg parses a JWT, does not validate it and returns the 'kid' and 'alg' headers
func JWTKidAlg(tokenString string) (string, jwa.SignatureAlgorithm, error) {
    j, err := jws.ParseString(tokenString)
    if err != nil {
        return "", "", err

    if len(j.Signatures()) != 1 {
        return "", "", errors.New("incorrect number of signatures in JWT")

    sig := j.Signatures()[0]
    hdrs := sig.ProtectedHeaders()
    return hdrs.KeyID(), hdrs.Algorithm(), nil

// PublicKeyFunc resolves the public key that belongs to the given key ID ("kid").
// ParseJWT and ParseJWS call it with the "kid" header of the token or signature being
// verified. A non-nil error aborts parsing and is returned to their caller.
type PublicKeyFunc func(kid string) (crypto.PublicKey, error)

// ParseJWT parses a token, validates and verifies it.
func ParseJWT(tokenString string, f PublicKeyFunc, options ...jwt.ParseOption) (jwt.Token, error) {
    kid, alg, err := JWTKidAlg(tokenString)
    if err != nil {
        return nil, err

    key, err := f(kid)
    if err != nil {
        return nil, err

    if !jwx.IsAlgorithmSupported(alg) {
        return nil, fmt.Errorf("token signing algorithm is not supported: %s", alg)

    options = append(options, jwt.WithKey(alg, key))
    options = append(options, jwt.WithVerify(true))

    return jwt.ParseString(tokenString, options...)

// ParseJWS parses a JWS byte array object, validates and verifies it.
// This method returns the value of the payload as byte array, or an error if
// the parsing fails at any level.
func ParseJWS(token []byte, f PublicKeyFunc) (payload []byte, err error) {
    message, err := jws.Parse(token)
    if err != nil {
        return nil, err
    headers, body, _, err := jws.SplitCompact(token)
    if err != nil {
        return nil, err
    signatures := message.Signatures()
    for i := range signatures {
        signature := signatures[i]
        // Get and check the algorithm
        alg := signature.ProtectedHeaders().Algorithm()
        if !jwx.IsAlgorithmSupported(alg) {
            return nil, fmt.Errorf("token signing algorithm is not supported: %s", alg)
        // Get the verifier for the algorithm
        verifier, err := jws.NewVerifier(alg)
        if err != nil {
            return nil, err
        // Get the key id, and get the associated key
        kid := signature.ProtectedHeaders().KeyID()
        key, err := f(kid)
        if err != nil {
            return nil, err
        // This seems an awkward way of appending 3 arrays.
        var payload []byte
        parts := [][]byte{headers, []byte("."), body}
        for _, part := range parts {
            payload = append(payload, part...)
        err = verifier.Verify(payload, signature.Signature(), key)
        if err != nil {
            return nil, err

    body = message.Payload()
    return body, nil

func signJWS(payload []byte, protectedHeaders map[string]interface{}, privateKey crypto.Signer, alg jwa.SignatureAlgorithm, detachedPayload bool) (string, error) {
    headers := jws.NewHeaders()
    for key, value := range protectedHeaders {
        if err := headers.Set(key, value); err != nil {
            return "", fmt.Errorf("unable to set header %s: %w", key, err)
    // The JWX library is fine with creating a JWK for a private key (including the private exponents), so
    // we want to make sure the `jwk` header (if present) does not (accidentally) contain a private key.
    // That would lead to the node leaking its private key material in the resulting JWS which would be very, very bad.
    if headers.JWK() != nil {
        var jwkAsPrivateKey crypto.Signer
        if err := headers.JWK().Raw(&jwkAsPrivateKey); err == nil {
            // `err != nil` is good in this case, because that means the key is not assignable to crypto.Signer,
            // which is the interface implemented by all private key types.
            return "", errors.New("refusing to sign JWS with private key in JWK header")

    var (
        data []byte
        err  error
    if detachedPayload {
        // Sign JWS with detached payload
        data, err = jws.Sign(nil, jws.WithKey(alg, privateKey, jws.WithProtectedHeaders(headers)), jws.WithDetachedPayload(payload))
    } else {
        // Sign normal JWS
        data, err = jws.Sign(payload, jws.WithKey(alg, privateKey, jws.WithProtectedHeaders(headers)))
    if err != nil {
        return "", fmt.Errorf("unable to sign JWS %w", err)
    return string(data), nil

func EncryptJWE(payload []byte, protectedHeaders map[string]interface{}, publicKey interface{}) (message string, err error) {
    if publicKey == nil {
        return "", errors.New("no publicKey provided")
    json, err := json.Marshal(protectedHeaders)
    if err != nil {
        return "", err
    headers := jwe.NewHeaders()
    err = headers.UnmarshalJSON(json)
    if err != nil {
        return "", err
    // Figure out the KeyEncryptionAlgorithm, give prevalence to the headers
    var alg jwa.KeyEncryptionAlgorithm
    if len(headers.Algorithm().String()) > 0 {
        alg = headers.Algorithm()
    } else {
        alg, err = encryptionAlgorithm(publicKey)
        if err != nil {
            return "", err

    // Figure out the KeyEncryptionAlgorithm, give prevalence to the headers
    enc := jwx.DefaultContentEncryptionAlgorithm
    if len(headers.ContentEncryption().String()) > 0 {
        enc = headers.ContentEncryption()
    options := []jwe.EncryptOption{
        jwe.WithKey(alg, publicKey),
        jwe.WithCompress(headers.Compression()), // "" means no compression

    encoded, err := jwe.Encrypt(payload, options...)
    return string(encoded), err

func (client *Crypto) getPrivateKey(ctx context.Context, key interface{}) (crypto.Signer, string, error) {
    var kid string
    switch k := key.(type) {
    case exportableKey:
        return k.Signer(), k.KID(), nil
    case Key:
        kid = k.KID()
    case string:
        kid = k
        return nil, "", errors.New("provided key must be either string or Key")

    privateKey, err :=, kid)
    if err != nil {
        if errors.Is(err, spi.ErrNotFound) {
            return nil, "", ErrPrivateKeyNotFound
        return nil, "", err
    return privateKey, kid, nil

func convertHeaders(headers map[string]interface{}) (jws.Headers, error) {
    hdr := jws.NewHeaders()

    for k, v := range headers {
        if err := hdr.Set(k, v); err != nil {
            return nil, err
    return hdr, nil

func signingAlg(key crypto.PublicKey) (jwa.SignatureAlgorithm, error) {
    switch k := key.(type) {
    case *rsa.PublicKey:
        return jwa.PS256, nil
    case *ecdsa.PublicKey:
        return ecAlgUsingPublicKey(*k)
    case ed25519.PublicKey:
        return jwa.EdDSA, nil
        return "", fmt.Errorf(`could not determine signature algorithm for key type '%T'`, key)

func ecAlgUsingPublicKey(key ecdsa.PublicKey) (alg jwa.SignatureAlgorithm, err error) {
    switch key.Params().BitSize {
    case 256:
        alg = jwa.ES256
    case 384:
        alg = jwa.ES384
    case 521:
        alg = jwa.ES512
        err = jwx.ErrUnsupportedSigningKey

// SignatureAlgorithm determines the jwa.SigningAlgorithm for ec/rsa/ed25519 keys.
func SignatureAlgorithm(key crypto.PublicKey) (jwa.SignatureAlgorithm, error) {
    if key == nil {
        return "", errors.New("no key provided")

    var ptr interface{}
    switch v := key.(type) {
    case rsa.PrivateKey:
        ptr = &v
    case rsa.PublicKey:
        ptr = &v
    case ecdsa.PrivateKey:
        ptr = &v
    case ecdsa.PublicKey:
        ptr = &v
        ptr = v

    switch k := ptr.(type) {
    case *rsa.PrivateKey:
        return jwa.PS256, nil
    case *rsa.PublicKey:
        return jwa.PS256, nil
    case *ecdsa.PrivateKey:
        return ecAlgUsingPublicKey(k.PublicKey)
    case *ecdsa.PublicKey:
        return ecAlgUsingPublicKey(*k)
    case ed25519.PrivateKey:
        return jwa.EdDSA, nil
    case ed25519.PublicKey:
        return jwa.EdDSA, nil
        return "", fmt.Errorf(`could not determine signature algorithm for key type '%T'`, key)

func encryptionAlgorithm(key crypto.PublicKey) (jwa.KeyEncryptionAlgorithm, error) {

    switch key.(type) {
    case *rsa.PublicKey:
        return jwx.DefaultRsaEncryptionAlgorithm, nil
    case *ecdsa.PublicKey:
        return jwx.DefaultEcEncryptionAlgorithm, nil
        return "", fmt.Errorf("could not determine encryption algorithm for key type '%T'", key)

// Thumbprint generates a Nuts compatible thumbprint: Base58(SHA256(rfc7638-json))
func Thumbprint(key jwk.Key) (string, error) {
    if key == nil {
        panic("Thumbprint(): key is nil")
    pkHash, err := key.Thumbprint(crypto.SHA256)
    if err != nil {
        return "", err
    return base58.EncodeAlphabet(pkHash[:], base58.BTCAlphabet), nil