CMSgov/dpc-app

View on GitHub
dpc-api/src/main/java/gov/cms/dpc/api/auth/jwt/PublicKeyHandler.java

Summary

Maintainability
A
0 mins
Test Coverage
B
83%
package gov.cms.dpc.api.auth.jwt;

import gov.cms.dpc.api.entities.PublicKeyEntity;
import gov.cms.dpc.api.exceptions.PublicKeyException;
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.sec.SECObjectIdentifiers;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.openssl.PEMException;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.util.io.pem.PemObject;
import org.bouncycastle.util.io.pem.PemWriter;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.util.Base64;

public class PublicKeyHandler {

    // ECC Curve names are defined here: https://tools.ietf.org/search/rfc4492#section-5.1.1
    // They're in two separate namespaces, certicom and ansi-x962
    public static final ASN1ObjectIdentifier RSA_PARENT = new ASN1ObjectIdentifier("1.2.840.113549");
    public static final ASN1ObjectIdentifier ECC_KEY = new ASN1ObjectIdentifier("1.2.840.10045.2.1");

    private PublicKeyHandler() {
        // Not used
    }

    /**
     * Parse and validate PEM encoded Public Key
     *
     * @param pem - {@link String} PEM encoded public key
     * @return - {@link SubjectPublicKeyInfo}
     */
    public static SubjectPublicKeyInfo parsePEMString(String pem) {
        final ByteArrayInputStream bas = new ByteArrayInputStream(pem.getBytes(StandardCharsets.ISO_8859_1));
        try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(bas, StandardCharsets.UTF_8))) {
            try (PEMParser pemParser = new PEMParser(bufferedReader)) {
                try {
                    final Object object = pemParser.readObject();
                    if (object == null) {
                        throw new PublicKeyException("Cannot parse public key, returned value is null");
                    }
                    if (!(object instanceof SubjectPublicKeyInfo)) {
                        throw new PublicKeyException(String.format("Cannot convert %s to %s.", object.getClass().getName(), SubjectPublicKeyInfo.class.getName()));
                    }
                    return (SubjectPublicKeyInfo) object;
                } catch (PEMException e) {
                    throw new PublicKeyException("Not a valid public key", e);
                }
            }
        } catch (IOException e) {
            throw new PublicKeyException("Cannot parse Public Key input", e);
        }
    }

    /**
     * Convert the {@link SubjectPublicKeyInfo} to a PEM encoded string value
     *
     * @param value - {@link SubjectPublicKeyInfo} to encode as PEM
     * @return - {@link String} PEM encoded public key
     */
    public static String pemEncodePublicKey(SubjectPublicKeyInfo value) {
        try {
            final PemObject object = new PemObject("PUBLIC KEY", value.getEncoded());
            try (StringWriter stringWriter = new StringWriter()) {
                try (PemWriter pemWriter = new PemWriter(stringWriter)) {
                    pemWriter.writeObject(object);
                }
                return stringWriter.toString();
            }
        } catch (IOException e) {
            throw new PublicKeyException("Cannot convert public key to PEM", e);
        }
    }

    public static void validatePublicKey(SubjectPublicKeyInfo value) {
        // If RSA, do some other validations
        final ASN1ObjectIdentifier algorithmID = value.getAlgorithm().getAlgorithm();
        if (algorithmID.on(RSA_PARENT)) {
            validateRSAKey(value);
        } else if (algorithmID.equals(ECC_KEY)) {
            validateECCKey(value);
        } else {
            throw new PublicKeyException(String.format("Unsupported key type `%s`.", algorithmID.getId()));
        }

    }

    public static void verifySignature(String publicKeyPem, String snippet, String sigStr) {
        String keyStr = publicKeyPem
                .replace("-----BEGIN PUBLIC KEY-----", "")
                .replace("-----END PUBLIC KEY-----", "")
                .replaceAll("[\n\r]+", "");
        sigStr = sigStr.replaceAll("[\n\r]+", "");

        try {
            byte[] keyBytes = Base64.getDecoder().decode(keyStr);
            X509EncodedKeySpec keySpec = new X509EncodedKeySpec(keyBytes);
            KeyFactory keyFactory = KeyFactory.getInstance("RSA");
            PublicKey publicKey = keyFactory.generatePublic(keySpec);

            Signature signature = Signature.getInstance("SHA256withRSA");
            signature.initVerify(publicKey);
            signature.update(snippet.getBytes(StandardCharsets.UTF_8));
            if (!signature.verify(Base64.getDecoder().decode(sigStr))) {
                throw new PublicKeyException("Key and signature do not match");
            }
        } catch (NoSuchAlgorithmException e) {
            throw new PublicKeyException("Invalid algorithm", e);
        } catch (GeneralSecurityException e) {
            throw new PublicKeyException("Signature could not be verified.", e);
        }
    }

    private static void validateRSAKey(SubjectPublicKeyInfo value) {
        // Should have a minimum length
        try {
            // Verifies the key is at least 4096 bits, which is 550 bytes of encoded data
            if (value.getEncoded().length < 550) {
                throw new PublicKeyException("Public key must be at least 4096 bits.");
            }
        } catch (IOException e) {
            throw new PublicKeyException("Cannot read public key.", e);
        }

    }

    private static void validateECCKey(SubjectPublicKeyInfo value) {
        // Verify we have a supported curve, which is currently secp256r1 or secp384r1
        final ASN1Primitive curveName = value.getAlgorithm().getParameters().toASN1Primitive();
        if (!(curveName.equals(SECObjectIdentifiers.secp256r1) || curveName.equals(SECObjectIdentifiers.secp384r1))) {
            throw new PublicKeyException(String.format("ECC curve `%s` is not supported.", curveName.toString()));
        }
    }

    /**
     * Convert the given {@link PublicKeyEntity} to a {@link PublicKey}
     *
     * @param entity - {@link PublicKeyEntity} to convert
     * @return - {@link PublicKey}
     * @throws PublicKeyException - throws if the conversion fails
     */
    static PublicKey publicKeyFromEntity(PublicKeyEntity entity) {
        X509EncodedKeySpec spec;
        try {
            final SubjectPublicKeyInfo publicKeySpec = entity.getPublicKey();
            final String keyType = publicKeySpec.getAlgorithm().getAlgorithm().on(RSA_PARENT) ? "RSA" : "EC";
            spec = new X509EncodedKeySpec(publicKeySpec.getEncoded());
            return KeyFactory.getInstance(keyType).generatePublic(spec);
        } catch (IOException | NoSuchAlgorithmException | InvalidKeySpecException e) {
            throw new PublicKeyException("Cannot convert Key Spec to Public Key", e);
        }
    }
}