CMSgov/dpc-app

View on GitHub
dpc-testing/src/main/java/gov/cms/dpc/testing/APIAuthHelpers.java

Summary

Maintainability
B
5 hrs
Test Coverage
package gov.cms.dpc.testing;

import ca.uhn.fhir.context.FhirContext;
import ca.uhn.fhir.rest.client.api.IClientInterceptor;
import ca.uhn.fhir.rest.client.api.IGenericClient;
import ca.uhn.fhir.rest.client.api.IHttpRequest;
import ca.uhn.fhir.rest.client.api.IHttpResponse;
import ca.uhn.fhir.rest.client.interceptor.LoggingInterceptor;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.nitram509.jmacaroons.MacaroonVersion;
import com.github.nitram509.jmacaroons.MacaroonsBuilder;
import com.google.common.net.HttpHeaders;
import gov.cms.dpc.testing.models.KeyView;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.http.HttpRequest;
import org.apache.http.HttpRequestInterceptor;
import org.apache.http.NameValuePair;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.protocol.HttpContext;
import org.apache.http.util.EntityUtils;
import org.eclipse.jetty.http.HttpStatus;

import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import javax.ws.rs.core.MediaType;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.security.spec.ECGenParameterSpec;
import java.sql.Date;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;
import java.util.*;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;

/**
 * Test helpers for authenticating against the DPC API.
 * <p>
 * Provides builders for FHIR clients that attach either a Macaroon (admin) or a JWT-derived
 * access token (authenticated) to every request, the JWT client-credentials auth flow itself,
 * key-pair generation, and public-key upload. The HTTP clients built here can be configured to
 * trust every TLS certificate, so they must only be used in tests.
 */
public class APIAuthHelpers {
    public static final String TASK_URL = "http://localhost:9900/tasks/";
    public static final String KEY_VERIFICATION_SNIPPET = "This is the snippet used to verify a key pair in DPC.";
    private static final String CLIENT_ASSERTION_TYPE = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
    /** Access tokens are refreshed this many seconds before their reported expiration. */
    private static final long TOKEN_REFRESH_MARGIN_SECONDS = 30;
    private static final ObjectMapper mapper = new ObjectMapper();

    private APIAuthHelpers() {
        // Static utility class; not instantiable
    }

    /**
     * Builds a FHIR client that authenticates via the JWT flow, with SSL verification enabled
     * and request logging disabled.
     *
     * @param ctx        - {@link FhirContext} used to create the client
     * @param baseURL    - {@link String} base URL of the API
     * @param macaroon   - {@link String} client token used as the JWT issuer/subject
     * @param keyID      - {@link UUID} ID of the uploaded public key (JWT {@code kid})
     * @param privateKey - {@link PrivateKey} matching the uploaded public key
     * @return - authenticated {@link IGenericClient}
     */
    public static IGenericClient buildAuthenticatedClient(FhirContext ctx, String baseURL, String macaroon, UUID keyID, PrivateKey privateKey) {
        return buildAuthenticatedClient(ctx, baseURL, macaroon, keyID, privateKey, false);
    }

    /**
     * Builds a FHIR client that authenticates via the JWT flow, with request logging disabled.
     *
     * @param disableSSLCheck - if {@code true}, every TLS certificate is trusted
     * @see #buildAuthenticatedClient(FhirContext, String, String, UUID, PrivateKey)
     */
    public static IGenericClient buildAuthenticatedClient(FhirContext ctx, String baseURL, String macaroon, UUID keyID, PrivateKey privateKey, boolean disableSSLCheck) {
        return buildAuthenticatedClient(ctx, baseURL, macaroon, keyID, privateKey, disableSSLCheck, false);
    }

