18F/identity-idp

View on GitHub
app/services/encryption/kms_client.rb

Summary

Maintainability
A
0 mins
Test Coverage
A
95%
# frozen_string_literal: true

require 'base64'

module Encryption
  class KmsClient
    include Encodable
    include ::NewRelic::Agent::MethodTracer

    # Four-character markers prepended to ciphertext to record which backend produced it.
    KEY_TYPE = { KMS: 'KMSc', LOCAL_KEY: 'LOCc' }.freeze

    # Match the corresponding marker only at the very start of a ciphertext string.
    KMS_KEY_REGEX = /\A#{KEY_TYPE.fetch(:KMS)}/
    LOCAL_KEY_REGEX = /\A#{KEY_TYPE.fetch(:LOCAL_KEY)}/

    # rubocop:disable Layout/LineLength
    # Shared pool of Aws::KMS::Client instances used by encrypt_raw_kms and decrypt_raw_kms.
    # The pool size comes from IdentityConfig.store.aws_kms_client_multi_pool_size, and every
    # client built by the block targets the single region in IdentityConfig.store.aws_region.
    KMS_CLIENT_POOL = ConnectionPool.new(size: IdentityConfig.store.aws_kms_client_multi_pool_size) do
      Aws::KMS::Client.new(
        instance_profile_credentials_timeout: 1, # defaults to 1 second
        instance_profile_credentials_retries: 5, # defaults to 0 retries
        region: IdentityConfig.store.aws_region, # The region in which the client is being instantiated
      )
    end.freeze
    # rubocop:enable Layout/LineLength

    # Key identifier sent as key_id on KMS encrypt calls and included in KmsLogger entries.
    attr_reader :kms_key_id

    # Builds a client bound to one key identifier.
    # kms_key_id: key identifier to use; defaults to IdentityConfig.store.aws_kms_key_id.
    def initialize(kms_key_id: IdentityConfig.store.aws_kms_key_id)
      @kms_key_id = kms_key_id
    end

    # Logs the operation, then encrypts +plaintext+ bound to +encryption_context+.
    # The KMS path is used when FeatureManagement.use_kms? is truthy, the local path otherwise.
    # Returns whatever encrypt_kms or encrypt_local returns (a prefixed ciphertext string).
    def encrypt(plaintext, encryption_context)
      KmsLogger.log(:encrypt, context: encryption_context, key_id: kms_key_id)

      if FeatureManagement.use_kms?
        encrypt_kms(plaintext, encryption_context)
      else
        encrypt_local(plaintext, encryption_context)
      end
    end

    # Decrypts +ciphertext+, choosing the backend from its prefix.
    # Ciphertext carrying neither known prefix is delegated to the contextless KMS client
    # without logging here. Otherwise the operation is logged and the KMS path is used when
    # use_kms? approves of the ciphertext, falling back to the local path.
    def decrypt(ciphertext, encryption_context)
      contextless = self.class.looks_like_contextless?(ciphertext)
      return decrypt_contextless_kms(ciphertext, encryption_context) if contextless

      KmsLogger.log(:decrypt, context: encryption_context, key_id: kms_key_id)

      if use_kms?(ciphertext)
        decrypt_kms(ciphertext, encryption_context)
      else
        decrypt_local(ciphertext, encryption_context)
      end
    end

    # True when +ciphertext+ begins with the KMS marker from KEY_TYPE.
    def self.looks_like_kms?(ciphertext)
      kms_prefix = KEY_TYPE[:KMS]
      ciphertext.start_with?(kms_prefix)
    end

    # True when +ciphertext+ begins with the local-key marker from KEY_TYPE.
    def self.looks_like_local_key?(ciphertext)
      local_prefix = KEY_TYPE[:LOCAL_KEY]
      ciphertext.start_with?(local_prefix)
    end

    # True when +ciphertext+ carries neither the KMS nor the local-key marker.
    def self.looks_like_contextless?(ciphertext)
      !(looks_like_kms?(ciphertext) || looks_like_local_key?(ciphertext))
    end

    private

    # Truthy only when KMS is enabled via FeatureManagement and +ciphertext+ has the KMS marker.
    # When the feature check is falsy, its value is returned as-is.
    def use_kms?(ciphertext)
      kms_enabled = FeatureManagement.use_kms?
      return kms_enabled unless kms_enabled

      self.class.looks_like_kms?(ciphertext)
    end

    # Encrypts each chunk of +plaintext+ through KMS and packs the results.
    # Returns the KMS marker followed by a JSON array holding one strict-base64 string
    # per encrypted chunk.
    def encrypt_kms(plaintext, encryption_context)
      encoded_chunks = chunk_plaintext(plaintext).map do |chunk|
        raw_ciphertext = encrypt_raw_kms(chunk, encryption_context)
        Base64.strict_encode64(raw_ciphertext)
      end

      "#{KEY_TYPE[:KMS]}#{encoded_chunks.to_json}"
    end

    # Sends one chunk to KMS using a pooled client and returns the response's ciphertext_blob.
    # Raises ArgumentError before contacting KMS if the chunk is larger than 4096 bytes.
    def encrypt_raw_kms(plaintext, encryption_context)
      if plaintext.bytesize > 4096
        raise ArgumentError, 'kms plaintext exceeds 4096 bytes'
      end

      KMS_CLIENT_POOL.with do |client|
        response = client.encrypt(
          key_id: kms_key_id,
          plaintext: plaintext,
          encryption_context: encryption_context,
        )
        response.ciphertext_blob
      end
    end

    # Reverses encrypt_kms: strips the KMS marker, parses the JSON array of strict-base64
    # chunks, decrypts each chunk via decrypt_raw_kms and concatenates the plaintexts.
    #
    # Raises EncryptionError when the payload is not valid JSON, is not an array of strings,
    # or holds invalid base64. Any other ArgumentError raised while decrypting is wrapped
    # the same way.
    def decrypt_kms(ciphertext, encryption_context)
      # The regex is anchored with \A, so at most one leading marker is removed.
      clipped_ciphertext = ciphertext.sub(KMS_KEY_REGEX, '')
      ciphertext_chunks = JSON.parse(clipped_ciphertext)

      # Well-formed JSON of the wrong shape (object, scalar, non-string elements) would
      # otherwise surface as NoMethodError; route it through the rescue below instead.
      unless ciphertext_chunks.is_a?(Array) && ciphertext_chunks.all?(String)
        raise ArgumentError, 'expected a JSON array of base64 strings'
      end

      ciphertext_chunks.map do |chunk|
        decrypt_raw_kms(
          Base64.strict_decode64(chunk),
          encryption_context,
        )
      end.join('')
    rescue JSON::ParserError, ArgumentError => error
      raise EncryptionError, "Failed to parse KMS ciphertext: #{error}"
    end

    # Asks KMS, through a pooled client, to decrypt one raw ciphertext blob and returns the
    # response's plaintext. An InvalidCiphertextException from KMS is re-raised as
    # EncryptionError; other errors propagate unchanged.
    def decrypt_raw_kms(ciphertext, encryption_context)
      KMS_CLIENT_POOL.with do |client|
        response = client.decrypt(
          ciphertext_blob: ciphertext,
          encryption_context: encryption_context,
        )
        response.plaintext
      end
    rescue Aws::KMS::Errors::InvalidCiphertextException
      raise EncryptionError, 'Aws::KMS::Errors::InvalidCiphertextException'
    end

    # Local (non-KMS) counterpart of encrypt_kms. Each chunk is encrypted by the AES
    # encryptor with a key derived from +encryption_context+. Returns the local-key marker
    # followed by a JSON array of strict-base64 strings, one per chunk.
    def encrypt_local(plaintext, encryption_context)
      encoded_chunks = chunk_plaintext(plaintext).map do |chunk|
        encrypted_chunk = encryptor.encrypt(chunk, local_encryption_key(encryption_context))
        Base64.strict_encode64(encrypted_chunk)
      end

      "#{KEY_TYPE[:LOCAL_KEY]}#{encoded_chunks.to_json}"
    end

    # Reverses encrypt_local: strips the local-key marker, parses the JSON array of
    # strict-base64 chunks, decrypts each chunk with the AES encryptor using the key derived
    # from +encryption_context+, and concatenates the plaintexts.
    #
    # Raises EncryptionError when the payload is not valid JSON, is not an array of strings,
    # or holds invalid base64. Any other ArgumentError raised while decrypting is wrapped
    # the same way.
    def decrypt_local(ciphertext, encryption_context)
      # The regex is anchored with \A, so at most one leading marker is removed.
      clipped_ciphertext = ciphertext.sub(LOCAL_KEY_REGEX, '')
      ciphertext_chunks = JSON.parse(clipped_ciphertext)

      # Well-formed JSON of the wrong shape (object, scalar, non-string elements) would
      # otherwise surface as NoMethodError; route it through the rescue below instead.
      unless ciphertext_chunks.is_a?(Array) && ciphertext_chunks.all?(String)
        raise ArgumentError, 'expected a JSON array of base64 strings'
      end

      ciphertext_chunks.map do |chunk|
        encryptor.decrypt(
          Base64.strict_decode64(chunk),
          local_encryption_key(encryption_context),
        )
      end.join('')
    rescue JSON::ParserError, ArgumentError => error
      raise EncryptionError, "Failed to parse local ciphertext: #{error}"
    end

    # Derives the local AES key as HMAC-SHA256, keyed with the configured password pepper,
    # over the context's keys and values pooled together, sorted and concatenated. Because
    # of the sort, the result does not depend on the hash's insertion order.
    # Assumes every key and value is mutually comparable (e.g. all Strings), since they are
    # sorted in one array.
    def local_encryption_key(encryption_context)
      pepper = IdentityConfig.store.password_pepper
      context_parts = encryption_context.keys + encryption_context.values
      context_material = context_parts.sort.join('')

      OpenSSL::HMAC.digest('sha256', pepper, context_material)
    end

    # Delegates ciphertext that has no known marker to a fresh ContextlessKmsClient,
    # passing the encryption context along only as log_context.
    def decrypt_contextless_kms(ciphertext, encryption_context)
      contextless_client = ContextlessKmsClient.new
      contextless_client.decrypt(ciphertext, log_context: encryption_context)
    end

    # Splits +plaintext+ into consecutive pieces of fewer than 4096 bytes each, so that
    # every piece passes the 4096-byte check in encrypt_raw_kms.
    #
    # The target piece size is bytesize / (1 + bytesize / 4096), using integer division.
    # Because the division floors, a short trailing piece can occur (for example 4097 bytes
    # become 2048 + 2048 + 1).
    #
    # Slicing is done on bytes, not characters, so multibyte text cannot produce oversized
    # pieces and strings with invalid encoding do not raise. A multibyte character may
    # therefore be split across two pieces; concatenating the pieces' bytes restores the
    # input exactly. Each piece keeps the encoding of +plaintext+.
    #
    # Returns an Array of Strings; empty input yields a single empty piece.
    def chunk_plaintext(plaintext)
      plain_size = plaintext.bytesize
      return [plaintext.byteslice(0, 0)] if plain_size.zero?

      number_chunks = plain_size / 4096
      chunk_size = plain_size / (1 + number_chunks)

      (0...plain_size).step(chunk_size).map do |offset|
        plaintext.byteslice(offset, chunk_size)
      end
    end

    # Memoized AES encryptor used by the local encrypt/decrypt paths.
    def encryptor
      return @encryptor if @encryptor

      @encryptor = Encryptors::AesEncryptor.new
    end

    # Register tracers (from the included ::NewRelic::Agent::MethodTracer) for the two public
    # entry points. These calls sit after the method definitions; `name` is evaluated in the
    # class body, so the class name is embedded in each metric name.
    add_method_tracer :decrypt, "Custom/#{name}/decrypt"
    add_method_tracer :encrypt, "Custom/#{name}/encrypt"
  end
end