iagopiimenta/activestorage_legacy

View on GitHub
lib/active_storage/verifier.rb

Summary

Maintainability
A
45 mins
Test Coverage
B
89%
# frozen_string_literal: true
require 'active_storage/messages_metadata'

# Signs and verifies messages with an HMAC, producing tokens of the form
# "<Base64 payload>--<hex HMAC digest>". Its public methods (valid_message?,
# verified, verify, generate) follow the API of ActiveSupport::MessageVerifier.
#
# The payload is only Base64-encoded, not encrypted: anyone holding a token can
# decode its contents. The digest only proves the token was produced with the
# same secret and has not been altered.
class ActiveStorage::Verifier
  # Raised by #verify when a message cannot be verified.
  class InvalidSignature < StandardError; end

  # Separator between the encoded payload and its digest. Strict Base64 and hex
  # digests never contain "-", so a well-formed token has exactly one separator.
  SEPARATOR = "--"
  private_constant :SEPARATOR

  # Builds a verifier bound to +secret+.
  #
  # ==== Options
  # * <tt>:digest</tt> - name of an <tt>OpenSSL::Digest</tt> constant used for
  #   the HMAC (defaults to <tt>"SHA1"</tt>).
  # * <tt>:serializer</tt> - object responding to +dump+ and +load+, used to
  #   turn values into bytes and back (defaults to +Marshal+).
  #
  # Raises +ArgumentError+ if +secret+ is +nil+ or +false+.
  def initialize(secret, options = {})
    raise ArgumentError, "Secret should not be nil." unless secret
    @secret = secret
    @digest = options[:digest] || "SHA1"
    @serializer = options[:serializer] || Marshal
  end

  # Checks whether +signed_message+ could have been produced by #generate with
  # this verifier's secret and digest algorithm.
  #
  #   verifier = ActiveStorage::Verifier.new 's3Krit'
  #   signed_message = verifier.generate 'a private message'
  #   verifier.valid_message?(signed_message) # => true
  #
  #   tampered_message = signed_message.chop # editing the message invalidates the signature
  #   verifier.valid_message?(tampered_message) # => false
  #
  # Always returns +true+ or +false+. Non-String, blank, or invalidly encoded
  # input is reported as +false+ instead of raising.
  def valid_message?(signed_message)
    !signed_data(signed_message).nil?
  end

  # Verifies and decodes +signed_message+, returning the original value.
  #
  #   verifier = ActiveStorage::Verifier.new 's3Krit'
  #
  #   signed_message = verifier.generate 'a private message'
  #   verifier.verified(signed_message) # => 'a private message'
  #
  # Returns +nil+ if the message was not signed with the same secret.
  #
  #   other_verifier = ActiveStorage::Verifier.new 'd1ff3r3nt-s3Krit'
  #   other_verifier.verified(signed_message) # => nil
  #
  # Returns +nil+ if the signed payload is not strict Base64.
  #
  #   invalid_message = "f--46a0120593880c733a53b6dad75b42ddc1c8996d"
  #   verifier.verified(invalid_message) # => nil
  #
  # Returns +nil+ when <tt>ActiveSupport::MessagesMetadata.verify</tt> returns
  # +nil+ for the decoded payload and +purpose+ (presumably a purpose mismatch
  # or expiry; see that module).
  #
  # Any other error raised while decoding is propagated.
  #
  #   incompatible_message = "test--dad7b06c94abba8d46a15fafaef56c327665d5ff"
  #   verifier.verified(incompatible_message) # => TypeError: incompatible marshal file format
  #
  # Additional keyword arguments are accepted and ignored.
  def verified(signed_message, purpose: nil, **)
    data = signed_data(signed_message)
    return unless data

    # Only payloads whose HMAC matched reach the serializer, so the default
    # Marshal serializer never loads unauthenticated bytes.
    message = ActiveSupport::MessagesMetadata.verify(decode(data), purpose)
    @serializer.load(message) if message
  rescue ArgumentError => argument_error
    # Base64.strict_decode64 signals malformed input with this message.
    return if argument_error.message.include?("invalid base64")
    raise
  end

  # Like #verified, but raises instead of returning +nil+.
  #
  #   verifier = ActiveStorage::Verifier.new 's3Krit'
  #   signed_message = verifier.generate 'a private message'
  #
  #   verifier.verify(signed_message) # => 'a private message'
  #
  # Raises +InvalidSignature+ if the message was not signed with the same
  # secret or was not Base64-encoded.
  #
  #   other_verifier = ActiveStorage::Verifier.new 'd1ff3r3nt-s3Krit'
  #   other_verifier.verify(signed_message) # => ActiveStorage::Verifier::InvalidSignature
  #
  # Because a falsy result is treated as failure, a signed +nil+ or +false+
  # value also raises +InvalidSignature+.
  def verify(*args, **options)
    # Keywords are forwarded explicitly. Under Ruby 3, a bare *args would turn
    # e.g. purpose: into a positional Hash and break the call to #verified.
    verified(*args, **options) || raise(InvalidSignature)
  end

  # Generates a signed token for +value+.
  #
  # +value+ is serialized, wrapped with optional expiry/purpose metadata by
  # <tt>ActiveSupport::MessagesMetadata.wrap</tt>, and Base64-encoded. The
  # token is that payload joined by "--" to its HMAC digest. The payload is
  # NOT encrypted: anyone can decode it. The digest only prevents tampering.
  #
  #   verifier = ActiveStorage::Verifier.new 's3Krit'
  #   verifier.generate 'a private message' # => "BAhJ...--<hex digest>"
  def generate(value, expires_at: nil, expires_in: nil, purpose: nil)
    payload = ActiveSupport::MessagesMetadata.wrap(
      @serializer.dump(value),
      expires_at: expires_at, expires_in: expires_in, purpose: purpose
    )
    data = encode(payload)
    "#{data}#{SEPARATOR}#{generate_digest(data)}"
  end

  private
  # Returns the encoded payload of +signed_message+ when its digest matches,
  # otherwise +nil+. Splitting on the *last* separator means anything appended
  # after a valid token changes the digest part and fails verification.
  def signed_data(signed_message)
    return unless signed_message.is_a?(String) && signed_message.valid_encoding?
    return if signed_message.blank?

    data, _separator, digest = signed_message.rpartition(SEPARATOR)
    return unless data.present? && digest.present?

    data if security_class.secure_compare(digest, generate_digest(data))
  end

  # Strict Base64 (no line breaks) so the token contains no "-" characters.
  def encode(data)
    ::Base64.strict_encode64(data)
  end

  # Constant-time comparison helper: ActiveSupport's when loaded, else Rack's.
  def security_class
    if defined?(ActiveSupport::SecurityUtils)
      ActiveSupport::SecurityUtils
    else
      Rack::Utils
    end
  end

  # Inverse of #encode; raises ArgumentError ("invalid base64") on bad input.
  def decode(data)
    ::Base64.strict_decode64(data)
  end

  # Hex HMAC of +data+ using the configured digest algorithm and secret.
  def generate_digest(data)
    require "openssl" unless defined?(OpenSSL)
    OpenSSL::HMAC.hexdigest(OpenSSL::Digest.const_get(@digest).new, @secret, data)
  end
end