    /**
     * Builds a FHIR client that performs the JWT auth flow immediately, refreshes the access token
     * before it expires, and adds a {@code Prefer: respond-async} header to every request.
     *
     * @param disableSSLCheck  - if {@code true}, every TLS certificate is trusted
     * @param enableRequestLog - if {@code true}, request/response summaries are logged
     * @see #buildAuthenticatedClient(FhirContext, String, String, UUID, PrivateKey)
     */
    public static IGenericClient buildAuthenticatedClient(FhirContext ctx, String baseURL, String macaroon, UUID keyID, PrivateKey privateKey, boolean disableSSLCheck, boolean enableRequestLog) {
        final IGenericClient client = createBaseFHIRClient(ctx, baseURL, disableSSLCheck, enableRequestLog);
        // Note: the interceptor's constructor runs the JWT auth flow right away
        client.registerInterceptor(new HAPISmartInterceptor(baseURL, macaroon, keyID, privateKey));

        // Add the async header the hard way
        final var addPreferInterceptor = new IClientInterceptor() {
            @Override
            public void interceptRequest(IHttpRequest iHttpRequest) {
                // Manually set these values, rather than pulling a dependency on dpc-common, where the constants are defined
                iHttpRequest.addHeader("Prefer", "respond-async");
            }

            @Override
            public void interceptResponse(IHttpResponse iHttpResponse) {
                // Not used
            }
        };
        client.registerInterceptor(addPreferInterceptor);

        return client;
    }

    /**
     * Builds a FHIR client that sends the given Macaroon as a Bearer token, with request logging disabled.
     *
     * @param ctx             - {@link FhirContext} used to create the client
     * @param baseURL         - {@link String} base URL of the API
     * @param macaroon        - {@link String} Macaroon sent on every request
     * @param disableSSLCheck - if {@code true}, every TLS certificate is trusted
     * @return - {@link IGenericClient} authorized with the Macaroon
     */
    public static IGenericClient buildAdminClient(FhirContext ctx, String baseURL, String macaroon, boolean disableSSLCheck) {
        return buildAdminClient(ctx, baseURL, macaroon, disableSSLCheck, false);
    }

    /**
     * Builds a FHIR client that sends the given Macaroon as a Bearer token.
     *
     * @param enableRequestLog - if {@code true}, request/response summaries are logged
     * @see #buildAdminClient(FhirContext, String, String, boolean)
     */
    public static IGenericClient buildAdminClient(FhirContext ctx, String baseURL, String macaroon, boolean disableSSLCheck, boolean enableRequestLog) {
        final IGenericClient client = createBaseFHIRClient(ctx, baseURL, disableSSLCheck, enableRequestLog);
        client.registerInterceptor(new MacaroonsInterceptor(macaroon));
        return client;
    }

    /**
     * Runs the JWT client-credentials flow: signs a JWT with the private key, checks it against
     * {@code /Token/validate}, then exchanges it at {@code /Token/auth} for an access token.
     * Each step is verified with JUnit assertions, so a failure fails the calling test.
     *
     * @param baseURL    - {@link String} base URL of the API
     * @param macaroon   - {@link String} client token, used as the JWT issuer and subject
     * @param keyID      - {@link UUID} ID of the uploaded public key, sent as the JWT {@code kid}
     * @param privateKey - {@link PrivateKey} used to sign the JWT (RSA signing algorithm)
     * @return - {@link AuthResponse} returned by the auth endpoint
     * @throws IOException        - if an HTTP request fails or the response cannot be parsed
     * @throws URISyntaxException - if the endpoint URIs cannot be built from {@code baseURL}
     */
    public static AuthResponse jwtAuthFlow(String baseURL, String macaroon, UUID keyID, PrivateKey privateKey) throws IOException, URISyntaxException {
        /* TODO revert this workaround to previous version of code
         * - git diff f2d3abe1f23e4d1ad2f2a01 5d799c57712418de674 <<< green is good
         * see also https://github.com/CMSgov/dpc-app/pull/849
         */
        String audience = baseURL;
        if (baseURL.startsWith("http://internal-dpc-prod-")) {
            audience = "https://prod.dpc.cms.gov/api/v1";
        }
        final String jwt = Jwts.builder()
                .setHeaderParam("kid", keyID)
                .setAudience(String.format("%s/Token/auth", audience))
                .setIssuer(macaroon)
                .setSubject(macaroon)
                .setId(UUID.randomUUID().toString())
                // Use java.util.Date explicitly: java.sql.Date is a JDBC type that merely inherits this factory
                .setExpiration(java.util.Date.from(Instant.now().plus(5, ChronoUnit.MINUTES).minus(30, ChronoUnit.SECONDS)))
                .signWith(privateKey, getSigningAlgorithm(KeyType.RSA))
                .compact();

        // One client serves both sequential calls
        try (final CloseableHttpClient client = createCustomHttpClient().trusting().build()) {
            assertJwtValid(client, baseURL, jwt);
            return requestAccessToken(client, baseURL, macaroon, jwt);
        }
    }

    /** POSTs the JWT to {@code /Token/validate} and asserts the API accepts it. */
    private static void assertJwtValid(CloseableHttpClient client, String baseURL, String jwt) throws IOException, URISyntaxException {
        final HttpPost post = new HttpPost(new URIBuilder(String.format("%s/Token/validate", baseURL)).build());
        post.setEntity(new StringEntity(jwt));
        post.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN);

        try (CloseableHttpResponse response = client.execute(post)) {
            assertEquals(HttpStatus.OK_200, response.getStatusLine().getStatusCode(), "Token validation should have succeeded");
            // Drain the body so the connection can be reused for the next request
            EntityUtils.consume(response.getEntity());
        }
    }

    /** Exchanges the signed JWT at {@code /Token/auth} and asserts the returned token looks valid. */
    private static AuthResponse requestAccessToken(CloseableHttpClient client, String baseURL, String macaroon, String jwt) throws IOException, URISyntaxException {
        final List<NameValuePair> formData = List.of(
                new BasicNameValuePair("scope", "system/*.*"),
                new BasicNameValuePair("grant_type", "client_credentials"),
                new BasicNameValuePair("client_assertion_type", CLIENT_ASSERTION_TYPE),
                new BasicNameValuePair("client_assertion", jwt));

        final HttpPost post = new HttpPost(new URIBuilder(String.format("%s/Token/auth", baseURL)).build());
        post.setEntity(new UrlEncodedFormEntity(formData));
        post.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED);

        try (CloseableHttpResponse response = client.execute(post)) {
            assertEquals(HttpStatus.OK_200, response.getStatusLine().getStatusCode(), "Token request should have succeeded");
            final AuthResponse authResponse = mapper.readValue(response.getEntity().getContent(), AuthResponse.class);
            assertNotEquals(macaroon, authResponse.accessToken, "New Macaroon should not be identical");
            assertEquals(300, authResponse.expiresIn, "Should expire in 300 seconds");
            return authResponse;
        }
    }

    /**
     * Requests a golden (admin) Macaroon from the default task URL ({@link #TASK_URL}).
     *
     * @return - serialized Macaroon returned by the server
     * @throws IOException - if the HTTP request fails
     */
    public static String createGoldenMacaroon() throws IOException {
        return createGoldenMacaroon(TASK_URL);
    }

    /**
     * POSTs to {@code <taskURL>/generate-token} and returns the response body, asserting HTTP 200.
     *
     * @param taskURL - {@link String} base URL of the task endpoint
     * @return - serialized Macaroon returned by the server
     * @throws IOException - if the HTTP request fails
     */
    public static String createGoldenMacaroon(String taskURL) throws IOException {
        try (CloseableHttpClient client = HttpClients.createDefault()) {
            final HttpPost post = new HttpPost(String.format("%s/generate-token", taskURL));

            try (CloseableHttpResponse execute = client.execute(post)) {
                assertEquals(HttpStatus.OK_200, execute.getStatusLine().getStatusCode(), "Generated macaroon");
                return EntityUtils.toString(execute.getEntity());
            }
        }
    }

    /**
     * Generates an RSA key pair.
     *
     * @return - new {@link KeyPair}
     * @throws NoSuchAlgorithmException - if the RSA algorithm is unavailable
     */
    public static KeyPair generateKeyPair() throws NoSuchAlgorithmException {
        // TODO: Possibly revert to ECC type.
        return generateKeyPair(KeyType.RSA);
    }

    /**
     * Generates a key pair of the given type: RSA with the type's key size, otherwise an EC key on
     * the {@code secp256r1} curve.
     *
     * @param keyType - {@link KeyType} to generate
     * @return - new {@link KeyPair}
     * @throws NoSuchAlgorithmException - if the algorithm is unavailable
     * @throws IllegalArgumentException - if the EC curve parameters are rejected
     */
    public static KeyPair generateKeyPair(KeyType keyType) throws NoSuchAlgorithmException {
        final KeyPairGenerator kpg = KeyPairGenerator.getInstance(keyType.getName());
        if (keyType == KeyType.RSA) {
            kpg.initialize(keyType.getKeySize());
        } else {
            ECGenParameterSpec spec = new ECGenParameterSpec("secp256r1");
            try {
                kpg.initialize(spec);
            } catch (InvalidAlgorithmParameterException e) {
                throw new IllegalArgumentException("Cannot generate key", e);
            }
        }
        return kpg.generateKeyPair();
    }

    /**
     * Encodes a public key as a PEM {@code PUBLIC KEY} block (MIME-wrapped Base64).
     *
     * @param key - {@link PublicKey} to encode
     * @return - PEM-formatted key
     */
    public static String generatePublicKey(PublicKey key) {
        final String encoded = Base64.getMimeEncoder().encodeToString(key.getEncoded());
        return String.format("-----BEGIN PUBLIC KEY-----\n%s\n-----END PUBLIC KEY-----\n", encoded);
    }

    /**
     * Generate a new {@link KeyPair} and submit the {@link PublicKey} to the API service, along with the given label.
     * The upload is authorized with an organization-scoped Macaroon derived from the golden Macaroon, and
     * includes a signature of {@link #KEY_VERIFICATION_SNIPPET} proving possession of the private key.
     *
     * @param keyLabel       - {@link String} label for the key, sent as the {@code label} query parameter
     * @param organizationID - {@link String} organization ID to register key with
     * @param goldenMacaroon - {@link String} admin Macaroon that can upload keys
     * @param baseURL        - {@link String} baseURL to submit Key to
     * @return - {@link Pair} of {@link UUID} (public key ID) and {@link PrivateKey} which matches the uploaded {@link PublicKey}
     * @throws IOException              - if the HTTP request fails or the response cannot be parsed
     * @throws URISyntaxException       - if the Key endpoint URI cannot be built
     * @throws GeneralSecurityException - if key generation or signing fails
     */
    public static Pair<UUID, PrivateKey> generateAndUploadKey(String keyLabel, String organizationID, String goldenMacaroon, String baseURL) throws IOException, URISyntaxException, GeneralSecurityException {
        final KeyPair keyPair = generateKeyPair();
        final String key = generatePublicKey(keyPair.getPublic());
        final String signature = signString(keyPair.getPrivate(), KEY_VERIFICATION_SNIPPET);

        // Create org specific macaroon from Golden Macaroon
        final String macaroon = MacaroonsBuilder
                .modify(MacaroonsBuilder.deserialize(goldenMacaroon).get(0))
                .add_first_party_caveat(String.format("organization_id = %s", organizationID))
                .getMacaroon().serialize(MacaroonVersion.SerializationVersion.V2_JSON);

        final URIBuilder builder = new URIBuilder(String.format("%s/Key", baseURL));
        builder.addParameter("label", keyLabel);
        final HttpPost post = new HttpPost(builder.build());
        final Map<String, String> body = Map.of("key", key, "signature", signature);
        post.setEntity(new StringEntity(mapper.writeValueAsString(body)));
        post.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + macaroon);
        post.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON);
        post.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON);

        final KeyView keyEntity;
        try (CloseableHttpClient client = createCustomHttpClient().trusting().build();
             CloseableHttpResponse response = client.execute(post)) {
            // Check the status first so a rejected upload fails the assertion instead of a JSON mapping error
            assertEquals(HttpStatus.OK_200, response.getStatusLine().getStatusCode(), "Key should be valid");
            keyEntity = mapper.readValue(response.getEntity().getContent(), KeyView.class);
        }

        return Pair.of(keyEntity.id, keyPair.getPrivate());
    }

    /**
     * Creates a fluent builder for Apache HTTP clients used by these helpers.
     *
     * @return - new {@link CustomHttpBuilder}
     */
    public static CustomHttpBuilder createCustomHttpClient() {
        return new CustomHttpBuilder();
    }

    /**
     * Signs the UTF-8 bytes of a string with {@code SHA256withRSA} and Base64-encodes the signature.
     *
     * @param privateKey - {@link PrivateKey} to sign with (must be an RSA key)
     * @param str        - {@link String} to sign
     * @return - Base64-encoded signature
     * @throws GeneralSecurityException - if the algorithm is unavailable or the key is invalid
     */
    public static String signString(PrivateKey privateKey, String str) throws GeneralSecurityException {
        Signature signature = Signature.getInstance("SHA256withRSA");
        signature.initSign(privateKey);
        signature.update(str.getBytes(StandardCharsets.UTF_8));
        byte[] sigBytes = signature.sign();
        return Base64.getEncoder().encodeToString(sigBytes);
    }

    /**
     * Creates a FHIR client with a logging interceptor. Note that this installs a new HTTP client on
     * the given context's restful client factory, which affects other clients created from {@code ctx}.
     */
    private static IGenericClient createBaseFHIRClient(FhirContext ctx, String baseURL, boolean disableSSLCheck, boolean enableRequestLog) {
        final HttpClientBuilder clientBuilder = HttpClients.custom();
        if (disableSSLCheck) {
            try {
                configureTrustingSSL(clientBuilder);
            } catch (NoSuchAlgorithmException | KeyManagementException e) {
                throw new RuntimeException("Cannot create custom SSL context", e);
            }
        }

        ctx.getRestfulClientFactory().setHttpClient(clientBuilder.build());

        IGenericClient client = ctx.newRestfulGenericClient(baseURL);

        // Request/response summary logging is off unless enableRequestLog is set
        LoggingInterceptor loggingInterceptor = new LoggingInterceptor();
        loggingInterceptor.setLogRequestSummary(enableRequestLog);
        loggingInterceptor.setLogResponseSummary(enableRequestLog);
        client.registerInterceptor(loggingInterceptor);

        return client;
    }

    /**
     * Configures the builder with an SSL context that trusts every certificate and a hostname
     * verifier that only compares the requested host with the session's peer host.
     */
    private static void configureTrustingSSL(HttpClientBuilder builder) throws NoSuchAlgorithmException, KeyManagementException {
        builder.setSSLContext(createTrustingSSLContext());
        builder.setSSLHostnameVerifier((host, sslSession) -> host.equalsIgnoreCase(sslSession.getPeerHost()));
    }

    /** Creates a TLSv1.2 context backed by the trust-all manager from {@link #getTrustingManager()}. */
    private static SSLContext createTrustingSSLContext() throws KeyManagementException, NoSuchAlgorithmException {
        final SSLContext tls = SSLContext.getInstance("TLSv1.2");
        tls.init(null, getTrustingManager(), new SecureRandom());
        return tls;
    }

    /** Returns a trust manager that accepts every client and server certificate chain (tests only). */
    private static TrustManager[] getTrustingManager() {
        return new TrustManager[]{new X509TrustManager() {
            @Override
            public X509Certificate[] getAcceptedIssuers() {
                // The X509TrustManager contract requires a non-null (possibly empty) array
                return new X509Certificate[0];
            }

            @Override
            public void checkClientTrusted(X509Certificate[] certs, String authType) throws CertificateException {
                // Only used for testing: every chain is accepted
            }

            @Override
            public void checkServerTrusted(X509Certificate[] certs, String authType) throws CertificateException {
                // Only used for testing: every chain is accepted
            }
        }};
    }

    /**
     * Get the correct {@link SignatureAlgorithm} for the given {@link KeyType}
     *
     * @param keyType - {@link KeyType} to get algorithm for
     * @return - {@link SignatureAlgorithm} to use for signing JWT ({@code ES256} for ECC, otherwise {@code RS384})
     */
    public static SignatureAlgorithm getSigningAlgorithm(KeyType keyType) {
        return keyType == KeyType.ECC ? SignatureAlgorithm.ES256 : SignatureAlgorithm.RS384;
    }

    /**
     * Runs the JWT auth flow for a token-refreshing interceptor, wrapping checked failures in
     * {@link IllegalStateException}.
     */
    private static AuthResponse fetchAuthResponse(String baseURL, String clientToken, UUID keyID, PrivateKey privateKey) {
        System.out.println("Refreshing access token");
        try {
            return jwtAuthFlow(baseURL, clientToken, keyID, privateKey);
        } catch (IOException | URISyntaxException e) {
            throw new IllegalStateException("Cannot perform auth flow", e);
        }
    }

    /** Computes when a token that expires in {@code expiresInSeconds} should be refreshed (UTC). */
    private static OffsetDateTime refreshDeadline(long expiresInSeconds) {
        return OffsetDateTime.now(ZoneOffset.UTC)
                .plus(expiresInSeconds, ChronoUnit.SECONDS)
                .minus(TOKEN_REFRESH_MARGIN_SECONDS, ChronoUnit.SECONDS);
    }

    /** HAPI client interceptor that sends a fixed Macaroon as a Bearer token. */
    public static class MacaroonsInterceptor implements IClientInterceptor {

        private String macaroon;

        public MacaroonsInterceptor(String macaroon) {
            this.macaroon = macaroon;
        }

        @Override
        public void interceptRequest(IHttpRequest theRequest) {
            theRequest.addHeader(HttpHeaders.AUTHORIZATION, "Bearer " + this.macaroon);
        }

        @Override
        public void interceptResponse(IHttpResponse theResponse) {
            // Not used
        }

        public String getMacaroon() {
            return macaroon;
        }

        public void setMacaroon(String macaroon) {
            this.macaroon = macaroon;
        }
    }

    /** JSON body returned by the {@code /Token/auth} endpoint. */
    @SuppressWarnings("WeakerAccess")
    public static class AuthResponse {

        @JsonProperty(value = "access_token")
        public String accessToken;
        @JsonProperty(value = "token_type")
        public String tokenType;
        @JsonProperty(value = "expires_in")
        public Long expiresIn;
        public String scope;
        @JsonIgnore
        public OffsetDateTime expiresAt;

        public AuthResponse() {
            // NOTE(review): despite its name, expiresAt holds the time this object was constructed
            // (i.e. when the response was deserialized), not now + expiresIn. Confirm before relying on it.
            this.expiresAt = OffsetDateTime.now(ZoneOffset.UTC);
        }
    }

    /**
     * HAPI client interceptor that runs the JWT auth flow on construction, attaches the access token
     * as a Bearer header, and re-runs the flow once the token is within 30 seconds of expiring.
     */
    public static class HAPISmartInterceptor implements IClientInterceptor {

        private final String baseURL;
        private final String clientToken;
        private final UUID keyID;
        private final PrivateKey privateKey;

        private OffsetDateTime shouldRefreshToken;
        private AuthResponse response;

        HAPISmartInterceptor(String baseURL, String clientToken, UUID keyID, PrivateKey privateKey) {
            this.baseURL = baseURL;
            this.clientToken = clientToken;
            this.keyID = keyID;
            this.privateKey = privateKey;

            // Do the initial JWT Auth flow
            refreshAuthToken();
        }

        @Override
        public void interceptRequest(IHttpRequest theRequest) {
            if (OffsetDateTime.now(ZoneOffset.UTC).isAfter(shouldRefreshToken)) {
                refreshAuthToken();
            }

            theRequest.addHeader(HttpHeaders.AUTHORIZATION, "Bearer " + this.response.accessToken);
        }

        @Override
        public void interceptResponse(IHttpResponse theResponse) {
            // We don't need this
        }

        public AuthResponse getAuthResponse() {
            return this.response;
        }

        private void refreshAuthToken() {
            final AuthResponse authResponse = fetchAuthResponse(this.baseURL, this.clientToken, this.keyID, this.privateKey);
            this.shouldRefreshToken = refreshDeadline(authResponse.expiresIn);
            this.response = authResponse;
        }
    }

    /**
     * Apache HttpClient interceptor equivalent of {@link HAPISmartInterceptor}: authenticates on
     * construction and refreshes the access token before it expires.
     */
    public static class HttpClientAuthInterceptor implements HttpRequestInterceptor {

        private final String baseURL;
        private final String clientToken;
        private final UUID keyID;
        private final PrivateKey privateKey;

        private OffsetDateTime shouldRefreshToken;
        private AuthResponse response;

        HttpClientAuthInterceptor(String baseURL, String clientToken, UUID keyID, PrivateKey privateKey) {
            this.baseURL = baseURL;
            this.clientToken = clientToken;
            this.keyID = keyID;
            this.privateKey = privateKey;

            // Do the initial refresh
            refreshAuthToken();
        }

        @Override
        public void process(HttpRequest request, HttpContext context) {
            if (OffsetDateTime.now(ZoneOffset.UTC).isAfter(this.shouldRefreshToken)) {
                refreshAuthToken();
            }
            request.addHeader(HttpHeaders.AUTHORIZATION, String.format("Bearer %s", this.response.accessToken));
        }

        private void refreshAuthToken() {
            final AuthResponse authResponse = fetchAuthResponse(this.baseURL, this.clientToken, this.keyID, this.privateKey);
            this.shouldRefreshToken = refreshDeadline(authResponse.expiresIn);
            this.response = authResponse;
        }
    }

    /** Fluent wrapper around {@link HttpClientBuilder} for the test HTTP clients. */
    public static class CustomHttpBuilder {

        private final HttpClientBuilder builder;

        CustomHttpBuilder() {
            this.builder = HttpClients.custom();
        }

        /**
         * Trusts every TLS certificate (tests only).
         *
         * @return - this builder
         * @throws IllegalStateException - if the SSL context cannot be created
         */
        public CustomHttpBuilder trusting() {
            try {
                configureTrustingSSL(this.builder);
            } catch (KeyManagementException | NoSuchAlgorithmException e) {
                // Keep the cause so SSL setup failures are diagnosable
                throw new IllegalStateException("Cannot create trusting http context", e);
            }

            return this;
        }

        /**
         * Adds an interceptor that performs the JWT auth flow now and attaches a Bearer token to every request.
         *
         * @return - this builder
         */
        public CustomHttpBuilder isAuthed(String baseURL, String clientToken, UUID keyID, PrivateKey privateKey) {
            this.builder.addInterceptorFirst(new HttpClientAuthInterceptor(baseURL, clientToken, keyID, privateKey));
            return this;
        }

        public CloseableHttpClient build() {
            return this.builder.build();
        }
    }
}